├── LICENSE
├── README.md
├── dataset_unprocessed
│   ├── Jsb16thSeparated.npz
│   ├── js-fakes-16thSeparated.npz
│   ├── test_chords.p
│   ├── test_keysigs.p
│   ├── train_chords.p
│   ├── train_keysigs.p
│   ├── train_majmin_chords.p
│   ├── valid_chords.p
│   └── valid_keysigs.p
├── eval
│   ├── TonicNet_epoch-55_loss-0.321_acc-90.755.pt
│   ├── TonicNet_epoch-56_loss-0.328_acc-90.750.pt
│   ├── TonicNet_epoch-58_loss-0.317_acc-90.928.pt
│   ├── eval.py
│   ├── sample.py
│   ├── samples
│   │   ├── Audio
│   │   │   ├── sample1.mp3
│   │   │   ├── sample10.mp3
│   │   │   ├── sample2.mp3
│   │   │   ├── sample3.mp3
│   │   │   ├── sample4.mp3
│   │   │   ├── sample5.mp3
│   │   │   ├── sample6.mp3
│   │   │   ├── sample7.mp3
│   │   │   ├── sample8.mp3
│   │   │   ├── sample9.mp3
│   │   │   ├── sample_2.mp3
│   │   │   ├── sample_3.mp3
│   │   │   ├── sample_4.mp3
│   │   │   └── sample_5.mp3
│   │   ├── MIDI 16th notes
│   │   │   ├── sample1.mid
│   │   │   ├── sample10.mid
│   │   │   ├── sample2.mid
│   │   │   ├── sample3.mid
│   │   │   ├── sample4.mid
│   │   │   ├── sample5.mid
│   │   │   ├── sample6.mid
│   │   │   ├── sample7.mid
│   │   │   ├── sample8.mid
│   │   │   └── sample9.mid
│   │   └── MIDI smoothed
│   │       ├── sample10_smoothed.mid
│   │       ├── sample1_smoothed.mid
│   │       ├── sample2_smoothed.mid
│   │       ├── sample3_smoothed.mid
│   │       ├── sample4_smoothed.mid
│   │       ├── sample5_smoothed.mid
│   │       ├── sample6_smoothed.mid
│   │       ├── sample7_smoothed.mid
│   │       ├── sample8_smoothed.mid
│   │       └── sample9_smoothed.mid
│   └── utils.py
├── main.py
├── preprocessing
│   ├── instruments.py
│   ├── nn_dataset.py
│   └── utils.py
├── tokenisers
│   ├── inverse_pitch_only.p
│   └── pitch_only.p
└── train
    ├── external.py
    ├── models.py
    ├── train_nn.py
    └── transformer.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TonicNet
2 |
3 | [PapersWithCode: Music Modeling on JSB Chorales (JS Fake Chorales)](https://paperswithcode.com/sota/music-modeling-on-jsb-chorales?p=js-fake-chorales-a-synthetic-dataset-of-1)
4 | [PapersWithCode: Music Modeling on JSB Chorales (Feature-Rich Encoding)](https://paperswithcode.com/sota/music-modeling-on-jsb-chorales?p=improving-polyphonic-music-models-with)
5 |
6 |
7 | Accompanying repository for my paper: [Improving Polyphonic Music Models with Feature-Rich Encoding](https://arxiv.org/abs/1911.11775)
8 |
9 | Requirements:
10 | - Python 3 (tested with 3.6.5)
11 | - PyTorch (tested with 1.2.0)
12 | - Music21
13 |
14 | Prepare Dataset:
15 |
16 | To prepare the vanilla JSB Chorales dataset with canonical train/validation/test split:
17 | ```
18 | python main.py --gen_dataset
19 | ```
20 |
21 | To prepare dataset augmented with [JS Fake Chorales](https://github.com/omarperacha/js-fakes):
22 | ```
23 | python main.py --gen_dataset --jsf
24 | ```
25 |
26 | To prepare dataset for training on JS Fake Chorales only:
27 | ```
28 | python main.py --gen_dataset --jsf_only
29 | ```
30 |
31 | Train Model from Scratch:
32 |
33 | First run `--gen_dataset`, optionally with one of the flags above, then:
34 | ```
35 | python main.py --train
36 | ```
37 |
38 | Training requires 60 epochs, taking roughly 3-6 hours on a GPU.
39 |
40 | Evaluate Pre-trained Model on Test Set:
41 |
42 | ```
43 | python main.py --eval_nn
44 | ```
45 |
46 | Sample with Pre-trained Model (via random sampling):
47 |
48 | ```
49 | python main.py --sample
50 | ```
51 |
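52 | Sampling writes a 16th-note quantised MIDI file to `eval/sample.mid` and a rhythmically smoothed version to `eval/sample_smoothed.mid` (see `eval/utils.py`). For programmatic use, the `--sample` flag is equivalent to the following minimal sketch, mirroring the corresponding branch in `main.py`:
53 | ```
54 | from eval.sample import sample_TonicNet_random
55 | from eval.utils import indices_to_stream, smooth_rhythm
56 |
57 | tokens = sample_TonicNet_random(load_path='eval/TonicNet_epoch-56_loss-0.328_acc-90.750.pt', temperature=1.0)
58 | indices_to_stream(tokens)  # writes eval/sample.mid
59 | smooth_rhythm()            # writes eval/sample_smoothed.mid
60 | ```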
--------------------------------------------------------------------------------
/dataset_unprocessed/Jsb16thSeparated.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/Jsb16thSeparated.npz
--------------------------------------------------------------------------------
/dataset_unprocessed/js-fakes-16thSeparated.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/js-fakes-16thSeparated.npz
--------------------------------------------------------------------------------
/dataset_unprocessed/test_chords.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/test_chords.p
--------------------------------------------------------------------------------
/dataset_unprocessed/test_keysigs.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/test_keysigs.p
--------------------------------------------------------------------------------
/dataset_unprocessed/train_chords.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/train_chords.p
--------------------------------------------------------------------------------
/dataset_unprocessed/train_keysigs.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/train_keysigs.p
--------------------------------------------------------------------------------
/dataset_unprocessed/train_majmin_chords.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/train_majmin_chords.p
--------------------------------------------------------------------------------
/dataset_unprocessed/valid_chords.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/valid_chords.p
--------------------------------------------------------------------------------
/dataset_unprocessed/valid_keysigs.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/dataset_unprocessed/valid_keysigs.p
--------------------------------------------------------------------------------
/eval/TonicNet_epoch-55_loss-0.321_acc-90.755.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/TonicNet_epoch-55_loss-0.321_acc-90.755.pt
--------------------------------------------------------------------------------
/eval/TonicNet_epoch-56_loss-0.328_acc-90.750.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/TonicNet_epoch-56_loss-0.328_acc-90.750.pt
--------------------------------------------------------------------------------
/eval/TonicNet_epoch-58_loss-0.317_acc-90.928.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/TonicNet_epoch-58_loss-0.317_acc-90.928.pt
--------------------------------------------------------------------------------
/eval/eval.py:
--------------------------------------------------------------------------------
1 | from torch import cuda, load, device, set_grad_enabled, max, sum, cat
2 | from preprocessing.nn_dataset import get_test_set_for_eval_classic
3 |
4 | """
5 | File containing functions to quantitatively evaluate trained models
6 | """
7 |
8 |
9 | def eval_on_test_set(load_path, model, criterion, set='test', notes_only=False):
10 | model = model
11 |
12 | try:
13 | if cuda.is_available():
14 | model.load_state_dict(load(load_path)['model_state_dict'])
15 | else:
16 | model.load_state_dict(load(load_path, map_location=device('cpu'))['model_state_dict'])
17 | print("loded params from", load_path)
18 | except:
19 | raise ImportError(f'No file located at {load_path}, could not load parameters')
20 | print(model)
21 |
22 | if cuda.is_available():
23 | model.cuda()
24 |
25 | model.eval()
26 | criterion = criterion
27 |
28 | count = 0
29 | batch_count = 0
30 | loss_epoch = 0
31 | running_accuray = 0.0
32 | running_batch_count = 0
33 | print_loss_batch = 0 # Reset on print
34 | print_acc_batch = 0 # Reset on print
35 | pr_interval = 1
36 |
37 | for x, y, psx, i, c in get_test_set_for_eval_classic(set):
38 | model.zero_grad()
39 |
40 | train_emb = False
41 |
42 | Y = y
43 |
44 | with set_grad_enabled(False):
45 | y_hat = model(x, z=i, train_embedding=train_emb)
46 | _, preds = max(y_hat, 2)
47 |
48 | if notes_only:
49 | for j in range(y_hat.shape[1]):
50 | if j % 5 != 4:
51 | if j == 0:
52 | new_y_hat = y_hat[:, j, :].view(1, 1, 98)
53 | new_y = Y[:, j].view(1, 1)
54 | else:
55 | new_y_hat = cat((new_y_hat, y_hat[:, j, :].view(1, 1, 98)), dim=1)
56 | new_y = cat((new_y, Y[:, j].view(1, 1)), dim=1)
57 | loss = criterion(new_y_hat, new_y, )
58 | else:
59 | loss = criterion(y_hat, Y, )
60 |
61 | loss_epoch += loss.item()
62 | print_loss_batch += loss.item()
63 | if notes_only:
64 | _, new_preds = max(new_y_hat, 2)
65 | running_accuray += sum(new_preds == new_y)
66 | print_acc_batch += sum(new_preds == new_y)
67 | else:
68 | running_accuray += sum(preds == Y)
69 | print_acc_batch += sum(preds == Y)
70 |
71 | count += 1
72 | if notes_only:
73 | batch_count += int(x.shape[1] * 0.8)
74 | running_batch_count += int(x.shape[1] * 0.8)
75 | else:
76 | batch_count += x.shape[1]
77 | running_batch_count += x.shape[1]
78 |
79 | # print loss for recent set of batches
80 | if count % pr_interval == 0:
81 | ave_loss = print_loss_batch / pr_interval
82 | ave_acc = 100 * print_acc_batch.float() / running_batch_count
83 | print_acc_batch = 0
84 | running_batch_count = 0
85 | print('\t\t[%d] loss: %.3f, acc: %.3f' % (count, ave_loss, ave_acc))
86 | print_loss_batch = 0
87 |
88 | # calculate loss and accuracy for phase
89 | ave_loss_epoch = loss_epoch / count
90 | epoch_acc = 100 * running_accuray.float() / batch_count
91 | print('\tfinished %s phase loss: %.3f, acc: %.3f' % ('eval', ave_loss_epoch, epoch_acc))
92 |
93 |
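94 | # Usage sketch (mirrors the --eval_nn branch in main.py; the checkpoint path and
95 | # model hyperparameters below are those of the pre-trained models shipped in eval/):
96 | #
97 | #   from train.models import TonicNet
98 | #   from train.train_nn import CrossEntropyTimeDistributedLoss
99 | #   eval_on_test_set('eval/TonicNet_epoch-58_loss-0.317_acc-90.928.pt',
100 | #                    TonicNet(nb_tags=98, z_dim=32, nb_layers=3, nb_rnn_units=256, dropout=0.0),
101 | #                    CrossEntropyTimeDistributedLoss(), set='test', notes_only=True)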
--------------------------------------------------------------------------------
/eval/sample.py:
--------------------------------------------------------------------------------
1 | from train.models import TonicNet
2 | from torch import cat, multinomial
3 | from torch import cuda, load, device, tensor, zeros
4 | from torch.nn import LogSoftmax
5 | import pickle
6 | import random
7 | from copy import deepcopy
8 |
9 | """
10 | Functions to sample from trained models
11 | """
12 |
13 |
14 | def sample_TonicNet_random(load_path, max_tokens=2999, temperature=1.0):
15 |
16 | model = TonicNet(nb_tags=98, z_dim=32, nb_layers=3, nb_rnn_units=256, dropout=0.0)
17 |
18 | try:
19 | if cuda.is_available():
20 | model.load_state_dict(load(load_path)['model_state_dict'])
21 | else:
22 | model.load_state_dict(load(load_path, map_location=device('cpu'))['model_state_dict'])
23 | print("loded params from", load_path)
24 | except:
25 | raise ImportError(f'No file located at {load_path}, could not load parameters')
26 | print(model)
27 |
28 | if cuda.is_available():
29 | model.cuda()
30 |
31 | model.eval()
32 | model.seq_len = 1
33 | model.hidden = model.init_hidden()
34 | model.zero_grad()
35 |
36 | inverse_t = pickle.load(open('tokenisers/inverse_pitch_only.p', mode='rb'))
37 |
38 | seed, pos_dict = __get_seed()
39 |
40 | x = seed
41 | x_post = x
42 |
43 | inst_conv_dict = {0: 0, 1: 1, 2: 4, 3: 2, 4: 3}
44 | current_token_dict = {0: '', 1: '', 2: '', 3: '', 4: ''}
45 |
46 | print("")
47 | print(0)
48 | print("\t", 0, ":", chord_from_token(x[0][0].item() - 48))
49 |
50 | for i in range(max_tokens):
51 |
52 | if i == 0:
53 | reset_hidden = True
54 | else:
55 | reset_hidden = False
56 |
57 | inst = inst_conv_dict[i % 5]
58 | psx = pos_dict[inst]
59 |
60 | psx_t = tensor(psx).view(1, 1)
61 |
62 | y_hat = model(x, z=psx_t, sampling=True,
63 | reset_hidden=reset_hidden).data.view(-1).div(temperature).exp()
64 | y = multinomial(y_hat, 1)[0]
65 | if y.item() == 0: # EOS token
66 | print("ending")
67 | break
68 | else:
69 | try:
70 | token = inverse_t[y.item()]
71 | except:
72 | token = chord_from_token(y.item() - 48)
73 |
74 | next_inst = inst_conv_dict[(i + 1) % 5]
75 |
76 | print("")
77 | print(i + 1)
78 |
79 | if current_token_dict[next_inst] == token:
80 | pos_dict[next_inst] += 1
81 | else:
82 | current_token_dict[next_inst] = token
83 | pos_dict[next_inst] = 0
84 |
85 | x = y.view(1, 1, 1)
86 | x_post = cat((x_post, x), dim=1)
87 |
88 | return x_post
89 |
90 |
91 | def sample_TonicNet_beam_search(load_path, max_tokens=2999, beam_width=10, alpha=1.0):
92 |
93 | """sample the model via beam search algorithm (not the most efficient implementation but functional)
94 |
95 | :param load_path: path to state_dict to load weights from
96 | :param max_tokens: maximum number of iterations to sample
97 | :param beam_width: breadth of beam search heuristic
98 | :param alpha: hyperparameter for length normalisation of beam search - a higher value prefers longer sequences
99 | :return: generated list of token indices
100 | """
101 |
102 | model = TonicNet(nb_tags=98, z_dim=32, nb_layers=3, nb_rnn_units=256, dropout=0.0)
103 | logsoftmax = LogSoftmax(dim=0)
104 |
105 | try:
106 | if cuda.is_available():
107 | model.load_state_dict(load(load_path)['model_state_dict'])
108 | else:
109 | model.load_state_dict(load(load_path, map_location=device('cpu'))['model_state_dict'])
110 | print("loded params from", load_path)
111 | except:
112 | raise ImportError(f'No file located at {load_path}, could not load parameters')
113 | print(model)
114 |
115 | if cuda.is_available():
116 | model.cuda()
117 |
118 | model.eval()
119 | model.seq_len = 1
120 | model.hidden = model.init_hidden()
121 | model.zero_grad()
122 |
123 | inverse_t = pickle.load(open('tokenisers/inverse_pitch_only.p', mode='rb'))
124 | inverse_t[0] = 'end'
125 |
126 | seed, pos_dict = __get_seed()
127 |
128 | x = seed
129 | x_post = x
130 |
131 | inst_conv_dict = {0: 0, 1: 1, 2: 4, 3: 2, 4: 3}
132 | current_token_dict = {0: '', 1: '', 2: '', 3: '', 4: ''}
133 |
134 | candidate_seqs = []
135 | c_ts = []
136 | pos_ds = []
137 | models = []
138 | scores = []
139 |
140 | ended = [0] * beam_width
141 |
142 | for i in range(max_tokens):
143 |
144 | print("")
145 | print(i)
146 |
147 | log_probs = zeros((98*beam_width))
148 | updated = [False] * beam_width
149 |
150 | inst = inst_conv_dict[i % 5]
151 | next_inst = inst_conv_dict[(i + 1) % 5]
152 |
153 | for b in range(beam_width):
154 |
155 | if i == 0:
156 | reset_hidden = True
157 | candidate_seqs.append(deepcopy([x]))
158 | c_ts.append(deepcopy(current_token_dict))
159 | pos_ds.append(deepcopy(pos_dict))
160 | models.append(deepcopy(model))
161 | scores.append(0)
162 | else:
163 | reset_hidden = False
164 |
165 | psx = pos_ds[b][inst]
166 |
167 | psx_t = tensor(psx).view(1, 1)
168 |
169 | if candidate_seqs[b][-1].item() == 0:
170 | log_probs[(98 * b):(98 * (b + 1))] = tensor(98).fill_(-9999)
171 | else:
172 | y_hat = models[b](candidate_seqs[b][-1], z=psx_t, sampling=True, reset_hidden=reset_hidden)
173 | log_probs[(98*b):(98*(b+1))] = logsoftmax(y_hat[0, 0, :]) + scores[b]
174 |
175 | if i == 0:
176 | top = log_probs[0:98].topk(k=beam_width)
177 | else:
178 | top = log_probs.topk(k=beam_width)
179 |
180 | temp_store = []
181 | rejection_reconciliation = {}
182 |
183 | for b1 in range(beam_width):
184 | candidate = top[1][b1].item()
185 | prob = top[0][b1].item()
186 | y = candidate % 98
187 | m = int(candidate / 98)
188 |
189 | if updated[m]:
190 | temp_store.append((m, y, prob))
191 | else:
192 | updated[m] = True
193 | rejection_reconciliation[m] = m
194 | if candidate_seqs[m][-1].item() != 0:
195 | scores[m] = prob
196 | candidate_seqs[m].append(tensor(y).view(1, 1, 1))
197 | else:
198 | ended[m] = 1
199 |
200 | for b2 in range(beam_width):
201 | if not updated[b2]:
202 | rejection_reconciliation[b2] = temp_store[0][0]
203 | if candidate_seqs[b2][-1].item() != 0:
204 |
205 | scores[b2] = temp_store[0][2]
206 |
207 | candidate_seqs[b2] = deepcopy(candidate_seqs[rejection_reconciliation[b2]])
208 | candidate_seqs[b2].pop(-1)
209 | candidate_seqs[b2].append(tensor(temp_store[0][1]).view(1, 1, 1))
210 |
211 | models[b2].hidden = models[rejection_reconciliation[b2]].hidden
212 |
213 | # copy over pos_dict and current token to replace rejected model's
214 | pos_ds[b2] = deepcopy(pos_ds[rejection_reconciliation[b2]])
215 | c_ts[b2] = deepcopy(c_ts[rejection_reconciliation[b2]])
216 |
217 | temp_store.pop(0)
218 | else:
219 | ended[b2] = 1
220 |
221 | try:
222 | token = inverse_t[candidate_seqs[b2][-1].item()]
223 | except:
224 | token = chord_from_token(candidate_seqs[b2][-1].item() - 48)
225 |
226 | if c_ts[b2][next_inst] == token:
227 | pos_ds[b2][next_inst] += 1
228 | else:
229 | c_ts[b2][next_inst] = token
230 | pos_ds[b2][next_inst] = 0
231 |
232 | if sum(ended) == beam_width:
233 | print('all ended')
234 | break
235 |
236 | normalised_scores = [scores[n] * (1/(len(candidate_seqs[n])**alpha)) for n in range(beam_width)]
237 |
238 | chosen_seq = max(zip(normalised_scores, range(len(normalised_scores))))[1]
239 |
240 | for n in candidate_seqs[chosen_seq][1:]:
241 | x_post = cat((x_post, n), dim=1)
242 |
243 | return x_post
244 |
245 |
246 | def __get_seed():
247 |
248 | seed = random.choice(range(48, 72))
249 | x = tensor(seed)
250 |
251 | p_dct = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0}
252 |
253 | return x.view(1, 1, 1), p_dct
254 |
255 |
256 | def chord_from_token(token):
257 |
258 | if token < 12:
259 | qual = 'major'
260 | elif token < 24:
261 | qual = 'minor'
262 | elif token < 36:
263 | qual = 'diminished'
264 | elif token < 48:
265 | qual = 'augmented'
266 | elif token == 48:
267 | qual = 'other'
268 | else:
269 | qual = 'none'
270 |
271 | return token % 12, qual
272 |
273 |
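274 | # Usage note: chord tokens occupy indices 48-97 of the vocabulary, so callers pass
275 | # chord_from_token(y - 48) for a raw token index y; e.g. chord_from_token(14)
276 | # returns (2, 'minor'), i.e. minor quality on (presumably) root pitch class 2.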
--------------------------------------------------------------------------------
/eval/samples/Audio/sample1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample1.mp3
--------------------------------------------------------------------------------
/eval/samples/Audio/sample10.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample10.mp3
--------------------------------------------------------------------------------
/eval/samples/Audio/sample2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample2.mp3
--------------------------------------------------------------------------------
/eval/samples/Audio/sample3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample3.mp3
--------------------------------------------------------------------------------
/eval/samples/Audio/sample4.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample4.mp3
--------------------------------------------------------------------------------
/eval/samples/Audio/sample5.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample5.mp3
--------------------------------------------------------------------------------
/eval/samples/Audio/sample6.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample6.mp3
--------------------------------------------------------------------------------
/eval/samples/Audio/sample7.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample7.mp3
--------------------------------------------------------------------------------
/eval/samples/Audio/sample8.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample8.mp3
--------------------------------------------------------------------------------
/eval/samples/Audio/sample9.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample9.mp3
--------------------------------------------------------------------------------
/eval/samples/Audio/sample_2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample_2.mp3
--------------------------------------------------------------------------------
/eval/samples/Audio/sample_3.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample_3.mp3
--------------------------------------------------------------------------------
/eval/samples/Audio/sample_4.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample_4.mp3
--------------------------------------------------------------------------------
/eval/samples/Audio/sample_5.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/Audio/sample_5.mp3
--------------------------------------------------------------------------------
/eval/samples/MIDI 16th notes/sample1.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample1.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI 16th notes/sample10.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample10.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI 16th notes/sample2.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample2.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI 16th notes/sample3.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample3.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI 16th notes/sample4.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample4.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI 16th notes/sample5.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample5.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI 16th notes/sample6.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample6.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI 16th notes/sample7.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample7.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI 16th notes/sample8.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample8.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI 16th notes/sample9.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI 16th notes/sample9.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI smoothed/sample10_smoothed.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample10_smoothed.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI smoothed/sample1_smoothed.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample1_smoothed.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI smoothed/sample2_smoothed.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample2_smoothed.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI smoothed/sample3_smoothed.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample3_smoothed.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI smoothed/sample4_smoothed.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample4_smoothed.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI smoothed/sample5_smoothed.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample5_smoothed.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI smoothed/sample6_smoothed.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample6_smoothed.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI smoothed/sample7_smoothed.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample7_smoothed.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI smoothed/sample8_smoothed.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample8_smoothed.mid
--------------------------------------------------------------------------------
/eval/samples/MIDI smoothed/sample9_smoothed.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/eval/samples/MIDI smoothed/sample9_smoothed.mid
--------------------------------------------------------------------------------
/eval/utils.py:
--------------------------------------------------------------------------------
1 | import music21
2 | import pickle
3 | import matplotlib.pyplot as plt
4 | from preprocessing.utils import get_parts_from_stream
5 |
6 | """
7 | Utility functions for evaluating model training and saving samples from trained models
8 | """
9 |
10 |
11 | def indices_to_stream(token_list, return_stream=False):
12 | inverse_t = pickle.load(open('tokenisers/inverse_pitch_only.p', 'rb'))
13 | tl = token_list.squeeze()
14 | tl = tl.numpy()
15 |
16 | sop_part = music21.stream.Part()
17 | sop_part.id = 'soprano'
18 |
19 | alto_part = music21.stream.Part()
20 | alto_part.id = 'alto'
21 |
22 | tenor_part = music21.stream.Part()
23 | tenor_part.id = 'tenor'
24 |
25 | bass_part = music21.stream.Part()
26 | bass_part.id = 'bass'
27 |
28 | score = music21.stream.Stream([sop_part, bass_part, alto_part, tenor_part])
29 |
30 | for j in range(len(tl)):
31 |
32 | i = tl[j]
33 | try:
34 | note = inverse_t[i]
35 | except:
36 | continue
37 |
38 | idx = (j % 5) - 1
39 |
40 | if note == 'Rest':
41 | n = music21.note.Rest()
42 | else:
43 | pitch = int(note)
44 | n = music21.note.Note(pitch)
45 |
46 | dur = 0.25
47 | n.quarterLength = dur
48 |
49 | score[idx].append(n)
50 |
51 | if return_stream:
52 | return score
53 | else:
54 | score.write('midi', fp='eval/sample.mid')
55 | print("SAVED sample to ./eval/sample.mid")
56 |
57 |
58 | def plot_loss_acc_curves(log='eval/out.log'):
59 | train_loss = []
60 | train_acc = []
61 |
62 | val_loss = []
63 | val_acc = []
64 |
65 | f = open(log, "r")
66 | txt = f.read()
67 | for line in txt.split("\n"):
68 | if 'finished' in line:
69 | components = line.split(" ")
70 | loss = components[5]
71 | loss = loss[:-1]
72 | acc = components[7]
73 | if 'train phase' in line:
74 | train_acc.append(float(acc))
75 | train_loss.append(float(loss))
76 | else:
77 | val_acc.append(float(acc))
78 | val_loss.append(float(loss))
79 |
80 | plt.figure(1)
81 | plt.subplot(121)
82 | plt.plot(train_loss)
83 | plt.plot(val_loss)
84 | plt.xlabel('epochs')
85 | plt.legend(['train loss', 'val loss'], loc='upper left')
86 | plt.ylim(0, 6)
87 |
88 | plt.subplot(122)
89 | plt.plot(train_acc)
90 | plt.plot(val_acc)
91 | plt.xlabel('epochs')
92 | plt.legend(['train acc', 'val acc'], loc='upper left')
93 | plt.ylim(0, 100)
94 | plt.show()
95 |
96 | plt.show()
97 |
98 |
99 | def smooth_rhythm():
100 | path = 'eval/sample.mid'
101 |
102 | mf = music21.midi.MidiFile()
103 | mf.open(path)
104 | mf.read()
105 | mf.close()
106 |
107 | s = music21.midi.translate.midiFileToStream(mf)
108 |
109 | score = music21.stream.Stream()
110 |
111 | parts = get_parts_from_stream(s)
112 |
113 | for part in parts:
114 | new_part = music21.stream.Part()
115 |
116 | current_pitch = -1
117 | current_offset = 0.0
118 | current_dur = 0.0
119 |
120 | for n in part.notesAndRests.flat:
121 | if isinstance(n, music21.note.Rest):
122 | if current_pitch == 129:
123 | current_dur += 0.25
124 | else:
125 | if current_pitch > -1:
126 | if current_pitch < 128:
127 | note = music21.note.Note(current_pitch)
128 | else:
129 | note = music21.note.Rest()
130 | note.quarterLength = current_dur
131 | new_part.insert(current_offset, note)
132 |
133 | current_pitch = 129
134 | current_offset = n.offset
135 | current_dur = 0.25
136 |
137 | else:
138 | if n.pitch.midi == current_pitch:
139 | current_dur += 0.25
140 | else:
141 | if current_pitch > -1:
142 | if current_pitch < 128:
143 | note = music21.note.Note(current_pitch)
144 | else:
145 | note = music21.note.Rest()
146 | note.quarterLength = current_dur
147 | new_part.insert(current_offset, note)
148 |
149 | current_pitch = n.pitch.midi
150 | current_offset = n.offset
151 | current_dur = 0.25
152 |
153 | if current_pitch < 128:
154 | note = music21.note.Note(current_pitch)
155 | else:
156 | note = music21.note.Rest()
157 | note.quarterLength = current_dur
158 | new_part.insert(current_offset, note)
159 |
160 | score.append(new_part)
161 |
162 | score.write('midi', fp='eval/sample_smoothed.mid')
163 | print("SAVED rhythmically 'smoothed' sample to ./eval/sample_smoothed.mid")
164 |
165 |
166 |
167 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from preprocessing.nn_dataset import bach_chorales_classic
3 | from train.train_nn import train_TonicNet, TonicNet_lr_finder, TonicNet_sanity_test
4 | from train.train_nn import CrossEntropyTimeDistributedLoss
5 | from train.models import TonicNet
6 | from eval.utils import plot_loss_acc_curves, indices_to_stream, smooth_rhythm
7 | from eval.eval import eval_on_test_set
8 | from eval.sample import sample_TonicNet_random
9 |
10 | if len(sys.argv) > 1:
11 | if sys.argv[1] in ['--train', '-t']:
12 | train_TonicNet(3000, shuffle_batches=1, train_emb_freq=1, load_path='')
13 |
14 | elif sys.argv[1] in ['--plot', '-p']:
15 | plot_loss_acc_curves()
16 |
17 | elif sys.argv[1] in ['--find_lr', '-lr']:
18 | TonicNet_lr_finder(train_emb_freq=1, load_path='')
19 |
20 | elif sys.argv[1] in ['--sanity_test', '-st']:
21 | TonicNet_sanity_test(num_batches=1, train_emb_freq=1)
22 |
23 | elif sys.argv[1] in ['--sample', '-s']:
24 | x = sample_TonicNet_random(load_path='eval/TonicNet_epoch-56_loss-0.328_acc-90.750.pt', temperature=1.0)
25 | indices_to_stream(x)
26 | smooth_rhythm()
27 |
28 | elif sys.argv[1] in ['--eval_nn', '-e']:
29 | eval_on_test_set(
30 | 'eval/TonicNet_epoch-58_loss-0.317_acc-90.928.pt',
31 | TonicNet(nb_tags=98, z_dim=32, nb_layers=3, nb_rnn_units=256, dropout=0.0),
32 | CrossEntropyTimeDistributedLoss(), set='test', notes_only=True)
33 |
34 | elif sys.argv[1] in ['--gen_dataset', '-gd']:
35 | if len(sys.argv) > 2 and sys.argv[2] == '--jsf':
36 | for x, y, p, i, c in bach_chorales_classic('save', transpose=True, jsf_aug='all'):
37 | continue
38 | elif len(sys.argv) > 2 and sys.argv[2] == '--jsf_only':
39 | for x, y, p, i, c in bach_chorales_classic('save', transpose=True, jsf_aug='only'):
40 | continue
41 | else:
42 | for x, y, p, i, c in bach_chorales_classic('save', transpose=True):
43 | continue
44 |
45 | else:
46 | print("")
47 | print("TonicNet (Training on Ordered Notation Including Chords)")
48 | print("Omar Peracha, 2019")
49 | print("")
50 | print("--gen_dataset\t\t\t\t prepare dataset")
51 | print("--gen_dataset --jsf \t\t prepare dataset with JS Fakes data augmentation")
52 | print("--gen_dataset --jsf_only \t prepare dataset with JS Fake Chorales only")
53 | print("--train\t\t\t train model from scratch")
54 | print("--eval_nn\t\t evaluate pretrained model on test set")
55 | print("--sample\t\t sample from pretrained model")
56 | print("")
57 | else:
58 |
59 | print("")
60 | print("TonicNet (Training on Ordered Notation Including Chords)")
61 | print("Omar Peracha, 2019")
62 | print("")
63 | print("--gen_dataset\t\t\t\t prepare dataset")
64 | print("--gen_dataset --jsf \t\t prepare dataset with JS Fake Chorales data augmentation")
65 | print("--gen_dataset --jsf_only \t prepare dataset with JS Fake Chorales only")
66 | print("--train\t\t\t train model from scratch")
67 | print("--eval_nn\t\t evaluate pretrained model on test set")
68 | print("--sample\t\t sample from pretrained model")
69 | print("")
70 |
71 |
72 |
73 |
74 |
75 |
--------------------------------------------------------------------------------
/preprocessing/instruments.py:
--------------------------------------------------------------------------------
1 | """
2 | File containing 'structs' and methods pertaining to determining and assigning instruments to parts in corpus
3 | """
4 |
5 |
6 | # MARK:- Instrument Data Objects
7 |
8 | class SopranoVoice:
9 | instrumentId = 'soprano'
10 | highestNote = 81
11 | lowestNote = 60
12 |
13 |
14 | class AltoVoice:
15 | instrumentId = 'alto'
16 | highestNote = 77
17 | lowestNote = 53
18 |
19 |
20 | class TenorVoice:
21 | instrumentId = 'tenor'
22 | highestNote = 72
23 | lowestNote = 45
24 |
25 |
26 | class BassVoice:
27 | instrumentId = 'bass'
28 | highestNote = 64
29 | lowestNote = 36
30 |
31 |
32 | def get_instrument(inst_name_in):
33 | """
34 |
35 | :param inst_name_in: string with the name of an instrument
36 | :return: a data object corresponding to the inst_name if applicable, or None if no match is found
37 | """
38 |
39 | if not isinstance(inst_name_in, str):
40 | inst_name = str(inst_name_in)
41 | else:
42 | inst_name = inst_name_in
43 |
44 | # Handle a few scenarios where multiple instruments could be scored
45 | if 'bass' in inst_name.lower() or 'B.' in inst_name:
46 | return BassVoice()
47 |
48 | elif 'tenor' in inst_name.lower():
49 | return TenorVoice()
50 |
51 | elif 'alto' in inst_name.lower():
52 | return AltoVoice()
53 |
54 | elif 'soprano' in inst_name.lower() or 'S.' in inst_name:
55 | return SopranoVoice()
56 |
57 | elif 'canto' in inst_name.lower():
58 | return SopranoVoice()
59 |
60 |
61 | def get_part_range(part):
62 | notes = part.pitches
63 | midi = list(map(__get_midi, notes))
64 | return [min(midi), max(midi)]
65 |
66 |
67 | def __get_midi(pitch):
68 | return pitch.midi
69 |
70 |
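71 | # Usage note: part names are matched by substring, e.g. get_instrument('Alto')
72 | # returns an AltoVoice instance; a name matching none of the SATB voices falls
73 | # through and the function returns None.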
--------------------------------------------------------------------------------
/preprocessing/nn_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import pickle
4 | import numpy as np
5 | from preprocessing.instruments import get_instrument
6 | from random import sample
7 |
8 | """
9 | File containing functions to derive training data for neural networks
10 | """
11 |
12 | CUDA = torch.cuda.is_available()
13 |
14 | if CUDA:
15 | PATH = 'train/training_set/X_cuda'
16 | else:
17 | PATH = 'train/training_set/X'
18 | if os.path.exists(PATH):
19 | TRAIN_BATCHES = len(os.listdir(PATH))
20 | else:
21 | TRAIN_BATCHES = 0
22 | TOTAL_BATCHES = TRAIN_BATCHES + 76
23 |
24 | MAX_SEQ = 2880
25 | N_PITCH = 48
26 | N_CHORD = 50
27 | N_TOKENS = N_PITCH + N_CHORD
28 |
29 |
30 | def get_data_set(mode, shuffle_batches=True, return_I=False):
31 |
32 | if mode == 'train':
33 | parent_dir = 'train/training_set'
34 | elif mode == 'val':
35 | parent_dir = 'train/val_set'
36 | else:
37 | raise Exception("invalid mode passed to get_data_set() - options are 'train' and 'val'")
38 |
39 | if torch.cuda.is_available():
40 | lst = os.listdir(f'{parent_dir}/X_cuda')
41 | else:
42 | lst = os.listdir(f'{parent_dir}/X')
43 | try:
44 | lst.remove('.DS_Store')
45 | except:
46 | pass
47 |
48 | if shuffle_batches:
49 | lst = sample(lst, len(lst))
50 |
51 | for file_name in lst:
52 | if torch.cuda.is_available():
53 | X = torch.load(f'{parent_dir}/X_cuda/{file_name}')
54 | Y = torch.load(f'{parent_dir}/Y_cuda/{file_name}')
55 | P = torch.load(f'{parent_dir}/P_cuda/{file_name}')
56 | if return_I:
57 | I = torch.load(f'{parent_dir}/I_cuda/{file_name}')
58 | C = torch.load(f'{parent_dir}/C_cuda/{file_name}')
59 | else:
60 | X = torch.load(f'{parent_dir}/X/{file_name}')
61 | Y = torch.load(f'{parent_dir}/Y/{file_name}')
62 | P = torch.load(f'{parent_dir}/P/{file_name}')
63 | if return_I:
64 | I = torch.load(f'{parent_dir}/I/{file_name}')
65 | C = torch.load(f'{parent_dir}/C/{file_name}')
66 |
67 | if return_I:
68 | yield X, Y, P, I, C
69 | else:
70 | yield X, Y, P
71 |
72 |
73 | def bach_chorales_classic(mode, transpose=False, maj_min=False, jsf_aug=None):
74 |
75 | if maj_min and jsf_aug is not None:
76 | raise ValueError("maj_min and jsf_aug can not both be true")
77 |
78 | if jsf_aug not in ['all', 'only', 'topk-all', 'topk-only', 'topk-skilled-all', 'topk-skilled-only', None]:
79 | raise ValueError("unrecognised value for jsf_aug parameter: can only be 'all', 'only', "
80 | "'topl-all', 'topk-only', 'topk-skilled-all', 'topk-skilled-only' or None")
81 |
82 | tokeniser = pickle.load(open('tokenisers/pitch_only.p', 'rb'))
83 | tokeniser["end"] = 0
84 | count = 0
85 |
86 | for folder_name in ["training_set", "val_set"]:
87 | if torch.cuda.is_available():
88 | print("cuda:")
89 | try:
90 | os.makedirs(f'train/{folder_name}/X_cuda')
91 | os.makedirs(f'train/{folder_name}/Y_cuda')
92 | os.makedirs(f'train/{folder_name}/P_cuda')
93 | os.makedirs(f'train/{folder_name}/I_cuda')
94 | os.makedirs(f'train/{folder_name}/C_cuda')
95 | except:
96 | pass
97 | else:
98 | try:
99 | os.makedirs(f'train/{folder_name}/X')
100 | os.makedirs(f'train/{folder_name}/Y')
101 | os.makedirs(f'train/{folder_name}/P')
102 | os.makedirs(f'train/{folder_name}/I')
103 | os.makedirs(f'train/{folder_name}/C')
104 | except:
105 | pass
106 |
107 | for phase in ['train', 'valid']:
108 |
109 | d = np.load('dataset_unprocessed/Jsb16thSeparated.npz', allow_pickle=True, encoding="latin1")
110 | train = (d[phase])
111 |
112 | ks = pickle.load(open(f'dataset_unprocessed/{phase}_keysigs.p', 'rb'))
113 | crds = pickle.load(open(f'dataset_unprocessed/{phase}_chords.p', 'rb'))
114 | crds_majmin = pickle.load(open('dataset_unprocessed/train_majmin_chords.p', 'rb'))
115 | k_count = 0
116 |
117 | if jsf_aug is not None and phase == 'train':
118 | if jsf_aug in ['all', 'only']:
119 | jsf_path = 'dataset_unprocessed/js-fakes-16thSeparated.npz'
120 | jsf = np.load(jsf_path, allow_pickle=True, encoding="latin1")
121 | js_chords = jsf["chords"]
122 | jsf = jsf["pitches"]
123 |
124 | if jsf_aug == "all":
125 | train = np.concatenate((train, jsf))
126 | crds = np.concatenate((crds, js_chords))
127 | elif jsf_aug == "only":
128 | train = jsf
129 | crds = js_chords
130 |
131 | for m in train:
132 | int_m = m.astype(int)
133 |
134 | if maj_min:
135 | tonic = ks[k_count][0]
136 | scale = ks[k_count][1]
137 | crd_majmin = crds_majmin[k_count]
138 |
139 | crd = crds[k_count]
140 | k_count += 1
141 |
142 | if transpose is False or phase == 'valid':
143 | transpositions = [int_m]
144 | crds_pieces = [crd]
145 | else:
146 | parts = [int_m[:, 0], int_m[:, 1], int_m[:, 2], int_m[:, 3]]
147 | transpositions, tonics, crds_pieces = __np_perform_all_transpositions(parts, 0, crd)
148 |
149 | if maj_min:
150 |
151 | mode_switch = __np_convert_major_minor(int_m, tonic, scale)
152 | ms_parts = [mode_switch[:, 0], mode_switch[:, 1], mode_switch[:, 2], mode_switch[:, 3]]
153 | ms_trans, ms_tons, ms_crds = __np_perform_all_transpositions(ms_parts, tonic, crd_majmin)
154 |
155 | transpositions += ms_trans
156 | tonics += ms_tons
157 | crds_pieces += ms_crds
158 |
159 | kc = 0
160 |
161 | for t in transpositions:
162 |
163 | crds_piece = crds_pieces[kc]
164 |
165 | _tokens = []
166 | inst_ids = []
167 | c_class = []
168 |
169 | current_s = ''
170 | s_count = 0
171 |
172 | current_a = ''
173 | a_count = 0
174 |
175 | current_t = ''
176 | t_count = 0
177 |
178 | current_b = ''
179 | b_count = 0
180 |
181 | current_c = ''
182 | c_count = 0
183 |
184 | timestep = 0
185 |
186 | for i in t:
187 | s = 'Rest' if i[0] < 36 else str(i[0])
188 | b = 'Rest' if i[3] < 36 else str(i[3])
189 | a = 'Rest' if i[1] < 36 else str(i[1])
190 | t = 'Rest' if i[2] < 36 else str(i[2])
191 |
192 | c_val = crds_piece[timestep] + 48
193 | timestep += 1
194 |
195 | _tokens = _tokens + [c_val, s, b, a, t]
196 | c_class = c_class + [c_val]
197 |
198 | if c_val == current_c:
199 | c_count += 1
200 | else:
201 | c_count = 0
202 | current_c = c_val
203 |
204 | if s == current_s:
205 | s_count += 1
206 | else:
207 | s_count = 0
208 | current_s = s
209 |
210 | if b == current_b:
211 | b_count += 1
212 | else:
213 | b_count = 0
214 | current_b = b
215 |
216 | if a == current_a:
217 | a_count += 1
218 | else:
219 | a_count = 0
220 | current_a = a
221 |
222 | if t == current_t:
223 | t_count += 1
224 | else:
225 | t_count = 0
226 | current_t = t
227 |
228 | inst_ids = inst_ids + [c_count, s_count, b_count, a_count, t_count]
229 |
230 | pos_ids = list(range(len(_tokens)))
231 |
232 | kc += 1
233 | _tokens.append('end')
234 | tokens = []
235 | try:
236 | for x in _tokens:
237 | if isinstance(x, str):
238 | tokens.append(tokeniser[x])
239 | else:
240 | tokens.append(x)
241 | except KeyError:
242 | print("ERROR: tokenisation - unknown token, skipping piece")
243 | continue
244 |
245 | SEQ_LEN = len(tokens) - 1
246 |
247 | count += 1
248 |
249 | data_x = []
250 | data_y = []
251 |
252 | pos_x = []
253 |
254 | for i in range(0, len(tokens) - SEQ_LEN, 1):
255 | t_seq_in = tokens[i:i + SEQ_LEN]
256 | t_seq_out = tokens[i + 1: i + 1 + SEQ_LEN]
257 | data_x.append(t_seq_in)
258 | data_y.append(t_seq_out)
259 |
260 | p_seq_in = pos_ids[i:i + SEQ_LEN]
261 | pos_x.append(p_seq_in)
262 |
263 | X = torch.tensor(data_x)
264 | X = torch.unsqueeze(X, 2)
265 |
266 | Y = torch.tensor(data_y)
267 | P = torch.tensor(pos_x)
268 | I = torch.tensor(inst_ids)
269 | C = torch.tensor(c_class)
270 |
271 | set_folder = 'training_set'
272 | if phase == 'valid':
273 | set_folder = 'val_set'
274 |
275 | if mode == 'save':
276 |
277 | if torch.cuda.is_available():
278 | print("cuda:")
279 | torch.save(X.cuda(), f'train/{set_folder}/X_cuda/{count}.pt')
280 | torch.save(Y.cuda(), f'train/{set_folder}/Y_cuda/{count}.pt')
281 | torch.save(P.cuda(), f'train/{set_folder}/P_cuda/{count}.pt')
282 | torch.save(I.cuda(), f'train/{set_folder}/I_cuda/{count}.pt')
283 | torch.save(C.cuda(), f'train/{set_folder}/C_cuda/{count}.pt')
284 | else:
285 | torch.save(X, f'train/{set_folder}/X/{count}.pt')
286 | torch.save(Y, f'train/{set_folder}/Y/{count}.pt')
287 | torch.save(P, f'train/{set_folder}/P/{count}.pt')
288 | torch.save(I, f'train/{set_folder}/I/{count}.pt')
289 | torch.save(C, f'train/{set_folder}/C/{count}.pt')
290 | print("saved", count)
291 | else:
292 | print("processed", count)
293 | yield X, Y, P, I, C
294 |
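# Illustrative sketch (hypothetical token ids): the generator above interleaves one
# chord token and four voice tokens per 16th-note timestep, ordered
# [chord, soprano, bass, alto, tenor], and the target Y is the input X shifted
# forward by one token.
def _demo_next_token_shift():
    tokens = [50, 10, 3, 7, 5, 51, 11, 4, 8, 6, 0]  # two timesteps + 'end'
    seq_len = len(tokens) - 1
    x = tokens[:seq_len]            # model input
    y = tokens[1:seq_len + 1]       # target: next token at every position
    return x, y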
295 |
296 | def get_test_set_for_eval_classic(phase='test'):
297 |
298 | tokeniser = pickle.load(open('tokenisers/pitch_only.p', 'rb'))
299 | tokeniser["end"] = 0
300 |
301 | d = np.load('dataset_unprocessed/Jsb16thSeparated.npz', allow_pickle=True, encoding="latin1")
302 | test = d[phase]
303 |
304 | crds = pickle.load(open(f'dataset_unprocessed/{phase}_chords.p', 'rb'))
305 | crd_count = 0
306 |
307 | for m in test:
308 | int_m = m.astype(int)
309 |
310 | crds_piece = crds[crd_count]
311 | crd_count += 1
312 |
313 | _tokens = []
314 | inst_ids = []
315 | c_class = []
316 |
317 | current_s = ''
318 | s_count = 0
319 |
320 | current_a = ''
321 | a_count = 0
322 |
323 | current_t = ''
324 | t_count = 0
325 |
326 | current_b = ''
327 | b_count = 0
328 |
329 | current_c = ''
330 | c_count = 0
331 |
332 | timestep = 0
333 |
334 | for i in int_m:
335 | s = 'Rest' if i[0] < 36 else str(i[0])
336 | b = 'Rest' if i[3] < 36 else str(i[3])
337 | a = 'Rest' if i[1] < 36 else str(i[1])
338 | t = 'Rest' if i[2] < 36 else str(i[2])
339 |
340 | c_val = crds_piece[timestep] + 48
341 | timestep += 1
342 |
343 | _tokens = _tokens + [c_val, s, b, a, t]
344 | c_class = c_class + [c_val]
345 |
346 | if c_val == current_c:
347 | c_count += 1
348 | else:
349 | c_count = 0
350 | current_c = c_val
351 |
352 | if s == current_s:
353 | s_count += 1
354 | else:
355 | s_count = 0
356 | current_s = s
357 |
358 | if b == current_b:
359 | b_count += 1
360 | else:
361 | b_count = 0
362 | current_b = b
363 |
364 | if a == current_a:
365 | a_count += 1
366 | else:
367 | a_count = 0
368 | current_a = a
369 |
370 | if t == current_t:
371 | t_count += 1
372 | else:
373 | t_count = 0
374 | current_t = t
375 |
376 | inst_ids = inst_ids + [c_count, s_count, b_count, a_count, t_count]
377 |
378 | pos_ids = list(range(len(_tokens)))
379 |
380 | _tokens.append('end')
381 |
382 | tokens = []
383 | try:
384 | for x in _tokens:
385 | if isinstance(x, str):
386 | tokens.append(tokeniser[x])
387 | else:
388 | tokens.append(x)
389 | except KeyError:
390 | print("ERROR: tokenisation - unknown token, skipping piece")
391 | continue
392 |
393 | SEQ_LEN = len(tokens) - 1
394 |
395 | data_x = []
396 | data_y = []
397 |
398 | pos_x = []
399 |
400 | for i in range(0, len(tokens) - SEQ_LEN, 1):
401 | t_seq_in = tokens[i:i + SEQ_LEN]
402 | t_seq_out = tokens[i + 1: i + 1 + SEQ_LEN]
403 | data_x.append(t_seq_in)
404 | data_y.append(t_seq_out)
405 |
406 | p_seq_in = pos_ids[i:i + SEQ_LEN]
407 | pos_x.append(p_seq_in)
408 |
409 | X = torch.tensor(data_x)
410 | X = torch.unsqueeze(X, 2)
411 |
412 | Y = torch.tensor(data_y)
413 | P = torch.tensor(pos_x)
414 | C = torch.tensor(c_class)
415 | I = torch.tensor(inst_ids)
416 |
417 | yield X, Y, P, I, C
418 |
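# Illustrative sketch (assumes the dataset and pickle files referenced above are
# present): peek at the first test batch to confirm tensor shapes before evaluation.
def _demo_inspect_test_batch():
    for X, Y, P, I, C in get_test_set_for_eval_classic('test'):
        return X.shape, Y.shape, P.shape, I.shape, C.shape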
419 |
420 | def position_encoding_init(n_position, emb_dim):
421 | ''' Init the sinusoid position encoding table '''
422 |
423 | # keep dim 0 for padding token position encoding zero vector
424 | position_enc = torch.tensor([
425 | [pos / np.power(10000, 2 * (j // 2) / emb_dim) for j in range(emb_dim)]
426 | if pos != 0 else np.zeros(emb_dim) for pos in range(n_position)], dtype=torch.float32)
427 |
428 | position_enc[1:, 0::2] = torch.sin(position_enc[1:, 0::2])  # sin on dims 0, 2, 4, ...
429 | position_enc[1:, 1::2] = torch.cos(position_enc[1:, 1::2])  # cos on dims 1, 3, 5, ...
430 |
431 | if torch.cuda.is_available():
432 | position_enc = position_enc.cuda()
433 |
434 | return position_enc
435 |
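# Illustrative sketch: the sinusoid table has one row per position; row 0 is the
# all-zero padding vector and the remaining rows alternate sin/cos terms.
def _demo_position_encoding():
    table = position_encoding_init(16, 8)
    return table.shape, bool((table[0] == 0).all())  # (torch.Size([16, 8]), True)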
436 |
437 | def to_categorical(y, num_classes=N_TOKENS):
438 | """ 1-hot encodes a tensor """
439 | return torch.eye(num_classes)[y]
440 |
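# Illustrative sketch: to_categorical maps class indices to one-hot rows.
def _demo_to_categorical():
    return to_categorical(torch.tensor([1, 3]), num_classes=5)  # shape (2, 5)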
441 |
442 | def __pos_from_beatStr(measure, beat, div=8):
443 |
444 | if len(beat.split(" ")) == 1:
445 | b_pos = float(beat.split(" ")[0])
446 | else:
447 | fraction = beat.split(" ")[1]
448 | fraction = fraction.split("/")
449 | decimal = float(fraction[0])/float(fraction[1])
450 | b_pos = float(beat.split(" ")[0]) + decimal
451 |
452 | b_pos *= div
453 | b_pos -= div
454 |
455 | pos_enc_idx = int(measure*(4*div) + b_pos)
456 |
457 | return pos_enc_idx
458 |
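# Illustrative sketch: music21-style beat strings such as "2 1/2" (beat 2.5) map to a
# flat position index, assuming div sub-divisions per beat and 4 beats per measure.
def _demo_beat_position():
    return __pos_from_beatStr(1, "2 1/2", div=8)  # 1*32 + (2.5*8 - 8) = 44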
459 |
460 | # Mark:- transposition
461 | def __np_perform_all_transpositions(parts, tonic, chords):
462 | mylist = []
463 | tonics = []
464 | my_chords = []
465 | try:
466 | t_r = __np_transposable_range_for_piece(parts)
467 | except Exception:
468 | print("error getting transpose range")
469 | return mylist, tonics, my_chords
470 | lower = t_r[0]
471 | higher = t_r[1] + 1
472 | quals = [__np_get_quals_from_chord(x) for x in chords]
473 | for i in range(lower, higher):
474 | try:
475 | roots = [(x + 12 + i) % 12 for x in chords]
476 | transposed_piece = np.zeros((len(parts[0]), 4), dtype=int)
477 | chord_prog = [__np_chord_from_root_qual(roots[n], quals[n]) for n in range(len(chords))]  # n indexes chords; i is the transposition offset
478 | for j in range(4):
479 | tp = parts[j] + i
480 | transposed_piece[:, j] = tp[:]
481 | except Exception:
482 | print("ERROR: transposition failed, skipping this offset")
483 | else:
484 | mylist.append(transposed_piece)
485 | tonics.append((tonic + i) % 12)
486 | my_chords.append(chord_prog)
487 | return mylist, tonics, my_chords
488 |
489 |
490 | def __np_transposable_range_for_part(part, inst):
491 |
492 | if not isinstance(inst, str):
493 | inst = str(inst)
494 | part_range = __np_get_part_range(part)
495 | instrument = get_instrument(inst)
496 |
497 | lower_transposable = instrument.lowestNote - part_range[0]
498 | higher_transposable = instrument.highestNote - part_range[1]
499 |
500 | # suggests there's perhaps no musical content in this score
501 | if higher_transposable - lower_transposable >= 128:
502 | lower_transposable = 0
503 | higher_transposable = 0
504 | return min(0, lower_transposable), max(0, higher_transposable)
505 |
506 |
507 | def __np_transposable_range_for_piece(parts):
508 |
509 | insts = ['soprano', 'alto', 'tenor', 'bass']
510 |
511 | lower = -127
512 | higher = 127
513 |
514 | for i in range(len(parts)):
515 | t_r = __np_transposable_range_for_part(parts[i], insts[i])
516 | if t_r[0] > lower:
517 | lower = t_r[0]
518 | if t_r[1] < higher:
519 | higher = t_r[1]
520 | # suggests there's perhaps no musical content in this score
521 | if higher - lower >= 128:
522 | lower = 0
523 | higher = 0
524 | return lower, higher
525 |
526 |
527 | def __np_get_part_range(part):
528 |
529 | mn = min(part)
530 |
531 | if mn < 36:
532 | p = sorted(part)
533 | c = 1
534 | while mn < 36:
535 | mn = p[c]
536 | c += 1
537 |
538 | return [mn, max(part)]
539 |
540 |
541 | def __np_convert_major_minor(piece, tonic, mode):
542 |
543 | _piece = piece.copy()  # work on a copy so the caller's array is not mutated in place
544 |
545 | for i in range(len(_piece)):
546 | s = _piece[i][0] if _piece[i][0] < 36 else (_piece[i][0] - tonic) % 12
547 | b = _piece[i][3] if _piece[i][3] < 36 else (_piece[i][3] - tonic) % 12
548 | a = _piece[i][1] if _piece[i][1] < 36 else (_piece[i][1] - tonic) % 12
549 | t = _piece[i][2] if _piece[i][2] < 36 else (_piece[i][2] - tonic) % 12
550 |
551 | parts = [s, a, t, b]
552 |
553 | for n in range(len(parts)):
554 | if mode == 'major':
555 | if parts[n] in [4, 9]:
556 | _piece[i][n] -= 1
557 | elif mode == 'minor':
558 | if parts[n] in [3, 8, 10]:
559 | _piece[i][n] += 1
560 | else:
561 | raise ValueError(f"mode must be minor or major, received {mode}")
562 |
563 | return _piece
564 |
565 |
566 | def __np_get_quals_from_chord(chord):
567 |
568 | if chord < 12:
569 | qual = 'major'
570 | elif chord < 24:
571 | qual = 'minor'
572 | elif chord < 36:
573 | qual = 'diminished'
574 | elif chord < 48:
575 | qual = 'augmented'
576 | elif chord == 48:
577 | qual = 'other'
578 | else:
579 | qual = 'none'
580 |
581 | return qual
582 |
583 |
584 | def __np_chord_from_root_qual(root, qual):
585 |
586 | if qual == "major":
587 | chord = root
588 | elif qual == "minor":
589 | chord = root + 12
590 | elif qual == "diminished":
591 | chord = root + 24
592 | elif qual == "augmented":
593 | chord = root + 36
594 | elif qual == "other":
595 | chord = 48
596 | elif qual == "none":
597 | chord = 49
598 |
599 | return chord
600 |
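# Illustrative sketch of the chord encoding used above: quality is stored in blocks
# of 12 above the root pitch class, so 14 decodes to root 2 (D) with quality 'minor'.
def _demo_chord_encoding():
    qual = __np_get_quals_from_chord(14)                 # 'minor'
    rebuilt = __np_chord_from_root_qual(14 % 12, qual)   # 14
    return qual, rebuilt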
601 |
602 |
603 |
604 |
--------------------------------------------------------------------------------
/preprocessing/utils.py:
--------------------------------------------------------------------------------
1 | import music21
2 | import pickle
3 |
4 | """
5 | Contains utility functions for dataset preprocessing
6 | """
7 |
8 |
9 | def get_parts_from_stream(piece):
10 | parts = []
11 | for i in piece:
12 | if isinstance(i, music21.stream.Part):
13 | parts.append(i)
14 | return parts
15 |
16 |
17 | def pitch_tokeniser_maker():
18 | post = {"end": 0}
19 | for i in range(36, 82):
20 | k = str(i)
21 | post[k] = len(post)
22 | post['Rest'] = len(post)
23 |
24 | return post
25 |
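# Illustrative sketch: the tokeniser maps 'end' to 0, MIDI pitches '36'..'81' to
# 1..46 and 'Rest' to 47.
def _demo_pitch_tokeniser():
    tok = pitch_tokeniser_maker()
    return tok['end'], tok['36'], tok['81'], tok['Rest']  # (0, 1, 46, 47)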
26 |
27 | def load_tokeniser():
28 | dic = pickle.load(open("tokenisers/pitch_only.p", "rb"))
29 | return dic
30 |
31 |
32 | def chord_from_pitches(pitches):
33 | cd = []
34 | for n in pitches:
35 | if n >= 36:
36 | cd.append(int(n))
37 |
38 | crd = music21.chord.Chord(cd)
39 | try:
40 | root = music21.pitch.Pitch(crd.root()).pitchClass
41 | except:
42 | c_val = 49
43 | else:
44 | if crd.quality == 'major':
45 | c_val = root
46 | if crd.quality == 'minor':
47 | c_val = root + 12
48 | if crd.quality == 'diminished':
49 | c_val = root + 24
50 | if crd.quality == 'augmented':
51 | c_val = root + 36
52 | if crd.quality == 'other':
53 | c_val = 48
54 |
55 | return c_val
56 |
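# Illustrative sketch (requires music21): a C major triad has root pitch class 0 and
# quality 'major', so it encodes to chord value 0 under the scheme above.
def _demo_chord_from_pitches():
    return chord_from_pitches([60, 64, 67])  # expected: 0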
57 |
--------------------------------------------------------------------------------
/tokenisers/inverse_pitch_only.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/tokenisers/inverse_pitch_only.p
--------------------------------------------------------------------------------
/tokenisers/pitch_only.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omarperacha/TonicNet/0a00997af67eb75d352ab2380fa737ce39c7be1a/tokenisers/pitch_only.p
--------------------------------------------------------------------------------
/train/external.py:
--------------------------------------------------------------------------------
1 | import math
2 | import itertools as it
3 | import torch
4 | from torch.optim.optimizer import Optimizer
5 | import torch.nn as nn
6 | from torch.nn.utils.rnn import PackedSequence
7 | from typing import *
8 |
9 | """
10 | File containing classes taken directly from other work to aid model training
11 | """
12 |
13 |
14 | class RAdam(Optimizer):
15 | """
16 | Directly from Liyuan Liu, Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, and Jiawei Han.
17 | "On the Variance of the Adaptive Learning Rate and Beyond."
18 | """
19 |
20 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
21 | weight_decay=0):
22 | defaults = dict(lr=lr, betas=betas, eps=eps,
23 | weight_decay=weight_decay)
24 |
25 | super(RAdam, self).__init__(params, defaults)
26 |
27 | def __setstate__(self, state):
28 | super(RAdam, self).__setstate__(state)
29 |
30 | def step(self, closure=None):
31 | loss = None
32 | beta2_t = None
33 | ratio = None
34 | N_sma_max = None
35 | N_sma = None
36 |
37 | if closure is not None:
38 | loss = closure()
39 |
40 | for group in self.param_groups:
41 |
42 | for p in group['params']:
43 | if p.grad is None:
44 | continue
45 | grad = p.grad.data.float()
46 | if grad.is_sparse:
47 | raise RuntimeError('RAdam does not support sparse gradients, please consider SparseAdam instead')
48 |
49 | p_data_fp32 = p.data.float()
50 |
51 | state = self.state[p]
52 |
53 | if len(state) == 0:
54 | state['step'] = 0
55 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
56 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
57 | else:
58 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
59 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
60 |
61 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
62 | beta1, beta2 = group['betas']
63 |
64 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
65 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
66 |
67 | state['step'] += 1
68 | if beta2_t is None:
69 | beta2_t = beta2 ** state['step']
70 | N_sma_max = 2 / (1 - beta2) - 1
71 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
72 | beta1_t = 1 - beta1 ** state['step']
73 | if N_sma >= 5:
74 | ratio = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / beta1_t
75 |
76 | if group['weight_decay'] != 0:
77 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
78 |
79 | # more conservative since it's an approximated value
80 | if N_sma >= 5:
81 | step_size = group['lr'] * ratio
82 | denom = exp_avg_sq.sqrt().add_(group['eps'])
83 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
84 | else:
85 | step_size = group['lr'] / beta1_t
86 | p_data_fp32.add_(-step_size, exp_avg)
87 |
88 | p.data.copy_(p_data_fp32)
89 |
90 | return loss
91 |
92 |
93 | class VariationalDropout(nn.Module):
94 | """
95 | Applies the same dropout mask across the temporal dimension
96 | See https://arxiv.org/abs/1512.05287 for more details.
97 | Note that this is not applied to the recurrent activations in the LSTM like the above paper.
98 | Instead, it is applied to the inputs and outputs of the recurrent layer.
99 | """
100 | def __init__(self, dropout: float, batch_first: Optional[bool]=False):
101 | super().__init__()
102 | self.dropout = dropout
103 | self.batch_first = batch_first
104 |
105 | def forward(self, x: torch.Tensor) -> torch.Tensor:
106 | if not self.training or self.dropout <= 0.:
107 | return x
108 |
109 | is_packed = isinstance(x, PackedSequence)
110 | if is_packed:
111 | x, batch_sizes = x
112 | max_batch_size = int(batch_sizes[0])
113 | else:
114 | batch_sizes = None
115 | max_batch_size = x.size(0)
116 |
117 | # Drop same mask across entire sequence
118 | if self.batch_first:
119 | m = x.new_empty(max_batch_size, 1, x.size(2), requires_grad=False).bernoulli_(1 - self.dropout)
120 | else:
121 | m = x.new_empty(1, max_batch_size, x.size(2), requires_grad=False).bernoulli_(1 - self.dropout)
122 | x = x.masked_fill(m == 0, 0) / (1 - self.dropout)
123 |
124 | if is_packed:
125 | return PackedSequence(x, batch_sizes)
126 | else:
127 | return x
128 |
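# Illustrative usage sketch: with batch_first=True the same dropout mask is reused
# at every timestep, so a dropped feature is zero across the whole sequence.
def _demo_variational_dropout():
    vd = VariationalDropout(dropout=0.5, batch_first=True)
    vd.train()
    out = vd(torch.ones(2, 4, 8))  # (batch, seq, features)
    # for each (sample, feature) pair, either every timestep is zero or none is
    shared_mask = bool(((out == 0).all(dim=1) | (out != 0).all(dim=1)).all())
    return out, shared_mask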
129 |
130 | class Lookahead(Optimizer):
131 | def __init__(self, base_optimizer,alpha=0.5, k=6):
132 | if not 0.0 <= alpha <= 1.0:
133 | raise ValueError(f'Invalid slow update rate: {alpha}')
134 | if not 1 <= k:
135 | raise ValueError(f'Invalid lookahead steps: {k}')
136 | self.optimizer = base_optimizer
137 | self.param_groups = self.optimizer.param_groups
138 | self.alpha = alpha
139 | self.k = k
140 | for group in self.param_groups:
141 | group["step_counter"] = 0
142 | self.slow_weights = [[p.clone().detach() for p in group['params']]
143 | for group in self.param_groups]
144 |
145 | for w in it.chain(*self.slow_weights):
146 | w.requires_grad = False
147 |
148 | def step(self, closure=None):
149 | loss = None
150 | if closure is not None:
151 | loss = closure()
152 | loss = self.optimizer.step()
153 | for group,slow_weights in zip(self.param_groups,self.slow_weights):
154 | group['step_counter'] += 1
155 | if group['step_counter'] % self.k != 0:
156 | continue
157 | for p,q in zip(group['params'],slow_weights):
158 | if p.grad is None:
159 | continue
160 | q.data.add_(self.alpha,p.data - q.data)
161 | p.data.copy_(q.data)
162 | return loss
163 |
164 |
165 | # MARK:- OneCycleLR
166 | class OneCycleLR(torch.optim.lr_scheduler._LRScheduler):
167 | r"""Sets the learning rate of each parameter group according to the
168 | 1cycle learning rate policy. The 1cycle policy anneals the learning
169 | rate from an initial learning rate to some maximum learning rate and then
170 | from that maximum learning rate to some minimum learning rate much lower
171 | than the initial learning rate.
172 | This policy was initially described in the paper `Super-Convergence:
173 | Very Fast Training of Neural Networks Using Large Learning Rates`_.
174 |
175 | The 1cycle learning rate policy changes the learning rate after every batch.
176 | `step` should be called after a batch has been used for training.
177 |
178 | This scheduler is not chainable.
179 |
180 | This class has two built-in annealing strategies:
181 | "cos":
182 | Cosine annealing
183 | "linear":
184 | Linear annealing
185 |
186 | Note also that the total number of steps in the cycle can be determined in one
187 | of two ways (listed in order of precedence):
188 | 1) A value for total_steps is explicitly provided.
189 | 2) A number of epochs (epochs) and a number of steps per epoch
190 | (steps_per_epoch) are provided.
191 | In this case, the number of total steps is inferred by
192 | total_steps = epochs * steps_per_epoch
193 | You must either provide a value for total_steps or provide a value for both
194 | epochs and steps_per_epoch.
195 |
196 | Args:
197 | optimizer (Optimizer): Wrapped optimizer.
198 | max_lr (float or list): Upper learning rate boundaries in the cycle
199 | for each parameter group.
200 | total_steps (int): The total number of steps in the cycle. Note that
201 | if a value is not provided here, then it must be inferred by providing
202 | a value for epochs and steps_per_epoch.
203 | Default: None
204 | epochs (int): The number of epochs to train for. This is used along
205 | with steps_per_epoch in order to infer the total number of steps in the cycle
206 | if a value for total_steps is not provided.
207 | Default: None
208 | steps_per_epoch (int): The number of steps per epoch to train for. This is
209 | used along with epochs in order to infer the total number of steps in the
210 | cycle if a value for total_steps is not provided.
211 | Default: None
212 | pct_start (float): The percentage of the cycle (in number of steps) spent
213 | increasing the learning rate.
214 | Default: 0.3
215 | anneal_strategy (str): {'cos', 'linear'}
216 | Specifies the annealing strategy.
217 | Default: 'cos'
218 | cycle_momentum (bool): If ``True``, momentum is cycled inversely
219 | to learning rate between 'base_momentum' and 'max_momentum'.
220 | Default: True
221 | base_momentum (float or list): Lower momentum boundaries in the cycle
222 | for each parameter group. Note that momentum is cycled inversely
223 | to learning rate; at the peak of a cycle, momentum is
224 | 'base_momentum' and learning rate is 'max_lr'.
225 | Default: 0.85
226 | max_momentum (float or list): Upper momentum boundaries in the cycle
227 | for each parameter group. Functionally,
228 | it defines the cycle amplitude (max_momentum - base_momentum).
229 | Note that momentum is cycled inversely
230 | to learning rate; at the start of a cycle, momentum is 'max_momentum'
231 | and learning rate is 'base_lr'
232 | Default: 0.95
233 | div_factor (float): Determines the initial learning rate via
234 | initial_lr = max_lr/div_factor
235 | Default: 25
236 | final_div_factor (float): Determines the minimum learning rate via
237 | min_lr = initial_lr/final_div_factor
238 | Default: 1e4
239 | last_epoch (int): The index of the last batch. This parameter is used when
240 | resuming a training job. Since `step()` should be invoked after each
241 | batch instead of after each epoch, this number represents the total
242 | number of *batches* computed, not the total number of epochs computed.
243 | When last_epoch=-1, the schedule is started from the beginning.
244 | Default: -1
245 |
246 | Example:
247 | >>> data_loader = torch.utils.data.DataLoader(...)
248 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
249 | >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
250 | >>> for epoch in range(10):
251 | >>> for batch in data_loader:
252 | >>> train_batch(...)
253 | >>> scheduler.step()
254 |
255 |
256 | .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
257 | https://arxiv.org/abs/1708.07120
258 | """
259 | def __init__(self,
260 | optimizer,
261 | max_lr,
262 | total_steps=None,
263 | epochs=None,
264 | steps_per_epoch=None,
265 | pct_start=0.3,
266 | anneal_strategy='cos',
267 | cycle_momentum=True,
268 | base_momentum=0.85,
269 | max_momentum=0.95,
270 | div_factor=25.,
271 | final_div_factor=1e4,
272 | last_epoch=-1):
273 |
274 | # Validate optimizer
275 | if not isinstance(optimizer, Optimizer):
276 | raise TypeError('{} is not an Optimizer'.format(
277 | type(optimizer).__name__))
278 | self.optimizer = optimizer
279 |
280 | # Validate total_steps
281 | if total_steps is None and epochs is None and steps_per_epoch is None:
282 | raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
283 | elif total_steps is not None:
284 | if total_steps <= 0 or not isinstance(total_steps, int):
285 | raise ValueError("Expected non-negative integer total_steps, but got {}".format(total_steps))
286 | self.total_steps = total_steps
287 | else:
288 | if epochs <= 0 or not isinstance(epochs, int):
289 | raise ValueError("Expected non-negative integer epochs, but got {}".format(epochs))
290 | if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
291 | raise ValueError("Expected non-negative integer steps_per_epoch, but got {}".format(steps_per_epoch))
292 | self.total_steps = epochs * steps_per_epoch
293 | self.step_size_up = float(pct_start * self.total_steps) - 1
294 | self.step_size_down = float(self.total_steps - self.step_size_up) - 1
295 |
296 | # Validate pct_start
297 | if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
298 | raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))
299 |
300 | # Validate anneal_strategy
301 | if anneal_strategy not in ['cos', 'linear']:
302 | raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
303 | elif anneal_strategy == 'cos':
304 | self.anneal_func = self._annealing_cos
305 | elif anneal_strategy == 'linear':
306 | self.anneal_func = self._annealing_linear
307 |
308 | # Initialize learning rate variables
309 | max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
310 | if last_epoch == -1:
311 | for idx, group in enumerate(self.optimizer.param_groups):
312 | group['lr'] = max_lrs[idx] / div_factor
313 | group['max_lr'] = max_lrs[idx]
314 | group['min_lr'] = group['lr'] / final_div_factor
315 |
316 | # Initialize momentum variables
317 | self.cycle_momentum = cycle_momentum
318 | if self.cycle_momentum:
319 | if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
320 | raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
321 | self.use_beta1 = 'betas' in self.optimizer.defaults
322 | max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
323 | base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
324 | if last_epoch == -1:
325 | for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
326 | if self.use_beta1:
327 | _, beta2 = group['betas']
328 | group['betas'] = (m_momentum, beta2)
329 | else:
330 | group['momentum'] = m_momentum
331 | group['max_momentum'] = m_momentum
332 | group['base_momentum'] = b_momentum
333 |
334 | super(OneCycleLR, self).__init__(optimizer, last_epoch)
335 |
336 | def _format_param(self, name, optimizer, param):
337 | """Return correctly formatted lr/momentum for each param group."""
338 | if isinstance(param, (list, tuple)):
339 | if len(param) != len(optimizer.param_groups):
340 | raise ValueError("expected {} values for {}, got {}".format(
341 | len(optimizer.param_groups), name, len(param)))
342 | return param
343 | else:
344 | return [param] * len(optimizer.param_groups)
345 |
346 | def _annealing_cos(self, start, end, pct):
347 | "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
348 | cos_out = math.cos(math.pi * pct) + 1
349 | return end + (start - end) / 2.0 * cos_out
350 |
351 | def _annealing_linear(self, start, end, pct):
352 | "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
353 | return (end - start) * pct + start
354 |
355 | def get_lr(self):
356 | lrs = []
357 | step_num = self.last_epoch
358 |
359 | if step_num > self.total_steps:
360 | raise ValueError("Tried to step {} times. The specified number of total steps is {}"
361 | .format(step_num + 1, self.total_steps))
362 |
363 | for group in self.optimizer.param_groups:
364 | if step_num <= self.step_size_up:
365 | computed_lr = self.anneal_func(group['initial_lr'], group['max_lr'], step_num / self.step_size_up)
366 | if self.cycle_momentum:
367 | computed_momentum = self.anneal_func(group['max_momentum'], group['base_momentum'],
368 | step_num / self.step_size_up)
369 | else:
370 | down_step_num = step_num - self.step_size_up
371 | computed_lr = self.anneal_func(group['max_lr'], group['min_lr'], down_step_num / self.step_size_down)
372 | if self.cycle_momentum:
373 | computed_momentum = self.anneal_func(group['base_momentum'], group['max_momentum'],
374 | down_step_num / self.step_size_down)
375 |
376 | lrs.append(computed_lr)
377 | if self.cycle_momentum:
378 | if self.use_beta1:
379 | _, beta2 = group['betas']
380 | group['betas'] = (computed_momentum, beta2)
381 | else:
382 | group['momentum'] = computed_momentum
383 |
384 | return lrs
385 |
386 |
--------------------------------------------------------------------------------
/train/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from preprocessing.nn_dataset import MAX_SEQ
4 | from preprocessing.nn_dataset import position_encoding_init
5 | from train.external import VariationalDropout
6 | from train.transformer import TransformerDecoder, TransformerDecoderLayer, TransformerEncoder, TransformerEncoderLayer
7 |
8 | """
9 | File containing classes representing various Neural architectures
10 | """
11 |
12 |
13 | # MARK:- TonicNet
14 | class TonicNet(nn.Module):
15 | def __init__(self, nb_tags, nb_layers=1, z_dim=0,
16 | nb_rnn_units=100, batch_size=1, seq_len=1, dropout=0.0):
17 | super(TonicNet, self).__init__()
18 |
19 | self.nb_layers = nb_layers
20 | self.nb_rnn_units = nb_rnn_units
21 | self.batch_size = batch_size
22 | self.seq_len = seq_len
23 | self.dropout = dropout
24 | self.z_dim = z_dim
25 | self.z_emb_size = 32
26 |
27 | self.nb_tags = nb_tags
28 |
29 | # build actual NN
30 | self.__build_model()
31 |
32 | def __build_model(self):
33 |
34 | self.embedding = nn.Embedding(self.nb_tags, self.nb_rnn_units)
35 |
36 | # Unused but key exists in state_dict
37 | self.pos_emb = nn.Embedding(64, 0)
38 |
39 | self.z_embedding = nn.Embedding(80, self.z_emb_size)
40 |
41 | self.dropout_i = VariationalDropout(max(0.0, self.dropout - 0.2), batch_first=True)
42 |
43 | # design RNN
44 |
45 | input_size = self.nb_rnn_units
46 | if self.z_dim > 0:
47 | input_size += self.z_emb_size
48 |
49 | self.rnn = nn.GRU(
50 | input_size=input_size,
51 | hidden_size=self.nb_rnn_units,
52 | num_layers=self.nb_layers,
53 | batch_first=True,
54 | dropout=self.dropout
55 | )
56 | self.dropout_o = VariationalDropout(self.dropout, batch_first=True)
57 |
58 | # output layer which projects back to tag space
59 | self.hidden_to_tag = nn.Linear(input_size, self.nb_tags, bias=False)
60 |
61 | def init_hidden(self):
62 | # the weights are of the form (nb_layers, batch_size, nb_rnn_units)
63 | hidden_a = torch.randn(self.nb_layers, self.batch_size, self.nb_rnn_units)
64 |
65 | if torch.cuda.is_available():
66 | hidden_a = hidden_a.cuda()
67 |
68 | return hidden_a
69 |
70 | def forward(self, X, z=None, train_embedding=True, sampling=False, reset_hidden=True):
71 | # reset the RNN hidden state.
72 | if not sampling:
73 | self.seq_len = X.shape[1]
74 | if reset_hidden:
75 | self.hidden = self.init_hidden()
76 |
77 | self.embedding.weight.requires_grad = train_embedding
78 |
79 | # ---------------------
80 | # Combine inputs
81 | X = self.embedding(X)
82 | X = X.view(self.batch_size, self.seq_len, self.nb_rnn_units)
83 |
84 | # repeating pitch encoding
85 | if self.z_dim > 0:
86 | Z = self.z_embedding(z % 80)
87 | Z = Z.view(self.batch_size, self.seq_len, self.z_emb_size)
88 | X = torch.cat((Z, X), 2)
89 |
90 | X = self.dropout_i(X)
91 |
92 | # Run through RNN
93 | X, self.hidden = self.rnn(X, self.hidden)
94 |
95 | if self.z_dim > 0:
96 | X = torch.cat((Z, X), 2)
97 |
98 | X = self.dropout_o(X)
99 |
100 | # run through linear layer
101 | X = self.hidden_to_tag(X)
102 |
103 | Y_hat = X
104 | return Y_hat
105 |
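# Illustrative forward-pass sketch (hypothetical vocabulary size; the real value is
# N_TOKENS from preprocessing.nn_dataset): X carries token indices and z the
# per-token repetition counts consumed by z_embedding.
def _demo_tonicnet_forward():
    model = TonicNet(nb_tags=98, z_dim=32, nb_layers=3, nb_rnn_units=256, dropout=0.0)
    x = torch.randint(0, 98, (1, 10, 1))   # (batch, seq_len, 1) token ids
    z = torch.randint(0, 80, (10,))        # repetition count for each token
    return model(x, z=z).shape             # torch.Size([1, 10, 98])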
106 |
107 | class Transformer_Model(nn.Module):
108 | def __init__(self, nb_tags, nb_layers=1, pe_dim=0,
109 | emb_dim=100, batch_size=1, seq_len=MAX_SEQ, dropout=0.0, encoder_only=True):
110 | super(Transformer_Model, self).__init__()
111 |
112 | self.nb_layers = nb_layers
113 | self.emb_dim = emb_dim
114 | self.batch_size = batch_size
115 | self.seq_len = seq_len
116 | self.pe_dim = pe_dim
117 | self.dropout = dropout
118 |
119 | self.nb_tags = nb_tags
120 |
121 | self.encoder_only = encoder_only
122 |
123 | # build actual NN
124 | self.__build_model()
125 |
126 | def __build_model(self):
127 |
128 | self.embedding = nn.Embedding(self.nb_tags, self.emb_dim)
129 |
130 | if not self.encoder_only:
131 | self.embedding2 = nn.Embedding(self.nb_tags, self.emb_dim)
132 |
133 | self.pos_emb = position_encoding_init(MAX_SEQ, self.pe_dim)
134 | self.pos_emb.requires_grad = False
135 |
136 | self.dropout_i = nn.Dropout(self.dropout)
137 |
138 | input_size = self.pe_dim + self.emb_dim
139 |
140 | self.transformerLayerI = TransformerEncoderLayer(d_model=input_size,
141 | nhead=8,
142 | dropout=self.dropout,
143 | dim_feedforward=1024)
144 |
145 | self.transformerI = TransformerEncoder(self.transformerLayerI,
146 | num_layers=self.nb_layers,)
147 |
148 | self.dropout_m = nn.Dropout(self.dropout)
149 |
150 | if not self.encoder_only:
151 | # design decoder
152 | self.transformerLayerO = TransformerDecoderLayer(d_model=input_size,
153 | nhead=8,
154 | dropout=self.dropout,
155 | dim_feedforward=1024)
156 |
157 | self.transformerO = TransformerDecoder(self.transformerLayerO,
158 | num_layers=self.nb_layers, )
159 |
160 | self.dropout_o = nn.Dropout(self.dropout)
161 |
162 | # output layer which projects back to tag space
163 | self.hidden_to_tag = nn.Linear(self.emb_dim + self.pe_dim, self.nb_tags)
164 |
165 | def __pos_encode(self, p):
166 | return self.pos_emb[p]
167 |
168 | def forward(self, X, p, X2=None, train_embedding=True):
169 |
170 | self.embedding.weight.requires_grad = train_embedding
171 | if not self.encoder_only:
172 | self.embedding2.weight.requires_grad = train_embedding
173 |
174 | I = X
175 |
176 | self.mask = (torch.triu(torch.ones(self.seq_len, self.seq_len)) == 1).transpose(0, 1)
177 | self.mask = self.mask.float().masked_fill(self.mask == 0, float('-inf')).masked_fill(self.mask == 1, float(0.0))
178 |
179 | if torch.cuda.is_available():
180 | self.mask = self.mask.cuda()
181 |
182 | # ---------------------
183 | # Combine inputs
184 | X = self.embedding(I)
185 | X = X.view(self.seq_len, self.batch_size, -1)
186 |
187 | if self.pe_dim > 0:
188 | P = self.__pos_encode(p)
189 | P = P.view(self.seq_len, self.batch_size, -1)
190 | X = torch.cat((X, P), 2)
191 |
192 | X = self.dropout_i(X)
193 |
194 | # Run through transformer encoder
195 |
196 | M = self.transformerI(X, mask=self.mask)
197 | M = self.dropout_m(M)
198 |
199 | if not self.encoder_only:
200 | # ---------------------
201 | # Decoder stack
202 | X = self.embedding2(X2)
203 | X = X.view(self.seq_len, self.batch_size, -1)
204 |
205 | if self.pe_dim > 0:
206 | X = torch.cat((X, P), 2)
207 |
208 | X = self.dropout_i(X)
209 |
210 | X = self.transformerO(X, M, tgt_mask=self.mask, memory_mask=None)
211 | X = self.dropout_o(X)
212 |
213 | # run through linear layer
214 | X = self.hidden_to_tag(X)
215 | else:
216 | X = self.hidden_to_tag(M)
217 |
218 | Y_hat = X
219 | return Y_hat
220 |
221 |
222 | # MARK:- Custom s2s Cross Entropy loss
223 | class CrossEntropyTimeDistributedLoss(nn.Module):
224 | """loss function for multi-timsetep model output"""
225 | def __init__(self):
226 | super(CrossEntropyTimeDistributedLoss, self).__init__()
227 |
228 | self.loss_func = nn.CrossEntropyLoss()
229 |
230 | def forward(self, y_hat, y):
231 |
232 | _y_hat = y_hat.squeeze(0)
233 | _y = y.squeeze(0)
234 |
235 | # Loss from one sequence
236 | loss = self.loss_func(_y_hat, _y)
237 | loss = torch.sum(loss)
238 | return loss
239 |
240 |
241 |
--------------------------------------------------------------------------------
/train/train_nn.py:
--------------------------------------------------------------------------------
1 | import time, os
2 | import math
3 | import matplotlib.pyplot as plt
4 | from torch import save, set_grad_enabled, sum, max
5 | from torch import optim, cuda, load, device
6 | from torch.nn.utils import clip_grad_norm_
7 | from preprocessing.nn_dataset import get_data_set, TOTAL_BATCHES, TRAIN_BATCHES, N_TOKENS
8 | from train.models import CrossEntropyTimeDistributedLoss
9 | from train.models import TonicNet, Transformer_Model
10 | from train.external import RAdam, Lookahead, OneCycleLR
11 |
12 | """
13 | File containing functions which train various neural networks defined in train.models
14 | """
15 |
16 |
17 | CV_PHASES = ['train', 'val']
18 | TRAIN_ONLY_PHASES = ['train']
19 |
20 |
21 | # MARK:- TonicNet
22 | def TonicNet_lr_finder(train_emb_freq=3000, load_path=''):
23 | train_TonicNet(epochs=3,
24 | save_model=False,
25 | load_path=load_path,
26 | shuffle_batches=True,
27 | num_batches=TRAIN_BATCHES,
28 | val=False,
29 | train_emb_freq=train_emb_freq,
30 | lr_range_test=True)
31 |
32 |
33 | def TonicNet_sanity_test(num_batches=1, train_emb_freq=3000, load_path=''):
34 | train_TonicNet(epochs=200,
35 | save_model=False,
36 | load_path=load_path,
37 | shuffle_batches=False,
38 | num_batches=num_batches,
39 | val=1,
40 | train_emb_freq=train_emb_freq,
41 | lr_range_test=False,
42 | sanity_test=True)
43 |
44 |
45 | def train_TonicNet(epochs,
46 | save_model=True,
47 | load_path='',
48 | shuffle_batches=False,
49 | num_batches=TOTAL_BATCHES,
50 | val=True,
51 | train_emb_freq=1,
52 | lr_range_test=False,
53 | sanity_test=False):
54 |
55 | model = TonicNet(nb_tags=N_TOKENS, z_dim=32, nb_layers=3, nb_rnn_units=256, dropout=0.3)
56 |
57 | if load_path != '':
58 | try:
59 | if cuda.is_available():
60 | model.load_state_dict(load(load_path)['model_state_dict'])
61 | else:
62 | model.load_state_dict(load(load_path, map_location=device('cpu'))['model_state_dict'])
63 | print("loded params from", load_path)
64 | except:
65 | raise ImportError(f'No file located at {load_path}, could not load parameters')
66 | print(model)
67 |
68 | if cuda.is_available():
69 | model.cuda()
70 |
71 | base_lr = 0.2
72 | max_lr = 0.2
73 |
74 | if lr_range_test:
75 | base_lr = 0.000003
76 | max_lr = 0.5
77 |
78 | step_size = 3 * min(TRAIN_BATCHES, num_batches)
79 |
80 | if sanity_test:
81 | base_optim = RAdam(model.parameters(), lr=base_lr)
82 | optimiser = Lookahead(base_optim, k=5, alpha=0.5)
83 | else:
84 | optimiser = optim.SGD(model.parameters(), base_lr)
85 | criterion = CrossEntropyTimeDistributedLoss()
86 |
87 | print(criterion)
88 |
89 | print(f"min lr: {base_lr}, max_lr: {max_lr}, stepsize: {step_size}")
90 |
91 | if not sanity_test and not lr_range_test:
92 | scheduler = OneCycleLR(optimiser, max_lr,
93 | epochs=60, steps_per_epoch=TRAIN_BATCHES, pct_start=0.3,
94 | anneal_strategy='cos', cycle_momentum=True, base_momentum=0.8,
95 | max_momentum=0.95, div_factor=25.0, final_div_factor=1000.0,
96 | last_epoch=-1)
97 |
98 | elif lr_range_test:
99 | lr_lambda = lambda x: math.exp(x * math.log(max_lr / base_lr) / (epochs * num_batches))
100 | scheduler = optim.lr_scheduler.LambdaLR(optimiser, lr_lambda)
101 |
102 | best_val_loss = 100.0
103 |
104 | if lr_range_test:
105 | lr_find_loss = []
106 | lr_find_lr = []
107 |
108 | itr = 0
109 | smoothing = 0.05
110 |
111 | if val:
112 | phases = CV_PHASES
113 | else:
114 | phases = TRAIN_ONLY_PHASES
115 |
116 | for epoch in range(epochs):
117 | start = time.time()
118 | pr_interval = 50
119 |
120 | print(f'Beginning EPOCH {epoch + 1}')
121 |
122 | for phase in phases:
123 |
124 | count = 0
125 | batch_count = 0
126 | loss_epoch = 0
127 | running_accuray = 0.0
128 | running_batch_count = 0
129 | print_loss_batch = 0 # Reset on print
130 | print_acc_batch = 0 # Reset on print
131 |
132 | print(f'\n\tPHASE: {phase}')
133 |
134 | if phase == 'train':
135 | model.train() # Set model to training mode
136 | else:
137 | model.eval() # Set model to evaluate mode
138 |
139 | for x, y, psx, i, c in get_data_set(phase, shuffle_batches=shuffle_batches, return_I=1):
140 | model.zero_grad()
141 |
142 | if phase == 'train' and (epoch > -1 or load_path != ''):
143 | if train_emb_freq < 1000:
144 | train_emb = ((count % train_emb_freq) == 0)
145 | else:
146 | train_emb = False
147 |
148 | else:
149 | train_emb = False
150 |
151 | with set_grad_enabled(phase == 'train'):
152 | y_hat = model(x, z=i, train_embedding=train_emb)
153 | _, preds = max(y_hat, 2)
154 | loss = criterion(y_hat, y,)
155 |
156 | if phase == 'train':
157 | loss.backward()
158 | clip_grad_norm_(model.parameters(), 5)
159 | optimiser.step()
160 | if not sanity_test:
161 | scheduler.step()
162 |
163 | loss_epoch += loss.item()
164 | print_loss_batch += loss.item()
165 | running_accuray += sum(preds == y)
166 | print_acc_batch += sum(preds == y)
167 |
168 | count += 1
169 | batch_count += x.shape[1]
170 | running_batch_count += x.shape[1]
171 |
172 | if lr_range_test:
173 | lr_step = optimiser.state_dict()["param_groups"][0]["lr"]
174 | lr_find_lr.append(lr_step)
175 |
176 | # smooth the loss
177 | if itr == 0:
178 | lr_find_loss.append(min(loss, 7))
179 | else:
180 | loss = smoothing * min(loss, 7) + (1 - smoothing) * lr_find_loss[-1]
181 | lr_find_loss.append(loss)
182 |
183 | itr += 1
184 |
185 | # print loss for recent set of batches
186 | if count % pr_interval == 0:
187 | ave_loss = print_loss_batch/pr_interval
188 | ave_acc = 100 * print_acc_batch.float()/running_batch_count
189 | print_acc_batch = 0
190 | running_batch_count = 0
191 | print('\t\t[%d] loss: %.3f, acc: %.3f' % (count, ave_loss, ave_acc))
192 | print_loss_batch = 0
193 |
194 | if count == num_batches:
195 | break
196 |
197 | # calculate loss and accuracy for phase
198 | ave_loss_epoch = loss_epoch/count
199 | epoch_acc = 100 * running_accuray.float() / batch_count
200 | print('\tfinished %s phase [%d] loss: %.3f, acc: %.3f' % (phase, epoch + 1, ave_loss_epoch, epoch_acc))
201 |
202 | print('\n\ttime:', __time_since(start), '\n')
203 |
204 | # save model when validation loss improves
205 | if ave_loss_epoch < best_val_loss:
206 | best_val_loss = ave_loss_epoch
207 | print("\tNEW BEST LOSS: %.3f" % ave_loss_epoch, '\n')
208 |
209 | if save_model:
210 | __save_model(epoch, ave_loss_epoch, model, "TonicNet", epoch_acc)
211 | else:
212 | print("\tLOSS DID NOT IMPROVE FROM %.3f" % best_val_loss, '\n')
213 |
214 | print("DONE")
215 | if lr_range_test:
216 | plt.plot(lr_find_lr, lr_find_loss)
217 | plt.xscale('log')
218 | plt.grid(True)
219 | plt.savefig('lr_finder.png')
220 | plt.show()
221 |
222 |
223 | # MARK:- Transformer
224 | def Transformer_lr_finder(load_path=''):
225 | train_Transformer(epochs=3,
226 | save_model=False,
227 | load_path=load_path,
228 | shuffle_batches=True,
229 | num_batches=TRAIN_BATCHES,
230 | val=False,
231 | lr_range_test=True)
232 |
233 |
234 | def Transformer_sanity_test(num_batches=1, load_path=''):
235 | train_Transformer(epochs=1000,
236 | save_model=0,
237 | load_path=load_path,
238 | shuffle_batches=False,
239 | num_batches=num_batches,
240 | val=1,
241 | lr_range_test=False,
242 | sanity_test=True)
243 |
244 |
245 | def train_Transformer(epochs,
246 | save_model=True,
247 | load_path='',
248 | shuffle_batches=False,
249 | num_batches=TOTAL_BATCHES,
250 | val=True,
251 | lr_range_test=False,
252 | sanity_test=False):
253 |
254 | model = Transformer_Model(nb_tags=N_TOKENS, nb_layers=5, emb_dim=256, dropout=0.1, pe_dim=256)
255 |
256 | if load_path != '':
257 | try:
258 | if cuda.is_available():
259 | model.load_state_dict(load(load_path)['model_state_dict'])
260 | else:
261 | model.load_state_dict(load(load_path, map_location=device('cpu'))['model_state_dict'])
262 | print("loded params from", load_path)
263 | except:
264 | raise ImportError(f'No file located at {load_path}, could not load parameters')
265 | print(model)
266 |
267 | if cuda.is_available():
268 | model.cuda()
269 |
270 | base_lr = 0.06
271 | max_lr = 0.06
272 |
273 | if lr_range_test:
274 | base_lr = 0.000003
275 | max_lr = 0.3
276 |
277 | step_size = 3 * min(TRAIN_BATCHES, num_batches)
278 |
279 | if sanity_test:
280 | base_optim = RAdam(model.parameters(), lr=base_lr/100)
281 | optimiser = Lookahead(base_optim, k=5, alpha=0.5)
282 | else:
283 | optimiser = optim.SGD(model.parameters(), base_lr)
284 | criterion = CrossEntropyTimeDistributedLoss()
285 |
286 | print(criterion)
287 |
288 | print(f"min lr: {base_lr}, max_lr: {max_lr}, stepsize: {step_size}")
289 |
290 | if not sanity_test and not lr_range_test:
291 | scheduler = OneCycleLR(optimiser, max_lr,
292 | epochs=30, steps_per_epoch=TRAIN_BATCHES, pct_start=0.3,
293 | anneal_strategy='cos', cycle_momentum=True, base_momentum=0.8,
294 | max_momentum=0.95, div_factor=100.0, final_div_factor=1000.0,
295 | last_epoch=-1)
296 |
297 | elif lr_range_test:
298 | lr_lambda = lambda x: math.exp(x * math.log(max_lr / base_lr) / (epochs * TRAIN_BATCHES))
299 | scheduler = optim.lr_scheduler.LambdaLR(optimiser, lr_lambda)
300 |
301 | best_val_loss = 100.0
302 | step = 0
303 |
304 | if lr_range_test:
305 | lr_find_loss = []
306 | lr_find_lr = []
307 |
308 | itr = 0
309 | smoothing = 0.05
310 |
311 | if val:
312 | phases = CV_PHASES
313 | else:
314 | phases = TRAIN_ONLY_PHASES
315 |
316 | for epoch in range(epochs):
317 | start = time.time()
318 | pr_interval = 50
319 |
320 | print(f'Beginning EPOCH {epoch + 1}')
321 |
322 | for phase in phases:
323 | phase_loss = None
324 | model.zero_grad()
325 | count = 0
326 | batch_count = 0
327 | loss_epoch = 0
328 | running_accuray = 0.0
329 | running_batch_count = 0
330 | print_loss_batch = 0 # Reset on print
331 | print_acc_batch = 0 # Reset on print
332 |
333 | print(f'\n\tPHASE: {phase}')
334 |
335 | if phase == 'train':
336 | model.train() # Set model to training mode
337 | else:
338 | model.eval() # Set model to evaluate mode
339 |
340 | for x, y, psx, i, c in get_data_set(phase, shuffle_batches=shuffle_batches, return_I=1):
341 |
342 | Y = y
343 | model.seq_len = x.shape[1]
344 |
345 | with set_grad_enabled(phase == 'train'):
346 | y_hat = model(x, psx)
347 | y_hat = y_hat.view(1, -1, N_TOKENS)
348 | _, preds = max(y_hat, 2)
349 | loss = criterion(y_hat, Y, )
350 | if phase_loss is None:
351 | phase_loss = loss
352 | else:
353 | phase_loss += loss
354 |
355 | loss_epoch += loss.item()
356 | print_loss_batch += loss.item()
357 |
358 | len_batch = model.seq_len
359 |
360 | running_accuray += sum(preds == Y)
361 | print_acc_batch += sum(preds == Y)
362 |
363 | count += 1
364 | batch_count += len_batch
365 | running_batch_count += len_batch
366 |
367 | if count % 1 == 0:  # gradient-accumulation interval of 1: step after every batch
368 | if phase == 'train':
369 | phase_loss.backward()
370 | clip_grad_norm_(model.parameters(), 5)
371 | optimiser.step()
372 | step += 1
373 | if not sanity_test:
374 | scheduler.step()
375 | phase_loss = None
376 | model.zero_grad()
377 |
378 | if lr_range_test:
379 | lr_step = optimiser.state_dict()["param_groups"][0]["lr"]
380 | lr_find_lr.append(lr_step)
381 |
382 | # smooth the loss
383 | if itr == 0:
384 | lr_find_loss.append(min(loss, 4))
385 | else:
386 | loss = smoothing * min(loss, 4) + (1 - smoothing) * lr_find_loss[-1]
387 | lr_find_loss.append(loss)
388 |
389 | itr += 1
390 |
391 | # print loss for recent set of batches
392 | if count % pr_interval == 0:
393 | ave_loss = print_loss_batch/pr_interval
394 | ave_acc = 100 * print_acc_batch.float()/running_batch_count
395 | print_acc_batch = 0
396 | running_batch_count = 0
397 | print('\t\t[%d] loss: %.3f, acc: %.3f' % (count, ave_loss, ave_acc))
398 | print_loss_batch = 0
399 |
400 | if count == num_batches:
401 | break
402 |
403 | # calculate loss and accuracy for phase
404 | ave_loss_epoch = loss_epoch/count
405 | epoch_acc = 100 * running_accuray.float() / batch_count
406 | print('\tfinished %s phase [%d] loss: %.3f, acc: %.3f' % (phase, epoch + 1, ave_loss_epoch, epoch_acc))
407 |
408 | print('\n\ttime:', __time_since(start), '\n')
409 |
410 | # save model when validation loss improves
411 | if ave_loss_epoch < best_val_loss:
412 | best_val_loss = ave_loss_epoch
413 | print("\tNEW BEST LOSS: %.3f" % ave_loss_epoch, '\n')
414 |
415 | if save_model:
416 | __save_model(epoch, ave_loss_epoch, model, "EncoReTransformer", epoch_acc)
417 | else:
418 | print("\tLOSS DID NOT IMPROVE FROM %.3f" % best_val_loss, '\n')
419 |
420 | print("DONE")
421 | if lr_range_test:
422 | plt.plot(lr_find_lr, lr_find_loss)
423 | plt.xscale('log')
424 | plt.grid(True)
425 | plt.savefig('lr_finder.png')
426 | plt.show()
427 |
428 |
429 | def __save_model(epoch, ave_loss_epoch, model, model_name, acc):
430 |
431 | test = os.listdir('eval')  # existing checkpoints; removed below so only the latest is kept
432 |
433 | for item in test:
434 | if item.endswith(".pt"):
435 | os.remove(os.path.join('eval', item))
436 |
437 | path_loss = round(ave_loss_epoch, 3)
438 | path_acc = '%.3f' % acc
439 | path = f'eval/{model_name}_epoch-{epoch}_loss-{path_loss}_acc-{path_acc}.pt'
440 | save({
441 | 'epoch': epoch,
442 | 'model_state_dict': model.state_dict(),
443 | 'loss': path_loss
444 | }, path)
445 | print("\tSAVED MODEL TO:", path)
446 |
447 |
448 | def __time_since(t):
449 | now = time.time()
450 | s = now - t
451 | return '%s' % (__as_minutes(s))
452 |
453 |
454 | def __as_minutes(s):
455 | m = math.floor(s / 60)
456 | s -= m * 60
457 | return '%dm %ds' % (m, s)
458 |
--------------------------------------------------------------------------------
/train/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import copy
3 | import warnings
4 | from torch.nn import functional as F
5 | from torch.nn import Module, Parameter
6 | from torch.nn import ModuleList
7 | from torch.nn.init import xavier_uniform_, xavier_normal_, constant_
8 | from torch.nn import Dropout
9 | from torch.nn import Linear
10 | from torch.nn import LayerNorm
11 |
12 | """
13 | Official pytorch implementation of the Transformer model
14 | """
15 |
16 |
17 | class Transformer(Module):
18 | r"""A transformer model. User is able to modify the attributes as needed. The architechture
19 | is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
20 | Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
21 | Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
22 | Processing Systems, pages 6000-6010.
23 | Args:
24 | d_model: the number of expected features in the encoder/decoder inputs (default=512).
25 | nhead: the number of heads in the multiheadattention models (default=8).
26 | num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
27 | num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
28 | dim_feedforward: the dimension of the feedforward network model (default=2048).
29 | dropout: the dropout value (default=0.1).
30 | custom_encoder: custom encoder (default=None).
31 | custom_decoder: custom decoder (default=None).
32 | Examples::
33 | >>> transformer_model = nn.Transformer(src_vocab, tgt_vocab)
34 | >>> transformer_model = nn.Transformer(src_vocab, tgt_vocab, nhead=16, num_encoder_layers=12)
35 | """
36 |
37 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
38 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
39 | custom_encoder=None, custom_decoder=None):
40 | super(Transformer, self).__init__()
41 |
42 | if custom_encoder is not None:
43 | self.encoder = custom_encoder
44 | else:
45 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
46 | encoder_norm = LayerNorm(d_model)
47 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
48 |
49 | if custom_decoder is not None:
50 | self.decoder = custom_decoder
51 | else:
52 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
53 | decoder_norm = LayerNorm(d_model)
54 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
55 |
56 | self._reset_parameters()
57 |
58 | self.d_model = d_model
59 | self.nhead = nhead
60 |
61 | def forward(self, src, tgt, src_mask=None, tgt_mask=None,
62 | memory_mask=None, src_key_padding_mask=None,
63 | tgt_key_padding_mask=None, memory_key_padding_mask=None):
64 | r"""Take in and process masked source/target sequences.
65 | Args:
66 | src: the sequence to the encoder (required).
67 | tgt: the sequence to the decoder (required).
68 | src_mask: the additive mask for the src sequence (optional).
69 | tgt_mask: the additive mask for the tgt sequence (optional).
70 | memory_mask: the additive mask for the encoder output (optional).
71 | src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
72 | tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
73 | memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).
74 | Shape:
75 | - src: :math:`(S, N, E)`.
76 | - tgt: :math:`(T, N, E)`.
77 | - src_mask: :math:`(S, S)`.
78 | - tgt_mask: :math:`(T, T)`.
79 | - memory_mask: :math:`(T, S)`.
80 | - src_key_padding_mask: :math:`(N, S)`.
81 | - tgt_key_padding_mask: :math:`(N, T)`.
82 | - memory_key_padding_mask: :math:`(N, S)`.
83 | Note: [src/tgt/memory]_mask should be filled with
84 | float('-inf') for the masked positions and float(0.0) else. These masks
85 | ensure that predictions for position i depend only on the unmasked positions
86 | j and are applied identically for each sequence in a batch.
87 | [src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions
88 | that should be masked with float('-inf') and False values will be unchanged.
89 | This mask ensures that no information will be taken from position i if
90 | it is masked, and has a separate mask for each sequence in a batch.
91 | - output: :math:`(T, N, E)`.
92 | Note: Due to the multi-head attention architecture in the transformer model,
93 | the output sequence length of a transformer is the same as the input sequence
94 | (i.e. target) length of the decoder.
95 | where S is the source sequence length, T is the target sequence length, N is the
96 | batch size, E is the feature number
97 | Examples:
98 | >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
99 | """
100 |
101 | if src.size(1) != tgt.size(1):
102 | raise RuntimeError("the batch number of src and tgt must be equal")
103 |
104 | if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
105 | raise RuntimeError("the feature number of src and tgt must be equal to d_model")
106 |
107 | memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
108 | output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
109 | tgt_key_padding_mask=tgt_key_padding_mask,
110 | memory_key_padding_mask=memory_key_padding_mask)
111 | return output
112 |
113 |
114 | def generate_square_subsequent_mask(self, sz):
115 | r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
116 | Unmasked positions are filled with float(0.0).
117 | """
118 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
119 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
120 | return mask
121 |
122 |
123 | def _reset_parameters(self):
124 | r"""Initiate parameters in the transformer model."""
125 |
126 | for p in self.parameters():
127 | if p.dim() > 1:
128 | xavier_uniform_(p)
129 |
130 |
131 | class TransformerEncoder(Module):
132 | r"""TransformerEncoder is a stack of N encoder layers
133 | Args:
134 | encoder_layer: an instance of the TransformerEncoderLayer() class (required).
135 | num_layers: the number of sub-encoder-layers in the encoder (required).
136 | norm: the layer normalization component (optional).
137 | Examples::
138 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
139 | >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
140 | """
141 |
142 | def __init__(self, encoder_layer, num_layers, norm=None):
143 | super(TransformerEncoder, self).__init__()
144 | self.layers = _get_clones(encoder_layer, num_layers)
145 | self.num_layers = num_layers
146 | self.norm = norm
147 |
148 | def forward(self, src, mask=None, src_key_padding_mask=None):
149 | r"""Pass the input through the endocder layers in turn.
150 | Args:
151 | src: the sequnce to the encoder (required).
152 | mask: the mask for the src sequence (optional).
153 | src_key_padding_mask: the mask for the src keys per batch (optional).
154 | Shape:
155 | see the docs in Transformer class.
156 | """
157 | output = src
158 |
159 | for i in range(self.num_layers):
160 | output = self.layers[i](output, src_mask=mask,
161 | src_key_padding_mask=src_key_padding_mask)
162 |
163 | if self.norm:
164 | output = self.norm(output)
165 |
166 | return output
167 |
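# Illustrative sketch (assumes the MultiheadAttention class defined later in this
# file): stacking encoder layers with an additive causal mask, as done in
# train/models.py, so each position only attends to itself and earlier positions.
def _demo_encoder_with_causal_mask(seq_len=10, d_model=512):
    layer = TransformerEncoderLayer(d_model=d_model, nhead=8)
    encoder = TransformerEncoder(layer, num_layers=2)
    mask = (torch.triu(torch.ones(seq_len, seq_len)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)
    src = torch.randn(seq_len, 1, d_model)     # (seq, batch, features)
    return encoder(src, mask=mask).shape       # (seq_len, 1, d_model)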
168 |
169 | class TransformerDecoder(Module):
170 | r"""TransformerDecoder is a stack of N decoder layers
171 | Args:
172 | decoder_layer: an instance of the TransformerDecoderLayer() class (required).
173 | num_layers: the number of sub-decoder-layers in the decoder (required).
174 | norm: the layer normalization component (optional).
175 | Examples::
176 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model, nhead)
177 | >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
178 | """
179 |
180 | def __init__(self, decoder_layer, num_layers, norm=None):
181 | super(TransformerDecoder, self).__init__()
182 | self.layers = _get_clones(decoder_layer, num_layers)
183 | self.num_layers = num_layers
184 | self.norm = norm
185 |
186 | def forward(self, tgt, memory, tgt_mask=None,
187 | memory_mask=None, tgt_key_padding_mask=None,
188 | memory_key_padding_mask=None):
189 | r"""Pass the inputs (and mask) through the decoder layer in turn.
190 | Args:
191 | tgt: the sequence to the decoder (required).
192 | memory: the sequence from the last layer of the encoder (required).
193 | tgt_mask: the mask for the tgt sequence (optional).
194 | memory_mask: the mask for the memory sequence (optional).
195 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
196 | memory_key_padding_mask: the mask for the memory keys per batch (optional).
197 | Shape:
198 | see the docs in Transformer class.
199 | """
200 | output = tgt
201 |
202 | for i in range(self.num_layers):
203 | output = self.layers[i](output, memory, tgt_mask=tgt_mask,
204 | memory_mask=memory_mask,
205 | tgt_key_padding_mask=tgt_key_padding_mask,
206 | memory_key_padding_mask=memory_key_padding_mask)
207 |
208 | if self.norm:
209 | output = self.norm(output)
210 |
211 | return output
212 |
213 |
214 | class TransformerEncoderLayer(Module):
215 | r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
216 | This standard encoder layer is based on the paper "Attention Is All You Need".
217 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
218 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
219 | Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
220 | in a different way during application.
221 | Args:
222 | d_model: the number of expected features in the input (required).
223 | nhead: the number of heads in the multiheadattention models (required).
224 | dim_feedforward: the dimension of the feedforward network model (default=2048).
225 | dropout: the dropout value (default=0.1).
226 | Examples::
227 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
228 | """
229 |
230 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
231 | super(TransformerEncoderLayer, self).__init__()
232 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
233 | # Implementation of Feedforward model
234 | self.linear1 = Linear(d_model, dim_feedforward)
235 | self.dropout = Dropout(dropout)
236 | self.linear2 = Linear(dim_feedforward, d_model)
237 |
238 | self.norm1 = LayerNorm(d_model)
239 | self.norm2 = LayerNorm(d_model)
240 | self.dropout1 = Dropout(dropout)
241 | self.dropout2 = Dropout(dropout)
242 |
243 | def forward(self, src, src_mask=None, src_key_padding_mask=None):
244 | r"""Pass the input through the endocder layer.
245 | Args:
246 | src: the sequnce to the encoder layer (required).
247 | src_mask: the mask for the src sequence (optional).
248 | src_key_padding_mask: the mask for the src keys per batch (optional).
249 | Shape:
250 | see the docs in Transformer class.
251 | """
252 | src2 = self.self_attn(src, src, src, attn_mask=src_mask,
253 | key_padding_mask=src_key_padding_mask)[0]
254 | src = src + self.dropout1(src2)
255 | src = self.norm1(src)
256 | src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
257 | src = src + self.dropout2(src2)
258 | src = self.norm2(src)
259 | return src
260 |
261 |
262 |
263 | class TransformerDecoderLayer(Module):
264 | r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
265 | This standard decoder layer is based on the paper "Attention Is All You Need".
266 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
267 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
268 | Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
269 | in a different way during application.
270 | Args:
271 | d_model: the number of expected features in the input (required).
272 | nhead: the number of heads in the multiheadattention models (required).
273 | dim_feedforward: the dimension of the feedforward network model (default=2048).
274 | dropout: the dropout value (default=0.1).
275 | Examples::
276 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model, nhead)
277 | """
278 |
279 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
280 | super(TransformerDecoderLayer, self).__init__()
281 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
282 | self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
283 | # Implementation of Feedforward model
284 | self.linear1 = Linear(d_model, dim_feedforward)
285 | self.dropout = Dropout(dropout)
286 | self.linear2 = Linear(dim_feedforward, d_model)
287 |
288 | self.norm1 = LayerNorm(d_model)
289 | self.norm2 = LayerNorm(d_model)
290 | self.norm3 = LayerNorm(d_model)
291 | self.dropout1 = Dropout(dropout)
292 | self.dropout2 = Dropout(dropout)
293 | self.dropout3 = Dropout(dropout)
294 |
295 | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
296 | tgt_key_padding_mask=None, memory_key_padding_mask=None):
297 | r"""Pass the inputs (and mask) through the decoder layer.
298 | Args:
299 | tgt: the sequence to the decoder layer (required).
300 | memory: the sequnce from the last layer of the encoder (required).
301 | tgt_mask: the mask for the tgt sequence (optional).
302 | memory_mask: the mask for the memory sequence (optional).
303 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
304 | memory_key_padding_mask: the mask for the memory keys per batch (optional).
305 | Shape:
306 | see the docs in Transformer class.
307 | """
308 | tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
309 | key_padding_mask=tgt_key_padding_mask)[0]
310 | tgt = tgt + self.dropout1(tgt2)
311 | tgt = self.norm1(tgt)
312 | tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
313 | key_padding_mask=memory_key_padding_mask)[0]
314 | tgt = tgt + self.dropout2(tgt2)
315 | tgt = self.norm2(tgt)
316 | tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
317 | tgt = tgt + self.dropout3(tgt2)
318 | tgt = self.norm3(tgt)
319 | return tgt
320 |
321 |
322 | def _get_clones(module, N):
323 | return ModuleList([copy.deepcopy(module) for i in range(N)])
324 |
325 |
326 | class MultiheadAttention(Module):
327 | r"""Allows the model to jointly attend to information
328 | from different representation subspaces.
329 | See reference: Attention Is All You Need
330 | .. math::
331 | \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
332 | \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
333 | Args:
334 | embed_dim: total dimension of the model.
335 | num_heads: parallel attention heads.
336 | dropout: a Dropout layer on attn_output_weights. Default: 0.0.
337 | bias: add bias as module parameter. Default: True.
338 | add_bias_kv: add bias to the key and value sequences at dim=0.
339 | add_zero_attn: add a new batch of zeros to the key and
340 | value sequences at dim=1.
341 | kdim: total number of features in key. Default: None.
342 | vdim: total number of features in key. Default: None.
343 | Note: if kdim and vdim are None, they will be set to embed_dim such that
344 | query, key, and value have the same number of features.
345 | Examples::
346 | >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
347 | >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
348 | """
349 |
350 | def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
351 | super(MultiheadAttention, self).__init__()
352 | self.embed_dim = embed_dim
353 | self.kdim = kdim if kdim is not None else embed_dim
354 | self.vdim = vdim if vdim is not None else embed_dim
355 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
356 |
357 | self.num_heads = num_heads
358 | self.dropout = dropout
359 | self.head_dim = embed_dim // num_heads
360 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
361 |
362 | self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
363 |
364 | if self._qkv_same_embed_dim is False:
365 | self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
366 | self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
367 | self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
368 |
369 | if bias:
370 | self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
371 | else:
372 | self.register_parameter('in_proj_bias', None)
373 | self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
374 |
375 | if add_bias_kv:
376 | self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
377 | self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
378 | else:
379 | self.bias_k = self.bias_v = None
380 |
381 | self.add_zero_attn = add_zero_attn
382 |
383 | self._reset_parameters()
384 |
385 | def _reset_parameters(self):
386 | if self._qkv_same_embed_dim:
387 | xavier_uniform_(self.in_proj_weight)
388 | else:
389 | xavier_uniform_(self.q_proj_weight)
390 | xavier_uniform_(self.k_proj_weight)
391 | xavier_uniform_(self.v_proj_weight)
392 |
393 | if self.in_proj_bias is not None:
394 | constant_(self.in_proj_bias, 0.)
395 | constant_(self.out_proj.bias, 0.)
396 | if self.bias_k is not None:
397 | xavier_normal_(self.bias_k)
398 | if self.bias_v is not None:
399 | xavier_normal_(self.bias_v)
400 |
401 | def forward(self, query, key, value, key_padding_mask=None,
402 | need_weights=True, attn_mask=None):
403 | r"""
404 | Args:
405 | query, key, value: map a query and a set of key-value pairs to an output.
406 | See "Attention Is All You Need" for more details.
407 | key_padding_mask: if provided, specified padding elements in the key will
408 | be ignored by the attention. This is an binary mask. When the value is True,
409 | the corresponding value on the attention layer will be filled with -inf.
410 | need_weights: output attn_output_weights.
411 | attn_mask: mask that prevents attention to certain positions. This is an additive mask
412 | (i.e. the values will be added to the attention layer).
413 | Shape:
414 | - Inputs:
415 | - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
416 | the embedding dimension.
417 | - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
418 | the embedding dimension.
419 | - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
420 | the embedding dimension.
421 | - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
422 | - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
423 | - Outputs:
424 | - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
425 | E is the embedding dimension.
426 | - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
427 | L is the target sequence length, S is the source sequence length.
428 | """
429 | if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False:
430 | return F.multi_head_attention_forward(
431 | query, key, value, self.embed_dim, self.num_heads,
432 | self.in_proj_weight, self.in_proj_bias,
433 | self.bias_k, self.bias_v, self.add_zero_attn,
434 | self.dropout, self.out_proj.weight, self.out_proj.bias,
435 | training=self.training,
436 | key_padding_mask=key_padding_mask, need_weights=need_weights,
437 | attn_mask=attn_mask, use_separate_proj_weight=True,
438 | q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
439 | v_proj_weight=self.v_proj_weight)
440 | else:
441 | if not hasattr(self, '_qkv_same_embed_dim'):
442 | warnings.warn('A new version of MultiheadAttention module has been implemented. \
443 | Please re-train your model with the new module',
444 | UserWarning)
445 |
446 | return F.multi_head_attention_forward(
447 | query, key, value, self.embed_dim, self.num_heads,
448 | self.in_proj_weight, self.in_proj_bias,
449 | self.bias_k, self.bias_v, self.add_zero_attn,
450 | self.dropout, self.out_proj.weight, self.out_proj.bias,
451 | training=self.training,
452 | key_padding_mask=key_padding_mask, need_weights=need_weights,
453 | attn_mask=attn_mask)
--------------------------------------------------------------------------------
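Usage note (a minimal sketch, not part of the repository code): the classes in train/transformer.py mirror the torch.nn Transformer building blocks, so an encoder stack can be assembled and run directly. The import path and all hyperparameter values below are illustrative assumptions.

    import torch
    from train.transformer import TransformerEncoder, TransformerEncoderLayer  # assumed import path

    d_model, nhead, num_layers = 64, 4, 2        # embed_dim must be divisible by nhead
    seq_len, batch_size = 16, 3

    # One prototype layer is deep-copied num_layers times by _get_clones.
    encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward=256, dropout=0.1)
    encoder = TransformerEncoder(encoder_layer, num_layers)

    # Additive causal mask, equivalent to generate_square_subsequent_mask above:
    # 0.0 on and below the diagonal, -inf above it.
    mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)

    src = torch.rand(seq_len, batch_size, d_model)   # (S, N, E) layout, as documented above
    out = encoder(src, mask=mask)                    # -> (seq_len, batch_size, d_model)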