├── Seeds ├── Honesty.mid ├── Nothing Else Matters.mid ├── House Of The Rising Sun.mid ├── README.md └── Sharing The Night Together.mid ├── Models ├── Version-1 │ ├── training_acc_graph.png │ ├── training_loss_graph.png │ ├── losses_accuracies.pickle │ ├── validation_acc_graph.png │ ├── validation_loss_graph.png │ └── README.md └── README.md ├── Artwork ├── Experimental-Music-Transformer-Artwork (1).png └── README.md ├── Training-Code ├── README.md ├── experimental_music_transformer_maker.py └── Experimental_Music_Transformer_Maker.ipynb ├── Training-Data ├── README.md ├── experimental_music_transformer_training_dataset_maker.py └── Experimental_Music_Transformer_Training_Dataset_Maker.ipynb ├── README.md ├── LICENSE ├── experimental_music_transformer_version_3.py ├── experimental_music_transformer_version_2.py └── experimental_music_transformer_version_1.py /Seeds/Honesty.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/Experimental-Music-Transformer/main/Seeds/Honesty.mid -------------------------------------------------------------------------------- /Seeds/Nothing Else Matters.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/Experimental-Music-Transformer/main/Seeds/Nothing Else Matters.mid -------------------------------------------------------------------------------- /Seeds/House Of The Rising Sun.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/Experimental-Music-Transformer/main/Seeds/House Of The Rising Sun.mid -------------------------------------------------------------------------------- /Seeds/README.md: -------------------------------------------------------------------------------- 1 | # Experimental Music Transformer Sample Seed MIDIs 2 | 3 | *** 4 | 5 | ### Project Los Angeles 6 | ### Tegridy Code 2023 7 | -------------------------------------------------------------------------------- /Seeds/Sharing The Night Together.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/Experimental-Music-Transformer/main/Seeds/Sharing The Night Together.mid -------------------------------------------------------------------------------- /Models/Version-1/training_acc_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/Experimental-Music-Transformer/main/Models/Version-1/training_acc_graph.png -------------------------------------------------------------------------------- /Models/Version-1/training_loss_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/Experimental-Music-Transformer/main/Models/Version-1/training_loss_graph.png -------------------------------------------------------------------------------- /Models/Version-1/losses_accuracies.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/Experimental-Music-Transformer/main/Models/Version-1/losses_accuracies.pickle -------------------------------------------------------------------------------- /Models/Version-1/validation_acc_graph.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/asigalov61/Experimental-Music-Transformer/main/Models/Version-1/validation_acc_graph.png -------------------------------------------------------------------------------- /Models/Version-1/validation_loss_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/Experimental-Music-Transformer/main/Models/Version-1/validation_loss_graph.png -------------------------------------------------------------------------------- /Artwork/Experimental-Music-Transformer-Artwork (1).png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/Experimental-Music-Transformer/main/Artwork/Experimental-Music-Transformer-Artwork (1).png -------------------------------------------------------------------------------- /Artwork/README.md: -------------------------------------------------------------------------------- 1 | # Experimental Music Transformer Concept Artwork 2 | 3 | ## Images are courtesy of Bing Image Search 4 | 5 | *** 6 | 7 | ### Project Los Angeles 8 | ### Tegridy Code 2023 9 | -------------------------------------------------------------------------------- /Models/README.md: -------------------------------------------------------------------------------- 1 | # Experimental Music Transformer Pre-Trained Models 2 | 3 | *** 4 | 5 | ## Models are hosted on [Hugging Face](https://huggingface.co/asigalov61/Experimental-Music-Transformer) 6 | 7 | *** 8 | 9 | ### Project Los Angeles 10 | ### Tegridy Code 2023 11 | -------------------------------------------------------------------------------- /Models/Version-1/README.md: -------------------------------------------------------------------------------- 1 | # Experimental Music Transformer Version 1 Pre-Trained Model 2 | 3 | *** 4 | 5 | ## The model was trained on all Karaoke MIDIs (~50k) from the Los Angeles MIDI Dataset for 65 hours with a batch size of 2 on a single A100 GPU 6 | 7 | *** 8 | 9 | ## Notes on the results 10 | 11 | ### 1) ...
12 | 13 | *** 14 | 15 | ### Project Los Angeles 16 | ### Tegridy Code 2023 17 | -------------------------------------------------------------------------------- /Training-Code/README.md: -------------------------------------------------------------------------------- 1 | # Experimental Music Transformer Training Code 2 | 3 | *** 4 | 5 | [![Open In Colab][colab-badge]][colab-notebook1] 6 | 7 | [colab-notebook1]: 8 | [colab-badge]: 9 | 10 | *** 11 | 12 | ## Recommended DL/ML cloud provider: [Lambda Labs](https://lambdalabs.com/) 13 | 14 | *** 15 | 16 | ### Project Los Angeles 17 | ### Tegridy Code 2023 18 | -------------------------------------------------------------------------------- /Training-Data/README.md: -------------------------------------------------------------------------------- 1 | # Experimental Music Transformer Training Dataset Maker 2 | 3 | *** 4 | 5 | [![Open In Colab][colab-badge]][colab-notebook2] 6 | 7 | [colab-notebook2]: 8 | [colab-badge]: 9 | 10 | *** 11 | 12 | ## Recommended MIDI datasets: 13 | ### 1) [Los Angeles MIDI Dataset](https://github.com/asigalov61/Los-Angeles-MIDI-Dataset) 14 | ### 2) [LAKH MIDI Dataset](https://colinraffel.com/projects/lmd/) (included with code/colab) 15 | 16 | *** 17 | 18 | ### Project Los Angeles 19 | ### Tegridy Code 2023 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Experimental Music Transformer 2 | ## Experimental music transformer to explore new techniques and ideas 3 | 4 | ![image_processing20200927-24845-1kxd3mo](https://github.com/asigalov61/Experimental-Music-Transformer/assets/56325539/d154431d-937b-4216-a4dc-24cd1bac6009) 5 | 6 | *** 7 | 8 | ## Version 1.0 9 | 10 | [![Open In Colab][colab-badge]][colab-notebook2] 11 | 12 | [colab-notebook2]: 13 | [colab-badge]: 14 | 15 | ### The main feature of this version is the emphasis tokens, which should make it possible to condition the model on simple lyrics. 16 | 17 | *** 18 | 19 | ## [WIP] [DEV] Version 2.0 20 | 21 | [![Open In Colab][colab-badge]][colab-notebook3] 22 | 23 | [colab-notebook3]: 24 | [colab-badge]: 25 | 26 | ### This version uses a mixed continuous/discrete transformer implementation to improve the quality of the generated output 27 | ### For more info, please see the [x-transformers](https://github.com/lucidrains/x-transformers) repo / [XVal implementation](https://github.com/lucidrains/x-transformers/blob/main/x_transformers/xval.py) 28 | 29 | *** 30 | 31 | ## Version 3.0 32 | 33 | [![Open In Colab][colab-badge]][colab-notebook4] 34 | 35 | [colab-notebook4]: 36 | [colab-badge]: 37 | 38 | ### A robust seq2seq implementation for music experiments. This version models seq2seq accompaniment generation 39 | 40 | *** 41 | 42 | ### Project Los Angeles 43 | ### Tegridy Code 2023 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License.
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Training-Code/experimental_music_transformer_maker.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Experimental_Music_Transformer_Maker.ipynb 3 | 4 | Automatically generated by Colaboratory. 5 | 6 | Original file is located at 7 | https://colab.research.google.com/github/asigalov61/Experimental-Music-Transformer/blob/main/Training-Code/Experimental_Music_Transformer_Maker.ipynb 8 | 9 | # Experimental Music Transformer Maker (ver. 1.0) 10 | 11 | *** 12 | 13 | Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools 14 | 15 | *** 16 | 17 | WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please exercise great humility, care, and respect. https://www.nscai.gov/ 18 | 19 | *** 20 | 21 | #### Project Los Angeles 22 | 23 | #### Tegridy Code 2023 24 | 25 | *** 26 | 27 | # GPU check 28 | """ 29 | 30 | !nvidia-smi 31 | 32 | """# Setup environment""" 33 | 34 | !git clone https://github.com/asigalov61/tegridy-tools 35 | 36 | !pip uninstall -y torch 37 | 38 | !pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 39 | !pip install einops 40 | !pip install torch-summary 41 | !pip install tqdm 42 | !pip install matplotlib 43 | 44 | # Commented out IPython magic to ensure Python compatibility.
45 | # Load modules and make data dir 46 | 47 | print('Loading modules...') 48 | 49 | import os 50 | import pickle 51 | import random 52 | import secrets 53 | import tqdm 54 | import math 55 | import torch 56 | import torch.optim as optim 57 | from torch.utils.data import DataLoader, Dataset 58 | 59 | import matplotlib.pyplot as plt 60 | 61 | from torchsummary import summary 62 | from sklearn import metrics 63 | 64 | # %cd /content/tegridy-tools/tegridy-tools/ 65 | 66 | import TMIDIX 67 | 68 | # %cd /content/tegridy-tools/tegridy-tools/X-Transformer 69 | 70 | from x_transformer_1_23_2 import * 71 | 72 | torch.set_float32_matmul_precision('high') 73 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 74 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 75 | 76 | # %cd /content/ 77 | 78 | if not os.path.exists('/content/INTS'): 79 | os.makedirs('/content/INTS') 80 | 81 | 82 | 83 | print('Done') 84 | 85 | print('Torch version:', torch.__version__) 86 | 87 | """# Load training data""" 88 | 89 | dataset_addr = "/content/INTS" 90 | 91 | #========================================================================== 92 | 93 | filez = list() 94 | for (dirpath, dirnames, filenames) in os.walk(dataset_addr): 95 | filez += [os.path.join(dirpath, file) for file in filenames] 96 | print('=' * 70) 97 | 98 | random.shuffle(filez) 99 | 100 | print('Loaded', len(filez), 'data files') 101 | print('=' * 70) 102 | 103 | """# Setup model""" 104 | 105 | # Setup model 106 | 107 | # constants 108 | 109 | NUM_DATA_FILES_TO_LOAD_PER_ITER = 8 110 | 111 | SEQ_LEN = 8192 # Model's seq len 112 | PAD_IDX = 9000 # Model's pad index 113 | 114 | NUM_EPOCHS = 1 115 | 116 | BATCH_SIZE = 4 117 | GRADIENT_ACCUMULATE_EVERY = 4 118 | 119 | LEARNING_RATE = 2e-4 120 | 121 | VALIDATE_EVERY = 100 122 | SAVE_EVERY = 500 123 | GENERATE_EVERY = 250 124 | GENERATE_LENGTH = 512 125 | PRINT_STATS_EVERY = 20 126 | 127 | # helpers 128 | 129 | def cycle(loader): 130 | while True: 131 | for data in loader: 132 | yield data 133 | 134 | # instantiate the model 135 | 136 | model = TransformerWrapper( 137 | num_tokens = PAD_IDX+1, 138 | max_seq_len = SEQ_LEN, 139 | attn_layers = Decoder(dim = 1024, depth = 20, heads = 16, attn_flash = True) 140 | ) 141 | 142 | model = AutoregressiveWrapper(model, ignore_index = PAD_IDX) 143 | 144 | model = torch.nn.DataParallel(model) 145 | 146 | model.cuda() 147 | 148 | print('Done!') 149 | 150 | summary(model) 151 | 152 | # Dataloader 153 | 154 | class MusicDataset(Dataset): 155 | def __init__(self, data, seq_len): 156 | super().__init__() 157 | self.data = data 158 | self.seq_len = seq_len 159 | 160 | def __getitem__(self, index): 161 | 162 | # consecutive sampling 163 | 164 | full_seq = torch.Tensor(self.data[index][:self.seq_len+1]).long() 165 | 166 | return full_seq.cuda() 167 | 168 | def __len__(self): 169 | return (len(self.data) // BATCH_SIZE) * BATCH_SIZE 170 | 171 | # precision/optimizer/scaler 172 | 173 | dtype = torch.float16 174 | 175 | ctx = torch.amp.autocast(device_type='cuda', dtype=dtype) 176 | 177 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 178 | 179 | scaler = torch.cuda.amp.GradScaler(enabled=True) 180 | 181 | """# Train""" 182 | 183 | # Train the model 184 | 185 | CHUNKS_LENGTH = SEQ_LEN+1 186 | MIN_NUMBER_OF_CHUNK_EVENTS = 512 # min number of tokens per chunk 187 | 188 | train_losses = [] 189 | val_losses = [] 190 | 191 | train_accs = [] 192 | val_accs = [] 193 | 194 | nsteps = 0 195 | 196 | for fa in range(0, len(filez),
NUM_DATA_FILES_TO_LOAD_PER_ITER): 197 | 198 | #========================================================================== 199 | print('=' * 70) 200 | print('Loading data files', fa, '---', fa+NUM_DATA_FILES_TO_LOAD_PER_ITER-1) 201 | print('Please wait...') 202 | print('=' * 70) 203 | 204 | train_data = [] 205 | 206 | chunks_counter = 0 207 | discarded_chunks_counter = 0 208 | 209 | for lfa in tqdm.tqdm(filez[fa:fa+NUM_DATA_FILES_TO_LOAD_PER_ITER]): 210 | 211 | train_d = pickle.load(open(lfa, 'rb')) 212 | random.shuffle(train_d) 213 | for t in train_d: 214 | for i in range(0, len(t), int((SEQ_LEN * 3) / 4)): # 3/4 stride -> 25% overlap between consecutive chunks 215 | 216 | #========================================================================= 217 | # collecting all possible chunks of chunks length 218 | 219 | if 0 <= max(t[i:i+CHUNKS_LENGTH]) < PAD_IDX: # final data integrity check 220 | if len(t[i:i+CHUNKS_LENGTH]) == CHUNKS_LENGTH: 221 | train_data.append(t[i:i+CHUNKS_LENGTH]) 222 | 223 | else: 224 | if len(t[i:i+CHUNKS_LENGTH]) >= MIN_NUMBER_OF_CHUNK_EVENTS: 225 | td = t[i:i+CHUNKS_LENGTH] + [PAD_IDX] * (CHUNKS_LENGTH-len(t[i:i+CHUNKS_LENGTH])) # padding with pad index 226 | train_data.append(td) 227 | else: 228 | discarded_chunks_counter += 1 229 | 230 | chunks_counter += 1 231 | 232 | else: 233 | print('Bad data!!!') 234 | break 235 | 236 | #========================================================================= 237 | # Collecting middle chunk if it is larger than chunks length 238 | 239 | if 0 <= max(t) < PAD_IDX: # final data integrity check 240 | if len(t) >= SEQ_LEN+8: 241 | comp_middle = int(len(t) / 8) 242 | sidx = int((comp_middle * 4)-(SEQ_LEN / 2)) 243 | train_data.append(t[sidx:sidx+CHUNKS_LENGTH]) 244 | 245 | else: 246 | discarded_chunks_counter += 1 247 | 248 | chunks_counter += 1 249 | 250 | else: 251 | print('Bad data!!!') 252 | break 253 | 254 | #========================================================================== 255 | 256 | print('Done!') 257 | print('=' * 70) 258 | print('Total number of input chunks:', chunks_counter) 259 | print('Total number of good chunks:', len(train_data)) 260 | print('Total number of discarded chunks:', discarded_chunks_counter, '/', round(100 * discarded_chunks_counter/chunks_counter, 3), '%') 261 | print('All data is good:', len(max(train_data, key=len)) == len(min(train_data, key=len))) 262 | print('=' * 70) 263 | print('Final data randomization...') 264 | random.shuffle(train_data) 265 | print('Done!') 266 | print('=' * 70) 267 | 268 | 269 | train_dataset = MusicDataset(train_data, SEQ_LEN) 270 | val_dataset = MusicDataset(train_data, SEQ_LEN) 271 | train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE)) 272 | val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE)) 273 | 274 | NUM_BATCHES = (len(train_data) // BATCH_SIZE // GRADIENT_ACCUMULATE_EVERY) * NUM_EPOCHS 275 | 276 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='Training'): 277 | model.train() 278 | 279 | for __ in range(GRADIENT_ACCUMULATE_EVERY): 280 | with ctx: 281 | loss, acc = model(next(train_loader)) 282 | loss = loss / GRADIENT_ACCUMULATE_EVERY 283 | scaler.scale(loss).backward(torch.ones(loss.shape).cuda()) 284 | 285 | if i % PRINT_STATS_EVERY == 0: 286 | print(f'Training loss: {loss.mean().item() * GRADIENT_ACCUMULATE_EVERY}') 287 | print(f'Training acc: {acc.mean().item()}') 288 | 289 | train_losses.append(loss.mean().item() * GRADIENT_ACCUMULATE_EVERY) 290 | train_accs.append(acc.mean().item()) 291 | 292 | scaler.unscale_(optim) 293 |
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 294 | scaler.step(optim) 295 | scaler.update() 296 | optim.zero_grad(set_to_none=True) 297 | 298 | nsteps += 1 299 | 300 | if i % VALIDATE_EVERY == 0: 301 | model.eval() 302 | with torch.no_grad(): 303 | with ctx: 304 | val_loss, val_acc = model(next(val_loader)) 305 | 306 | print(f'Validation loss: {val_loss.mean().item()}') 307 | print(f'Validation acc: {val_acc.mean().item()}') 308 | 309 | val_losses.append(val_loss.mean().item()) 310 | val_accs.append(val_acc.mean().item()) 311 | 312 | print('Plotting training loss graph...') 313 | 314 | tr_loss_list = train_losses 315 | plt.plot([i for i in range(len(tr_loss_list))], tr_loss_list, 'b') 316 | plt.show() 317 | plt.close() 318 | print('Done!') 319 | 320 | print('Plotting training acc graph...') 321 | 322 | tr_loss_list = train_accs 323 | plt.plot([i for i in range(len(tr_loss_list))], tr_loss_list, 'b') 324 | plt.show() 325 | plt.close() 326 | print('Done!') 327 | 328 | print('Plotting validation loss graph...') 329 | tr_loss_list = val_losses 330 | plt.plot([i for i in range(len(tr_loss_list))], tr_loss_list, 'b') 331 | plt.show() 332 | plt.close() 333 | print('Done!') 334 | 335 | print('Plotting validation acc graph...') 336 | tr_loss_list = val_accs 337 | plt.plot([i for i in range(len(tr_loss_list))], tr_loss_list, 'b') 338 | plt.show() 339 | plt.close() 340 | print('Done!') 341 | 342 | if i % GENERATE_EVERY == 0: 343 | model.eval() 344 | 345 | inp = random.choice(val_dataset)[:-1] 346 | 347 | print(inp) 348 | 349 | with ctx: 350 | sample = model.module.generate(inp[None, ...], GENERATE_LENGTH) 351 | 352 | print(sample) 353 | 354 | data = sample.tolist()[0] 355 | 356 | print('Sample INTs', data[:15]) 357 | 358 | out = data[:200000] 359 | 360 | if len(out) != 0: 361 | 362 | song = out 363 | song_f = [] 364 | 365 | time = 0 366 | dur = 0 367 | vel = 90 368 | pitch = 0 369 | channel = 0 370 | 371 | for ss in song: 372 | 373 | if 0 <= ss < 512: # tokens 0..511: delta start-time 374 | 375 | time += ss * 8 376 | 377 | if 512 <= ss < 4608: # tokens 512..4607: duration + velocity 378 | 379 | dur = ((ss-512) // 8) * 8 380 | vel = (((ss-512) % 8)+1) * 15 381 | 382 | if 4608 <= ss < 6784: # tokens 4608..6783: MIDI patch + pitch 383 | 384 | patch = (ss-4608) // 128 385 | 386 | if patch == 16: 387 | channel = 9 388 | else: 389 | if 9 <= patch <= 14: 390 | channel = patch + 1 391 | else: 392 | channel = patch 393 | 394 | if patch == 15: 395 | channel = 15 396 | 397 | pitch = (ss-4608) % 128 398 | 399 | song_f.append(['note', time, dur, channel, pitch, vel ]) 400 | 401 | detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f, 402 | output_signature = 'Experimental Music Transformer', 403 | output_file_name = '/content/Experimental-Music-Transformer-Composition', 404 | track_name='Project Los Angeles', 405 | list_of_MIDI_patches=[0, 10, 19, 24, 35, 40, 53, 56, 65, 9, 73, 87, 89, 99, 105, 117] 406 | ) 407 | 408 | print('Done!') 409 | 410 | if i % SAVE_EVERY == 0: 411 | 412 | print('Saving model progress.
Please wait...') 413 | print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth') 414 | 415 | fname = '/content/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth' 416 | 417 | torch.save(model.state_dict(), fname) 418 | 419 | data = [train_losses, train_accs, val_losses, val_accs] 420 | 421 | TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accs') 422 | 423 | print('Done!') 424 | 425 | #====================================================================================================== 426 | 427 | print('Saving model progress. Please wait...') 428 | print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth') 429 | 430 | fname = '/content/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth' 431 | 432 | torch.save(model.state_dict(), fname) 433 | 434 | print('Done!') 435 | 436 | data = [train_losses, train_accs, val_losses, val_accs] 437 | 438 | TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accuracies') 439 | 440 | # Save training loss graph 441 | 442 | plt.plot([i for i in range(len(train_losses))] ,train_losses, 'b') 443 | plt.savefig('/content/training_loss_graph.png') 444 | plt.close() 445 | print('Done!') 446 | 447 | # Save training acc graph 448 | 449 | plt.plot([i for i in range(len(train_accs))] ,train_accs, 'b') 450 | plt.savefig('/content/training_acc_graph.png') 451 | plt.close() 452 | print('Done!') 453 | 454 | # Save validation loss graph 455 | 456 | plt.plot([i for i in range(len(val_losses))] ,val_losses, 'b') 457 | plt.savefig('/content/validation_loss_graph.png') 458 | plt.close() 459 | print('Done!') 460 | 461 | # Save validation acc graph 462 | 463 | plt.plot([i for i in range(len(val_accs))] ,val_accs, 'b') 464 | plt.savefig('/content/validation_acc_graph.png') 465 | plt.close() 466 | print('Done!') 467 | 468 | """# Final Save""" 469 | 470 | print('Saving model progress. 
Please wait...') 471 | print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth') 472 | 473 | fname = '/content/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth' 474 | 475 | torch.save(model.state_dict(), fname) 476 | 477 | print('Done!') 478 | 479 | data = [train_losses, train_accs, val_losses, val_accs] 480 | 481 | TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accuracies') 482 | 483 | # Save training loss graph 484 | 485 | plt.plot([i for i in range(len(train_losses))], train_losses, 'b') 486 | plt.savefig('/content/training_loss_graph.png') 487 | plt.close() 488 | print('Done!') 489 | 490 | # Save training acc graph 491 | 492 | plt.plot([i for i in range(len(train_accs))], train_accs, 'b') 493 | plt.savefig('/content/training_acc_graph.png') 494 | plt.close() 495 | print('Done!') 496 | 497 | # Save validation loss graph 498 | 499 | plt.plot([i for i in range(len(val_losses))], val_losses, 'b') 500 | plt.savefig('/content/validation_loss_graph.png') 501 | plt.close() 502 | print('Done!') 503 | 504 | # Save validation acc graph 505 | 506 | plt.plot([i for i in range(len(val_accs))], val_accs, 'b') 507 | plt.savefig('/content/validation_acc_graph.png') 508 | plt.close() 509 | print('Done!') 510 | 511 | """# Eval""" 512 | 513 | model.eval() 514 | 515 | #x = torch.tensor((random.choice(train_data)[:1000], dtype=torch.long, device=device_type)[None, ...]) 516 | x = torch.tensor([[8998, 8851+0, 8853+0, 8870+60]] * 4, dtype=torch.long, device='cuda') 517 | 518 | # run generation 519 | 520 | with ctx: 521 | out = model.module.generate(x, 522 | 500, 523 | temperature=0.9, 524 | return_prime=True, 525 | verbose=True) 526 | 527 | y = out.tolist() 528 | 529 | print('---------------') 530 | 531 | #@title Test INTs 532 | 533 | data = y[0] 534 | 535 | print('Sample INTs', data[:15]) 536 | 537 | out = data[:200000] 538 | 539 | if len(out) != 0: 540 | 541 | song = out 542 | song_f = [] 543 | 544 | time = 0 545 | dur = 0 546 | vel = 90 547 | pitch = 0 548 | channel = 0 549 | 550 | for ss in song: 551 | 552 | if 0 <= ss < 512: 553 | 554 | time += ss * 8 555 | 556 | if 512 <= ss < 4608: 557 | 558 | dur = ((ss-512) // 8) * 8 559 | vel = (((ss-512) % 8)+1) * 15 560 | 561 | if 4608 <= ss < 6784: 562 | 563 | patch = (ss-4608) // 128 564 | 565 | if patch == 16: 566 | channel = 9 567 | else: 568 | if 9 <= patch <= 14: 569 | channel = patch + 1 570 | else: 571 | channel = patch 572 | 573 | if patch == 15: 574 | channel = 15 575 | 576 | pitch = (ss-4608) % 128 577 | 578 | if patch == 17: 579 | break 580 | 581 | song_f.append(['note', time, dur, channel, pitch, vel ]) 582 | 583 | detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f, 584 | output_signature = 'Experimental Music Transformer', 585 | output_file_name = '/content/Experimental-Music-Transformer-Composition', 586 | track_name='Project Los Angeles', 587 | list_of_MIDI_patches=[0, 10, 19, 24, 35, 40, 53, 56, 65, 9, 73, 87, 89, 99, 105, 117] 588 | ) 589 | 590 | print('Done!') 591 | 592 | 593 | 594 | tok_emb = model.module.net.token_emb.emb.weight.detach().cpu().tolist() 595 | 596 | cos_sim = metrics.pairwise_distances( 597 | tok_emb, metric='cosine' 598 | ) 599 | plt.figure(figsize=(7, 7)) 600 | plt.imshow(cos_sim, cmap="inferno", interpolation="nearest") 601 | im_ratio = cos_sim.shape[0] / cos_sim.shape[1] 602 |
plt.colorbar(fraction=0.046 * im_ratio, pad=0.04) 603 | plt.xlabel("Position") 604 | plt.ylabel("Position") 605 | plt.tight_layout() 606 | plt.plot() 607 | plt.savefig("/content/Experimental-Music-Transformer-Tokens-Embeddings-Plot.png", bbox_inches="tight") 608 | 609 | """# Congrats! You did it! :)""" -------------------------------------------------------------------------------- /experimental_music_transformer_version_3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Experimental_Music_Transformer_Version_3.ipynb 3 | 4 | Automatically generated by Colaboratory. 5 | 6 | Original file is located at 7 | https://colab.research.google.com/drive/1yKIPMV0y4adB8YYScEVHnjHqrj8Diyse 8 | 9 | # Experimental Music Transformer Version 3 (ver. 0.5) 10 | 11 | *** 12 | 13 | Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools 14 | 15 | *** 16 | 17 | WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please exercise great humility, care, and respect. https://www.nscai.gov/ 18 | 19 | *** 20 | 21 | #### Project Los Angeles 22 | 23 | #### Tegridy Code 2023 24 | 25 | *** 26 | 27 | # (SETUP ENVIRONMENT) 28 | """ 29 | 30 | #@title Install all dependencies (run only once per session) 31 | 32 | !git clone https://github.com/asigalov61/tegridy-tools 33 | !pip install einops 34 | !pip install torch-summary 35 | 36 | # Commented out IPython magic to ensure Python compatibility. 37 | #@title Import all needed modules 38 | 39 | print('Loading needed modules. Please wait...') 40 | 41 | import os 42 | import pickle 43 | import copy 44 | import statistics 45 | import secrets 46 | import tqdm 47 | import math 48 | 49 | from joblib import Parallel, delayed, parallel_config 50 | 51 | import torch 52 | import torch.optim as optim 53 | from torch.optim import Adam 54 | from torch.utils.data import DataLoader, Dataset 55 | 56 | torch.set_float32_matmul_precision('high') 57 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 58 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 59 | 60 | import matplotlib.pyplot as plt 61 | 62 | from torchsummary import summary 63 | from sklearn import metrics 64 | 65 | print('Loading TMIDIX module...') 66 | 67 | # %cd /content/tegridy-tools/tegridy-tools/ 68 | 69 | import TMIDIX 70 | 71 | print('Loading X Transformer module...') 72 | 73 | # %cd /content/tegridy-tools/tegridy-tools/X-Transformer 74 | 75 | from x_transformer_1_23_2 import * 76 | import random 77 | 78 | # %cd /content/ 79 | 80 | print('Creating I/O dirs...') 81 | 82 | if not os.path.exists('/content/Dataset'): 83 | os.makedirs('/content/Dataset') 84 | 85 | if not os.path.exists('/content/DATA'): 86 | os.makedirs('/content/DATA') 87 | 88 | print('Done!') 89 | print('PyTorch version:', torch.__version__) 90 | print('Enjoy! :)') 91 | 92 | """# (DOWNLOAD AND UNZIP MIDI DATASET)""" 93 | 94 | # Commented out IPython magic to ensure Python compatibility.
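# Note: the dataset archive in the next cell is distributed as split zip parts (.001/.002); the wget/cat/unzip commands below first download the parts, concatenate them back into a single zip, and then extract it.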
95 | # @title Download and unzip Mono Melodies Piano Violin MIDI Dataset 96 | # %cd /content/Dataset 97 | !wget https://github.com/asigalov61/Tegridy-MIDI-Dataset/raw/master/Mono-Melodies/Piano-Violin/Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip.001 98 | !wget https://github.com/asigalov61/Tegridy-MIDI-Dataset/raw/master/Mono-Melodies/Piano-Violin/Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip.002 99 | !cat Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip* > Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip 100 | !unzip Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip 101 | !rm Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip 102 | # %cd /content/ 103 | 104 | """# (LOAD MIDI PROCESSOR)""" 105 | 106 | #@title TMIDIX MIDI Processor 107 | 108 | print('=' * 70) 109 | print('Loading TMIDIX MIDI Processor...') 110 | print('=' * 70) 111 | 112 | def TMIDIX_MIDI_Processor(midi_file): 113 | 114 | melody_chords = [] 115 | 116 | try: 117 | 118 | fn = os.path.basename(midi_file) 119 | 120 | # Filtering out GIANT4 MIDIs 121 | file_size = os.path.getsize(midi_file) 122 | 123 | if file_size <= 1000000: 124 | 125 | #======================================================= 126 | # START PROCESSING 127 | 128 | # Converting MIDI to ms score with MIDI.py module 129 | score = TMIDIX.midi2single_track_ms_score(open(midi_file, 'rb').read(), recalculate_channels=False) 130 | 131 | # INSTRUMENTS CONVERSION CYCLE 132 | events_matrix = [] 133 | itrack = 1 134 | patches = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 135 | 136 | while itrack < len(score): 137 | for event in score[itrack]: 138 | if event[0] == 'note' or event[0] == 'patch_change': 139 | events_matrix.append(event) 140 | itrack += 1 141 | 142 | events_matrix.sort(key=lambda x: x[1]) 143 | 144 | events_matrix1 = [] 145 | 146 | for event in events_matrix: 147 | if event[0] == 'patch_change': 148 | patches[event[2]] = event[3] 149 | 150 | if event[0] == 'note': 151 | event.extend([patches[event[3]]]) 152 | 153 | if events_matrix1: 154 | if (event[1] == events_matrix1[-1][1]): 155 | if ([event[3], event[4]] != events_matrix1[-1][3:5]): 156 | events_matrix1.append(event) 157 | else: 158 | events_matrix1.append(event) 159 | 160 | else: 161 | events_matrix1.append(event) 162 | 163 | if len(events_matrix1) > 0: 164 | if min([e[1] for e in events_matrix1]) >= 0 and min([e[2] for e in events_matrix1]) >= 0: 165 | 166 | #======================================================= 167 | # PRE-PROCESSING 168 | 169 | # checking number of instruments in a composition 170 | instruments_list = list(set([y[3] for y in events_matrix1])) 171 | 172 | if len(events_matrix1) > 0: 173 | 174 | #=================================== 175 | # ORIGINAL COMPOSITION 176 | #=================================== 177 | 178 | # Adjusting timings 179 | 180 | for e in events_matrix1: 181 | e[1] = int(e[1] / 16) 182 | e[2] = int(e[2] / 16) 183 | 184 | # Sorting by patch, pitch, then by start-time 185 | 186 | events_matrix1.sort(key=lambda x: x[6]) 187 | events_matrix1.sort(key=lambda x: x[4], reverse=True) 188 | events_matrix1.sort(key=lambda x: x[1]) 189 | 190 | #======================================================= 191 | # FINAL PROCESSING 192 | 193 | #======================================================= 194 | # MAIN PROCESSING CYCLE 195 | #======================================================= 196 | 197 | pe = events_matrix1[0] 198 | 199 | notes = [] 200 | 201 | for e in events_matrix1: 202 | 203 | time = max(0, min(255, (e[1] - pe[1]))) 204 | dur = max(0, min(255, e[2])) 205 | cha = max(0, min(15, e[3])) 206 | ptc =
max(1, min(127, e[4])) 207 | 208 | notes.append([time, dur, cha, ptc]) 209 | 210 | pe = e 211 | 212 | chords = [] 213 | cho = [] 214 | 215 | for n in notes: 216 | 217 | if n[2] not in [0, 3]: 218 | n[2] = 0 219 | 220 | if n[0] == 0: 221 | chans = list(set([nn[2] for nn in cho])) 222 | if (n[2] == 3) and (3 in chans): 223 | n[2] = 0 224 | 225 | cho.append(n) 226 | else: 227 | if len(cho) > 0: 228 | chords.append(cho) 229 | 230 | cho = [] 231 | cho.append(n) 232 | 233 | 234 | if len(cho) > 0: 235 | chords.append(cho) 236 | 237 | return chords 238 | 239 | except: 240 | return None 241 | 242 | print('Done!') 243 | print('=' * 70) 244 | 245 | """# (FILES LIST)""" 246 | 247 | #@title Save file list 248 | ########### 249 | 250 | print('=' * 70) 251 | print('Loading MIDI files...') 252 | print('This may take a while on a large dataset in particular.') 253 | 254 | dataset_addr = "/content/Dataset" 255 | 256 | # os.chdir(dataset_addr) 257 | filez = list() 258 | for (dirpath, dirnames, filenames) in os.walk(dataset_addr): 259 | filez += [os.path.join(dirpath, file) for file in filenames] 260 | print('=' * 70) 261 | 262 | if not filez: 263 | print('Could not find any MIDI files. Please check Dataset dir...') 264 | print('=' * 70) 265 | 266 | else: 267 | print('Randomizing file list...') 268 | random.shuffle(filez) 269 | print('Done!') 270 | print('=' * 70) 271 | print('Total files:', len(filez)) 272 | print('=' * 70) 273 | 274 | """# (PROCESS MIDIs)""" 275 | 276 | #@title Process MIDIs with TMIDIX MIDI processor 277 | 278 | print('=' * 70) 279 | print('TMIDIX MIDI Processor') 280 | print('=' * 70) 281 | print('Starting up...') 282 | print('=' * 70) 283 | 284 | ########### 285 | 286 | melody_chords_f = [] 287 | 288 | print('Processing MIDI files. Please wait...') 289 | print('=' * 70) 290 | 291 | for i in tqdm.tqdm(range(0, len(filez), 16)): 292 | 293 | with parallel_config(backend='threading', n_jobs=4, verbose = 0): 294 | 295 | output = Parallel()(delayed(TMIDIX_MIDI_Processor)(f) for f in filez[i:i+16]) 296 | 297 | for o in output: 298 | if o is not None: 299 | melody_chords_f.append(o) 300 | 301 | print('Done!') 302 | print('=' * 70) 303 | 304 | """# (SAVE/LOAD PROCESSED MIDIs)""" 305 | 306 | #@title Save processed MIDIs 307 | TMIDIX.Tegridy_Any_Pickle_File_Writer(melody_chords_f, '/content/DATA/Processed_MIDIs') 308 | 309 | # @title Load processed MIDIs 310 | melody_chords_f = TMIDIX.Tegridy_Any_Pickle_File_Reader('/content/DATA/Processed_MIDIs') 311 | print('Done!') 312 | 313 | """# (PREP INTs)""" 314 | 315 | # @title Convert processed MIDIs to INTs for training 316 | 317 | def split_data_into_chunks(data, chunk_size): 318 | # Use list comprehension to create chunks of the specified size 319 | return [data[i:i + chunk_size] for i in range(0, len(data) - len(data) % chunk_size, chunk_size)] 320 | 321 | print('=' * 70) 322 | 323 | SEQ_LEN = 256 324 | 325 | src_list = [] 326 | trg_list = [] 327 | 328 | for m in tqdm.tqdm(melody_chords_f): 329 | 330 | for j in range(-6, 6): 331 | 332 | sdat = [] 333 | tdat = [] 334 | 335 | for mmm in m: 336 | 337 | melody_tone = max(1, min(127, mmm[0][3]+j)) % 12 338 | 339 | tones_chord = sorted(list(set([(n[3]+j) % 12 for n in mmm]))) 340 | 341 | try: 342 | chord_id = TMIDIX.ALL_CHORDS.index(tones_chord) 343 | except: 344 | chord_id = -1 345 | 346 | if chord_id != -1: 347 | pchord_id = chord_id 348 | 349 | else: 350 | pchord_id = TMIDIX.ALL_CHORDS.index([melody_tone]) 351 | 352 | sdat.extend([melody_tone]) 353 | tdat.extend([pchord_id]) 354 | 355 | if len(sdat) > 
SEQ_LEN: 356 | sdat1 = split_data_into_chunks(sdat, SEQ_LEN) 357 | 358 | if len(tdat) > SEQ_LEN: 359 | tdat1 = split_data_into_chunks(tdat, SEQ_LEN) 360 | 361 | for i in range(len(tdat1)): 362 | fill_ratio = sum(1 for t in tdat1[i] if t > 11) / len(tdat1[i]) 363 | 364 | if fill_ratio >= 0.6: 365 | src_list.append(sdat1[i]) 366 | trg_list.append(tdat1[i]) 367 | 368 | print('Done!') 369 | print('=' * 70) 370 | if len(max(src_list, key=len)) == len(min(src_list, key=len)) and len(max(trg_list, key=len)) == len(min(trg_list, key=len)): 371 | print('All data is good!') 372 | else: 373 | print('WARNING!!! BAD DATA!!!') 374 | print('=' * 70) 375 | 376 | trg_list[0] 377 | 378 | """#=========================================================== 379 | 380 | # (SAVE/LOAD TRAINING INTs) 381 | """ 382 | 383 | # @title Save INTs 384 | TMIDIX.Tegridy_Any_Pickle_File_Writer([src_list, trg_list], '/content/DATA/Training_INTs') 385 | 386 | # @title Load INTs 387 | src_list, trg_list = TMIDIX.Tegridy_Any_Pickle_File_Reader('/content/DATA/Training_INTs') 388 | print('Done!') 389 | 390 | """# (PREP DATA)""" 391 | 392 | # @title DATA 393 | SEQ_LEN = 256 394 | batch_size = 128 395 | 396 | def train_test_split(src, trg, test_size=0.2): 397 | indices = torch.randperm(len(src)).tolist() 398 | split = int(test_size * len(src)) 399 | src_train = src[indices[split:]] 400 | trg_train = trg[indices[split:]] 401 | src_test = src[indices[:split]] 402 | trg_test = trg[indices[:split]] 403 | return src_train, src_test, trg_train, trg_test 404 | 405 | # Convert lists to PyTorch tensors 406 | src_in = torch.tensor(src_list).long() 407 | trg_in = torch.tensor(trg_list).long() 408 | 409 | # Split the data into train, validation, and test sets 410 | src_train, src_val_test, trg_train, trg_val_test = train_test_split(src_in, trg_in, test_size=0.05) 411 | src_val, src_test, trg_val, trg_test = train_test_split(src_val_test, trg_val_test, test_size=0.3) 412 | 413 | class MusicDataset(Dataset): 414 | def __init__(self, src, trg): 415 | self.src = src 416 | self.trg = trg 417 | 418 | def __len__(self): 419 | return len(self.src) 420 | 421 | def __getitem__(self, idx): 422 | src = self.src[idx].long() 423 | trg = self.trg[idx].long() 424 | return src, trg 425 | 426 | # Create datasets for each split 427 | train_dataset = MusicDataset(src_train, trg_train) 428 | val_dataset = MusicDataset(src_val, trg_val) 429 | test_dataset = MusicDataset(src_test, trg_test) 430 | 431 | # Create data loaders for each split 432 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True) 433 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, drop_last=True) 434 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, drop_last=True) 435 | 436 | train_dataset[0] 437 | 438 | """# (TRAIN)""" 439 | 440 | # @title Train model 441 | 442 | torch.cuda.empty_cache() 443 | 444 | # Initialize the model 445 | model = XTransformer( 446 | dim = 512, 447 | enc_num_tokens = 12, 448 | enc_depth = 8, 449 | enc_heads = 8, 450 | enc_max_seq_len = SEQ_LEN, 451 | enc_dropout = 0.3, 452 | enc_attn_flash = True, 453 | dec_num_tokens = 320, 454 | dec_depth = 8, 455 | dec_heads = 8, 456 | dec_max_seq_len = SEQ_LEN, 457 | dec_dropout = 0.3, 458 | dec_attn_flash = True, 459 | cross_attn_tokens_dropout = 0.3 460 | ) 461 | 462 | model.cuda() 463 | 464 | # Define the optimizer 465 | optimizer = Adam(model.parameters()) 466 | 467 | # Initialize AMP 468 | scaler 
= torch.cuda.amp.GradScaler() 469 | 470 | NUM_EPOCHS = 100 471 | GRADIENT_ACCUMULATE_EVERY = 1 # Set the number of steps to accumulate gradients 472 | 473 | PRINT_STATS_EVERY = 20 474 | 475 | train_losses = [] 476 | train_accs = [] 477 | 478 | num_steps = 0 479 | 480 | # Training loop with gradient accumulation 481 | 482 | for epoch in range(NUM_EPOCHS): 483 | model.train() # set the model to training mode 484 | total_loss = 0 485 | optimizer.zero_grad(set_to_none=True) # Initialize gradients to zero at the start of the epoch 486 | 487 | for batch_idx, batch in enumerate(tqdm.tqdm(train_loader)): # iterate over batches of data 488 | src, tgt = [item.cuda() for item in batch] # unpack the source and target tensors from the current batch 489 | 490 | src_mask = src.bool().cuda() # create a mask for the source sequence 491 | with torch.cuda.amp.autocast(): 492 | loss, acc = model(src, tgt, mask=src_mask) # forward pass 493 | 494 | # loss = loss / GRADIENT_ACCUMULATE_EVERY # Normalize the loss by the number of accumulation steps 495 | scaler.scale(loss).backward() # Backward pass with gradient scaling 496 | 497 | train_losses.append(loss.mean().item() * GRADIENT_ACCUMULATE_EVERY) 498 | train_accs.append(acc.mean().item()) 499 | 500 | 501 | if (batch_idx + 1) % GRADIENT_ACCUMULATE_EVERY == 0: # Perform optimization step after accumulating gradients 502 | scaler.unscale_(optimizer) 503 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 504 | scaler.step(optimizer) 505 | scaler.update() 506 | optimizer.zero_grad(set_to_none=True) # Reset gradients after optimization step 507 | 508 | total_loss += loss.item() * GRADIENT_ACCUMULATE_EVERY # Undo the normalization for logging 509 | 510 | if num_steps % PRINT_STATS_EVERY == 0: 511 | print(f'Training Loss: {total_loss / (batch_idx + 1)}, Accuracy: {acc.item()}') 512 | 513 | num_steps += 1 514 | 515 | # Validation loop 516 | model.eval() 517 | with torch.no_grad(): 518 | for i, (src, trg) in enumerate(val_loader): 519 | src = src.cuda() 520 | trg = trg.cuda() 521 | src_mask = src.bool().cuda() # create a mask for the source sequence 522 | with torch.cuda.amp.autocast(): 523 | loss, acc = model(src, trg, mask=src_mask) # forward pass 524 | 525 | print(f'Validation Loss: {loss.item()}, Accuracy: {acc.item()}') 526 | 527 | if i > 10: 528 | break 529 | 530 | avg_loss = total_loss / len(train_loader) # calculate average loss for the epoch 531 | print(f'Epoch {epoch}: Average Loss: {avg_loss}') 532 | 533 | 534 | print('Plotting training loss graph...') 535 | 536 | tr_loss_list = train_losses 537 | plt.plot([i for i in range(len(tr_loss_list))], tr_loss_list, 'b') 538 | plt.show() 539 | plt.close() 540 | print('Done!') 541 | 542 | print('Plotting training acc graph...') 543 | 544 | tr_loss_list = train_accs 545 | plt.plot([i for i in range(len(tr_loss_list))], tr_loss_list, 'b') 546 | plt.show() 547 | plt.close() 548 | print('Done!') 549 | 550 | """# EVAL""" 551 | 552 | # Test loop 553 | model.eval() 554 | with torch.no_grad(): 555 | for i, (src, trg) in enumerate(test_loader): 556 | src = src.cuda() 557 | trg = trg.cuda() 558 | src_mask = src.bool().cuda() # create a mask for the source sequence 559 | with torch.cuda.amp.autocast(): 560 | loss, acc = model(src, trg, mask=src_mask) # forward pass 561 | 562 | print(f'Test Loss: {loss.item()}, Accuracy: {acc.item()}') 563 | 564 | """#===========================================================""" 565 | 566 | # Converting MIDI to ms score with
MIDI.py module 567 | 568 | midi_file = '/content/tegridy-tools/tegridy-tools/seed-melody.mid' 569 | 570 | score = TMIDIX.midi2single_track_ms_score(open(midi_file, 'rb').read(), recalculate_channels=False) 571 | 572 | # INSTRUMENTS CONVERSION CYCLE 573 | events_matrix = [] 574 | itrack = 1 575 | patches = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 576 | 577 | while itrack < len(score): 578 | for event in score[itrack]: 579 | if event[0] == 'note' or event[0] == 'patch_change': 580 | events_matrix.append(event) 581 | itrack += 1 582 | 583 | events_matrix.sort(key=lambda x: x[1]) 584 | 585 | events_matrix1 = [] 586 | 587 | for event in events_matrix: 588 | if event[0] == 'patch_change': 589 | patches[event[2]] = event[3] 590 | 591 | if event[0] == 'note': 592 | event.extend([patches[event[3]]]) 593 | 594 | if events_matrix1: 595 | if (event[1] == events_matrix1[-1][1]): 596 | if ([event[3], event[4]] != events_matrix1[-1][3:5]): 597 | events_matrix1.append(event) 598 | else: 599 | events_matrix1.append(event) 600 | 601 | else: 602 | events_matrix1.append(event) 603 | 604 | if len(events_matrix1) > 0: 605 | if min([e[1] for e in events_matrix1]) >= 0 and min([e[2] for e in events_matrix1]) >= 0: 606 | 607 | #======================================================= 608 | # PRE-PROCESSING 609 | 610 | # checking number of instruments in a composition 611 | instruments_list = list(set([y[3] for y in events_matrix1])) 612 | 613 | if len(events_matrix1) > 0: 614 | 615 | #=================================== 616 | # ORIGINAL COMPOSITION 617 | #=================================== 618 | 619 | # Adjusting timings 620 | 621 | for e in events_matrix1: 622 | e[1] = int(e[1] / 16) 623 | e[2] = int(e[2] / 16) 624 | 625 | # Sorting by patch, pitch, then by start-time 626 | 627 | events_matrix1.sort(key=lambda x: x[6]) 628 | events_matrix1.sort(key=lambda x: x[4], reverse=True) 629 | events_matrix1.sort(key=lambda x: x[1]) 630 | 631 | #======================================================= 632 | # FINAL PROCESSING 633 | 634 | #======================================================= 635 | # MAIN PROCESSING CYCLE 636 | #======================================================= 637 | 638 | pe = events_matrix1[0] 639 | 640 | notes = [] 641 | 642 | for e in events_matrix1: 643 | 644 | time = max(0, min(255, (e[1] - pe[1]))) 645 | dur = max(0, min(255, e[2])) 646 | cha = max(0, min(15, e[3])) 647 | ptc = max(1, min(127, e[4])) 648 | 649 | notes.append([time, dur, cha, ptc]) 650 | 651 | pe = e 652 | 653 | mel_pitches = [n[3] for n in notes] 654 | mel_tones = [n[3] % 12 for n in notes] 655 | 656 | # @title Eval the model 657 | dtype = 'float16' 658 | device_type = 'cuda' 659 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 660 | ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) 661 | 662 | model.eval() 663 | 664 | x = torch.tensor(mel_tones, dtype=torch.long, device='cuda')[None, ...] 665 | 666 | prime = [] 667 | 668 | for t in mel_tones[:8]: 669 | prime.append(TMIDIX.ALL_CHORDS.index([t])) 670 | y = torch.tensor(prime, dtype=torch.long, device='cuda')[None, ...] 
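# Seq2seq generation setup (a reading of the cells above and below): the encoder input x is the full melody tone sequence (one token in [0..11] per melody note), while the decoder prime y holds the TMIDIX.ALL_CHORDS indices of the first 8 melody tones treated as single-note chords; model.generate(x, y, ...) below then autoregressively emits one chord id per remaining melody note (seq_len = x.shape[1]-9 new tokens on top of the 8-token prime).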
671 | 672 | #x = torch.tensor([[0]] * 1, dtype=torch.long, device='cuda') 673 | 674 | # run generation 675 | 676 | with ctx: 677 | out = model.generate(x, y, 678 | seq_len=x.shape[1]-9, 679 | temperature=0.85, 680 | return_prime=True, 681 | verbose=True) 682 | 683 | y = out.tolist() 684 | 685 | print('=' * 70) 686 | print(y[0]) 687 | print('=' * 70) 688 | 689 | for c in y[0]: 690 | print(TMIDIX.ALL_CHORDS[c]) 691 | 692 | mel_tones[:8] 693 | 694 | #@title Test model output 695 | 696 | train_data1 = y[0] 697 | 698 | #train_data1 = max(melody_chords_f, key = len) 699 | 700 | print('Sample INTs', train_data1[:15]) 701 | 702 | out = train_data1 703 | 704 | patches = [0] * 16 705 | patches[3] = 40 706 | 707 | if len(out) != 0: 708 | 709 | song = out 710 | song_f = [] 711 | 712 | time = 0 713 | dur = 0 714 | vel = 90 715 | pitch = 0 716 | channel = 0 717 | 718 | mel_idx = 0 719 | 720 | for ss in song: 721 | 722 | time += notes[mel_idx][0] * 16 723 | dur = notes[mel_idx][1] * 16 724 | pitch = notes[mel_idx][3] 725 | channel = 3 726 | song_f.append(['note', time, dur, channel, pitch, vel ]) 727 | 728 | chord = TMIDIX.ALL_CHORDS[ss] 729 | for c in chord: 730 | pitch = 48+c 731 | channel = 0 732 | song_f.append(['note', time, dur, channel, pitch, vel ]) 733 | 734 | mel_idx += 1 735 | 736 | detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f, 737 | output_signature = 'Experimental Music Transformer', 738 | output_file_name = '/content/Experimental-Music-Transformer-Composition', 739 | track_name='Project Los Angeles', 740 | list_of_MIDI_patches=patches 741 | ) 742 | 743 | torch.save(model.state_dict(), '/content/model.pth') 744 | 745 | """# (TOKENS EMBEDDINGS)""" 746 | 747 | # @title Explore model tokens embeddings 748 | tok_emb = model.decoder.net.token_emb.emb.weight.detach().cpu().tolist() 749 | 750 | cos_sim = metrics.pairwise_distances( 751 | tok_emb, metric='cosine' 752 | ) 753 | plt.figure(figsize=(7, 7)) 754 | plt.imshow(cos_sim, cmap="inferno", interpolation="nearest") 755 | im_ratio = cos_sim.shape[0] / cos_sim.shape[1] 756 | plt.colorbar(fraction=0.046 * im_ratio, pad=0.04) 757 | plt.xlabel("Position") 758 | plt.ylabel("Position") 759 | plt.tight_layout() 760 | plt.plot() 761 | plt.savefig("/content/Experimental-Music-Transformer-Tokens-Embeddings-Plot.png", bbox_inches="tight") 762 | 763 | """# Congrats! You did it! :)""" -------------------------------------------------------------------------------- /experimental_music_transformer_version_2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Experimental_Music_Transformer_Version_2.ipynb 3 | 4 | Automatically generated by Colaboratory. 5 | 6 | Original file is located at 7 | https://colab.research.google.com/drive/1h6bSAa7rDfd6cE_UFPnIjLGc03M2Pgpv 8 | 9 | # Experimental Music Transformer Version 2 (ver. 0.1) 10 | 11 | *** 12 | 13 | Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools 14 | 15 | *** 16 | 17 | WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please excercise great humility, care, and respect. 
https://www.nscai.gov/ 18 | 19 | *** 20 | 21 | #### Project Los Angeles 22 | 23 | #### Tegridy Code 2023 24 | 25 | *** 26 | 27 | # (SETUP ENVIRONMENT) 28 | """ 29 | 30 | !nvidia-smi 31 | 32 | #@title Install all dependencies (run only once per session) 33 | 34 | !git clone https://github.com/asigalov61/tegridy-tools 35 | 36 | #@title Import all needed modules 37 | 38 | print('Loading needed modules. Please wait...') 39 | import os 40 | 41 | import math 42 | import statistics 43 | import random 44 | 45 | from collections import Counter 46 | 47 | from tqdm import tqdm 48 | 49 | if not os.path.exists('/content/Dataset'): 50 | os.makedirs('/content/Dataset') 51 | 52 | print('Loading TMIDIX module...') 53 | os.chdir('/content/tegridy-tools/tegridy-tools') 54 | 55 | import TMIDIX 56 | 57 | from joblib import Parallel, delayed, parallel_config 58 | 59 | print('Done!') 60 | 61 | os.chdir('/content/') 62 | print('Enjoy! :)') 63 | 64 | """# (DOWNLOAD AND UNZIP DATASETS)""" 65 | 66 | # Commented out IPython magic to ensure Python compatibility. 67 | # @title MIDI Dataset 68 | # %cd /content/Dataset 69 | !wget https://github.com/asigalov61/Tegridy-MIDI-Dataset/raw/master/Mono-Melodies/Piano-Violin/Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip.001 70 | !wget https://github.com/asigalov61/Tegridy-MIDI-Dataset/raw/master/Mono-Melodies/Piano-Violin/Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip.002 71 | !cat Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip* > Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip 72 | !unzip Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip 73 | !rm Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip.001 74 | !rm Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip.002 75 | # %cd /content/ 76 | 77 | """# (FILE LIST) 78 | 79 | # (PROCESS) 80 | """ 81 | 82 | #@title TMIDIX MIDI Processor 83 | 84 | print('=' * 70) 85 | print('Loading TMIDIX MIDI Processor...') 86 | print('=' * 70) 87 | 88 | def TMIDIX_MIDI_Processor(midi_file): 89 | 90 | melody_chords = [] 91 | 92 | try: 93 | 94 | fn = os.path.basename(midi_file) 95 | 96 | # Filtering out GIANT4 MIDIs 97 | file_size = os.path.getsize(midi_file) 98 | 99 | if file_size <= 1000000: 100 | 101 | #======================================================= 102 | # START PROCESSING 103 | 104 | # Convering MIDI to ms score with MIDI.py module 105 | score = TMIDIX.midi2single_track_ms_score(open(midi_file, 'rb').read(), recalculate_channels=False) 106 | 107 | # INSTRUMENTS CONVERSION CYCLE 108 | events_matrix = [] 109 | itrack = 1 110 | patches = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 111 | 112 | while itrack < len(score): 113 | for event in score[itrack]: 114 | if event[0] == 'note' or event[0] == 'patch_change': 115 | events_matrix.append(event) 116 | itrack += 1 117 | 118 | events_matrix.sort(key=lambda x: x[1]) 119 | 120 | events_matrix1 = [] 121 | 122 | for event in events_matrix: 123 | if event[0] == 'patch_change': 124 | patches[event[2]] = event[3] 125 | 126 | if event[0] == 'note': 127 | event.extend([patches[event[3]]]) 128 | 129 | if events_matrix1: 130 | if (event[1] == events_matrix1[-1][1]): 131 | if ([event[3], event[4]] != events_matrix1[-1][3:5]): 132 | events_matrix1.append(event) 133 | else: 134 | events_matrix1.append(event) 135 | 136 | else: 137 | events_matrix1.append(event) 138 | 139 | if len(events_matrix1) > 0: 140 | if min([e[1] for e in events_matrix1]) >= 0 and min([e[2] for e in events_matrix1]) >= 0: 141 | 142 | #======================================================= 143 | # PRE-PROCESSING 144 | 145 | # checking number of instruments in a 
composition 146 | instruments_list = list(set([y[3] for y in events_matrix1])) 147 | 148 | if len(events_matrix1) > 0: 149 | 150 | #=================================== 151 | # ORIGINAL COMPOSITION 152 | #=================================== 153 | 154 | # Adjusting timings 155 | 156 | for e in events_matrix1: 157 | e[1] = int(e[1] / 16) 158 | e[2] = int(e[2] / 16) 159 | 160 | # Sorting by patch, pitch, then by start-time 161 | 162 | events_matrix1.sort(key=lambda x: x[6]) 163 | events_matrix1.sort(key=lambda x: x[4], reverse=True) 164 | events_matrix1.sort(key=lambda x: x[1]) 165 | 166 | #======================================================= 167 | # FINAL PROCESSING 168 | 169 | #======================================================= 170 | # MAIN PROCESSING CYCLE 171 | #======================================================= 172 | 173 | pe = events_matrix1[0] 174 | 175 | notes = [] 176 | 177 | for e in events_matrix1: 178 | 179 | time = max(0, min(127, (e[1] - pe[1]))) 180 | dur = max(0, min(127, e[2])) 181 | cha = max(0, min(15, e[3])) 182 | ptc = max(1, min(127, e[4])) 183 | 184 | if cha == 3: 185 | cha = 1 186 | 187 | notes.append([time, dur, cha, ptc]) 188 | 189 | pe = e 190 | 191 | return notes 192 | 193 | except: 194 | return None 195 | 196 | print('Done!') 197 | print('=' * 70) 198 | 199 | #@title Save file list 200 | ########### 201 | 202 | print('=' * 70) 203 | print('Loading MIDI files...') 204 | print('This may take a while on a large dataset in particular.') 205 | 206 | dataset_addr = "/content/Dataset" 207 | 208 | # os.chdir(dataset_addr) 209 | filez = list() 210 | for (dirpath, dirnames, filenames) in os.walk(dataset_addr): 211 | filez += [os.path.join(dirpath, file) for file in filenames] 212 | print('=' * 70) 213 | 214 | if not filez: 215 | print('Could not find any MIDI files. Please check Dataset dir...') 216 | print('=' * 70) 217 | 218 | else: 219 | print('Randomizing file list...') 220 | random.shuffle(filez) 221 | print('Done!') 222 | print('=' * 70) 223 | print('Total files:', len(filez)) 224 | print('=' * 70) 225 | 226 | #@title Process MIDIs with TMIDIX MIDI processor 227 | 228 | print('=' * 70) 229 | print('TMIDIX MIDI Processor') 230 | print('=' * 70) 231 | print('Starting up...') 232 | print('=' * 70) 233 | 234 | ########### 235 | 236 | melody_chords_f = [] 237 | 238 | print('Processing MIDI files. 
Please wait...')
239 | print('=' * 70)
240 | 
241 | for i in tqdm(range(0, len(filez), 16)):
242 | 
243 |     with parallel_config(backend='threading', n_jobs=16, verbose = 0):
244 | 
245 |         output = Parallel()(delayed(TMIDIX_MIDI_Processor)(f) for f in filez[i:i+16])
246 | 
247 |     for o in output:
248 | 
249 |         if o is not None:
250 |             melody_chords_f.append(o)
251 | 
252 | print('Done!')
253 | print('=' * 70)
254 | 
255 | melody_chords_f[1]
256 | 
257 | TMIDIX.Tegridy_Any_Pickle_File_Writer(melody_chords_f, '/content/Processed_MIDIs')
258 | 
259 | melody_chords_f = TMIDIX.Tegridy_Any_Pickle_File_Reader('/content/Processed_MIDIs')
260 | 
261 | #@title Test INTs
262 | 
263 | train_data1 = melody_chords_f[4]
264 | 
265 | #train_data1 = max(melody_chords_f, key = len)
266 | 
267 | print('Sample INTs', train_data1[:15])
268 | 
269 | out = train_data1
270 | 
271 | patches = [0] * 16
272 | patches[1] = 40
273 | 
274 | if len(out) != 0:
275 | 
276 |     song = out
277 |     song_f = []
278 | 
279 |     time = 0
280 |     dur = 0
281 |     vel = 90
282 |     pitch = 0
283 |     channel = 0
284 | 
285 | 
286 |     for s in song:
287 | 
288 | 
289 |         time += s[0] * 16
290 |         dur = s[1] * 16
291 |         channel = s[2]
292 |         pitch = s[3]
293 | 
294 | 
295 |         song_f.append(['note', time, dur, channel, pitch, vel ])
296 | 
297 | 
298 | 
299 |     detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
300 |                                                               output_signature = 'Experimental Music Transformer',
301 |                                                               output_file_name = '/content/Experimental-Music-Transformer-Composition',
302 |                                                               track_name='Project Los Angeles',
303 |                                                               list_of_MIDI_patches=patches
304 |                                                               )
305 | 
306 | print('Done!')
307 | 
308 | len(melody_chords_f[0])
309 | 
310 | """# (TRAIN DATA)"""
311 | 
312 | train_data = []
313 | 
314 | for m in tqdm(melody_chords_f):
315 | 
316 |     cha = m[0][2]
317 | 
318 |     dat = [1025, ((cha * 128) + m[0][3])+640, 0] # start token, intro channel/pitch token, zero start-time
319 | 
320 |     for mm in m:
321 | 
322 |         cha = mm[2]
323 | 
324 |         if mm[0] != 0:
325 |             dat.extend([mm[0], mm[1]+128, ((cha * 128) + mm[3])+256])
326 |         else:
327 |             dat.extend([mm[1]+128, ((cha * 128) + mm[3])+256])
328 | 
329 |     dat = dat[:1025]
330 | 
331 |     ids = [] # 0 - 256 and 640 - 1024
332 |     nums = [] # 256 - 640
333 |     masks = [] # 1024
334 | 
335 |     for d in dat:
336 |         if 0 <= d < 256:
337 |             ids.append(d)
338 |             nums.append(-1)
339 |             masks.append(False)
340 | 
341 |         if 256 <= d < 640:
342 |             ids.append(1024)
343 |             nums.append(d)
344 |             masks.append(True)
345 |         if d >= 640: ids.append(d); nums.append(-1); masks.append(False) # intro/start tokens (640-1025) kept as plain ids, per the ranges noted above
346 |     ids += [1026] * (1025 - len(ids))
347 |     nums += [-1] * (1025 - len(nums))
348 |     masks += [False] * (1025 - len(masks))
349 | 
350 |     train_data.append([ids, nums, masks])
351 | 
352 | # Total dict size 1027
353 | 
354 | len(train_data), max(train_data, key=len) == min(train_data, key=len)
355 | 
356 | train_data[555][:8]
357 | 
358 | random.shuffle(train_data)
359 | 
360 | TMIDIX.Tegridy_Any_Pickle_File_Writer(train_data, '/content/INTs')
361 | 
362 | train_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('/content/INTs')
363 | 
364 | len(max(train_data[0], key=len)), len(min(train_data[0], key=len))
365 | 
366 | """# (TRAIN MODEL)"""
367 | 
368 | !pip install x-transformers
369 | !pip install einops
370 | !pip install torch-summary
371 | 
372 | import torch
373 | torch.__version__
374 | 
375 | # Commented out IPython magic to ensure Python compatibility.
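# A note before the training cells below: this version trains an xVal-style
# transformer, i.e. ordinary discrete tokens plus one numerical placeholder
# token (id 1024 here) whose embedding is scaled by a continuous value at
# the masked positions. Under that assumption, each training example is the
# [ids, nums, masks] triple built above, and a forward call looks like
# (illustrative sketch; see the actual loop further down):
#
#   loss = model(ids, nums, mask=masks)  # ids: long, nums: numeric, masks: bool
#   loss.backward()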
376 | # Load modules and make data dir 377 | 378 | print('Loading modules...') 379 | 380 | import os 381 | import pickle 382 | import random 383 | import secrets 384 | import tqdm 385 | import math 386 | import torch 387 | import torch.optim as optim 388 | from torch.utils.data import DataLoader, Dataset 389 | 390 | import matplotlib.pyplot as plt 391 | 392 | from torchsummary import summary 393 | from sklearn import metrics 394 | 395 | # %cd /content/tegridy-tools/tegridy-tools/ 396 | 397 | import TMIDIX 398 | 399 | # %cd /content/tegridy-tools/tegridy-tools/X-Transformer 400 | 401 | from x_transformers import ( 402 | Decoder, 403 | XValTransformerWrapper, 404 | XValAutoregressiveWrapper 405 | ) 406 | 407 | torch.set_float32_matmul_precision('high') 408 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 409 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 410 | 411 | # %cd /content/ 412 | 413 | if not os.path.exists('/content/INTS'): 414 | os.makedirs('/content/INTS') 415 | 416 | import random 417 | 418 | print('Done') 419 | 420 | train_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('/content/INTs') 421 | 422 | len(train_data) // 8 423 | 424 | # @title Setup and init the model 425 | 426 | # constants 427 | 428 | SEQ_LEN = 8192 # Models seq len 429 | PAD_IDX = 1026 # Models pad index 430 | 431 | BATCH_SIZE = 4 432 | NUM_EPOCHS = 100 433 | GRADIENT_ACCUMULATE_EVERY = 4 434 | 435 | 436 | LEARNING_RATE = 2e-4 437 | 438 | VALIDATE_EVERY = 100 439 | SAVE_EVERY = 500 440 | GENERATE_EVERY = 100 441 | PRINT_STATS_EVERY = 20 442 | 443 | GENERATE_LENGTH = 32 444 | 445 | # helpers 446 | 447 | def cycle(loader): 448 | while True: 449 | for data in loader: 450 | yield data 451 | 452 | # instantiate the model 453 | 454 | model = XValTransformerWrapper( 455 | num_tokens = 1027, 456 | numerical_token_id = 1024, 457 | max_seq_len = 1024, 458 | attn_layers = Decoder( 459 | dim = 1024, 460 | depth = 8, 461 | heads = 8, 462 | 463 | ) 464 | ) 465 | 466 | # wrap it with the xval autoregressive wrapper 467 | 468 | model = XValAutoregressiveWrapper(model, ignore_index=PAD_IDX) 469 | 470 | model.cuda() 471 | 472 | print('Done!') 473 | 474 | summary(model) 475 | 476 | # Dataloader 477 | 478 | class MusicDataset(Dataset): 479 | def __init__(self, data, seq_len): 480 | super().__init__() 481 | self.data = data 482 | self.seq_len = seq_len 483 | 484 | def __getitem__(self, index): 485 | 486 | ids = torch.Tensor(self.data[index][0][:self.seq_len+1]).long() 487 | nums = torch.Tensor(self.data[index][1][:self.seq_len+1]).long() 488 | masks = torch.Tensor(self.data[index][2][:self.seq_len+1]).bool() 489 | 490 | return ids.cuda(), nums.cuda(), masks.cuda() 491 | 492 | def __len__(self): 493 | return (len(self.data) // BATCH_SIZE) * BATCH_SIZE 494 | 495 | # precision/optimizer/scaler 496 | 497 | dtype = torch.float16 498 | 499 | ctx = torch.amp.autocast(device_type='cuda', dtype=dtype) 500 | 501 | optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 502 | 503 | scaler = torch.cuda.amp.GradScaler(enabled=True) 504 | 505 | random.shuffle(train_data) 506 | 507 | train_dataset = MusicDataset(train_data, SEQ_LEN) 508 | val_dataset = MusicDataset(train_data, SEQ_LEN) 509 | train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE) 510 | val_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE) 511 | 512 | # @title Train the model 513 | torch.cuda.empty_cache() 514 | train_losses = [] 515 | val_losses = [] 516 | 517 | train_accs = [] 518 | val_accs = [] 519 | 520 | nsteps = 0 
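# Note on the training loop below: gradients are accumulated over
# GRADIENT_ACCUMULATE_EVERY batches before each optimizer step, so the
# effective batch size is BATCH_SIZE * GRADIENT_ACCUMULATE_EVERY
# (4 * 4 = 16 with the constants above). The canonical pattern is
# (illustrative sketch only):
#
#   for step, batch in enumerate(loader):
#       loss = model(*batch) / GRADIENT_ACCUMULATE_EVERY  # normalize per-step loss
#       loss.backward()                                   # grads add up across steps
#       if (step + 1) % GRADIENT_ACCUMULATE_EVERY == 0:
#           torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#           optimizer.step()
#           optimizer.zero_grad(set_to_none=True)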
521 | 
522 | PRINT_STATS_EVERY = 200
523 | 
524 | for epoch in range(NUM_EPOCHS):
525 | 
526 |     print('=' * 70)
527 |     print('Epoch #', epoch)
528 |     print('=' * 70)
529 |     model.train() # set the model to training mode
530 |     total_loss = 0
531 |     optimizer.zero_grad(set_to_none=True) # Initialize gradients to zero at the start of the epoch
532 | 
533 |     for batch_idx, batch in enumerate(tqdm.tqdm(train_loader)): # iterate over batches of data
534 |         ids, nums, masks = batch # unpack the ids/numbers/numerical-mask tensors from the current batch
535 | 
536 |         with torch.cuda.amp.autocast():
537 |             loss = model(ids, nums, mask=masks) # forward pass
538 | 
539 |         # loss = loss / GRADIENT_ACCUMULATE_EVERY # Normalize the loss by the number of accumulation steps
540 |         # scaler.scale(loss).backward() # Backward pass with gradient scaling
541 |         loss.backward() # Backward pass (plain, since the scaler path above is disabled)
542 |         train_losses.append(loss.mean().item() * GRADIENT_ACCUMULATE_EVERY)
543 |         # train_accs.append(acc.mean().item())
544 | 
545 | 
546 |         if (batch_idx + 1) % GRADIENT_ACCUMULATE_EVERY == 0: # Perform optimization step after accumulating gradients
547 |             # scaler.unscale_(optimizer)
548 |             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
549 |             #scaler.step(optimizer)
550 |             #scaler.update()
551 |             optimizer.step()
552 |             optimizer.zero_grad(set_to_none=True) # Reset gradients after optimization step
553 | 
554 |         total_loss += loss.item() * GRADIENT_ACCUMULATE_EVERY # Undo the normalization for logging
555 | 
556 |         if nsteps % PRINT_STATS_EVERY == 0:
557 |             # print(f'Training Loss: {total_loss / (batch_idx + 1)}, Accuracy: {acc.item()}')
558 |             print(f'Training Loss: {total_loss / (batch_idx + 1)}')
559 | 
560 | 
561 | 
562 |         nsteps += 1
563 | 
564 |         '''if i % VALIDATE_EVERY == 0:
565 |             model.eval()
566 |             with torch.no_grad():
567 |                 with ctx:
568 |                     val_loss, val_acc = model(next(val_loader))
569 | 
570 |                 print(f'Validation loss: {val_loss.mean().item()}')
571 |                 print(f'Validation acc: {val_acc.mean().item()}')
572 | 
573 |                 val_losses.append(val_loss.mean().item())
574 |                 val_accs.append(val_acc.mean().item())
575 | 
576 |                 print('Plotting training loss graph...')
577 | 
578 |                 tr_loss_list = train_losses
579 |                 plt.plot([i for i in range(len(tr_loss_list))], tr_loss_list, 'b')
580 |                 plt.show()
581 |                 plt.close()
582 |                 print('Done!')
583 | 
584 |                 print('Plotting training acc graph...')
585 | 
586 |                 tr_loss_list = train_accs
587 |                 plt.plot([i for i in range(len(tr_loss_list))], tr_loss_list, 'b')
588 |                 plt.show()
589 |                 plt.close()
590 |                 print('Done!')
591 | 
592 |                 print('Plotting validation loss graph...')
593 |                 tr_loss_list = val_losses
594 |                 plt.plot([i for i in range(len(tr_loss_list))], tr_loss_list, 'b')
595 |                 plt.show()
596 |                 plt.close()
597 |                 print('Done!')
598 | 
599 |                 print('Plotting validation acc graph...')
600 |                 tr_loss_list = val_accs
601 |                 plt.plot([i for i in range(len(tr_loss_list))], tr_loss_list, 'b')
602 |                 plt.show()
603 |                 plt.close()
604 |                 print('Done!')'''
605 | 
606 |         '''if i % GENERATE_EVERY == 0:
607 |             model.eval()
608 | 
609 |             inp = random.choice(val_dataset)[:-1]
610 | 
611 |             print(inp)
612 | 
613 |             with ctx:
614 | 
615 |                 sample = model.generate(inp[None, ...], GENERATE_LENGTH)
616 | 
617 |             print(sample)'''
618 | 
619 |         if nsteps % SAVE_EVERY == 0:
620 | 
621 |             print('Saving model progress. Please wait...')
622 |             print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss.pth') # acc is not tracked in this version, so the checkpoint name only carries the loss
623 | 
624 |             fname = '/content/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss.pth'
625 | 
626 |             torch.save(model.state_dict(), fname)
627 | 
628 |             data = [train_losses, train_accs, val_losses, val_accs]
629 | 
630 |             TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accs')
631 | 
632 |             print('Done!')
633 | 
634 | #======================================================================================================
635 | 
636 | print('Saving model progress. Please wait...')
637 | print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss.pth')
638 | 
639 | fname = '/content/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss.pth'
640 | 
641 | torch.save(model.state_dict(), fname)
642 | 
643 | print('Done!')
644 | 
645 | data = [train_losses, train_accs, val_losses, val_accs]
646 | 
647 | TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accuracies')
648 | 
649 | # Save training loss graph
650 | 
651 | plt.plot([i for i in range(len(train_losses))], train_losses, 'b')
652 | plt.savefig('/content/training_loss_graph.png')
653 | plt.close()
654 | print('Done!')
655 | 
656 | # Save training acc graph
657 | 
658 | plt.plot([i for i in range(len(train_accs))], train_accs, 'b')
659 | plt.savefig('/content/training_acc_graph.png')
660 | plt.close()
661 | print('Done!')
662 | 
663 | # Save validation loss graph
664 | 
665 | plt.plot([i for i in range(len(val_losses))], val_losses, 'b')
666 | plt.savefig('/content/validation_loss_graph.png')
667 | plt.close()
668 | print('Done!')
669 | 
670 | # Save validation acc graph
671 | 
672 | plt.plot([i for i in range(len(val_accs))], val_accs, 'b')
673 | plt.savefig('/content/validation_acc_graph.png')
674 | plt.close()
675 | print('Done!')
676 | 
717 | 
718 | """# EVAL"""
719 | 
720 | dtype = 'float16'
721 | device_type = 'cuda'
722 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
723 | ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
724 | 
725 | model.eval()
726 | 
727 | x = torch.tensor(train_data[2][0][:900], dtype=torch.long, device='cuda')[None, ...] # prime with the ids channel of an [ids, nums, masks] training triple
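# Token layout used by this version, as built in the (TRAIN DATA) cell above:
#   0 - 127    : delta start-times (16 ms steps)
#   128 - 255  : durations + 128
#   256 - 639  : (channel * 128) + pitch + 256 (handled as xVal numbers)
#   640 - 1023 : (channel * 128) + first pitch + 640 (intro token)
#   1024       : numerical placeholder token
#   1025       : start-of-sequence token
#   1026       : padding token (total dictionary size 1027)
# The prime above therefore feeds the model the discrete ids channel of one
# training example.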
728 | #x = torch.tensor([[1024]] * 1, dtype=torch.long, device='cuda')
729 | 
730 | # run generation
731 | 
732 | with ctx:
733 |     out = model.generate(x,
734 |                          1023,
735 |                          temperature=1,
736 |                          return_prime=False,
737 |                          verbose=True)
738 | 
739 | y = out.tolist()
740 | 
741 | print('---------------')
742 | 
743 | print(y)
744 | 
745 | #@title Test INTs
746 | 
747 | train_data1 = y[0]
748 | # NOTE: the token ranges decoded below (times 0-255, durations 256-511, pitches 512-639, channels 640-641, velocities 642-769) belong to an earlier flat encoding and do not match the ids/nums scheme built above
749 | #train_data1 = max(melody_chords_f, key = len)
750 | 
751 | print('Sample INTs', train_data1[:15])
752 | 
753 | out = train_data1
754 | 
755 | patches = [0] * 16
756 | patches[3] = 40
757 | 
758 | if len(out) != 0:
759 | 
760 |     song = out
761 |     song_f = []
762 | 
763 |     time = 0
764 |     dur = 0
765 |     vel = 90
766 |     pitch = 0
767 |     channel = 0
768 | 
769 |     for ss in tqdm.tqdm(song):
770 | 
771 |         if 0 <= ss < 256:
772 | 
773 |             time += (ss * 16)
774 | 
775 |         if 256 <= ss < 512:
776 | 
777 |             dur = (ss-256) * 16
778 | 
779 |         if 512 <= ss < 640:
780 | 
781 |             pitch = ss-512
782 | 
783 |         if 640 <= ss < 642:
784 | 
785 |             channel = ss-640
786 | 
787 |             if channel == 1:
788 |                 channel = 3
789 | 
790 |         if 642 <= ss < 770:
791 |             vel = ss-642
792 | 
793 |             song_f.append(['note', time, dur, channel, pitch, vel ])
794 | 
795 |     detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
796 |                                                               output_signature = 'Experimental Music Transformer',
797 |                                                               output_file_name = '/content/Experimental-Music-Transformer-Composition',
798 |                                                               track_name='Project Los Angeles',
799 |                                                               list_of_MIDI_patches=patches
800 |                                                               )
801 | 
802 | print('Done!')
803 | 
804 | tok_emb = model.net.token_emb.emb.weight.detach().cpu().tolist()
805 | 
806 | cos_sim = metrics.pairwise_distances(
807 |     tok_emb, metric='cosine'
808 | )
809 | plt.figure(figsize=(7, 7))
810 | plt.imshow(cos_sim, cmap="inferno", interpolation="nearest")
811 | im_ratio = cos_sim.shape[0] / cos_sim.shape[1]
812 | plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)
813 | plt.xlabel("Position")
814 | plt.ylabel("Position")
815 | plt.tight_layout()
816 | plt.plot()
817 | plt.savefig("/content/Experimental-Music-Transformer-Tokens-Embeddings-Plot.png", bbox_inches="tight")
818 | 
819 | """# Congrats! You did it! :)"""
--------------------------------------------------------------------------------
/Training-Data/experimental_music_transformer_training_dataset_maker.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Experimental_Music_Transformer_Training_Dataset_Maker.ipynb
3 | 
4 | Automatically generated by Colaboratory.
5 | 
6 | Original file is located at
7 |     https://colab.research.google.com/github/asigalov61/Experimental-Music-Transformer/blob/main/Training-Data/Experimental_Music_Transformer_Training_Dataset_Maker.ipynb
8 | 
9 | # Experimental Music Transformer Training Dataset Maker (ver. 1.0)
10 | 
11 | ***
12 | 
13 | Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools
14 | 
15 | ***
16 | 
17 | #### Project Los Angeles
18 | 
19 | #### Tegridy Code 2023
20 | 
21 | ***
22 | 
23 | # (SETUP ENVIRONMENT)
24 | """
25 | 
26 | #@title Install all dependencies (run only once per session)
27 | 
28 | !git clone https://github.com/asigalov61/tegridy-tools
29 | !pip install tqdm
30 | 
31 | #@title Import all needed modules
32 | 
33 | print('Loading needed modules.
Please wait...') 34 | import os 35 | import copy 36 | import math 37 | import statistics 38 | import random 39 | 40 | from tqdm import tqdm 41 | 42 | if not os.path.exists('/content/Dataset'): 43 | os.makedirs('/content/Dataset') 44 | 45 | print('Loading TMIDIX module...') 46 | os.chdir('/content/tegridy-tools/tegridy-tools') 47 | 48 | import TMIDIX 49 | 50 | from joblib import Parallel, delayed 51 | 52 | print('Done!') 53 | 54 | os.chdir('/content/') 55 | print('Enjoy! :)') 56 | 57 | """# (DOWNLOAD SOURCE MIDI DATASET)""" 58 | 59 | # Commented out IPython magic to ensure Python compatibility. 60 | #@title Download original LAKH MIDI Dataset 61 | 62 | # %cd /content/Dataset/ 63 | 64 | !wget 'http://hog.ee.columbia.edu/craffel/lmd/lmd_full.tar.gz' 65 | !tar -xvf 'lmd_full.tar.gz' 66 | !rm 'lmd_full.tar.gz' 67 | 68 | # %cd /content/ 69 | 70 | #@title Mount Google Drive 71 | from google.colab import drive 72 | drive.mount('/content/drive') 73 | 74 | """# (FILE LIST)""" 75 | 76 | #@title Save file list 77 | ########### 78 | 79 | print('Loading MIDI files...') 80 | print('This may take a while on a large dataset in particular.') 81 | 82 | dataset_addr = "/content/Dataset" 83 | # os.chdir(dataset_addr) 84 | filez = list() 85 | for (dirpath, dirnames, filenames) in os.walk(dataset_addr): 86 | filez += [os.path.join(dirpath, file) for file in filenames] 87 | print('=' * 70) 88 | 89 | if filez == []: 90 | print('Could not find any MIDI files. Please check Dataset dir...') 91 | print('=' * 70) 92 | 93 | print('Randomizing file list...') 94 | random.shuffle(filez) 95 | 96 | TMIDIX.Tegridy_Any_Pickle_File_Writer(filez, '/content/drive/MyDrive/filez') 97 | 98 | #@title Load file list 99 | filez = TMIDIX.Tegridy_Any_Pickle_File_Reader('/content/drive/MyDrive/filez') 100 | 101 | """# (PROCESS)""" 102 | 103 | #@title Process MIDIs with TMIDIX MIDI processor 104 | 105 | #=============================================================================== 106 | 107 | def TMIDIX_MIDI_Processor(midi_file): 108 | 109 | melody_chords = [] 110 | melody_chords_aug = [] 111 | 112 | try: 113 | 114 | fn = os.path.basename(midi_file) 115 | 116 | # Filtering out EXP MIDIs 117 | file_size = os.path.getsize(midi_file) 118 | 119 | if file_size <= 1000000: 120 | 121 | #======================================================= 122 | # START PROCESSING 123 | 124 | score = TMIDIX.midi2single_track_ms_score(open(midi_file, 'rb').read(), recalculate_channels=False, pass_old_timings_events=True) 125 | 126 | # INSTRUMENTS CONVERSION CYCLE 127 | events_matrix = [] 128 | itrack = 1 129 | patches = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 130 | 131 | emph_once = False 132 | emphasis_time = 0 133 | 134 | tpq = 0 135 | tempo = 0 136 | time_sig = 0 137 | key_sig = 0 138 | 139 | while itrack < len(score): 140 | for event in score[itrack]: 141 | 142 | if event[0] != 'note': 143 | event.extend([0, 128]) 144 | else: 145 | event.extend([0, 0]) 146 | 147 | if event[0] == 'text_event' or event[0] == 'lyric' or event[0] == 'patch_change' or event[0] == 'time_signature': 148 | event[4] = 128 149 | 150 | events_matrix.append(event) 151 | 152 | itrack += 1 153 | 154 | events_matrix.sort(key=lambda x: x[4], reverse = True) 155 | events_matrix.sort(key=lambda x: x[1]) 156 | 157 | events_matrix1 = [] 158 | 159 | pt = events_matrix[0][1] 160 | 161 | for event in events_matrix: 162 | if event[0] == 'patch_change': 163 | patches[event[2]] = event[3] 164 | 165 | #======================================================================== 166 | # Emphasis 
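# How the emphasis flags are derived from the karaoke events (as the logic
# below reads): each 'text_event'/'lyric' onset is remembered as
# emphasis_time, and every note event carries a flag in event[7]:
#   0 = plain note
#   1 = note falling in the same 8-tick bin as the latest lyric onset
#   2 = note starting a new time step past that bin
# These flags later become the emph+6784 emphasis tokens of the final note
# sequence.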
167 | 168 | if event[0] == 'text_event' or event[0] == 'lyric': 169 | emphasis_time = event[1] 170 | emph_once = True 171 | 172 | if event[0] == 'note' and int(event[1] / 8) > int(emphasis_time / 8) and event[1] > pt: 173 | event[7] = 2 174 | emph_once = False 175 | 176 | if event[0] == 'note' and int(event[1] / 8) == int(emphasis_time / 8) and emph_once: 177 | event[7] = 1 178 | emph_once = False 179 | 180 | pt = event[1] 181 | 182 | #======================================================================== 183 | # Tempo 184 | 185 | if event[0] == 'old_tpq': 186 | tpq = event[2] 187 | 188 | if event[0] == 'old_set_tempo': 189 | tempo = event[2] 190 | 191 | #======================================================================== 192 | # Time and key sigs 193 | 194 | if event[0] == 'time_signature': 195 | time_sig = round((event[2] / max(1, event[3])) * 10) 196 | 197 | if event[0] == 'key_signature': 198 | key_sig = (event[3] * 16) + event[2]+8 199 | 200 | #======================================================================== 201 | # Notes 202 | 203 | if event[0] == 'note': 204 | event[6] = patches[event[3]] 205 | event.extend([round(tempo / tpq / 100)]) 206 | event.extend([time_sig]) 207 | event.extend([key_sig]) 208 | 209 | if events_matrix1: 210 | if (event[1] == events_matrix1[-1][1]): 211 | if ([event[3], event[4]] != events_matrix1[-1][3:5]): 212 | events_matrix1.append(event) 213 | else: 214 | events_matrix1.append(event) 215 | 216 | else: 217 | events_matrix1.append(event) 218 | 219 | if len(events_matrix1) > 0: 220 | if min([e[1] for e in events_matrix1]) >= 0 and min([e[2] for e in events_matrix1]) >= 0: 221 | 222 | #======================================================= 223 | # PRE-PROCESSING 224 | 225 | # checking number of instruments in a composition 226 | instruments_list_without_drums = list(set([y[3] for y in events_matrix1 if y[3] != 9])) 227 | instruments_list = list(set([y[3] for y in events_matrix1])) 228 | 229 | if len(events_matrix1) > 0 and len(instruments_list_without_drums) > 0: 230 | 231 | num_karaoke_events = len([y for y in events_matrix if y[0] == 'text_event' or y[0] == 'lyric']) 232 | 233 | # checking number of karaoke events in a composition 234 | if num_karaoke_events >= 100: 235 | 236 | #======================================================= 237 | # MAIN PROCESSING 238 | #======================================================= 239 | 240 | #======================================================= 241 | # Timings 242 | #======================================================= 243 | 244 | events_matrix2 = [] 245 | 246 | # Recalculating timings 247 | for e in events_matrix1: 248 | 249 | ev = copy.deepcopy(e) 250 | 251 | # Original timings 252 | e[1] = int(e[1] / 8) 253 | e[2] = int(e[2] / 8) 254 | 255 | # Augmented timings (+ 5%) 256 | ev[1] = int((ev[1] * 1.05) / 8) 257 | ev[2] = int((ev[2] * 1.05) / 8) 258 | 259 | events_matrix2.append(ev) 260 | 261 | #=================================== 262 | # ORIGINAL COMPOSITION 263 | #=================================== 264 | 265 | # Sorting by patch, pitch, then by start-time 266 | 267 | events_matrix1.sort(key=lambda x: x[6]) 268 | events_matrix1.sort(key=lambda x: x[4], reverse=True) 269 | events_matrix1.sort(key=lambda x: x[1]) 270 | 271 | #======================================================= 272 | # FINAL PROCESSING 273 | 274 | melody_chords = [] 275 | 276 | # Break between compositions / Intro seq 277 | 278 | if 9 in instruments_list: 279 | drums_present = 8852 # Yes 280 | else: 281 | drums_present = 8851 # 
No 282 | 283 | if events_matrix1[0][3] != 9: 284 | pat = max(0, min(127, events_matrix1[0][6])) // 8 285 | else: 286 | pat = 16 287 | 288 | ptc = events_matrix1[0][4] 289 | 290 | melody_chords.extend([8998, drums_present, 8853+pat, 8870+ptc]) # Intro seq 291 | 292 | #======================================================= 293 | # PROCESSING CYCLE 294 | #======================================================= 295 | 296 | abs_time = 0 297 | 298 | pbar_time = 0 299 | 300 | pe = events_matrix1[0] 301 | 302 | chords_counter = 1 303 | 304 | time_key_seq = [0, 0, 0] 305 | old_time_key_seq = [0, 0, 0] 306 | 307 | tempo = 0 308 | time_sig = 0 309 | key_sig = 0 310 | 311 | comp_chords_len = len(list(set([y[1] for y in events_matrix1]))) 312 | 313 | for e in events_matrix1: 314 | 315 | #======================================================= 316 | # Timings... 317 | 318 | # Cliping all values... 319 | delta_time = max(0, min(511, e[1]-pe[1])) 320 | abs_time += delta_time 321 | 322 | bar_time = abs_time // 512 323 | bar_time_local = abs_time % 512 324 | 325 | if bar_time >= 1022: 326 | break 327 | 328 | # Durations and channels 329 | 330 | dur = max(0, min(511, e[2])) 331 | cha = max(0, min(15, e[3])) 332 | 333 | # Patches 334 | if cha == 9: # Drums patch will be == 16 335 | pat = 16 336 | 337 | else: 338 | pat = max(0, min(127, e[6])) // 8 339 | 340 | # Pitches 341 | ptc = max(1, min(127, e[4])) 342 | 343 | 344 | # Emphasis 345 | emph = e[7] 346 | 347 | # Velocities 348 | # Calculating octo-velocity 349 | vel = max(8, min(127, e[5])) 350 | velocity = round(vel / 15)-1 351 | 352 | #======================================================= 353 | # Outro seq 354 | 355 | if ((comp_chords_len - chords_counter) == 50) and (delta_time != 0): 356 | out_t = 7810+delta_time 357 | out_p = 8322+ptc 358 | melody_chords.extend([8850, 8850, out_t, out_p]) # outro seq 359 | 360 | #======================================================= 361 | 362 | if time_key_seq[0] != e[8]: # Tempo 363 | time_key_seq[0] = e[8] 364 | 365 | if time_key_seq[1] != e[9]: # Time sig 366 | time_key_seq[1] = e[9] 367 | 368 | if time_key_seq[2] != e[10]: # Key sig 369 | time_key_seq[2] = e[10] 370 | 371 | if time_key_seq != old_time_key_seq: 372 | 373 | old_time_key_seq = time_key_seq 374 | 375 | time_key_seq[0] = max(0, min(254, time_key_seq[0])) + 8451 376 | time_key_seq[1] = max(0, min(128, time_key_seq[1])) + 8706 377 | time_key_seq[2] = max(0, min(16, time_key_seq[2])) + 8834 378 | 379 | melody_chords.extend([8450] + time_key_seq) 380 | 381 | #======================================================= 382 | # Bar counter seq 383 | 384 | if (bar_time > pbar_time) and (delta_time != 0): 385 | bar = 6787+min(1022, (bar_time)) # bar counter seq 386 | bar_t = 7810+bar_time_local 387 | bar_p = 8322+ptc 388 | melody_chords.extend([6787, bar, bar_t, bar_p]) 389 | chords_counter += 1 390 | pbar_time = bar_time 391 | 392 | else: 393 | if delta_time != 0: 394 | chords_counter += 1 395 | 396 | #======================================================= 397 | # FINAL NOTE SEQ 398 | 399 | # Writing final note asynchronously 400 | 401 | dur_vel = (8 * dur) + velocity 402 | pat_ptc = (128 * pat) + ptc 403 | 404 | melody_chords.extend([emph+6784, delta_time, dur_vel+512, pat_ptc+4608]) 405 | 406 | pe = e 407 | 408 | #======================================================= 409 | 410 | melody_chords.extend([8999, 8999, 8999, 8999]) # EOS 411 | 412 | #=================================== 413 | # AUGMENTED COMPOSITION 414 | #=================================== 415 
| 416 | # Sorting by patch, pitch, then by start-time 417 | 418 | events_matrix2.sort(key=lambda x: x[6]) 419 | events_matrix2.sort(key=lambda x: x[4], reverse=True) 420 | events_matrix2.sort(key=lambda x: x[1]) 421 | 422 | # Simple pitches augmentation 423 | 424 | ptc_shift = 1 # Shifting up by 1 semi-tone 425 | 426 | for e in events_matrix2: 427 | if e[3] != 9: 428 | e[4] = e[4] + ptc_shift 429 | 430 | #======================================================= 431 | # FINAL PROCESSING 432 | 433 | melody_chords_aug = [] 434 | 435 | # Break between compositions / Intro seq 436 | 437 | if 9 in instruments_list: 438 | drums_present = 8852 # Yes 439 | else: 440 | drums_present = 8851 # No 441 | 442 | if events_matrix2[0][3] != 9: 443 | pat = max(0, min(127, events_matrix2[0][6])) // 8 444 | else: 445 | pat = 16 446 | 447 | ptc = events_matrix2[0][4] 448 | 449 | melody_chords_aug.extend([8998, drums_present, 8853+pat, 8870+ptc]) # Intro seq 450 | 451 | #======================================================= 452 | # PROCESSING CYCLE 453 | #======================================================= 454 | 455 | abs_time = 0 456 | 457 | pbar_time = 0 458 | 459 | pe = events_matrix2[0] 460 | 461 | chords_counter = 1 462 | 463 | time_key_seq = [0, 0, 0] 464 | old_time_key_seq = [0, 0, 0] 465 | 466 | tempo = 0 467 | time_sig = 0 468 | key_sig = 0 469 | 470 | comp_chords_len = len(list(set([y[1] for y in events_matrix2]))) 471 | 472 | for e in events_matrix2: 473 | 474 | #======================================================= 475 | # Timings... 476 | 477 | # Cliping all values... 478 | delta_time = max(0, min(511, e[1]-pe[1])) 479 | abs_time += delta_time 480 | 481 | bar_time = abs_time // 512 482 | bar_time_local = abs_time % 512 483 | 484 | if bar_time >= 1022: 485 | break 486 | 487 | # Durations and channels 488 | 489 | dur = max(0, min(511, e[2])) 490 | cha = max(0, min(15, e[3])) 491 | 492 | # Patches 493 | if cha == 9: # Drums patch will be == 128 494 | pat = 16 495 | 496 | else: 497 | pat = max(0, min(127, e[6])) // 8 498 | 499 | # Pitches 500 | ptc = max(1, min(127, e[4])) 501 | 502 | # Emphasis 503 | emph = e[7] 504 | 505 | # Velocities 506 | # Calculating octo-velocity 507 | vel = max(8, min(127, e[5]-4)) 508 | velocity = round(vel / 15)-1 509 | 510 | #======================================================= 511 | # Outro seq 512 | 513 | if ((comp_chords_len - chords_counter) == 50) and (delta_time != 0): 514 | out_t = 7810+delta_time 515 | out_p = 8322+ptc 516 | melody_chords_aug.extend([8850, 8850, out_t, out_p]) # outro seq 517 | 518 | #======================================================= 519 | 520 | if time_key_seq[0] != e[8]: # Tempo 521 | time_key_seq[0] = e[8] 522 | 523 | if time_key_seq[1] != e[9]: # Time sig 524 | time_key_seq[1] = e[9] 525 | 526 | if time_key_seq[2] != e[10]: # Key sig 527 | time_key_seq[2] = e[10] 528 | 529 | if time_key_seq != old_time_key_seq: 530 | old_time_key_seq = time_key_seq 531 | 532 | time_key_seq[0] = max(0, min(254, time_key_seq[0])) + 8451 533 | time_key_seq[1] = max(0, min(128, time_key_seq[1])) + 8706 534 | time_key_seq[2] = max(0, min(16, time_key_seq[2])) + 8834 535 | 536 | melody_chords_aug.extend([8450] + time_key_seq) 537 | 538 | #======================================================= 539 | # Bar counter seq 540 | 541 | if (bar_time > pbar_time) and (delta_time != 0): 542 | bar = 6787+min(1022, (bar_time)) # bar counter seq 543 | bar_t = 7810+bar_time_local 544 | bar_p = 8322+ptc 545 | melody_chords_aug.extend([6787, bar, bar_t, bar_p]) 546 | 
chords_counter += 1
547 |             pbar_time = bar_time
548 | 
549 |         else:
550 |             if delta_time != 0:
551 |                 chords_counter += 1
552 | 
553 |         #=======================================================
554 |         # FINAL NOTE SEQ
555 | 
556 |         # Writing final note asynchronously
557 | 
558 |         dur_vel = (8 * dur) + velocity
559 |         pat_ptc = (128 * pat) + ptc
560 | 
561 |         melody_chords_aug.extend([emph+6784, delta_time, dur_vel+512, pat_ptc+4608])
562 | 
563 |         pe = e
564 | 
565 |     #=======================================================
566 | 
567 |     melody_chords_aug.extend([8999, 8999, 8999, 8999]) # EOS
568 | 
569 |     #=======================================================
570 | 
571 |     # TOTAL DICTIONARY SIZE 8999+1=9000
572 | 
573 |     #=======================================================
574 | 
575 |     return melody_chords, melody_chords_aug
576 | 
577 |   except Exception as ex:
578 |     print('WARNING !!!')
579 |     print('=' * 70)
580 |     print('Bad MIDI:', midi_file)
581 |     print('Error detected:', ex)
582 |     print('=' * 70)
583 |     return None
584 | 
585 | #===============================================================================
586 | 
587 | print('=' * 70)
588 | print('TMIDIX MIDI Processor')
589 | print('=' * 70)
590 | print('Starting up...')
591 | print('=' * 70)
592 | 
593 | ###########
594 | 
595 | melody_chords_f = []
596 | melody_chords_f_aug = []
597 | 
598 | files_count = 0
599 | 
600 | print('Processing MIDI files. Please wait...')
601 | print('=' * 70)
602 | 
603 | for i in tqdm(range(0, len(filez), 16)):
604 | 
605 |     output = Parallel(n_jobs=4, verbose=0)(delayed(TMIDIX_MIDI_Processor)(fa) for fa in filez[i:i+16])
606 | 
607 |     for o in output:
608 | 
609 |         if o is not None:
610 |             melody_chords_f.append(o[0])
611 |             melody_chords_f_aug.append(o[1])
612 |             files_count += 1
613 | 
614 |     # Saving every 2560 processed files
615 |     if files_count % 2560 == 0 and files_count != 0:
616 |         print('SAVING !!!')
617 |         print('=' * 70)
618 |         print('Saving processed files...')
619 |         print('=' * 70)
620 |         print('Data check:', min(melody_chords_f[0]), '===', max(melody_chords_f[0]), '===', len(list(set(melody_chords_f[0]))), '===', len(melody_chords_f[0]))
621 |         print('=' * 70)
622 |         print('Processed so far:', files_count, 'out of', len(filez), '===', files_count / len(filez), 'good files ratio')
623 |         print('=' * 70)
624 |         count = str(files_count)
625 |         TMIDIX.Tegridy_Any_Pickle_File_Writer(melody_chords_f, '/content/drive/MyDrive/LAKH_INTs_'+count)
626 |         TMIDIX.Tegridy_Any_Pickle_File_Writer(melody_chords_f_aug, '/content/drive/MyDrive/LAKH_AUG_INTs_'+count)
627 | 
628 |         melody_chords_f = []
629 |         melody_chords_f_aug = []
630 | 
631 |         print('=' * 70)
632 | 
633 | print('FINAL SAVING !!!')
634 | print('=' * 70)
635 | print('Saving processed files...')
636 | print('=' * 70)
637 | print('Data check:', min(melody_chords_f[0]), '===', max(melody_chords_f[0]), '===', len(list(set(melody_chords_f[0]))), '===', len(melody_chords_f[0]))
638 | print('=' * 70)
639 | print('Processed so far:', files_count, 'out of', len(filez), '===', files_count / len(filez), 'good files ratio')
640 | print('=' * 70)
641 | count = str(files_count)
642 | TMIDIX.Tegridy_Any_Pickle_File_Writer(melody_chords_f, '/content/drive/MyDrive/LAKH_INTs_'+count)
643 | TMIDIX.Tegridy_Any_Pickle_File_Writer(melody_chords_f_aug, '/content/drive/MyDrive/LAKH_AUG_INTs_'+count)
644 | print('=' * 70)
645 | 
646 | """# (TEST INTS)"""
647 | 
648 | #@title Test INTs
649 | 
650 | train_data1 = random.choice(melody_chords_f + melody_chords_f_aug)
651 | 
652 | print('Sample INTs', train_data1[:15])
653 | 
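# For reference, the main token ranges produced by the processor above
# (derived from the offsets in the code; a few service-token boundaries
# overlap in the original scheme):
#   0 - 511     : note delta start-times (8 ms steps)
#   512 - 4607  : (duration * 8) + velocity + 512
#   4608 - 6783 : (patch * 128) + pitch + 4608 (patch 16 = drums)
#   6784 - 6786 : emphasis flag + 6784
#   6787 - 7809 : bar-counter marker and bar numbers
#   7810 - 8449 : bar/outro local times (+7810) and pitches (+8322)
#   8450 - 8850 : tempo / time-signature / key-signature updates and the outro marker (8850)
#   8851 - 8997 : intro tokens (drums flag, first patch +8853, first pitch +8870)
#   8998        : intro / start-of-sequence marker
#   8999        : EOS (total dictionary size 9000)
# The decoder below reconstructs only the note events and skips the service
# tokens.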
654 | out = train_data1
655 | 
656 | if len(out) != 0:
657 | 
658 |     song = out
659 |     song_f = []
660 | 
661 |     time = 0
662 |     dur = 0
663 |     vel = 90
664 |     pitch = 0
665 |     channel = 0
666 | 
667 |     for ss in song:
668 | 
669 |         if 0 <= ss < 512:
670 | 
671 |             time += ss * 8
672 | 
673 |         if 512 <= ss < 4608:
674 | 
675 |             dur = ((ss-512) // 8) * 8
676 |             vel = (((ss-512) % 8)+1) * 15
677 | 
678 |         if 4608 <= ss < 6784:
679 | 
680 |             patch = (ss-4608) // 128
681 | 
682 |             if patch == 16:
683 |                 channel = 9
684 |             else:
685 |                 if 9 <= patch <= 14:
686 |                     channel = patch + 1
687 |                 else:
688 |                     channel = patch
689 | 
690 |                 if patch == 15:
691 |                     channel = 15
692 | 
693 |             pitch = (ss-4608) % 128
694 | 
695 |             if patch == 17:
696 |                 break
697 | 
698 |             song_f.append(['note', time, dur, channel, pitch, vel ])
699 | 
700 |     detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
701 |                                                               output_signature = 'Experimental Music Transformer',
702 |                                                               output_file_name = '/content/Experimental-Music-Transformer-Composition',
703 |                                                               track_name='Project Los Angeles',
704 |                                                               list_of_MIDI_patches=[0, 10, 19, 24, 35, 40, 53, 56, 65, 9, 73, 87, 89, 99, 105, 117]
705 |                                                               )
706 | 
707 | print('Done!')
708 | 
709 | """# Congrats! You did it! :)"""
--------------------------------------------------------------------------------
/Training-Code/Experimental_Music_Transformer_Maker.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {
6 |     "id": "VGrGd6__l5ch"
7 |    },
8 |    "source": [
9 |     "# Experimental Music Transformer Maker (ver. 1.0)\n",
10 |     "\n",
11 |     "***\n",
12 |     "\n",
13 |     "Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools\n",
14 |     "\n",
15 |     "***\n",
16 |     "\n",
17 |     "WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please exercise great humility, care, and respect.
https://www.nscai.gov/\n", 18 | "\n", 19 | "***\n", 20 | "\n", 21 | "#### Project Los Angeles\n", 22 | "\n", 23 | "#### Tegridy Code 2023\n", 24 | "\n", 25 | "***" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "shLrgoXdl5cj" 32 | }, 33 | "source": [ 34 | "# GPU check" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "id": "X3rABEpKCO02" 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "!nvidia-smi" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": { 51 | "id": "0RcVC4btl5ck" 52 | }, 53 | "source": [ 54 | "# Setup environment" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": { 61 | "id": "viHgEaNACPTs" 62 | }, 63 | "outputs": [], 64 | "source": [ 65 | "!git clone https://github.com/asigalov61/tegridy-tools" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "source": [ 71 | "!pip uninstall torch" 72 | ], 73 | "metadata": { 74 | "id": "_As9UafQ2AuS" 75 | }, 76 | "execution_count": null, 77 | "outputs": [] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": { 83 | "id": "vK40g6V_BTNj" 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "!pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121\n", 88 | "!pip install einops\n", 89 | "!pip install torch-summary\n", 90 | "!pip install tqdm\n", 91 | "!pip install matplotlib" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": { 98 | "id": "DzCOZU_gBiQV" 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "# Load modules and make data dir\n", 103 | "\n", 104 | "print('Loading modules...')\n", 105 | "\n", 106 | "import os\n", 107 | "import pickle\n", 108 | "import random\n", 109 | "import secrets\n", 110 | "import tqdm\n", 111 | "import math\n", 112 | "import torch\n", 113 | "import torch.optim as optim\n", 114 | "from torch.utils.data import DataLoader, Dataset\n", 115 | "\n", 116 | "import matplotlib.pyplot as plt\n", 117 | "\n", 118 | "from torchsummary import summary\n", 119 | "from sklearn import metrics\n", 120 | "\n", 121 | "%cd /content/tegridy-tools/tegridy-tools/\n", 122 | "\n", 123 | "import TMIDIX\n", 124 | "\n", 125 | "%cd /content/tegridy-tools/tegridy-tools/X-Transformer\n", 126 | "\n", 127 | "from x_transformer_1_23_2 import *\n", 128 | "\n", 129 | "torch.set_float32_matmul_precision('high')\n", 130 | "torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul\n", 131 | "torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn\n", 132 | "\n", 133 | "%cd /content/\n", 134 | "\n", 135 | "if not os.path.exists('/content/INTS'):\n", 136 | " os.makedirs('/content/INTS')\n", 137 | "\n", 138 | "import random\n", 139 | "\n", 140 | "print('Done')\n", 141 | "\n", 142 | "print('Torch version:', torch.__version__)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": { 148 | "id": "Sbhzy8FGl5cm" 149 | }, 150 | "source": [ 151 | "# Load training data" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": { 158 | "id": "IdBpL-HUHLBW" 159 | }, 160 | "outputs": [], 161 | "source": [ 162 | "dataset_addr = \"/content/INTS\"\n", 163 | "\n", 164 | "#==========================================================================\n", 165 | "\n", 166 | "filez = list()\n", 167 | "for (dirpath, dirnames, filenames) in os.walk(dataset_addr):\n", 168 | " filez += [os.path.join(dirpath, file) for file in 
filenames]\n", 169 | "print('=' * 70)\n", 170 | "\n", 171 | "random.shuffle(filez)\n", 172 | "\n", 173 | "print('Loaded', len(filez), 'data files')\n", 174 | "print('=' * 70)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": { 180 | "id": "VhZqBvqVl5cn" 181 | }, 182 | "source": [ 183 | "# Setup model" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "source": [ 189 | "# Setup model\n", 190 | "\n", 191 | "# constants\n", 192 | "\n", 193 | "NUM_DATA_FILES_TO_LOAD_PER_ITER = 8\n", 194 | "\n", 195 | "SEQ_LEN = 8192 # Models seq len\n", 196 | "PAD_IDX = 9000 # Models pad index\n", 197 | "\n", 198 | "NUM_EPOCHS = 1\n", 199 | "\n", 200 | "BATCH_SIZE = 4\n", 201 | "GRADIENT_ACCUMULATE_EVERY = 4\n", 202 | "\n", 203 | "LEARNING_RATE = 2e-4\n", 204 | "\n", 205 | "VALIDATE_EVERY = 100\n", 206 | "SAVE_EVERY = 500\n", 207 | "GENERATE_EVERY = 250\n", 208 | "GENERATE_LENGTH = 512\n", 209 | "PRINT_STATS_EVERY = 20\n", 210 | "\n", 211 | "# helpers\n", 212 | "\n", 213 | "def cycle(loader):\n", 214 | " while True:\n", 215 | " for data in loader:\n", 216 | " yield data\n", 217 | "\n", 218 | "# instantiate the model\n", 219 | "\n", 220 | "model = TransformerWrapper(\n", 221 | " num_tokens = PAD_IDX+1,\n", 222 | " max_seq_len = SEQ_LEN,\n", 223 | " attn_layers = Decoder(dim = 1024, depth = 20, heads = 16, attn_flash = True)\n", 224 | " )\n", 225 | "\n", 226 | "model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)\n", 227 | "\n", 228 | "model = torch.nn.DataParallel(model)\n", 229 | "\n", 230 | "model.cuda()\n", 231 | "\n", 232 | "print('Done!')\n", 233 | "\n", 234 | "summary(model)\n", 235 | "\n", 236 | "# Dataloader\n", 237 | "\n", 238 | "class MusicDataset(Dataset):\n", 239 | " def __init__(self, data, seq_len):\n", 240 | " super().__init__()\n", 241 | " self.data = data\n", 242 | " self.seq_len = seq_len\n", 243 | "\n", 244 | " def __getitem__(self, index):\n", 245 | "\n", 246 | " # consequtive sampling\n", 247 | "\n", 248 | " full_seq = torch.Tensor(self.data[index][:self.seq_len+1]).long()\n", 249 | "\n", 250 | " return full_seq.cuda()\n", 251 | "\n", 252 | " def __len__(self):\n", 253 | " return (len(self.data) // BATCH_SIZE) * BATCH_SIZE\n", 254 | "\n", 255 | "# precision/optimizer/scaler\n", 256 | "\n", 257 | "dtype = torch.float16\n", 258 | "\n", 259 | "ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)\n", 260 | "\n", 261 | "optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n", 262 | "\n", 263 | "scaler = torch.cuda.amp.GradScaler(enabled=True)" 264 | ], 265 | "metadata": { 266 | "id": "mfwp06xzzPZ5" 267 | }, 268 | "execution_count": null, 269 | "outputs": [] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": { 274 | "id": "xJPxxFiwl5cn" 275 | }, 276 | "source": [ 277 | "# Train" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": { 284 | "id": "HETGqz_6K1ml" 285 | }, 286 | "outputs": [], 287 | "source": [ 288 | "# Train the model\n", 289 | "\n", 290 | "CHUNKS_LENGTH = SEQ_LEN+1\n", 291 | "MIN_NUMBER_OF_CHUNK_EVENTS = 512 # min number of tokens per chunk\n", 292 | "\n", 293 | "train_losses = []\n", 294 | "val_losses = []\n", 295 | "\n", 296 | "train_accs = []\n", 297 | "val_accs = []\n", 298 | "\n", 299 | "nsteps = 0\n", 300 | "\n", 301 | "for fa in range(0, len(filez), NUM_DATA_FILES_TO_LOAD_PER_ITER):\n", 302 | "\n", 303 | " #==========================================================================\n", 304 | " print('=' * 70)\n", 305 | " print('Loading data files', fa, '---', 
fa+NUM_DATA_FILES_TO_LOAD_PER_ITER-1)\n", 306 | " print('Please wait...')\n", 307 | " print('=' * 70)\n", 308 | "\n", 309 | " train_data = []\n", 310 | "\n", 311 | " chunks_counter = 0\n", 312 | " discarted_chunks_counter = 1\n", 313 | "\n", 314 | " for lfa in tqdm.tqdm(filez[fa:fa+NUM_DATA_FILES_TO_LOAD_PER_ITER]):\n", 315 | "\n", 316 | " train_d = pickle.load(open(lfa, 'rb'))\n", 317 | " random.shuffle(train_d)\n", 318 | " for t in train_d:\n", 319 | " for i in range(0, len(t), int((SEQ_LEN * 3) / 4)):\n", 320 | "\n", 321 | " #=========================================================================\n", 322 | " # collecting all possible chunks of chunks length\n", 323 | "\n", 324 | " if 0 <= max(t[i:i+CHUNKS_LENGTH]) < PAD_IDX: # final data integrity check\n", 325 | " if len(t[i:i+CHUNKS_LENGTH]) == CHUNKS_LENGTH:\n", 326 | " train_data.append(t[i:i+CHUNKS_LENGTH])\n", 327 | "\n", 328 | " else:\n", 329 | " if len(t[i:i+CHUNKS_LENGTH]) >= MIN_NUMBER_OF_CHUNK_EVENTS:\n", 330 | " td = t[i:i+CHUNKS_LENGTH] + [PAD_IDX] * (CHUNKS_LENGTH-len(t[i:i+CHUNKS_LENGTH])) # padding with pad index\n", 331 | " train_data.append(td)\n", 332 | " else:\n", 333 | " discarted_chunks_counter += 1\n", 334 | "\n", 335 | " chunks_counter += 1\n", 336 | "\n", 337 | " else:\n", 338 | " print('Bad data!!!')\n", 339 | " break\n", 340 | "\n", 341 | " #=========================================================================\n", 342 | " # Collecting middle chunk if it larger than chunks length\n", 343 | "\n", 344 | " if 0 <= max(t) < PAD_IDX: # final data integrity check\n", 345 | " if len(t) >= SEQ_LEN+8:\n", 346 | " comp_middle = int(len(t) / 8)\n", 347 | " sidx = int((comp_middle * 4)-(SEQ_LEN / 2))\n", 348 | " train_data.append(t[sidx:sidx+CHUNKS_LENGTH])\n", 349 | "\n", 350 | " else:\n", 351 | " discarted_chunks_counter += 1\n", 352 | "\n", 353 | " chunks_counter += 1\n", 354 | "\n", 355 | " else:\n", 356 | " print('Bad data!!!')\n", 357 | " break\n", 358 | "\n", 359 | " #==========================================================================\n", 360 | "\n", 361 | " print('Done!')\n", 362 | " print('=' * 70)\n", 363 | " print('Total number of imput chunks:', chunks_counter)\n", 364 | " print('Total number of good chunks:', len(train_data))\n", 365 | " print('Total number of discarted chunks:', discarted_chunks_counter, '/', round(100 * discarted_chunks_counter/chunks_counter, 3), '%')\n", 366 | " print('All data is good:', len(max(train_data, key=len)) == len(min(train_data, key=len)))\n", 367 | " print('=' * 70)\n", 368 | " print('Final data randomization...')\n", 369 | " random.shuffle(train_data)\n", 370 | " print('Done!')\n", 371 | " print('=' * 70)\n", 372 | "\n", 373 | "\n", 374 | " train_dataset = MusicDataset(train_data, SEQ_LEN)\n", 375 | " val_dataset = MusicDataset(train_data, SEQ_LEN)\n", 376 | " train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))\n", 377 | " val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))\n", 378 | "\n", 379 | " NUM_BATCHES = (len(train_data) // BATCH_SIZE // GRADIENT_ACCUMULATE_EVERY) * NUM_EPOCHS\n", 380 | "\n", 381 | " for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='Training'):\n", 382 | " model.train()\n", 383 | "\n", 384 | " for __ in range(GRADIENT_ACCUMULATE_EVERY):\n", 385 | " with ctx:\n", 386 | " loss, acc = model(next(train_loader))\n", 387 | " loss = loss / GRADIENT_ACCUMULATE_EVERY\n", 388 | " scaler.scale(loss).backward(torch.ones(loss.shape).cuda())\n", 389 | "\n", 390 | " if i % PRINT_STATS_EVERY == 0:\n", 391 
| " print(f'Training loss: {loss.mean().item() * GRADIENT_ACCUMULATE_EVERY}')\n", 392 | " print(f'Training acc: {acc.mean().item()}')\n", 393 | "\n", 394 | " train_losses.append(loss.mean().item() * GRADIENT_ACCUMULATE_EVERY)\n", 395 | " train_accs.append(acc.mean().item())\n", 396 | "\n", 397 | " scaler.unscale_(optim)\n", 398 | " torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)\n", 399 | " scaler.step(optim)\n", 400 | " scaler.update()\n", 401 | " optim.zero_grad(set_to_none=True)\n", 402 | "\n", 403 | " nsteps += 1\n", 404 | "\n", 405 | " if i % VALIDATE_EVERY == 0:\n", 406 | " model.eval()\n", 407 | " with torch.no_grad():\n", 408 | " with ctx:\n", 409 | " val_loss, val_acc = model(next(val_loader))\n", 410 | "\n", 411 | " print(f'Validation loss: {val_loss.mean().item()}')\n", 412 | " print(f'Validation acc: {val_acc.mean().item()}')\n", 413 | "\n", 414 | " val_losses.append(val_loss.mean().item())\n", 415 | " val_accs.append(val_acc.mean().item())\n", 416 | "\n", 417 | " print('Plotting training loss graph...')\n", 418 | "\n", 419 | " tr_loss_list = train_losses\n", 420 | " plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')\n", 421 | " plt.show()\n", 422 | " plt.close()\n", 423 | " print('Done!')\n", 424 | "\n", 425 | " print('Plotting training acc graph...')\n", 426 | "\n", 427 | " tr_loss_list = train_accs\n", 428 | " plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')\n", 429 | " plt.show()\n", 430 | " plt.close()\n", 431 | " print('Done!')\n", 432 | "\n", 433 | " print('Plotting validation loss graph...')\n", 434 | " tr_loss_list = val_losses\n", 435 | " plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')\n", 436 | " plt.show()\n", 437 | " plt.close()\n", 438 | " print('Done!')\n", 439 | "\n", 440 | " print('Plotting validation acc graph...')\n", 441 | " tr_loss_list = val_accs\n", 442 | " plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')\n", 443 | " plt.show()\n", 444 | " plt.close()\n", 445 | " print('Done!')\n", 446 | "\n", 447 | " if i % GENERATE_EVERY == 0:\n", 448 | " model.eval()\n", 449 | "\n", 450 | " inp = random.choice(val_dataset)[:-1]\n", 451 | "\n", 452 | " print(inp)\n", 453 | "\n", 454 | " with ctx:\n", 455 | " sample = model.module.generate(inp[None, ...], GENERATE_LENGTH)\n", 456 | "\n", 457 | " print(sample)\n", 458 | "\n", 459 | " data = sample.tolist()[0]\n", 460 | "\n", 461 | " print('Sample INTs', data[:15])\n", 462 | "\n", 463 | " out = data[:200000]\n", 464 | "\n", 465 | " if len(out) != 0:\n", 466 | "\n", 467 | " song = out\n", 468 | " song_f = []\n", 469 | "\n", 470 | " time = 0\n", 471 | " dur = 0\n", 472 | " vel = 90\n", 473 | " pitch = 0\n", 474 | " channel = 0\n", 475 | "\n", 476 | " for ss in song:\n", 477 | "\n", 478 | " if 0 <= ss < 512:\n", 479 | "\n", 480 | " time += ss * 8\n", 481 | "\n", 482 | " if 512 <= ss < 4608:\n", 483 | "\n", 484 | " dur = ((ss-512) // 8) * 8\n", 485 | " vel = (((ss-512) % 8)+1) * 15\n", 486 | "\n", 487 | " if 4608 <= ss < 6784:\n", 488 | "\n", 489 | " patch = (ss-4608) // 128\n", 490 | "\n", 491 | " if patch == 16:\n", 492 | " channel = 9\n", 493 | " else:\n", 494 | " if 9 <= patch <= 14:\n", 495 | " channel = patch + 1\n", 496 | " else:\n", 497 | " channel = patch\n", 498 | "\n", 499 | " if patch == 15:\n", 500 | " channel = 15\n", 501 | "\n", 502 | " pitch = (ss-4608) % 128\n", 503 | "\n", 504 | " song_f.append(['note', time, dur, channel, pitch, vel ])\n", 505 | "\n", 506 | " detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,\n", 
507 | " output_signature = 'Experimental Music Transformer',\n", 508 | " output_file_name = '/content/Experimental-Music-Transformer-Composition',\n", 509 | " track_name='Project Los Angeles',\n", 510 | " list_of_MIDI_patches=[0, 10, 19, 24, 35, 40, 53, 56, 65, 9, 73, 87, 89, 99, 105, 117]\n", 511 | " )\n", 512 | "\n", 513 | " print('Done!')\n", 514 | "\n", 515 | " if i % SAVE_EVERY == 0:\n", 516 | "\n", 517 | " print('Saving model progress. Please wait...')\n", 518 | " print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth')\n", 519 | "\n", 520 | " fname = '/content/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth'\n", 521 | "\n", 522 | " torch.save(model.state_dict(), fname)\n", 523 | "\n", 524 | " data = [train_losses, train_accs, val_losses, val_accs]\n", 525 | "\n", 526 | " TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accs')\n", 527 | "\n", 528 | " print('Done!')\n", 529 | "\n", 530 | "#======================================================================================================\n", 531 | "\n", 532 | "print('Saving model progress. Please wait...')\n", 533 | "print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth')\n", 534 | "\n", 535 | "fname = '/content/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth'\n", 536 | "\n", 537 | "torch.save(model.state_dict(), fname)\n", 538 | "\n", 539 | "print('Done!')\n", 540 | "\n", 541 | "data = [train_losses, train_accs, val_losses, val_accs]\n", 542 | "\n", 543 | "TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accuracies')\n", 544 | "\n", 545 | "# Save training loss graph\n", 546 | "\n", 547 | "plt.plot([i for i in range(len(train_losses))] ,train_losses, 'b')\n", 548 | "plt.savefig('/content/training_loss_graph.png')\n", 549 | "plt.close()\n", 550 | "print('Done!')\n", 551 | "\n", 552 | "# Save training acc graph\n", 553 | "\n", 554 | "plt.plot([i for i in range(len(train_accs))] ,train_accs, 'b')\n", 555 | "plt.savefig('/content/training_acc_graph.png')\n", 556 | "plt.close()\n", 557 | "print('Done!')\n", 558 | "\n", 559 | "# Save validation loss graph\n", 560 | "\n", 561 | "plt.plot([i for i in range(len(val_losses))] ,val_losses, 'b')\n", 562 | "plt.savefig('/content/validation_loss_graph.png')\n", 563 | "plt.close()\n", 564 | "print('Done!')\n", 565 | "\n", 566 | "# Save validation acc graph\n", 567 | "\n", 568 | "plt.plot([i for i in range(len(val_accs))] ,val_accs, 'b')\n", 569 | "plt.savefig('/content/validation_acc_graph.png')\n", 570 | "plt.close()\n", 571 | "print('Done!')" 572 | ] 573 | }, 574 | { 575 | "cell_type": "markdown", 576 | "metadata": { 577 | "id": "wBkMH2gWl5co" 578 | }, 579 | "source": [ 580 | "# Final Save" 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": null, 586 | "metadata": { 587 | "id": "fCmj4MBmAOjF" 588 | }, 589 | "outputs": [], 590 | "source": [ 591 | "print('Saving model progress. 
Please wait...')\n", 592 | "print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth')\n", 593 | "\n", 594 | "fname = '/content/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth'\n", 595 | "\n", 596 | "torch.save(model.state_dict(), fname)\n", 597 | "\n", 598 | "print('Done!')" 599 | ] 600 | }, 601 | { 602 | "cell_type": "code", 603 | "execution_count": null, 604 | "metadata": { 605 | "id": "WwgV2ZA9ndQr" 606 | }, 607 | "outputs": [], 608 | "source": [ 609 | "data = [train_losses, train_accs, val_losses, val_accs]\n", 610 | "\n", 611 | "TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accuracies')" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": null, 617 | "metadata": { 618 | "id": "4vE5Z15fCz1M" 619 | }, 620 | "outputs": [], 621 | "source": [ 622 | "# Save training loss graph\n", 623 | "\n", 624 | "plt.plot([i for i in range(len(train_losses))] ,train_losses, 'b')\n", 625 | "plt.savefig('/content/training_loss_graph.png')\n", 626 | "plt.close()\n", 627 | "print('Done!')\n", 628 | "\n", 629 | "# Save training acc graph\n", 630 | "\n", 631 | "plt.plot([i for i in range(len(train_accs))] ,train_accs, 'b')\n", 632 | "plt.savefig('/content/training_acc_graph.png')\n", 633 | "plt.close()\n", 634 | "print('Done!')\n", 635 | "\n", 636 | "# Save validation loss graph\n", 637 | "\n", 638 | "plt.plot([i for i in range(len(val_losses))] ,val_losses, 'b')\n", 639 | "plt.savefig('/content/validation_loss_graph.png')\n", 640 | "plt.close()\n", 641 | "print('Done!')\n", 642 | "\n", 643 | "# Save validation acc graph\n", 644 | "\n", 645 | "plt.plot([i for i in range(len(val_accs))] ,val_accs, 'b')\n", 646 | "plt.savefig('/content/validation_acc_graph.png')\n", 647 | "plt.close()\n", 648 | "print('Done!')" 649 | ] 650 | }, 651 | { 652 | "cell_type": "markdown", 653 | "metadata": { 654 | "id": "feXay_Ed7mG5" 655 | }, 656 | "source": [ 657 | "# Eval" 658 | ] 659 | }, 660 | { 661 | "cell_type": "code", 662 | "execution_count": null, 663 | "metadata": { 664 | "id": "naf65RxUXwDg" 665 | }, 666 | "outputs": [], 667 | "source": [ 668 | "model.eval()\n", 669 | "\n", 670 | "#x = torch.tensor((random.choice(train_data)[:1000], dtype=torch.long, device=device_type)[None, ...])\n", 671 | "x = torch.tensor([[8998, 8851+0, 8853+0, 8870+60]] * 4, dtype=torch.long, device='cuda')\n", 672 | "\n", 673 | "# run generation\n", 674 | "\n", 675 | "with ctx:\n", 676 | " out = model.module.generate(x,\n", 677 | " 500,\n", 678 | " temperature=0.9,\n", 679 | " return_prime=True,\n", 680 | " verbose=True)\n", 681 | "\n", 682 | "y = out.tolist()\n", 683 | "\n", 684 | "print('---------------')" 685 | ] 686 | }, 687 | { 688 | "cell_type": "code", 689 | "execution_count": null, 690 | "metadata": { 691 | "id": "tlBzqWpAnZna" 692 | }, 693 | "outputs": [], 694 | "source": [ 695 | "#@title Test INTs\n", 696 | "\n", 697 | "data = y[0]\n", 698 | "\n", 699 | "print('Sample INTs', data[:15])\n", 700 | "\n", 701 | "out = data[:200000]\n", 702 | "\n", 703 | "if len(out) != 0:\n", 704 | "\n", 705 | " song = out\n", 706 | " song_f = []\n", 707 | "\n", 708 | " time = 0\n", 709 | " dur = 0\n", 710 | " vel = 90\n", 711 | " pitch = 0\n", 712 | " channel = 0\n", 713 | "\n", 714 | " for ss in song:\n", 715 | "\n", 716 | " if 0 <= ss < 512:\n", 717 | "\n", 718 | " time += ss * 8\n", 719 | "\n", 720 | " if 512 <= ss < 
4608:\n", 721 | "\n", 722 | " dur = ((ss-512) // 8) * 8\n", 723 | " vel = (((ss-512) % 8)+1) * 15\n", 724 | "\n", 725 | " if 4608 <= ss < 6784:\n", 726 | "\n", 727 | " patch = (ss-4608) // 128\n", 728 | "\n", 729 | " if patch == 16:\n", 730 | " channel = 9\n", 731 | " else:\n", 732 | " if 9 <= patch <= 14:\n", 733 | " channel = patch + 1\n", 734 | " else:\n", 735 | " channel = patch\n", 736 | "\n", 737 | " if patch == 15:\n", 738 | " channel = 15\n", 739 | "\n", 740 | " pitch = (ss-4608) % 128\n", 741 | "\n", 742 | " if patch == 17:\n", 743 | " break\n", 744 | "\n", 745 | " song_f.append(['note', time, dur, channel, pitch, vel ])\n", 746 | "\n", 747 | "detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,\n", 748 | " output_signature = 'Experimental Music Transformer',\n", 749 | " output_file_name = '/content/Experimental-Music-Transformer-Composition',\n", 750 | " track_name='Project Los Angeles',\n", 751 | " list_of_MIDI_patches=[0, 10, 19, 24, 35, 40, 53, 56, 65, 9, 73, 87, 89, 99, 105, 117]\n", 752 | " )\n", 753 | "\n", 754 | "print('Done!')" 755 | ] 756 | }, 757 | { 758 | "cell_type": "code", 759 | "execution_count": null, 760 | "metadata": { 761 | "id": "LZGayVt6okU9" 762 | }, 763 | "outputs": [], 764 | "source": [ 765 | "patches" 766 | ] 767 | }, 768 | { 769 | "cell_type": "code", 770 | "execution_count": null, 771 | "metadata": { 772 | "id": "al3TDlH7T8m7" 773 | }, 774 | "outputs": [], 775 | "source": [ 776 | "tok_emb = model.module.net.token_emb.emb.weight.detach().cpu().tolist()\n", 777 | "\n", 778 | "cos_sim = metrics.pairwise_distances(\n", 779 | " tok_emb, metric='cosine'\n", 780 | ")\n", 781 | "plt.figure(figsize=(7, 7))\n", 782 | "plt.imshow(cos_sim, cmap=\"inferno\", interpolation=\"nearest\")\n", 783 | "im_ratio = cos_sim.shape[0] / cos_sim.shape[1]\n", 784 | "plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)\n", 785 | "plt.xlabel(\"Position\")\n", 786 | "plt.ylabel(\"Position\")\n", 787 | "plt.tight_layout()\n", 788 | "plt.plot()\n", 789 | "plt.savefig(\"/content/Experimental-Music-Transformer-Tokens-Embeddings-Plot.png\", bbox_inches=\"tight\")" 790 | ] 791 | }, 792 | { 793 | "cell_type": "markdown", 794 | "metadata": { 795 | "id": "z87TlDTVl5cp" 796 | }, 797 | "source": [ 798 | "# Congrats! You did it! :)" 799 | ] 800 | } 801 | ], 802 | "metadata": { 803 | "accelerator": "GPU", 804 | "colab": { 805 | "gpuClass": "premium", 806 | "gpuType": "A100", 807 | "machine_shape": "hm", 808 | "private_outputs": true, 809 | "provenance": [] 810 | }, 811 | "kernelspec": { 812 | "display_name": "Python 3", 813 | "language": "python", 814 | "name": "python3" 815 | }, 816 | "language_info": { 817 | "codemirror_mode": { 818 | "name": "ipython", 819 | "version": 3 820 | }, 821 | "file_extension": ".py", 822 | "mimetype": "text/x-python", 823 | "name": "python", 824 | "nbconvert_exporter": "python", 825 | "pygments_lexer": "ipython3", 826 | "version": "3.8.10" 827 | } 828 | }, 829 | "nbformat": 4, 830 | "nbformat_minor": 0 831 | } -------------------------------------------------------------------------------- /experimental_music_transformer_version_1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Experimental_Music_Transformer_Version_1.ipynb 3 | 4 | Automatically generated by Colaboratory. 5 | 6 | Original file is located at 7 | https://colab.research.google.com/drive/1k0wDK63L6iJCBKwaUGLqbvCYGsP8cpOo 8 | 9 | # Experimental Music Transformer Version 1 (ver. 
0.5) 10 | 11 | *** 12 | 13 | Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools 14 | 15 | *** 16 | 17 | WARNING: This complete implementation is a functioning artificial intelligence model. Please exercise great humility, care, and respect. https://www.nscai.gov/ 18 | 19 | *** 20 | 21 | #### Project Los Angeles 22 | 23 | #### Tegridy Code 2023 24 | 25 | *** 26 | 27 | # (GPU CHECK) 28 | """ 29 | 30 | #@title NVIDIA GPU check 31 | !nvidia-smi 32 | 33 | """# (SETUP ENVIRONMENT)""" 34 | 35 | #@title Install dependencies 36 | !git clone --depth 1 https://github.com/asigalov61/Experimental-Music-Transformer 37 | !pip install huggingface_hub 38 | !pip install torch 39 | !pip install einops 40 | !pip install torch-summary 41 | !pip install tqdm 42 | !pip install matplotlib 43 | !apt install fluidsynth # pip does not work for some reason; only apt works 44 | 45 | # Commented out IPython magic to ensure Python compatibility. 46 | #@title Import modules 47 | 48 | print('=' * 70) 49 | print('Loading core Experimental Music Transformer modules...') 50 | 51 | import os 52 | import copy 53 | import pickle 54 | import secrets 55 | import statistics 56 | from time import time 57 | import tqdm 58 | 59 | print('=' * 70) 60 | print('Loading main Experimental Music Transformer modules...') 61 | import torch 62 | 63 | # %cd /content/Experimental-Music-Transformer 64 | 65 | import TMIDIX 66 | 67 | from midi_to_colab_audio import midi_to_colab_audio 68 | 69 | from x_transformer_1_23_2 import * 70 | 71 | import random 72 | 73 | # %cd /content/ 74 | print('=' * 70) 75 | print('Loading aux Experimental Music Transformer modules...') 76 | 77 | import matplotlib.pyplot as plt 78 | 79 | from torchsummary import summary 80 | from sklearn import metrics 81 | 82 | from IPython.display import Audio, display 83 | 84 | from huggingface_hub import hf_hub_download 85 | 86 | from google.colab import files 87 | 88 | print('=' * 70) 89 | print('Done!') 90 | print('Enjoy! 
:)') 91 | print('=' * 70) 92 | 93 | """# (LOAD MODEL)""" 94 | 95 | #@title Load Experimental Music Transformer Large Model 96 | 97 | #@markdown Very fast model, 32 layers, 245k MIDIs training corpus 98 | 99 | full_path_to_model_checkpoint = "/content/Experimental-Music-Transformer/Models/Version-1/Experimental_Music_Transformer_Version_1_Large_Trained_Model_18581_steps_0.636_loss_0.823_acc.pth" #@param {type:"string"} 100 | 101 | #@markdown Model precision option 102 | 103 | model_precision = "bfloat16" # @param ["bfloat16", "float16"] 104 | 105 | #@markdown bfloat16 == Half precision/faster speed (if supported, otherwise the model will default to float16) 106 | 107 | #@markdown float16 == Half precision/fast speed 108 | 109 | plot_tokens_embeddings = False # @param {type:"boolean"} 110 | 111 | print('=' * 70) 112 | print('Loading Experimental Music Transformer Large Pre-Trained Model...') 113 | print('Please wait...') 114 | print('=' * 70) 115 | 116 | if os.path.isfile(full_path_to_model_checkpoint): 117 | print('Model already exists...') 118 | 119 | else: 120 | hf_hub_download(repo_id='asigalov61/Experimental-Music-Transformer', 121 | filename='Experimental_Music_Transformer_Version_1_Large_Trained_Model_18581_steps_0.636_loss_0.823_acc.pth', 122 | local_dir='/content/Experimental-Music-Transformer/Models/Version-1', 123 | local_dir_use_symlinks=False) 124 | 125 | print('=' * 70) 126 | print('Instantiating model...') 127 | 128 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 129 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 130 | device_type = 'cuda' 131 | 132 | if model_precision == 'bfloat16' and torch.cuda.is_bf16_supported(): 133 | dtype = 'bfloat16' 134 | else: 135 | dtype = 'float16' 136 | 137 | if model_precision == 'float16': 138 | dtype = 'float16' 139 | 140 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 141 | ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) 142 | 143 | SEQ_LEN = 8192 144 | 145 | # instantiate the model 146 | 147 | model = TransformerWrapper( 148 | num_tokens = 7578, 149 | max_seq_len = SEQ_LEN, 150 | attn_layers = Decoder(dim = 1024, depth = 32, heads = 32, attn_flash = True) 151 | ) 152 | 153 | model = AutoregressiveWrapper(model, ignore_index=7577) 154 | 155 | model.cuda() 156 | print('=' * 70) 157 | 158 | print('Loading model checkpoint...') 159 | 160 | model.load_state_dict(torch.load(full_path_to_model_checkpoint)) 161 | print('=' * 70) 162 | 163 | model.eval() 164 | 165 | print('Done!') 166 | print('=' * 70) 167 | 168 | print('Model will use', dtype, 'precision...') 169 | print('=' * 70) 170 | 171 | # Model stats 172 | print('Model summary...') 173 | summary(model) 174 | 175 | # Plot Token Embeddings 176 | if plot_tokens_embeddings: 177 | tok_emb = model.net.token_emb.emb.weight.detach().cpu().tolist() 178 | 179 | cos_sim = metrics.pairwise_distances( 180 | tok_emb, metric='cosine' 181 | ) 182 | plt.figure(figsize=(7, 7)) 183 | plt.imshow(cos_sim, cmap="inferno", interpolation="nearest") 184 | im_ratio = cos_sim.shape[0] / cos_sim.shape[1] 185 | plt.colorbar(fraction=0.046 * im_ratio, pad=0.04) 186 | plt.xlabel("Position") 187 | plt.ylabel("Position") 188 | plt.tight_layout() 189 | plt.plot() 190 | plt.savefig("/content/Experimental-Music-Transformer-Large-Tokens-Embeddings-Plot.png", bbox_inches="tight") 191 | 192 | """# (GENERATE) 193 | 194 | # (IMPROV) 195 | """ 196 | 197 | #@title Standard Improv Generator 198 | 199 | #@markdown Improv type 200 | 
201 | improv_type = "Random Freestyle" # @param ["Random Freestyle", "Freestyle without Drums", "Freestyle with Drums", "Custom"] 202 | 203 | #@markdown Custom Improv settings 204 | 205 | first_note_MIDI_patch_number = 0 # @param {type:"slider", min:0, max:128, step:1} 206 | first_note_MIDI_pitch_number = 60 # @param {type:"slider", min:1, max:127, step:1} 207 | add_drums = False #@param {type:"boolean"} 208 | 209 | #@markdown Generation settings 210 | 211 | number_of_tokens_to_generate = 1010 # @param {type:"slider", min:30, max:8190, step:4} 212 | number_of_batches_to_generate = 4 #@param {type:"slider", min:1, max:16, step:1} 213 | temperature = 0.9 # @param {type:"slider", min:0.1, max:1, step:0.05} 214 | 215 | #@markdown Other settings 216 | 217 | render_MIDI_to_audio = True # @param {type:"boolean"} 218 | 219 | print('=' * 70) 220 | print('Experimental Music Transformer Standard Improv Model Generator') 221 | print('=' * 70) 222 | 223 | if improv_type == 'Random Freestyle': 224 | 225 | outy = [7575] 226 | 227 | if improv_type == 'Freestyle without Drums': 228 | 229 | outy = [7575, 7428] 230 | 231 | if improv_type == 'Freestyle with Drums': 232 | 233 | outy = [7575, 7429] 234 | 235 | if improv_type == 'Custom': 236 | 237 | if add_drums: 238 | drumsp = 7429 # Yes 239 | else: 240 | drumsp = 7428 # No 241 | 242 | outy = [7575, 243 | drumsp, 244 | 7430+first_note_MIDI_patch_number, 245 | 7447+first_note_MIDI_pitch_number] 246 | 247 | print('Selected Improv sequence:') 248 | print(outy) 249 | print('=' * 70) 250 | 251 | inp = [outy] * number_of_batches_to_generate 252 | 253 | inp = torch.LongTensor(inp).cuda() 254 | 255 | with ctx: 256 | out = model.generate(inp, 257 | number_of_tokens_to_generate, 258 | temperature=temperature, 259 | return_prime=True, 260 | verbose=True) 261 | 262 | out0 = out.tolist() 263 | 264 | print('=' * 70) 265 | print('Done!') 266 | print('=' * 70) 267 | 268 | #====================================================================== 269 | 270 | print('Rendering results...') 271 | 272 | for i in range(number_of_batches_to_generate): 273 | 274 | print('=' * 70) 275 | print('Batch #', i) 276 | print('=' * 70) 277 | 278 | out1 = out0[i] 279 | 280 | print('Sample INTs', out1[:12]) 281 | print('=' * 70) 282 | 283 | if len(out1) != 0: 284 | 285 | song = out1 286 | song_f = [] 287 | 288 | time = 0 289 | dur = 0 290 | vel = 90 291 | pitch = 0 292 | channel = 0 293 | emph = 0 294 | 295 | for ss in song: 296 | 297 | if 0 <= ss < 512: 298 | 299 | time += ss * 8 300 | 301 | if 512 <= ss < 4608: 302 | 303 | dur = ((ss-512) // 8) * 8 304 | vel = (((ss-512) % 8)+1) * 15 305 | 306 | if 4608 <= ss < 6784: 307 | 308 | patch = (ss-4608) // 128 309 | 310 | if patch == 16: 311 | channel = 9 312 | else: 313 | if 9 <= patch <= 14: 314 | channel = patch + 1 315 | else: 316 | channel = patch 317 | 318 | if patch == 15: 319 | channel = 15 320 | 321 | pitch = (ss-4608) % 128 322 | 323 | if emph == 1: 324 | song_f.append(['text_event', time, 'Emph']) 325 | 326 | song_f.append(['note', time, dur, channel, pitch, vel ]) 327 | 328 | if 6784 < ss < 6787: 329 | emph = ss - 6784 330 | 331 | 332 | 333 | 334 | data = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f, 335 | output_signature = 'Experimental Music Transformer', 336 | output_file_name = '/content/Experimental-Music-Transformer-Composition_'+str(i), 337 | track_name='Project Los Angeles', 338 | list_of_MIDI_patches=[0, 10, 19, 24, 35, 40, 53, 56, 65, 9, 73, 87, 89, 99, 105, 117] 339 | ) 340 | 
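# ----------------------------------------------------------------------
# Aside: the rendering loop above decodes each generated token purely by
# its value range. A summary of the ranges, as inferred from the code in
# this script (kept as comments so it stays inert inside the batch loop):
#
#   0    - 511  : delta-time step; decoded as time += ss * 8 (ms)
#   512  - 4607 : duration + octo-velocity; (ss-512) // 8 and (ss-512) % 8
#   4608 - 6783 : patch + pitch; (ss-4608) // 128 and (ss-4608) % 128
#                 (patch 16 = drums, mapped to MIDI channel 9)
#   6784 - 6786 : note emphasis marker; value 1 emits an 'Emph' text event
#   6787+       : control tokens (e.g. 7428/7429 = drums no/yes,
#                 7430+patch and 7447+pitch intro tokens, 7575 = start)
# ----------------------------------------------------------------------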
341 | 342 | print('=' * 70) 343 | print('Displaying resulting composition...') 344 | print('=' * 70) 345 | 346 | fname = '/content/Experimental-Music-Transformer-Composition_'+str(i) 347 | 348 | x = [] 349 | y =[] 350 | c = [] 351 | 352 | colors = ['red', 'yellow', 'green', 'cyan', 353 | 'blue', 'pink', 'orange', 'purple', 354 | 'gray', 'white', 'gold', 'silver', 355 | 'lightgreen', 'indigo', 'maroon', 'turquoise'] 356 | 357 | for s in song_f: 358 | if s[0] == 'note': 359 | x.append(s[1] / 1000) 360 | y.append(s[4]) 361 | c.append(colors[s[3]]) 362 | 363 | if render_MIDI_to_audio: 364 | midi_audio = midi_to_colab_audio(fname + '.mid') 365 | display(Audio(midi_audio, rate=16000, normalize=False)) 366 | 367 | plt.figure(figsize=(14,5)) 368 | ax=plt.axes(title=fname) 369 | ax.set_facecolor('black') 370 | 371 | plt.scatter(x,y, c=c) 372 | plt.xlabel("Time") 373 | plt.ylabel("Pitch") 374 | plt.show() 375 | 376 | """# (CUSTOM MIDI)""" 377 | 378 | #@title Load Seed MIDI 379 | 380 | #@markdown Press play button to upload your own seed MIDI or to load one of the provided sample seed MIDIs from the dropdown list below 381 | 382 | select_seed_MIDI = "Nothing Else Matters" # @param ["Upload your own custom MIDI", "Nothing Else Matters", "Sharing The Night Together", "Honesty", "House Of The Rising Sun"] 383 | render_MIDI_to_audio = False # @param {type:"boolean"} 384 | 385 | print('=' * 70) 386 | print('Experimental Music Transformer Seed MIDI Loader') 387 | print('=' * 70) 388 | 389 | f = '' 390 | 391 | if select_seed_MIDI != "Upload your own custom MIDI": 392 | print('Loading seed MIDI...') 393 | f = '/content/Experimental-Music-Transformer/Seeds/'+select_seed_MIDI+'.mid' 394 | score = TMIDIX.midi2single_track_ms_score(open(f, 'rb').read(), recalculate_channels=False, pass_old_timings_events=True) 395 | 396 | else: 397 | print('Upload your own custom MIDI...') 398 | print('=' * 70) 399 | uploaded_MIDI = files.upload() 400 | if list(uploaded_MIDI.keys()): 401 | f = list(uploaded_MIDI.keys())[0] 402 | score = TMIDIX.midi2single_track_ms_score(open(f, 'rb').read(), recalculate_channels=False, pass_old_timings_events=True) 403 | 404 | if f != '': 405 | 406 | print('=' * 70) 407 | print('File:', f) 408 | print('=' * 70) 409 | 410 | #======================================================= 411 | # START PROCESSING 412 | 413 | score = TMIDIX.midi2single_track_ms_score(open(f, 'rb').read(), recalculate_channels=False, pass_old_timings_events=True) 414 | 415 | # INSTRUMENTS CONVERSION CYCLE 416 | events_matrix = [] 417 | itrack = 1 418 | patches = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 419 | 420 | emph_once = False 421 | emphasis_time = 0 422 | 423 | tpq = 0 424 | tempo = 0 425 | time_sig = 0 426 | key_sig = 0 427 | 428 | while itrack < len(score): 429 | for event in score[itrack]: 430 | 431 | if event[0] != 'note': 432 | event.extend([0, 128]) 433 | else: 434 | event.extend([0, 0]) 435 | 436 | if event[0] == 'text_event' or event[0] == 'lyric' or event[0] == 'patch_change' or event[0] == 'time_signature': 437 | event[4] = 128 438 | 439 | events_matrix.append(event) 440 | 441 | itrack += 1 442 | 443 | events_matrix.sort(key=lambda x: x[4], reverse = True) 444 | events_matrix.sort(key=lambda x: x[1]) 445 | 446 | events_matrix1 = [] 447 | 448 | pt = events_matrix[0][1] 449 | 450 | for event in events_matrix: 451 | if event[0] == 'patch_change': 452 | patches[event[2]] = event[3] 453 | 454 | #======================================================================== 455 | # Emphasis 456 | 457 | if 
event[0] == 'text_event' or event[0] == 'lyric': 458 | emphasis_time = event[1] 459 | emph_once = True 460 | 461 | if event[0] == 'note' and int(event[1] / 8) > int(emphasis_time / 8) and event[1] > pt: 462 | event[7] = 2 463 | emph_once = False 464 | 465 | if event[0] == 'note' and int(event[1] / 8) == int(emphasis_time / 8) and emph_once: 466 | event[7] = 1 467 | emph_once = False 468 | 469 | pt = event[1] 470 | 471 | #======================================================================== 472 | # Notes 473 | 474 | if event[0] == 'note': 475 | event[6] = patches[event[3]] 476 | 477 | if events_matrix1: 478 | if (event[1] == events_matrix1[-1][1]): 479 | if ([event[3], event[4]] != events_matrix1[-1][3:5]): 480 | events_matrix1.append(event) 481 | else: 482 | events_matrix1.append(event) 483 | 484 | else: 485 | events_matrix1.append(event) 486 | 487 | if len(events_matrix1) > 0: 488 | if min([e[1] for e in events_matrix1]) >= 0 and min([e[2] for e in events_matrix1]) >= 0: 489 | 490 | #======================================================= 491 | # PRE-PROCESSING 492 | 493 | # checking number of instruments in a composition 494 | instruments_list_without_drums = list(set([y[3] for y in events_matrix1 if y[3] != 9])) 495 | instruments_list = list(set([y[3] for y in events_matrix1])) 496 | 497 | if len(events_matrix1) > 0 and len(instruments_list_without_drums) > 0: 498 | 499 | #======================================================= 500 | # MAIN PROCESSING 501 | #======================================================= 502 | 503 | #======================================================= 504 | # Timings 505 | #======================================================= 506 | 507 | events_matrix2 = [] 508 | 509 | # Recalculating timings 510 | for e in events_matrix1: 511 | 512 | # Original timings 513 | e[1] = int(e[1] / 8) 514 | e[2] = int(e[2] / 8) 515 | 516 | #=================================== 517 | # ORIGINAL COMPOSITION 518 | #=================================== 519 | 520 | # Sorting by patch, pitch, then by start-time 521 | 522 | events_matrix1.sort(key=lambda x: x[6]) 523 | events_matrix1.sort(key=lambda x: x[4], reverse=True) 524 | events_matrix1.sort(key=lambda x: x[1]) 525 | 526 | #======================================================= 527 | # FINAL PROCESSING 528 | 529 | melody_chords = [] 530 | melody_chords1 = [] 531 | 532 | # Break between compositions / Intro seq 533 | 534 | if 9 in instruments_list: 535 | drums_present = 7429 # Yes 536 | else: 537 | drums_present = 7428 # No 538 | 539 | if events_matrix1[0][3] != 9: 540 | pat = events_matrix1[0][6] // 8 541 | else: 542 | pat = 16 543 | 544 | ptc = events_matrix1[0][4] 545 | 546 | melody_chords.extend([7575, drums_present, 7430+pat, 7447+ptc]) # Intro seq 547 | melody_chords1.append([7575, drums_present, 7430+pat, 7447+ptc]) 548 | #======================================================= 549 | # PROCESSING CYCLE 550 | #======================================================= 551 | 552 | pe = events_matrix1[0] 553 | 554 | for e in events_matrix1: 555 | 556 | #======================================================= 557 | # Timings... 558 | 559 | # Clipping all values... 
560 | delta_time = max(0, min(511, e[1]-pe[1])) 561 | 562 | # Durations and channels 563 | 564 | dur = max(0, min(511, e[2])) 565 | cha = max(0, min(15, e[3])) 566 | 567 | # Patches 568 | if cha == 9: # Drums patch will be == 16 569 | pat = 16 570 | 571 | else: 572 | pat = e[6] // 8 573 | 574 | # Pitches 575 | ptc = max(1, min(127, e[4])) 576 | 577 | # Emphasis 578 | emph = e[7] 579 | 580 | # Velocities 581 | # Calculating octo-velocity 582 | vel = max(8, min(127, e[5])) 583 | velocity = round(vel / 15)-1 584 | 585 | #======================================================= 586 | # FINAL NOTE SEQ 587 | 588 | # Writing final note asynchronously 589 | 590 | dur_vel = (8 * dur) + velocity 591 | pat_ptc = (128 * pat) + ptc 592 | 593 | melody_chords.extend([emph+6784, delta_time, dur_vel+512, pat_ptc+4608]) 594 | melody_chords1.append([emph+6784, delta_time, dur_vel+512, pat_ptc+4608]) 595 | 596 | pe = e 597 | 598 | #======================================================= 599 | 600 | emphasis = [m for m in melody_chords if 6784 <= m <= 6785] 601 | 602 | #======================================================= 603 | 604 | song = melody_chords 605 | 606 | song_f = [] 607 | 608 | time = 0 609 | dur = 0 610 | vel = 90 611 | pitch = 0 612 | channel = 0 613 | 614 | for ss in song: 615 | 616 | if 0 <= ss < 512: 617 | 618 | time += ss * 8 619 | 620 | if 512 <= ss < 4608: 621 | 622 | dur = ((ss-512) // 8) * 8 623 | vel = (((ss-512) % 8)+1) * 15 624 | 625 | if 4608 <= ss < 6784: 626 | 627 | patch = (ss-4608) // 128 628 | 629 | if patch == 16: 630 | channel = 9 631 | else: 632 | if 9 <= patch <= 14: 633 | channel = patch + 1 634 | else: 635 | channel = patch 636 | 637 | if patch == 15: 638 | channel = 15 639 | 640 | pitch = (ss-4608) % 128 641 | 642 | if emph == 1: 643 | song_f.append(['text_event', time, 'Emph']) 644 | 645 | song_f.append(['note', time, dur, channel, pitch, vel ]) 646 | 647 | if 6784 < ss < 6787: 648 | emph = ss - 6784 649 | 650 | data = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f, 651 | output_signature = 'Experimental Music Transformer', 652 | output_file_name = '/content/Experimental-Music-Transformer-Seed-Composition', 653 | track_name='Project Los Angeles', 654 | list_of_MIDI_patches=[0, 10, 19, 24, 35, 40, 53, 56, 65, 9, 73, 87, 89, 99, 105, 117] 655 | ) 656 | 657 | #======================================================= 658 | 659 | print('=' * 70) 660 | print('Composition stats:') 661 | print('Composition has', len(melody_chords1), 'notes') 662 | print('Composition has', len(melody_chords), 'tokens') 663 | print('=' * 70) 664 | 665 | print('Displaying resulting composition...') 666 | print('=' * 70) 667 | 668 | fname = '/content/Experimental-Music-Transformer-Seed-Composition' 669 | 670 | x = [] 671 | y =[] 672 | c = [] 673 | 674 | colors = ['red', 'yellow', 'green', 'cyan', 675 | 'blue', 'pink', 'orange', 'purple', 676 | 'gray', 'white', 'gold', 'silver', 677 | 'lightgreen', 'indigo', 'maroon', 'turquoise'] 678 | 679 | for s in song_f: 680 | if s[0] == 'note': 681 | x.append(s[1] / 1000) 682 | y.append(s[4]) 683 | c.append(colors[s[3]]) 684 | 685 | if render_MIDI_to_audio: 686 | midi_audio = midi_to_colab_audio(fname + '.mid') 687 | display(Audio(midi_audio, rate=16000, normalize=False)) 688 | 689 | plt.figure(figsize=(14,5)) 690 | ax=plt.axes(title=fname) 691 | ax.set_facecolor('black') 692 | 693 | plt.scatter(x,y, c=c) 694 | plt.xlabel("Time") 695 | plt.ylabel("Pitch") 696 | plt.show() 697 | 698 | else: 699 | print('=' * 70) 700 | 701 | """# (CONTINUATION)""" 702 | 703 | 
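# Note on priming: the seed encoder above emits a 4-token intro sequence
# and then one 4-token group per note, [emphasis, delta-time,
# duration+velocity, patch+pitch], so `number_of_prime_tokens` in the
# continuation cell below should stay a multiple of 4 to keep generation
# aligned to a group boundary (the form slider already steps by 4). A
# minimal sketch of that guard, as a hypothetical helper that is not
# part of this repo or of TMIDIX:

def trim_to_group_boundary(tokens, group_size=4):
    """Truncate a token list so its length is a multiple of group_size."""
    return tokens[:len(tokens) - (len(tokens) % group_size)]

# Example usage (assuming melody_chords from the seed cell above):
# outy = trim_to_group_boundary(melody_chords[:number_of_prime_tokens])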
#@title Standard Continuation 704 | 705 | #@markdown Generation settings 706 | 707 | try_to_generate_outro = False #@param {type:"boolean"} 708 | number_of_prime_tokens = 1008 # @param {type:"slider", min:4, max:8190, step:4} 709 | number_of_tokens_to_generate = 1026 # @param {type:"slider", min:30, max:8190, step:4} 710 | number_of_batches_to_generate = 4 #@param {type:"slider", min:1, max:16, step:1} 711 | temperature = 0.9 # @param {type:"slider", min:0.1, max:1, step:0.05} 712 | 713 | #@markdown Other settings 714 | include_prime_tokens_in_generated_output = True #@param {type:"boolean"} 715 | allow_model_to_stop_generation_if_needed = False #@param {type:"boolean"} 716 | render_MIDI_to_audio = True # @param {type:"boolean"} 717 | 718 | print('=' * 70) 719 | print('Experimental Music Transformer Standard Continuation Model Generator') 720 | print('=' * 70) 721 | 722 | if allow_model_to_stop_generation_if_needed: 723 | min_stop_token = 7576 724 | else: 725 | min_stop_token = None 726 | 727 | outy = melody_chords[:number_of_prime_tokens] 728 | 729 | if try_to_generate_outro: 730 | outy.extend([6787, 6787]) 731 | 732 | inp = [outy] * number_of_batches_to_generate 733 | 734 | inp = torch.LongTensor(inp).cuda() 735 | 736 | with ctx: 737 | out = model.generate(inp, 738 | number_of_tokens_to_generate, 739 | temperature=temperature, 740 | return_prime=include_prime_tokens_in_generated_output, 741 | eos_token=min_stop_token, 742 | verbose=True) 743 | 744 | out0 = out.tolist() 745 | 746 | print('=' * 70) 747 | print('Done!') 748 | print('=' * 70) 749 | 750 | #====================================================================== 751 | print('Rendering results...') 752 | 753 | for i in range(number_of_batches_to_generate): 754 | 755 | print('=' * 70) 756 | print('Batch #', i) 757 | print('=' * 70) 758 | 759 | out1 = out0[i] 760 | 761 | print('Sample INTs', out1[:12]) 762 | print('=' * 70) 763 | 764 | if len(out1) != 0: 765 | 766 | song = out1 767 | song_f = [] 768 | 769 | time = 0 770 | dur = 0 771 | vel = 90 772 | pitch = 0 773 | channel = 0 774 | emph = 0 775 | 776 | for ss in song: 777 | 778 | if 0 <= ss < 512: 779 | 780 | time += ss * 8 781 | 782 | if 512 <= ss < 4608: 783 | 784 | dur = ((ss-512) // 8) * 8 785 | vel = (((ss-512) % 8)+1) * 15 786 | 787 | if 4608 <= ss < 6784: 788 | 789 | patch = (ss-4608) // 128 790 | 791 | if patch == 16: 792 | channel = 9 793 | else: 794 | if 9 <= patch <= 14: 795 | channel = patch + 1 796 | else: 797 | channel = patch 798 | 799 | if patch == 15: 800 | channel = 15 801 | 802 | pitch = (ss-4608) % 128 803 | 804 | if emph == 1: 805 | song_f.append(['text_event', time, 'Emph']) 806 | 807 | song_f.append(['note', time, dur, channel, pitch, vel ]) 808 | 809 | if 6784 < ss < 6787: 810 | emph = ss - 6784 811 | 812 | data = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f, 813 | output_signature = 'Experimental Music Transformer', 814 | output_file_name = '/content/Experimental-Music-Transformer-Composition_'+str(i), 815 | track_name='Project Los Angeles', 816 | list_of_MIDI_patches=[0, 10, 19, 24, 35, 40, 53, 56, 65, 9, 73, 87, 89, 99, 105, 117] 817 | ) 818 | 819 | 820 | print('=' * 70) 821 | print('Displaying resulting composition...') 822 | print('=' * 70) 823 | 824 | fname = '/content/Experimental-Music-Transformer-Composition_'+str(i) 825 | 826 | x = [] 827 | y =[] 828 | c = [] 829 | 830 | colors = ['red', 'yellow', 'green', 'cyan', 831 | 'blue', 'pink', 'orange', 'purple', 832 | 'gray', 'white', 'gold', 'silver', 833 | 'lightgreen', 'indigo', 'maroon', 'turquoise'] 834 | 
for s in song_f: 835 | if s[0] == 'note': 836 | x.append(s[1] / 1000) 837 | y.append(s[4]) 838 | c.append(colors[s[3]]) 839 | 840 | if render_MIDI_to_audio: 841 | midi_audio = midi_to_colab_audio(fname + '.mid') 842 | display(Audio(midi_audio, rate=16000, normalize=False)) 843 | 844 | plt.figure(figsize=(14,5)) 845 | ax=plt.axes(title=fname) 846 | ax.set_facecolor('black') 847 | 848 | plt.scatter(x,y, c=c) 849 | plt.xlabel("Time") 850 | plt.ylabel("Pitch") 851 | plt.show() 852 | 853 | """# (INPAINTING)""" 854 | 855 | #@title Emphasis-based Notes Inpainting 856 | 857 | #@markdown You can stop the inpainting at any time to render partial results 858 | 859 | #@markdown Inpainting settings 860 | 861 | #@markdown Select MIDI patch present in the composition to inpaint 862 | 863 | inpainting_type = "Times-Durations-Velocities-Pitches" # @param ["Pitches", "Durations-Velocities-Pitches", "Times-Durations-Velocities-Pitches"] 864 | 865 | #@markdown Generation settings 866 | 867 | number_of_prime_notes = 1 # @param {type:"slider", min:1, max:2047, step:1} 868 | number_of_memory_tokens = 8188 # @param {type:"slider", min:4, max:8188, step:4} 869 | number_of_samples_per_inpainted_note = 4 #@param {type:"slider", min:1, max:16, step:1} 870 | temperature = 0.9 # @param {type:"slider", min:0.1, max:1, step:0.05} 871 | 872 | #@markdown Other settings 873 | 874 | render_MIDI_to_audio = False # @param {type:"boolean"} 875 | 876 | print('=' * 70) 877 | print('Experimental Music Transformer Inpainting Model Generator') 878 | print('=' * 70) 879 | 880 | if inpainting_type == 'Pitches': 881 | t1 = 3 882 | t2 = 1 883 | 884 | if inpainting_type == 'Durations-Velocities-Pitches': 885 | t1 = t2 = 2 886 | 887 | if inpainting_type == 'Times-Durations-Velocities-Pitches': 888 | t1 = 1 889 | t2 = 3 890 | 891 | out2 = [] 892 | 893 | number_of_prime_tokens = number_of_prime_notes * 4 894 | 895 | for m in melody_chords[:number_of_prime_tokens]: 896 | out2.append(m) 897 | 898 | for i in tqdm.tqdm(range(number_of_prime_tokens, len(melody_chords), 4)): 899 | 900 | try: 901 | out2.extend(melody_chords[i:i+t1]) 902 | 903 | if melody_chords[i] < 6787: 904 | 905 | samples = [] 906 | 907 | for j in range(number_of_samples_per_inpainted_note): 908 | 909 | inp = torch.LongTensor(out2[-number_of_memory_tokens:]).cuda() 910 | 911 | with ctx: 912 | out1 = model.generate(inp, 913 | t2, 914 | temperature=temperature, 915 | return_prime=False, 916 | verbose=False) 917 | 918 | with torch.no_grad(): 919 | test_loss, test_acc = model(out1) 920 | 921 | samples.append([out1.tolist()[0], test_acc.tolist()]) 922 | 923 | accs = [y[1] for y in samples] 924 | max_acc = max(accs) 925 | max_acc_sample = samples[accs.index(max_acc)][0] 926 | 927 | out2.extend(max_acc_sample) 928 | 929 | else: 930 | out2.extend(melody_chords[i+t1:i+4]) 931 | 932 | except KeyboardInterrupt: 933 | print('Stopping inpainting...') 934 | break 935 | 936 | except Exception as e: 937 | print('Error', e) 938 | break 939 | 940 | print('Done!') 941 | print('=' * 70) 942 | 943 | #================================================== 944 | 945 | print('Rendering results...') 946 | print('=' * 70) 947 | 948 | if len(out2) != 0: 949 | 950 | song = out2 951 | song_f = [] 952 | 953 | time = 0 954 | dur = 0 955 | vel = 90 956 | pitch = 0 957 | channel = 0 958 | 959 | for ss in song: 960 | 961 | if 0 <= ss < 512: 962 | 963 | time += ss * 8 964 | 965 | if 512 <= ss < 4608: 966 | 967 | dur = ((ss-512) // 8) * 8 968 | vel = (((ss-512) % 8)+1) * 15 969 | 970 | if 4608 <= ss < 6784: 971 | 972 
| patch = (ss-4608) // 128 973 | 974 | if patch == 16: 975 | channel = 9 976 | else: 977 | if 9 <= patch <= 14: 978 | channel = patch + 1 979 | else: 980 | channel = patch 981 | 982 | if patch == 15: 983 | channel = 15 984 | 985 | pitch = (ss-4608) % 128 986 | 987 | if emph == 1: 988 | song_f.append(['text_event', time, 'Emph']) 989 | 990 | song_f.append(['note', time, dur, channel, pitch, vel ]) 991 | 992 | if 6784 < ss < 6787: 993 | emph = ss - 6784 994 | 995 | data = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f, 996 | output_signature = 'Experimental Music Transformer', 997 | output_file_name = '/content/Experimental-Music-Transformer-Composition', 998 | track_name='Project Los Angeles', 999 | list_of_MIDI_patches=[0, 10, 19, 24, 35, 40, 53, 56, 65, 9, 73, 87, 89, 99, 105, 117] 1000 | ) 1001 | 1002 | 1003 | print('=' * 70) 1004 | print('Displaying resulting composition...') 1005 | print('=' * 70) 1006 | 1007 | fname = '/content/Experimental-Music-Transformer-Composition' 1008 | 1009 | x = [] 1010 | y =[] 1011 | c = [] 1012 | 1013 | colors = ['red', 'yellow', 'green', 'cyan', 1014 | 'blue', 'pink', 'orange', 'purple', 1015 | 'gray', 'white', 'gold', 'silver', 1016 | 'lightgreen', 'indigo', 'maroon', 'turquoise'] 1017 | 1018 | for s in song_f: 1019 | if s[0] == 'note': 1020 | x.append(s[1] / 1000) 1021 | y.append(s[4]) 1022 | c.append(colors[s[3]]) 1023 | 1024 | if render_MIDI_to_audio: 1025 | midi_audio = midi_to_colab_audio(fname + '.mid') 1026 | display(Audio(midi_audio, rate=16000, normalize=False)) 1027 | 1028 | plt.figure(figsize=(14,5)) 1029 | ax=plt.axes(title=fname) 1030 | ax.set_facecolor('black') 1031 | 1032 | plt.scatter(x,y, c=c) 1033 | plt.xlabel("Time") 1034 | plt.ylabel("Pitch") 1035 | plt.show() 1036 | 1037 | """# Congrats! You did it! :)""" -------------------------------------------------------------------------------- /Training-Data/Experimental_Music_Transformer_Training_Dataset_Maker.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "gradient": { 7 | "editing": false, 8 | "id": "ac5a4cf0-d9d2-47b5-9633-b53f8d99a4d2", 9 | "kernelId": "" 10 | }, 11 | "id": "SiTIpPjArIyr" 12 | }, 13 | "source": [ 14 | "# Experimental Music Transformer Training Dataset Maker (ver. 
1.0)\n", 15 | "\n", 16 | "***\n", 17 | "\n", 18 | "Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools\n", 19 | "\n", 20 | "***\n", 21 | "\n", 22 | "#### Project Los Angeles\n", 23 | "\n", 24 | "#### Tegridy Code 2023\n", 25 | "\n", 26 | "***" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "gradient": { 33 | "editing": false, 34 | "id": "fa0a611c-1803-42ae-bdf6-a49b5a4e781b", 35 | "kernelId": "" 36 | }, 37 | "id": "gOd93yV0sGd2" 38 | }, 39 | "source": [ 40 | "# (SETUP ENVIRONMENT)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": { 47 | "cellView": "form", 48 | "gradient": { 49 | "editing": false, 50 | "id": "a1a45a91-d909-4fd4-b67a-5e16b971d179", 51 | "kernelId": "" 52 | }, 53 | "id": "fX12Yquyuihc" 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "#@title Install all dependencies (run only once per session)\n", 58 | "\n", 59 | "!git clone https://github.com/asigalov61/tegridy-tools\n", 60 | "!pip install tqdm" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": { 67 | "cellView": "form", 68 | "gradient": { 69 | "editing": false, 70 | "id": "b8207b76-9514-4c07-95db-95a4742e52c5", 71 | "kernelId": "" 72 | }, 73 | "id": "z7n9vnKmug1J" 74 | }, 75 | "outputs": [], 76 | "source": [ 77 | "#@title Import all needed modules\n", 78 | "\n", 79 | "print('Loading needed modules. Please wait...')\n", 80 | "import os\n", 81 | "import copy\n", 82 | "import math\n", 83 | "import statistics\n", 84 | "import random\n", 85 | "\n", 86 | "from tqdm import tqdm\n", 87 | "\n", 88 | "if not os.path.exists('/content/Dataset'):\n", 89 | " os.makedirs('/content/Dataset')\n", 90 | "\n", 91 | "print('Loading TMIDIX module...')\n", 92 | "os.chdir('/content/tegridy-tools/tegridy-tools')\n", 93 | "\n", 94 | "import TMIDIX\n", 95 | "\n", 96 | "from joblib import Parallel, delayed\n", 97 | "\n", 98 | "print('Done!')\n", 99 | "\n", 100 | "os.chdir('/content/')\n", 101 | "print('Enjoy! 
:)')" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "source": [], 107 | "metadata": { 108 | "id": "1t3SBFxq_UOR" 109 | } 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": { 114 | "gradient": { 115 | "editing": false, 116 | "id": "20b8698a-0b4e-4fdb-ae49-24d063782e77", 117 | "kernelId": "" 118 | }, 119 | "id": "ObPxlEutsQBj" 120 | }, 121 | "source": [ 122 | "# (DOWNLOAD SOURCE MIDI DATASET)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "source": [ 128 | "#@title Download original LAKH MIDI Dataset\n", 129 | "\n", 130 | "%cd /content/Dataset/\n", 131 | "\n", 132 | "!wget 'http://hog.ee.columbia.edu/craffel/lmd/lmd_full.tar.gz'\n", 133 | "!tar -xvf 'lmd_full.tar.gz'\n", 134 | "!rm 'lmd_full.tar.gz'\n", 135 | "\n", 136 | "%cd /content/" 137 | ], 138 | "metadata": { 139 | "cellView": "form", 140 | "id": "7aItlhq9cRxZ" 141 | }, 142 | "execution_count": null, 143 | "outputs": [] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": { 149 | "cellView": "form", 150 | "id": "S69mWHAcn5Bg" 151 | }, 152 | "outputs": [], 153 | "source": [ 154 | "#@title Mount Google Drive\n", 155 | "from google.colab import drive\n", 156 | "drive.mount('/content/drive')" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": { 162 | "id": "JwrqQeie08t0" 163 | }, 164 | "source": [ 165 | "# (FILE LIST)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": { 172 | "cellView": "form", 173 | "id": "DuVWtdDNcqKh" 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "#@title Save file list\n", 178 | "###########\n", 179 | "\n", 180 | "print('Loading MIDI files...')\n", 181 | "print('This may take a while on a large dataset in particular.')\n", 182 | "\n", 183 | "dataset_addr = \"/content/Dataset\"\n", 184 | "# os.chdir(dataset_addr)\n", 185 | "filez = list()\n", 186 | "for (dirpath, dirnames, filenames) in os.walk(dataset_addr):\n", 187 | " filez += [os.path.join(dirpath, file) for file in filenames]\n", 188 | "print('=' * 70)\n", 189 | "\n", 190 | "if filez == []:\n", 191 | " print('Could not find any MIDI files. 
Please check Dataset dir...')\n", 192 | " print('=' * 70)\n", 193 | "\n", 194 | "print('Randomizing file list...')\n", 195 | "random.shuffle(filez)\n", 196 | "\n", 197 | "TMIDIX.Tegridy_Any_Pickle_File_Writer(filez, '/content/drive/MyDrive/filez')" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "metadata": { 204 | "cellView": "form", 205 | "id": "qI_adhjojrJ9" 206 | }, 207 | "outputs": [], 208 | "source": [ 209 | "#@title Load file list\n", 210 | "filez = TMIDIX.Tegridy_Any_Pickle_File_Reader('/content/drive/MyDrive/filez')" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": { 216 | "id": "FLxHvO-wlwfU" 217 | }, 218 | "source": [ 219 | "# (PROCESS)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "source": [ 225 | "#@title Process MIDIs with TMIDIX MIDI processor\n", 226 | "\n", 227 | "#===============================================================================\n", 228 | "\n", 229 | "def TMIDIX_MIDI_Processor(midi_file):\n", 230 | "\n", 231 | " melody_chords = []\n", 232 | " melody_chords_aug = []\n", 233 | "\n", 234 | " try:\n", 235 | "\n", 236 | " fn = os.path.basename(midi_file)\n", 237 | "\n", 238 | " # Filtering out EXP MIDIs\n", 239 | " file_size = os.path.getsize(midi_file)\n", 240 | "\n", 241 | " if file_size <= 1000000:\n", 242 | "\n", 243 | " #=======================================================\n", 244 | " # START PROCESSING\n", 245 | "\n", 246 | " score = TMIDIX.midi2single_track_ms_score(open(midi_file, 'rb').read(), recalculate_channels=False, pass_old_timings_events=True)\n", 247 | "\n", 248 | " # INSTRUMENTS CONVERSION CYCLE\n", 249 | " events_matrix = []\n", 250 | " itrack = 1\n", 251 | " patches = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n", 252 | "\n", 253 | " emph_once = False\n", 254 | " emphasis_time = 0\n", 255 | "\n", 256 | " tpq = 0\n", 257 | " tempo = 0\n", 258 | " time_sig = 0\n", 259 | " key_sig = 0\n", 260 | "\n", 261 | " while itrack < len(score):\n", 262 | " for event in score[itrack]:\n", 263 | "\n", 264 | " if event[0] != 'note':\n", 265 | " event.extend([0, 128])\n", 266 | " else:\n", 267 | " event.extend([0, 0])\n", 268 | "\n", 269 | " if event[0] == 'text_event' or event[0] == 'lyric' or event[0] == 'patch_change' or event[0] == 'time_signature':\n", 270 | " event[4] = 128\n", 271 | "\n", 272 | " events_matrix.append(event)\n", 273 | "\n", 274 | " itrack += 1\n", 275 | "\n", 276 | " events_matrix.sort(key=lambda x: x[4], reverse = True)\n", 277 | " events_matrix.sort(key=lambda x: x[1])\n", 278 | "\n", 279 | " events_matrix1 = []\n", 280 | "\n", 281 | " pt = events_matrix[0][1]\n", 282 | "\n", 283 | " for event in events_matrix:\n", 284 | " if event[0] == 'patch_change':\n", 285 | " patches[event[2]] = event[3]\n", 286 | "\n", 287 | " #========================================================================\n", 288 | " # Emphasis\n", 289 | "\n", 290 | " if event[0] == 'text_event' or event[0] == 'lyric':\n", 291 | " emphasis_time = event[1]\n", 292 | " emph_once = True\n", 293 | "\n", 294 | " if event[0] == 'note' and int(event[1] / 8) > int(emphasis_time / 8) and event[1] > pt:\n", 295 | " event[7] = 2\n", 296 | " emph_once = False\n", 297 | "\n", 298 | " if event[0] == 'note' and int(event[1] / 8) == int(emphasis_time / 8) and emph_once:\n", 299 | " event[7] = 1\n", 300 | " emph_once = False\n", 301 | "\n", 302 | " pt = event[1]\n", 303 | "\n", 304 | " #========================================================================\n", 305 | " # Tempo\n", 306 | "\n", 307 
| " if event[0] == 'old_tpq':\n", 308 | " tpq = event[2]\n", 309 | "\n", 310 | " if event[0] == 'old_set_tempo':\n", 311 | " tempo = event[2]\n", 312 | "\n", 313 | " #========================================================================\n", 314 | " # Time and key sigs\n", 315 | "\n", 316 | " if event[0] == 'time_signature':\n", 317 | " time_sig = round((event[2] / max(1, event[3])) * 10)\n", 318 | "\n", 319 | " if event[0] == 'key_signature':\n", 320 | " key_sig = (event[3] * 16) + event[2]+8\n", 321 | "\n", 322 | " #========================================================================\n", 323 | " # Notes\n", 324 | "\n", 325 | " if event[0] == 'note':\n", 326 | " event[6] = patches[event[3]]\n", 327 | " event.extend([round(tempo / tpq / 100)])\n", 328 | " event.extend([time_sig])\n", 329 | " event.extend([key_sig])\n", 330 | "\n", 331 | " if events_matrix1:\n", 332 | " if (event[1] == events_matrix1[-1][1]):\n", 333 | " if ([event[3], event[4]] != events_matrix1[-1][3:5]):\n", 334 | " events_matrix1.append(event)\n", 335 | " else:\n", 336 | " events_matrix1.append(event)\n", 337 | "\n", 338 | " else:\n", 339 | " events_matrix1.append(event)\n", 340 | "\n", 341 | " if len(events_matrix1) > 0:\n", 342 | " if min([e[1] for e in events_matrix1]) >= 0 and min([e[2] for e in events_matrix1]) >= 0:\n", 343 | "\n", 344 | " #=======================================================\n", 345 | " # PRE-PROCESSING\n", 346 | "\n", 347 | " # checking number of instruments in a composition\n", 348 | " instruments_list_without_drums = list(set([y[3] for y in events_matrix1 if y[3] != 9]))\n", 349 | " instruments_list = list(set([y[3] for y in events_matrix1]))\n", 350 | "\n", 351 | " if len(events_matrix1) > 0 and len(instruments_list_without_drums) > 0:\n", 352 | "\n", 353 | " num_karaoke_events = len([y for y in events_matrix if y[0] == 'text_event' or y[0] == 'lyric'])\n", 354 | "\n", 355 | " # checking number of karaoke events in a composition\n", 356 | " if num_karaoke_events >= 100:\n", 357 | "\n", 358 | " #=======================================================\n", 359 | " # MAIN PROCESSING\n", 360 | " #=======================================================\n", 361 | "\n", 362 | " #=======================================================\n", 363 | " # Timings\n", 364 | " #=======================================================\n", 365 | "\n", 366 | " events_matrix2 = []\n", 367 | "\n", 368 | " # Recalculating timings\n", 369 | " for e in events_matrix1:\n", 370 | "\n", 371 | " ev = copy.deepcopy(e)\n", 372 | "\n", 373 | " # Original timings\n", 374 | " e[1] = int(e[1] / 8)\n", 375 | " e[2] = int(e[2] / 8)\n", 376 | "\n", 377 | " # Augmented timings (+ 5%)\n", 378 | " ev[1] = int((ev[1] * 1.05) / 8)\n", 379 | " ev[2] = int((ev[2] * 1.05) / 8)\n", 380 | "\n", 381 | " events_matrix2.append(ev)\n", 382 | "\n", 383 | " #===================================\n", 384 | " # ORIGINAL COMPOSITION\n", 385 | " #===================================\n", 386 | "\n", 387 | " # Sorting by patch, pitch, then by start-time\n", 388 | "\n", 389 | " events_matrix1.sort(key=lambda x: x[6])\n", 390 | " events_matrix1.sort(key=lambda x: x[4], reverse=True)\n", 391 | " events_matrix1.sort(key=lambda x: x[1])\n", 392 | "\n", 393 | " #=======================================================\n", 394 | " # FINAL PROCESSING\n", 395 | "\n", 396 | " melody_chords = []\n", 397 | "\n", 398 | " # Break between compositions / Intro seq\n", 399 | "\n", 400 | " if 9 in instruments_list:\n", 401 | " drums_present = 8852 # Yes\n", 402 | " 
else:\n", 403 | " drums_present = 8851 # No\n", 404 | "\n", 405 | " if events_matrix1[0][3] != 9:\n", 406 | " pat = max(0, min(127, events_matrix1[0][6])) // 8\n", 407 | " else:\n", 408 | " pat = 16\n", 409 | "\n", 410 | " ptc = events_matrix1[0][4]\n", 411 | "\n", 412 | " melody_chords.extend([8998, drums_present, 8853+pat, 8870+ptc]) # Intro seq\n", 413 | "\n", 414 | " #=======================================================\n", 415 | " # PROCESSING CYCLE\n", 416 | " #=======================================================\n", 417 | "\n", 418 | " abs_time = 0\n", 419 | "\n", 420 | " pbar_time = 0\n", 421 | "\n", 422 | " pe = events_matrix1[0]\n", 423 | "\n", 424 | " chords_counter = 1\n", 425 | "\n", 426 | " time_key_seq = [0, 0, 0]\n", 427 | " old_time_key_seq = [0, 0, 0]\n", 428 | "\n", 429 | " tempo = 0\n", 430 | " time_sig = 0\n", 431 | " key_sig = 0\n", 432 | "\n", 433 | " comp_chords_len = len(list(set([y[1] for y in events_matrix1])))\n", 434 | "\n", 435 | " for e in events_matrix1:\n", 436 | "\n", 437 | " #=======================================================\n", 438 | " # Timings...\n", 439 | "\n", 440 | " # Clipping all values...\n", 441 | " delta_time = max(0, min(511, e[1]-pe[1]))\n", 442 | " abs_time += delta_time\n", 443 | "\n", 444 | " bar_time = abs_time // 512\n", 445 | " bar_time_local = abs_time % 512\n", 446 | "\n", 447 | " if bar_time >= 1022:\n", 448 | " break\n", 449 | "\n", 450 | " # Durations and channels\n", 451 | "\n", 452 | " dur = max(0, min(511, e[2]))\n", 453 | " cha = max(0, min(15, e[3]))\n", 454 | "\n", 455 | " # Patches\n", 456 | " if cha == 9: # Drums patch will be == 16\n", 457 | " pat = 16\n", 458 | "\n", 459 | " else:\n", 460 | " pat = max(0, min(127, e[6])) // 8\n", 461 | "\n", 462 | " # Pitches\n", 463 | " ptc = max(1, min(127, e[4]))\n", 464 | "\n", 465 | "\n", 466 | " # Emphasis\n", 467 | " emph = e[7]\n", 468 | "\n", 469 | " # Velocities\n", 470 | " # Calculating octo-velocity\n", 471 | " vel = max(8, min(127, e[5]))\n", 472 | " velocity = round(vel / 15)-1\n", 473 | "\n", 474 | " #=======================================================\n", 475 | " # Outro seq\n", 476 | "\n", 477 | " if ((comp_chords_len - chords_counter) == 50) and (delta_time != 0):\n", 478 | " out_t = 7810+delta_time\n", 479 | " out_p = 8322+ptc\n", 480 | " melody_chords.extend([8850, 8850, out_t, out_p]) # outro seq\n", 481 | "\n", 482 | " #=======================================================\n", 483 | "\n", 484 | " if time_key_seq[0] != e[8]: # Tempo\n", 485 | " time_key_seq[0] = e[8]\n", 486 | "\n", 487 | " if time_key_seq[1] != e[9]: # Time sig\n", 488 | " time_key_seq[1] = e[9]\n", 489 | "\n", 490 | " if time_key_seq[2] != e[10]: # Key sig\n", 491 | " time_key_seq[2] = e[10]\n", 492 | "\n", 493 | " if time_key_seq != old_time_key_seq:\n", 494 | "\n", 495 | " old_time_key_seq = list(time_key_seq) # copy, not alias, so later changes are still detected\n", 496 | "\n", 497 | " tks_tokens = [max(0, min(254, time_key_seq[0])) + 8451,\n", 498 | " max(0, min(128, time_key_seq[1])) + 8706,\n", 499 | " max(0, min(16, time_key_seq[2])) + 8834]\n", 500 | "\n", 501 | " melody_chords.extend([8450] + tks_tokens)\n", 502 | "\n", 503 | " #=======================================================\n", 504 | " # Bar counter seq\n", 505 | "\n", 506 | " if (bar_time > pbar_time) and (delta_time != 0):\n", 507 | " bar = 6787+min(1022, (bar_time)) # bar counter seq\n", 508 | " bar_t = 7810+bar_time_local\n", 509 | " bar_p = 8322+ptc\n", 510 | " melody_chords.extend([6787, bar, bar_t, bar_p])\n", 511 | " 
chords_counter += 1\n", 512 | " pbar_time = bar_time\n", 513 | "\n", 514 | " else:\n", 515 | " if delta_time != 0:\n", 516 | " chords_counter += 1\n", 517 | "\n", 518 | " #=======================================================\n", 519 | " # FINAL NOTE SEQ\n", 520 | "\n", 521 | " # Writing final note asynchronously\n", 522 | "\n", 523 | " dur_vel = (8 * dur) + velocity\n", 524 | " pat_ptc = (128 * pat) + ptc\n", 525 | "\n", 526 | " melody_chords.extend([emph+6784, delta_time, dur_vel+512, pat_ptc+4608])\n", 527 | "\n", 528 | " pe = e\n", 529 | "\n", 530 | " #=======================================================\n", 531 | "\n", 532 | " melody_chords.extend([8999, 8999, 8999, 8999]) # EOS\n", 533 | "\n", 534 | " #===================================\n", 535 | " # AUGMENTED COMPOSITION\n", 536 | " #===================================\n", 537 | "\n", 538 | " # Sorting by patch, pitch, then by start-time\n", 539 | "\n", 540 | " events_matrix2.sort(key=lambda x: x[6])\n", 541 | " events_matrix2.sort(key=lambda x: x[4], reverse=True)\n", 542 | " events_matrix2.sort(key=lambda x: x[1])\n", 543 | "\n", 544 | " # Simple pitches augmentation\n", 545 | "\n", 546 | " ptc_shift = 1 # Shifting up by 1 semi-tone\n", 547 | "\n", 548 | " for e in events_matrix2:\n", 549 | " if e[3] != 9:\n", 550 | " e[4] = e[4] + ptc_shift\n", 551 | "\n", 552 | " #=======================================================\n", 553 | " # FINAL PROCESSING\n", 554 | "\n", 555 | " melody_chords_aug = []\n", 556 | "\n", 557 | " # Break between compositions / Intro seq\n", 558 | "\n", 559 | " if 9 in instruments_list:\n", 560 | " drums_present = 8852 # Yes\n", 561 | " else:\n", 562 | " drums_present = 8851 # No\n", 563 | "\n", 564 | " if events_matrix2[0][3] != 9:\n", 565 | " pat = max(0, min(127, events_matrix2[0][6])) // 8\n", 566 | " else:\n", 567 | " pat = 16\n", 568 | "\n", 569 | " ptc = events_matrix2[0][4]\n", 570 | "\n", 571 | " melody_chords_aug.extend([8998, drums_present, 8853+pat, 8870+ptc]) # Intro seq\n", 572 | "\n", 573 | " #=======================================================\n", 574 | " # PROCESSING CYCLE\n", 575 | " #=======================================================\n", 576 | "\n", 577 | " abs_time = 0\n", 578 | "\n", 579 | " pbar_time = 0\n", 580 | "\n", 581 | " pe = events_matrix2[0]\n", 582 | "\n", 583 | " chords_counter = 1\n", 584 | "\n", 585 | " time_key_seq = [0, 0, 0]\n", 586 | " old_time_key_seq = [0, 0, 0]\n", 587 | "\n", 588 | " tempo = 0\n", 589 | " time_sig = 0\n", 590 | " key_sig = 0\n", 591 | "\n", 592 | " comp_chords_len = len(list(set([y[1] for y in events_matrix2])))\n", 593 | "\n", 594 | " for e in events_matrix2:\n", 595 | "\n", 596 | " #=======================================================\n", 597 | " # Timings...\n", 598 | "\n", 599 | " # Clipping all values...\n", 600 | " delta_time = max(0, min(511, e[1]-pe[1]))\n", 601 | " abs_time += delta_time\n", 602 | "\n", 603 | " bar_time = abs_time // 512\n", 604 | " bar_time_local = abs_time % 512\n", 605 | "\n", 606 | " if bar_time >= 1022:\n", 607 | " break\n", 608 | "\n", 609 | " # Durations and channels\n", 610 | "\n", 611 | " dur = max(0, min(511, e[2]))\n", 612 | " cha = max(0, min(15, e[3]))\n", 613 | "\n", 614 | " # Patches\n", 615 | " if cha == 9: # Drums patch will be == 16\n", 616 | " pat = 16\n", 617 | "\n", 618 | " else:\n", 619 | " pat = max(0, min(127, e[6])) // 8\n", 620 | "\n", 621 | " # Pitches\n", 622 | " ptc = max(1, min(127, e[4]))\n", 623 | "\n", 624 | " # Emphasis\n", 625 | " emph = e[7]\n", 626 | "\n", 627 | " # 
Velocities\n", 628 | " # Calculating octo-velocity\n", 629 | " vel = max(8, min(127, e[5]-4))\n", 630 | " velocity = round(vel / 15)-1\n", 631 | "\n", 632 | " #=======================================================\n", 633 | " # Outro seq\n", 634 | "\n", 635 | " if ((comp_chords_len - chords_counter) == 50) and (delta_time != 0):\n", 636 | " out_t = 7810+delta_time\n", 637 | " out_p = 8322+ptc\n", 638 | " melody_chords_aug.extend([8850, 8850, out_t, out_p]) # outro seq\n", 639 | "\n", 640 | " #=======================================================\n", 641 | "\n", 642 | " if time_key_seq[0] != e[8]: # Tempo\n", 643 | " time_key_seq[0] = e[8]\n", 644 | "\n", 645 | " if time_key_seq[1] != e[9]: # Time sig\n", 646 | " time_key_seq[1] = e[9]\n", 647 | "\n", 648 | " if time_key_seq[2] != e[10]: # Key sig\n", 649 | " time_key_seq[2] = e[10]\n", 650 | "\n", 651 | " if time_key_seq != old_time_key_seq:\n", 652 | " old_time_key_seq = time_key_seq\n", 653 | "\n", 654 | " time_key_seq[0] = max(0, min(254, time_key_seq[0])) + 8451\n", 655 | " time_key_seq[1] = max(0, min(128, time_key_seq[1])) + 8706\n", 656 | " time_key_seq[2] = max(0, min(16, time_key_seq[2])) + 8834\n", 657 | "\n", 658 | " melody_chords_aug.extend([8450] + time_key_seq)\n", 659 | "\n", 660 | " #=======================================================\n", 661 | " # Bar counter seq\n", 662 | "\n", 663 | " if (bar_time > pbar_time) and (delta_time != 0):\n", 664 | " bar = 6787+min(1022, (bar_time)) # bar counter seq\n", 665 | " bar_t = 7810+bar_time_local\n", 666 | " bar_p = 8322+ptc\n", 667 | " melody_chords_aug.extend([6787, bar, bar_t, bar_p])\n", 668 | " chords_counter += 1\n", 669 | " pbar_time = bar_time\n", 670 | "\n", 671 | " else:\n", 672 | " if delta_time != 0:\n", 673 | " chords_counter += 1\n", 674 | "\n", 675 | " #=======================================================\n", 676 | " # FINAL NOTE SEQ\n", 677 | "\n", 678 | " # Writing final note asynchronously\n", 679 | "\n", 680 | " dur_vel = (8 * dur) + velocity\n", 681 | " pat_ptc = (128 * pat) + ptc\n", 682 | "\n", 683 | " melody_chords_aug.extend([emph+6784, delta_time, dur_vel+512, pat_ptc+4608])\n", 684 | "\n", 685 | " pe = e\n", 686 | "\n", 687 | " #=======================================================\n", 688 | "\n", 689 | " melody_chords_aug.extend([8999, 8999, 8999, 8999]) # EOS\n", 690 | "\n", 691 | " #=======================================================\n", 692 | "\n", 693 | " # TOTAL DICTIONARY SIZE 8999+1=9000\n", 694 | "\n", 695 | " #=======================================================\n", 696 | "\n", 697 | " return melody_chords, melody_chords_aug\n", 698 | "\n", 699 | " except Exception as ex:\n", 700 | " print('WARNING !!!')\n", 701 | " print('=' * 70)\n", 702 | " print('Bad MIDI:', f)\n", 703 | " print('Error detected:', ex)\n", 704 | " print('=' * 70)\n", 705 | " return None\n", 706 | "\n", 707 | "#===============================================================================\n", 708 | "\n", 709 | "print('=' * 70)\n", 710 | "print('TMIDIX MIDI Processor')\n", 711 | "print('=' * 70)\n", 712 | "print('Starting up...')\n", 713 | "print('=' * 70)\n", 714 | "\n", 715 | "###########\n", 716 | "\n", 717 | "melody_chords_f = []\n", 718 | "melody_chords_f_aug = []\n", 719 | "\n", 720 | "files_count = 0\n", 721 | "\n", 722 | "print('Processing MIDI files. 
Please wait...')\n", 723 | "print('=' * 70)\n", 724 | "\n", 725 | "for i in tqdm(range(0, len(filez), 16)):\n", 726 | "\n", 727 | "  output = Parallel(n_jobs=4, verbose=0)(delayed(TMIDIX_MIDI_Processor)(fa) for fa in filez[i:i+16])\n", 728 | "\n", 729 | "  for o in output:\n", 730 | "\n", 731 | "    if o is not None:\n", 732 | "      melody_chords_f.append(o[0])\n", 733 | "      melody_chords_f_aug.append(o[1])\n", 734 | "      files_count += 1\n", 735 | "\n", 736 | "    # Saving every 2560 processed files\n", 737 | "    if files_count % 2560 == 0 and files_count != 0:\n", 738 | "      print('SAVING !!!')\n", 739 | "      print('=' * 70)\n", 740 | "      print('Saving processed files...')\n", 741 | "      print('=' * 70)\n", 742 | "      print('Data check:', min(melody_chords_f[0]), '===', max(melody_chords_f[0]), '===', len(list(set(melody_chords_f[0]))), '===', len(melody_chords_f[0]))\n", 743 | "      print('=' * 70)\n", 744 | "      print('Processed so far:', files_count, 'out of', len(filez), '===', files_count / len(filez), 'good files ratio')\n", 745 | "      print('=' * 70)\n", 746 | "      count = str(files_count)\n", 747 | "      TMIDIX.Tegridy_Any_Pickle_File_Writer(melody_chords_f, '/content/drive/MyDrive/LAKH_INTs_'+count)\n", 748 | "      TMIDIX.Tegridy_Any_Pickle_File_Writer(melody_chords_f_aug, '/content/drive/MyDrive/LAKH_AUG_INTs_'+count)\n", 749 | "\n", 750 | "      melody_chords_f = []\n", 751 | "      melody_chords_f_aug = []\n", 752 | "\n", 753 | "      print('=' * 70)\n", 754 | "\n", "if melody_chords_f: # may be empty if the last file landed exactly on a periodic save\n", "\n", 755 | "  print('FINAL SAVING !!!')\n", 756 | "  print('=' * 70)\n", 757 | "  print('Saving processed files...')\n", 758 | "  print('=' * 70)\n", 759 | "  print('Data check:', min(melody_chords_f[0]), '===', max(melody_chords_f[0]), '===', len(list(set(melody_chords_f[0]))), '===', len(melody_chords_f[0]))\n", 760 | "  print('=' * 70)\n", 761 | "  print('Processed so far:', files_count, 'out of', len(filez), '===', files_count / len(filez), 'good files ratio')\n", 762 | "  print('=' * 70)\n", 763 | "  count = str(files_count)\n", 764 | "  TMIDIX.Tegridy_Any_Pickle_File_Writer(melody_chords_f, '/content/drive/MyDrive/LAKH_INTs_'+count)\n", 765 | "  TMIDIX.Tegridy_Any_Pickle_File_Writer(melody_chords_f_aug, '/content/drive/MyDrive/LAKH_AUG_INTs_'+count)\n", 766 | "print('=' * 70)" 767 | ], 768 | "metadata": { 769 | "cellView": "form", 770 | "id": "kpjqdWXvg2Nw" 771 | }, 772 | "execution_count": null, 773 | "outputs": [] 774 | }, 775 | { 776 | "cell_type": "markdown", 777 | "metadata": { 778 | "id": "-ye9rNzOHX90" 779 | }, 780 | "source": [ 781 | "# (TEST INTS)" 782 | ] 783 | }, 784 | { 785 | "cell_type": "code", 786 | "execution_count": null, 787 | "metadata": { 788 | "cellView": "form", 789 | "id": "zppMJ8gA3L4K" 790 | }, 791 | "outputs": [], 792 | "source": [ 793 | "#@title Test INTs\n", 794 | "\n", 795 | "train_data1 = random.choice(melody_chords_f + melody_chords_f_aug)\n", 796 | "\n", 797 | "print('Sample INTs', train_data1[:15])\n", 798 | "\n", 799 | "out = train_data1\n", 800 | "\n", 801 | "if len(out) != 0:\n", 802 | "\n", 803 | "  song = out\n", 804 | "  song_f = []\n", 805 | "\n", 806 | "  time = 0\n", 807 | "  dur = 0\n", 808 | "  vel = 90\n", 809 | "  pitch = 0\n", 810 | "  channel = 0\n", 811 | "\n", 812 | "  for ss in song:\n", 813 | "\n", 814 | "    if 0 <= ss < 512:\n", 815 | "\n", 816 | "      time += ss * 8\n", 817 | "\n", 818 | "    if 512 <= ss < 4608:\n", 819 | "\n", 820 | "      dur = ((ss-512) // 8) * 8\n", 821 | "      vel = (((ss-512) % 8)+1) * 15\n", 822 | "\n", 823 | "    if 4608 <= ss < 6784:\n", 824 | "\n", 825 | "      patch = (ss-4608) // 128\n", 826 | "\n",
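"      # Map the decoded patch (0..16) back to a MIDI channel: patch 16 is the\n", "      # drums token (channel 9); melodic patches skip channel 9, so patches\n", "      # 9..14 shift up by one and patch 15 stays on channel 15.\n", "      # e.g. ss=5520 -> patch=(5520-4608)//128=7, pitch=(5520-4608)%128=16\n",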
827 | "      if patch == 16:\n", 828 | "        channel = 9\n", 829 | "      else:\n", 830 | "        if 9 <= patch <= 14:\n", 831 | "          channel = patch + 1\n", 832 | "        else:\n", 833 | "          channel = patch\n", 834 | "\n", 835 | "        if patch == 15:\n", 836 | "          channel = 15\n", 837 | "\n", 838 | "      pitch = (ss-4608) % 128\n", 839 | "\n", 840 | "      if patch == 17:\n", 841 | "        break\n", 842 | "\n", 843 | "      song_f.append(['note', time, dur, channel, pitch, vel ])\n", 844 | "\n", 845 | "detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,\n", 846 | "        output_signature = 'Experimental Music Transformer',\n", 847 | "        output_file_name = '/content/Experimental-Music-Transformer-Composition',\n", 848 | "        track_name='Project Los Angeles',\n", 849 | "        list_of_MIDI_patches=[0, 10, 19, 24, 35, 40, 53, 56, 65, 9, 73, 87, 89, 99, 105, 117]\n", 850 | "        )\n", 851 | "\n", 852 | "print('Done!')" 853 | ] 854 | }, 855 | { 856 | "cell_type": "markdown", 857 | "metadata": { 858 | "id": "YzCMd94Tu_gz" 859 | }, 860 | "source": [ 861 | "# Congrats! You did it! :)" 862 | ] 863 | } 864 | ], 865 | "metadata": { 866 | "colab": { 867 | "machine_shape": "hm", 868 | "private_outputs": true, 869 | "provenance": [] 870 | }, 871 | "gpuClass": "standard", 872 | "kernelspec": { 873 | "display_name": "Python 3 (ipykernel)", 874 | "language": "python", 875 | "name": "python3" 876 | }, 877 | "language_info": { 878 | "codemirror_mode": { 879 | "name": "ipython", 880 | "version": 3 881 | }, 882 | "file_extension": ".py", 883 | "mimetype": "text/x-python", 884 | "name": "python", 885 | "nbconvert_exporter": "python", 886 | "pygments_lexer": "ipython3", 887 | "version": "3.9.7" 888 | } 889 | }, 890 | "nbformat": 4, 891 | "nbformat_minor": 0 892 | } --------------------------------------------------------------------------------