├── .gitignore ├── LICENSE ├── README.md ├── c_rnn_gan.py ├── data └── classical │ ├── mozart │ ├── 11.mid │ ├── 1584morzcondma.mid │ ├── 1763wamk465a.mid │ ├── 1786gpk136a.mid │ ├── 1787gpk136b.mid │ ├── 1788gpk136c.mid │ ├── 1996rondoflute.mid │ ├── 2017mozk299no1.mid │ ├── 2018mozk299no2.mid │ ├── 2019mozk299no3.mid │ ├── 2021mozk191no1.mid │ ├── 2022mozk191no2.mid │ ├── 2023mozk191no3.mid │ ├── 2024panos3no1.mid │ ├── 2025pianos3no2.mid │ ├── 2026pianos3no3.mid │ ├── 2029rondoSteven2.mid │ ├── 2249phantasieSteven.mid │ ├── 2250adagioSteven.mid │ ├── 2294rondomozSteven.mid │ ├── 2344serenadeinbb.mid │ ├── 241Brondo.mid │ ├── 2421mozartfunnya.mid │ ├── 242Adag_bm.mid │ ├── 244cosifn2t.mid │ ├── 245div1.mid │ ├── 264diesirae.mid │ ├── 2677mozartspring.mid │ ├── 269Fantsy_c.mid │ ├── 2916requm1.mid │ ├── 2917requm2.mid │ ├── 2918requm4.mid │ ├── 351fandm.mid │ ├── 36.mid │ ├── 362sonatc1.mid │ ├── 363sonat_c2.mid │ ├── 364sonat_c3.mid │ ├── 365fant_dm.mid │ ├── 366fantsy_c.mid │ ├── 367sonat_a1.mid │ ├── 368sonat_a2.mid │ ├── 369sonat_a3.mid │ ├── 426mozfig.mid │ ├── 4321_duett.mid │ ├── 495p141ck.mid │ ├── 496p142ck.mid │ ├── 497p143ck.mid │ ├── 57811meno1.mid │ ├── 57911meno10.mid │ ├── 659andante.mid │ ├── 72.mid │ ├── MozartMagicFluteOverKV620e.mid │ ├── confuta.mid │ ├── cosifn2t.mid │ ├── jm_mozdi.mid │ ├── k452.mid │ ├── krebs.mid │ ├── magflute.mid │ ├── mfig.mid │ ├── moz_k191.mid │ ├── moz_k299.mid │ ├── mozartromancestev.mid │ ├── mozeine.mid │ ├── mozk175a.mid │ ├── mozk175b.mid │ ├── mozk175c.mid │ ├── mozk211a.mid │ ├── mozk211b.mid │ ├── mozk211c.mid │ ├── mozk216a.mid │ ├── mozk216b.mid │ ├── mozk216c.mid │ ├── mozk218a.mid │ ├── mozk218b.mid │ ├── mozk218c.mid │ ├── mozk219a.mid │ ├── mozk219b.mid │ ├── mozk219c.mid │ ├── mozk246a.mid │ ├── mozk246b.mid │ ├── mozk246c.mid │ ├── mozk281a.mid │ ├── mozk281b.mid │ ├── mozk281c.mid │ ├── mozk299a.mid │ ├── mozk299b.mid │ ├── mozk299c.mid │ ├── mozk309a.mid │ ├── mozk309b.mid │ ├── mozk309c.mid │ ├── 
mozk310a.mid │ ├── mozk310b.mid │ ├── mozk310c.mid │ ├── mozk311a.mid │ ├── mozk311b.mid │ ├── mozk311c.mid │ ├── mozk313a.mid │ ├── mozk313b.mid │ ├── mozk313c.mid │ ├── mozk331a.mid │ ├── mozk331b.mid │ ├── mozk331c.mid │ ├── mozk332a.mid │ ├── mozk332b.mid │ ├── mozk332c.mid │ ├── mozk333a.mid │ ├── mozk333b.mid │ ├── mozk333c.mid │ ├── mozk450a.mid │ ├── mozk450b.mid │ ├── mozk450c.mid │ ├── mozk453a.mid │ ├── mozk453b.mid │ ├── mozk453c.mid │ ├── mozk466a.mid │ ├── mozk466b.mid │ ├── mozk466c.mid │ ├── mozk467a.mid │ ├── mozk467b.mid │ ├── mozk467c.mid │ ├── mozk488a.mid │ ├── mozk488b.mid │ ├── mozk488c.mid │ ├── mozk545a.mid │ ├── mozk545b.mid │ ├── mozk545c.mid │ ├── mozk622a.mid │ ├── mozk622b.mid │ ├── mozk622c.mid │ ├── mozsq1.mid │ ├── mtmoz1.mid │ ├── mzflqnar.mid │ └── sym40-1.mid │ └── sonata-ish │ ├── 2026pianos3no3.mid │ ├── 2029rondoSteven2.mid │ ├── 241Brondo.mid │ ├── 2421mozartfunnya.mid │ ├── 362sonatc1.mid │ ├── 363sonat_c2.mid │ ├── 364sonat_c3.mid │ ├── 367sonat_a1.mid │ ├── 368sonat_a2.mid │ ├── 369sonat_a3.mid │ ├── 57911meno10.mid │ ├── mozartromancestev.mid │ ├── mozk246c.mid │ ├── mozk281a.mid │ ├── mozk281b.mid │ ├── mozk281c.mid │ ├── mozk299a.mid │ ├── mozk299b.mid │ ├── mozk299c.mid │ ├── mozk309a.mid │ ├── mozk309b.mid │ ├── mozk309c.mid │ ├── mozk310b.mid │ ├── mozk310c.mid │ ├── mozk311a.mid │ ├── mozk311b.mid │ ├── mozk311c.mid │ ├── mozk313a.mid │ ├── mozk313c.mid │ ├── mozk331a.mid │ ├── mozk331b.mid │ ├── mozk333a.mid │ ├── mozk333b.mid │ ├── mozk333c.mid │ ├── mozk450a.mid │ ├── mozk450b.mid │ ├── mozk450c.mid │ ├── mozk453a.mid │ ├── mozk453c.mid │ ├── mozk466b.mid │ ├── mozk467a.mid │ ├── mozk467c.mid │ ├── mozk488a.mid │ ├── mozk488c.mid │ ├── mozk545a.mid │ ├── mozk545b.mid │ ├── mozk545c.mid │ ├── mozk622a.mid │ ├── mozk622b.mid │ └── mozk622c.mid ├── generate.py ├── images └── simple_out.png ├── midi_statistics.py ├── models └── .gitkeep ├── music_data_utils.py ├── samples ├── sample1.mid └── sample2.mid ├── train.py 
└── train_simple.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # c-rnn-gan.pytorch 107 | old/ 108 | #data/ 109 | models/* 110 | !models/.gitkeep 111 | data/* 112 | !data/classical/ 113 | data/classical/* 114 | !data/classical/mozart 115 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Music Generation using C-RNN-GAN in PyTorch 2 | 3 | ## Introduction 4 | 5 | This project is a PyTorch implementation of [C-RNN-GAN](https://github.com/olofmogren/c-rnn-gan), which was originally developed in TensorFlow. In a nutshell, C-RNN-GAN is a GAN variant where both the Generator and the Discriminator are RNNs, with each output at each timestep from the Generator correspondingly fed into each timestep as input to the Discriminator. The goal is to train the Generator to output structured sequences, such as the MIDI music used in the paper. If you'd like to know more, head over to this [link](http://mogren.one/publications/2016/c-rnn-gan/) or read the [paper](http://mogren.one/publications/2016/c-rnn-gan/mogren2016crnngan.pdf). 6 | 7 | ## Status 8 | 9 | The implementation can work well on simple sequences such as `a(n+1) = 2*a(n)`, where each element is twice the previous one.
You can try this by executing: 10 | ``` 11 | $ python train_simple.py 12 | ``` 13 | This runs for 200 epochs, after which you should get something similar to this: 14 | 15 | ![Simple output](images/simple_out.png) 16 | 17 | When fed with MIDI data, training of this C-RNN-GAN implementation tends to be unstable. With a lot of hyperparameter tweaking and training techniques such as freezing, I managed to generate some convincing MIDI music, but the results were not reproducible, even when using the same set of hyperparameters and techniques. Nevertheless, I'll share these runs: 18 | 19 | ``` 20 | $ python train.py --g_lrn_rate=0.0001 --d_lrn_rate=0.0004 --g_pretraining_epochs=10 --d_pretraining_epochs=10 --label_smoothing 21 | ``` 22 | * Adam optimizer for both G and D, with learning rates 0.0001 and 0.0004, respectively 23 | * Pre-trains G and D independently for 10 epochs, i.e. train G for 10 epochs with D frozen & vice versa 24 | * Use label smoothing on real data, i.e. use 0.9 as the label instead of 1.0 25 | * Output: [sample1.mid](samples/sample1.mid) 26 | 27 | ``` 28 | $ python train.py --use_sgd --g_lrn_rate=0.01 --d_lrn_rate=0.005 --label_smoothing --feature_matching 29 | ``` 30 | * SGD optimizer for both G and D, with learning rates 0.01 and 0.005, respectively 31 | * Pre-trains G and D independently for 5 epochs (default) 32 | * Use label smoothing 33 | * Use feature matching for G loss (refer to [paper](http://mogren.one/publications/2016/c-rnn-gan/mogren2016crnngan.pdf) for more info) 34 | * Output: [sample2.mid](samples/sample2.mid) 35 | 36 | ## Deviations from Original 37 | 38 | This implementation is not an exact port of the original. These are some of the differences in particular: 39 | 40 | * Different set of hyperparameters 41 | * Added label smoothing 42 | * Added option to use Adam optimization 43 | * Pitch is not represented as frequency, but simply as the MIDI note number of each tone (e.g.
middle C = 60) 44 | * Training is done only on a small subset of classical data (sonata-sounding pieces of Mozart in particular, see [sonata-ish](data/classical/sonata-ish)) 45 | 46 | ## Prerequisites 47 | 48 | * Python 3.6 49 | * PyTorch 50 | * [python3-midi](https://github.com/louisabraham/python3-midi) 51 | 52 | ## License 53 | 54 | This project is licensed under the Apache License, Version 2.0 - see the [LICENSE](LICENSE) file for details 55 | -------------------------------------------------------------------------------- /c_rnn_gan.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Christopher John Bayron 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file has been created by Christopher John Bayron based on "rnn_gan.py" 16 | # by Olof Mogren.
# The referenced code is available in:
#
# https://github.com/olofmogren/c-rnn-gan

import torch
import torch.nn as nn
import torch.nn.functional as F


class Generator(nn.Module):
    ''' C-RNN-GAN generator.

    Two stacked LSTM cells unrolled manually over time. At every timestep the
    input to the first cell is the current random vector concatenated with the
    generator's own output from the previous timestep.

    Args:
        num_feats: number of features per timestep, both for the random
            input ``z`` and for the generated output.
        hidden_units: size of the first linear projection and of both LSTM
            cells' hidden state.
        drop_prob: dropout probability applied to the first cell's output
            before it is fed to the second cell (not to the recurrent state).
        use_cuda: if True, inputs and initial states are moved to the GPU.
    '''
    def __init__(self, num_feats, hidden_units=256, drop_prob=0.6, use_cuda=False):
        super(Generator, self).__init__()

        # params
        self.hidden_dim = hidden_units
        self.use_cuda = use_cuda
        self.num_feats = num_feats

        # input at each step is [z_t, previous output] -> 2 * num_feats features
        self.fc_layer1 = nn.Linear(in_features=(num_feats*2), out_features=hidden_units)
        self.lstm_cell1 = nn.LSTMCell(input_size=hidden_units, hidden_size=hidden_units)
        self.dropout = nn.Dropout(p=drop_prob)
        self.lstm_cell2 = nn.LSTMCell(input_size=hidden_units, hidden_size=hidden_units)
        self.fc_layer2 = nn.Linear(in_features=hidden_units, out_features=num_feats)

    def forward(self, z, states=None):
        ''' Forward prop.

        Args:
            z: (batch_size, seq_len, num_feats) random input; the last
                dimension must equal ``num_feats`` given at construction.
            states: ((h1, c1), (h2, c2)) as returned by ``init_hidden`` or by
                a previous call. When None, zero states are created.

        Returns:
            gen_feats: (batch_size, seq_len, num_feats) generated features.
            states: the final ((h1, c1), (h2, c2)) LSTM states.
        '''
        if self.use_cuda:
            z = z.cuda()
        # unpacking also rejects input that is not 3-D
        batch_size, _, _ = z.shape

        if states is None:
            states = self.init_hidden(batch_size)

        # dummy "previous output" for the first timestep, uniform in [0, 1);
        # sized by self.num_feats because that is what fc_layer2 emits
        prev_gen = torch.empty(batch_size, self.num_feats,
                               device=z.device, dtype=z.dtype).uniform_()

        # manually process each timestep
        state1, state2 = states  # (h1, c1), (h2, c2)
        gen_feats = []
        for z_step in z.unbind(dim=1):  # each z_step: (batch_size, num_feats)
            concat_in = torch.cat((z_step, prev_gen), dim=-1)
            out = F.relu(self.fc_layer1(concat_in))
            h1, c1 = self.lstm_cell1(out, state1)
            # Feature dropout only: the dropped-out copy of h1 feeds the next
            # layer, while the raw h1 is carried over as the recurrent state
            # (no recurrent dropout).
            h2, c2 = self.lstm_cell2(self.dropout(h1), state2)
            prev_gen = self.fc_layer2(h2)
            gen_feats.append(prev_gen)

            state1 = (h1, c1)
            state2 = (h2, c2)

        # seq_len * (batch_size, num_feats) -> (batch_size, seq_len, num_feats)
        gen_feats = torch.stack(gen_feats, dim=1)

        return gen_feats, (state1, state2)

    def init_hidden(self, batch_size):
        ''' Initialize hidden state.

        Returns:
            ((h1, c1), (h2, c2)): four zero tensors of shape
            (batch_size, hidden_units) with the parameters' dtype and device,
            moved to the GPU when ``use_cuda`` is set.
        '''
        # create NEW tensors with SAME TYPE/device as the weights
        weight = next(self.parameters())

        def _zeros():
            ''' One zero state tensor, placed per use_cuda. '''
            t = weight.new_zeros(batch_size, self.hidden_dim)
            return t.cuda() if self.use_cuda else t

        return ((_zeros(), _zeros()), (_zeros(), _zeros()))


class Discriminator(nn.Module):
    ''' C-RNN-GAN discriminator.

    A 2-layer bidirectional LSTM over the whole sequence, followed by a
    per-timestep linear layer and sigmoid. The per-timestep scores are
    averaged into a single score per sequence.

    Args:
        num_feats: number of features per timestep of the input sequence.
        hidden_units: hidden size of each LSTM direction.
        drop_prob: dropout probability applied to the input sequence and,
            via ``nn.LSTM(dropout=...)``, between the two LSTM layers.
        use_cuda: if True, inputs and initial states are moved to the GPU.
    '''
    def __init__(self, num_feats, hidden_units=256, drop_prob=0.6, use_cuda=False):

        super(Discriminator, self).__init__()

        # params
        self.hidden_dim = hidden_units
        self.num_layers = 2
        self.use_cuda = use_cuda

        self.dropout = nn.Dropout(p=drop_prob)
        self.lstm = nn.LSTM(input_size=num_feats, hidden_size=hidden_units,
                            num_layers=self.num_layers, batch_first=True, dropout=drop_prob,
                            bidirectional=True)
        # both directions are concatenated -> 2 * hidden_units features
        self.fc_layer = nn.Linear(in_features=(2*hidden_units), out_features=1)

    def forward(self, note_seq, state=None):
        ''' Forward prop.

        Args:
            note_seq: (batch_size, seq_len, num_feats) input sequence.
            state: (h0, c0), each (num_layers * 2, batch_size, hidden_units),
                e.g. from ``init_hidden``. When None, nn.LSTM uses zero states.

        Returns:
            out: (batch_size,) mean of the per-timestep sigmoid scores.
            lstm_out: (batch_size, seq_len, 2 * hidden_units) LSTM outputs.
            state: the final (h_n, c_n) LSTM state.
        '''
        if self.use_cuda:
            note_seq = note_seq.cuda()

        drop_in = self.dropout(note_seq)  # input with dropout
        # (batch_size, seq_len, num_directions*hidden_size)
        lstm_out, state = self.lstm(drop_in, state)
        # (batch_size, seq_len, 1)
        out = torch.sigmoid(self.fc_layer(lstm_out))

        # average over every dim except batch -> (batch_size)
        reduction_dims = tuple(range(1, out.dim()))
        out = torch.mean(out, dim=reduction_dims)

        return out, lstm_out, state

    def init_hidden(self, batch_size):
        ''' Initialize hidden state.

        Returns:
            (h0, c0): zero tensors of shape
            (num_layers * 2, batch_size, hidden_units) with the parameters'
            dtype and device, moved to the GPU when ``use_cuda`` is set.
        '''
        # create NEW tensors with SAME TYPE/device as the weights
        weight = next(self.parameters())

        layer_mult = 2  # for being bidirectional

        def _zeros():
            ''' One zero state tensor, placed per use_cuda. '''
            t = weight.new_zeros(self.num_layers * layer_mult, batch_size,
                                 self.hidden_dim)
            return t.cuda() if self.use_cuda else t

        return (_zeros(), _zeros())

# --------------------------------------------------------------------------------
# /data/classical/mozart/11.mid:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/11.mid
# --------------------------------------------------------------------------------
# /data/classical/mozart/1584morzcondma.mid:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/1584morzcondma.mid
# --------------------------------------------------------------------------------
# /data/classical/mozart/1763wamk465a.mid:
# --------------------------------------------------------------------------------
https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/1763wamk465a.mid -------------------------------------------------------------------------------- /data/classical/mozart/1786gpk136a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/1786gpk136a.mid -------------------------------------------------------------------------------- /data/classical/mozart/1787gpk136b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/1787gpk136b.mid -------------------------------------------------------------------------------- /data/classical/mozart/1788gpk136c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/1788gpk136c.mid -------------------------------------------------------------------------------- /data/classical/mozart/1996rondoflute.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/1996rondoflute.mid -------------------------------------------------------------------------------- /data/classical/mozart/2017mozk299no1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2017mozk299no1.mid -------------------------------------------------------------------------------- /data/classical/mozart/2018mozk299no2.mid: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2018mozk299no2.mid -------------------------------------------------------------------------------- /data/classical/mozart/2019mozk299no3.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2019mozk299no3.mid -------------------------------------------------------------------------------- /data/classical/mozart/2021mozk191no1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2021mozk191no1.mid -------------------------------------------------------------------------------- /data/classical/mozart/2022mozk191no2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2022mozk191no2.mid -------------------------------------------------------------------------------- /data/classical/mozart/2023mozk191no3.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2023mozk191no3.mid -------------------------------------------------------------------------------- /data/classical/mozart/2024panos3no1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2024panos3no1.mid 
-------------------------------------------------------------------------------- /data/classical/mozart/2025pianos3no2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2025pianos3no2.mid -------------------------------------------------------------------------------- /data/classical/mozart/2026pianos3no3.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2026pianos3no3.mid -------------------------------------------------------------------------------- /data/classical/mozart/2029rondoSteven2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2029rondoSteven2.mid -------------------------------------------------------------------------------- /data/classical/mozart/2249phantasieSteven.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2249phantasieSteven.mid -------------------------------------------------------------------------------- /data/classical/mozart/2250adagioSteven.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2250adagioSteven.mid -------------------------------------------------------------------------------- /data/classical/mozart/2294rondomozSteven.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2294rondomozSteven.mid -------------------------------------------------------------------------------- /data/classical/mozart/2344serenadeinbb.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2344serenadeinbb.mid -------------------------------------------------------------------------------- /data/classical/mozart/241Brondo.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/241Brondo.mid -------------------------------------------------------------------------------- /data/classical/mozart/2421mozartfunnya.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2421mozartfunnya.mid -------------------------------------------------------------------------------- /data/classical/mozart/242Adag_bm.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/242Adag_bm.mid -------------------------------------------------------------------------------- /data/classical/mozart/244cosifn2t.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/244cosifn2t.mid -------------------------------------------------------------------------------- /data/classical/mozart/245div1.mid: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/245div1.mid -------------------------------------------------------------------------------- /data/classical/mozart/264diesirae.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/264diesirae.mid -------------------------------------------------------------------------------- /data/classical/mozart/2677mozartspring.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2677mozartspring.mid -------------------------------------------------------------------------------- /data/classical/mozart/269Fantsy_c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/269Fantsy_c.mid -------------------------------------------------------------------------------- /data/classical/mozart/2916requm1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2916requm1.mid -------------------------------------------------------------------------------- /data/classical/mozart/2917requm2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2917requm2.mid 
-------------------------------------------------------------------------------- /data/classical/mozart/2918requm4.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/2918requm4.mid -------------------------------------------------------------------------------- /data/classical/mozart/351fandm.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/351fandm.mid -------------------------------------------------------------------------------- /data/classical/mozart/36.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/36.mid -------------------------------------------------------------------------------- /data/classical/mozart/362sonatc1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/362sonatc1.mid -------------------------------------------------------------------------------- /data/classical/mozart/363sonat_c2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/363sonat_c2.mid -------------------------------------------------------------------------------- /data/classical/mozart/364sonat_c3.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/364sonat_c3.mid -------------------------------------------------------------------------------- /data/classical/mozart/365fant_dm.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/365fant_dm.mid -------------------------------------------------------------------------------- /data/classical/mozart/366fantsy_c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/366fantsy_c.mid -------------------------------------------------------------------------------- /data/classical/mozart/367sonat_a1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/367sonat_a1.mid -------------------------------------------------------------------------------- /data/classical/mozart/368sonat_a2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/368sonat_a2.mid -------------------------------------------------------------------------------- /data/classical/mozart/369sonat_a3.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/369sonat_a3.mid -------------------------------------------------------------------------------- /data/classical/mozart/426mozfig.mid: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/426mozfig.mid -------------------------------------------------------------------------------- /data/classical/mozart/4321_duett.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/4321_duett.mid -------------------------------------------------------------------------------- /data/classical/mozart/495p141ck.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/495p141ck.mid -------------------------------------------------------------------------------- /data/classical/mozart/496p142ck.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/496p142ck.mid -------------------------------------------------------------------------------- /data/classical/mozart/497p143ck.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/497p143ck.mid -------------------------------------------------------------------------------- /data/classical/mozart/57811meno1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/57811meno1.mid -------------------------------------------------------------------------------- 
/data/classical/mozart/57911meno10.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/57911meno10.mid -------------------------------------------------------------------------------- /data/classical/mozart/659andante.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/659andante.mid -------------------------------------------------------------------------------- /data/classical/mozart/72.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/72.mid -------------------------------------------------------------------------------- /data/classical/mozart/MozartMagicFluteOverKV620e.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/MozartMagicFluteOverKV620e.mid -------------------------------------------------------------------------------- /data/classical/mozart/confuta.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/confuta.mid -------------------------------------------------------------------------------- /data/classical/mozart/cosifn2t.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/cosifn2t.mid 
-------------------------------------------------------------------------------- /data/classical/mozart/jm_mozdi.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/jm_mozdi.mid -------------------------------------------------------------------------------- /data/classical/mozart/k452.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/k452.mid -------------------------------------------------------------------------------- /data/classical/mozart/krebs.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/krebs.mid -------------------------------------------------------------------------------- /data/classical/mozart/magflute.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/magflute.mid -------------------------------------------------------------------------------- /data/classical/mozart/mfig.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mfig.mid -------------------------------------------------------------------------------- /data/classical/mozart/moz_k191.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/moz_k191.mid 
-------------------------------------------------------------------------------- /data/classical/mozart/moz_k299.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/moz_k299.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozartromancestev.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozartromancestev.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozeine.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozeine.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk175a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk175a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk175b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk175b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk175c.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk175c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk211a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk211a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk211b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk211b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk211c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk211c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk216a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk216a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk216b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk216b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk216c.mid: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk216c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk218a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk218a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk218b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk218b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk218c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk218c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk219a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk219a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk219b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk219b.mid -------------------------------------------------------------------------------- 
/data/classical/mozart/mozk219c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk219c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk246a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk246a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk246b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk246b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk246c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk246c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk281a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk281a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk281b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk281b.mid 
-------------------------------------------------------------------------------- /data/classical/mozart/mozk281c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk281c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk299a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk299a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk299b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk299b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk299c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk299c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk309a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk309a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk309b.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk309b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk309c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk309c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk310a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk310a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk310b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk310b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk310c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk310c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk311a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk311a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk311b.mid: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk311b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk311c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk311c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk313a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk313a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk313b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk313b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk313c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk313c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk331a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk331a.mid -------------------------------------------------------------------------------- 
/data/classical/mozart/mozk331b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk331b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk331c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk331c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk332a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk332a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk332b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk332b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk332c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk332c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk333a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk333a.mid 
-------------------------------------------------------------------------------- /data/classical/mozart/mozk333b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk333b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk333c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk333c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk450a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk450a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk450b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk450b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk450c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk450c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk453a.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk453a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk453b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk453b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk453c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk453c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk466a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk466a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk466b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk466b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk466c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk466c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk467a.mid: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk467a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk467b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk467b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk467c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk467c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk488a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk488a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk488b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk488b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk488c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk488c.mid -------------------------------------------------------------------------------- 
/data/classical/mozart/mozk545a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk545a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk545b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk545b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk545c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk545c.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk622a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk622a.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk622b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk622b.mid -------------------------------------------------------------------------------- /data/classical/mozart/mozk622c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozk622c.mid 
-------------------------------------------------------------------------------- /data/classical/mozart/mozsq1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mozsq1.mid -------------------------------------------------------------------------------- /data/classical/mozart/mtmoz1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mtmoz1.mid -------------------------------------------------------------------------------- /data/classical/mozart/mzflqnar.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/mzflqnar.mid -------------------------------------------------------------------------------- /data/classical/mozart/sym40-1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/mozart/sym40-1.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/2026pianos3no3.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/2026pianos3no3.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/2029rondoSteven2.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/2029rondoSteven2.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/241Brondo.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/241Brondo.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/2421mozartfunnya.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/2421mozartfunnya.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/362sonatc1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/362sonatc1.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/363sonat_c2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/363sonat_c2.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/364sonat_c3.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/364sonat_c3.mid -------------------------------------------------------------------------------- 
/data/classical/sonata-ish/367sonat_a1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/367sonat_a1.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/368sonat_a2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/368sonat_a2.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/369sonat_a3.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/369sonat_a3.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/57911meno10.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/57911meno10.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozartromancestev.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozartromancestev.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk246c.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk246c.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk281a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk281a.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk281b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk281b.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk281c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk281c.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk299a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk299a.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk299b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk299b.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk299c.mid: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk299c.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk309a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk309a.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk309b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk309b.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk309c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk309c.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk310b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk310b.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk310c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk310c.mid 
-------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk311a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk311a.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk311b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk311b.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk311c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk311c.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk313a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk313a.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk313c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk313c.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk331a.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk331a.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk331b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk331b.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk333a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk333a.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk333b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk333b.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk333c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk333c.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk450a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk450a.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk450b.mid: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk450b.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk450c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk450c.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk453a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk453a.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk453c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk453c.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk466b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk466b.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk467a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk467a.mid 
-------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk467c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk467c.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk488a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk488a.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk488c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk488c.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk545a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk545a.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk545b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk545b.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk545c.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk545c.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk622a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk622a.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk622b.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk622b.mid -------------------------------------------------------------------------------- /data/classical/sonata-ish/mozk622c.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/data/classical/sonata-ish/mozk622c.mid -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Christopher John Bayron 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # This file has been created by Christopher John Bayron based on "rnn_gan.py" 16 | # by Olof Mogren. The referenced code is available in: 17 | # 18 | # https://github.com/olofmogren/c-rnn-gan 19 | 20 | import os 21 | from argparse import ArgumentParser 22 | import numpy as np 23 | import torch 24 | 25 | from c_rnn_gan import Generator 26 | import music_data_utils 27 | 28 | CKPT_DIR = 'models' 29 | G_FN = 'c_rnn_gan_g.pth' 30 | MAX_SEQ_LEN = 256 31 | FILENAME = 'sample.mid' 32 | 33 | def generate(n): 34 | ''' Sample MIDI from trained generator model 35 | ''' 36 | # prepare model 37 | dataloader = music_data_utils.MusicDataLoader(datadir=None) 38 | num_feats = dataloader.get_num_song_features() 39 | 40 | use_gpu = torch.cuda.is_available() 41 | g_model = Generator(num_feats, use_cuda=use_gpu) 42 | 43 | if not use_gpu: 44 | ckpt = torch.load(os.path.join(CKPT_DIR, G_FN), map_location='cpu') 45 | else: 46 | ckpt = torch.load(os.path.join(CKPT_DIR, G_FN)) 47 | 48 | g_model.load_state_dict(ckpt) 49 | 50 | # generate from model then save to MIDI file 51 | g_states = g_model.init_hidden(1) 52 | z = torch.empty([1, MAX_SEQ_LEN, num_feats]).uniform_() # random vector 53 | if use_gpu: 54 | z = z.cuda() 55 | g_model.cuda() 56 | 57 | g_model.eval() 58 | 59 | full_song_data = [] 60 | for i in range(n): 61 | g_feats, g_states = g_model(z, g_states) 62 | song_data = g_feats.squeeze().cpu() 63 | song_data = song_data.detach().numpy() 64 | full_song_data.append(song_data) 65 | 66 | if len(full_song_data) > 1: 67 | full_song_data = np.concatenate(full_song_data, axis=0) 68 | else: 69 | full_song_data = full_song_data[0] 70 | 71 | dataloader.save_data(FILENAME, full_song_data) 72 | print('Full sequence shape: ', full_song_data.shape) 73 | print('Generated {}'.format(FILENAME)) 74 | 75 | 76 | if __name__ == "__main__": 77 | ARG_PARSER = ArgumentParser() 78 | # number of times to execute generator model; 79 | # all generated data are concatenated to form a single longer 
sequence 80 | ARG_PARSER.add_argument('-n', default=1, type=int) 81 | ARGS = ARG_PARSER.parse_args() 82 | 83 | generate(ARGS.n) 84 | -------------------------------------------------------------------------------- /images/simple_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/images/simple_out.png -------------------------------------------------------------------------------- /midi_statistics.py: -------------------------------------------------------------------------------- 1 | # Tools to load and save midi files for the rnn-gan-project. 2 | # 3 | # Written by Olof Mogren, http://mogren.one/ 4 | # 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | # ============================================================================== 18 | 19 | 20 | import sys, os, midi, math, string, time 21 | 22 | GENRE = 0 23 | COMPOSER = 1 24 | SONG_DATA = 2 25 | 26 | # INDICES IN BATCHES: 27 | LENGTH = 0 28 | FREQ = 1 29 | VELOCITY = 2 30 | TICKS_FROM_PREV_START = 3 31 | 32 | # INDICES IN SONG DATA (NOT YET BATCHED): 33 | BEGIN_TICK = 3 34 | CHANNEL = 4 35 | 36 | debug = '' 37 | #debug = 'overfit' 38 | 39 | 40 | base_tones = {'C': 0, 41 | 'C#': 1, 42 | 'D': 2, 43 | 'D#': 3, 44 | 'E': 4, 45 | 'F': 5, 46 | 'F#': 6, 47 | 'G': 7, 48 | 'G#': 8, 49 | 'A': 9, 50 | 'A#': 10, 51 | 'B': 11} 52 | 53 | scale = {} 54 | #Major scale: 55 | scale['major'] = [0,2,4,5,7,9,11] 56 | #(W-W-H-W-W-W-H) 57 | #(2 2 1 2 2 2 1) 58 | 59 | #Natural minor scale: 60 | scale['natural_minor'] = [0,2,3,5,7,8,10] 61 | #(W-H-W-W-H-W-W) 62 | #(2 1 2 2 1 2 2) 63 | 64 | #Harmonic minor scale: 65 | scale['harmonic_minor'] = [0,2,3,5,7,8,11] 66 | #(W-H-W-W-H-WH-H) 67 | #(2 1 2 2 1 3 1) 68 | 69 | tone_names = {} 70 | for tone_name in base_tones: 71 | tone_names[base_tones[tone_name]] = tone_name 72 | 73 | 74 | def get_tones(midi_pattern): 75 | """ 76 | returns a dict of statistics, keys: [scale_distribution, 77 | """ 78 | 79 | tones = [] 80 | 81 | for track in midi_pattern: 82 | for event in track: 83 | if type(event) == midi.events.SetTempoEvent: 84 | pass # These are currently ignored 85 | elif (type(event) == midi.events.NoteOffEvent) or \ 86 | (type(event) == midi.events.NoteOnEvent and \ 87 | event.velocity == 0): 88 | pass # not needed here 89 | elif type(event) == midi.events.NoteOnEvent: 90 | tones.append(event.data[0]) 91 | return tones 92 | 93 | def detect_beat(midi_pattern): 94 | """ 95 | returns a dict of statistics, keys: [scale_distribution, 96 | """ 97 | 98 | abs_ticks = [] 99 | 100 | # Tempo: 101 | ticks_per_quarter_note = float(midi_pattern.resolution) 102 | 103 | for track in midi_pattern: 104 | abs_tick=0 105 | for event in track: 106 | abs_tick 
+= event.tick 107 | if type(event) == midi.events.SetTempoEvent: 108 | pass # These are currently ignored 109 | elif (type(event) == midi.events.NoteOffEvent) or \ 110 | (type(event) == midi.events.NoteOnEvent and \ 111 | event.velocity == 0): 112 | pass 113 | elif type(event) == midi.events.NoteOnEvent: 114 | abs_ticks.append(abs_tick) 115 | stats = {} 116 | for quarter_note_estimate in range(int(ticks_per_quarter_note), int(0.75*ticks_per_quarter_note), -1): 117 | #print('est: {}'.format(quarter_note_estimate)) 118 | avg_ticks_off = [] 119 | for begin_tick in range(quarter_note_estimate): 120 | ticks_off = [] 121 | for abs_tick in abs_ticks: 122 | #print('abs_tick: {} % {}'.format(abs_tick, quarter_note_estimate/4)) 123 | sixteenth_note_estimate = quarter_note_estimate//4 124 | ticks_off_sixteenths = int((begin_tick+abs_tick)%sixteenth_note_estimate) 125 | if ticks_off_sixteenths > sixteenth_note_estimate//2: 126 | # off, but before beat 127 | ticks_off_sixteenths = -(ticks_off_sixteenths-sixteenth_note_estimate) 128 | #print('ticks_off: {}'.format(ticks_off_sixteenths)) 129 | ticks_off.append(ticks_off_sixteenths) 130 | avg_ticks_off.append(float(sum(ticks_off))/float(len(ticks_off))) 131 | #print('avg_ticks_off: {}. 
min: {}.'.format(avg_ticks_off, min(avg_ticks_off))) 132 | stats[quarter_note_estimate] = min(avg_ticks_off) 133 | return stats 134 | 135 | def get_abs_ticks(midi_pattern): 136 | abs_ticks = [] 137 | for track in midi_pattern: 138 | abs_tick=0 139 | for event in track: 140 | abs_tick += event.tick 141 | if type(event) == midi.events.SetTempoEvent: 142 | pass # These are currently ignored 143 | elif (type(event) == midi.events.NoteOffEvent) or \ 144 | (type(event) == midi.events.NoteOnEvent and \ 145 | event.velocity == 0): 146 | pass 147 | elif type(event) == midi.events.NoteOnEvent: 148 | abs_ticks.append(abs_tick) 149 | abs_ticks.sort() 150 | return abs_ticks 151 | 152 | def get_top_k_intervals(midi_pattern, k): 153 | """ 154 | returns a fraction of the noteon events in midi_pattern that are polyphonous 155 | (several notes occurring at the same time). 156 | Here, two note on events are counted as the same event if they 157 | occur at the same time, and in this case it is considered a polyphonous event. 158 | """ 159 | intervals = {} 160 | abs_ticks = get_abs_ticks(midi_pattern) 161 | accumulator = 0 162 | last_abs_tick = 0 163 | for abs_tick in abs_ticks: 164 | interval = abs_tick-last_abs_tick 165 | if interval not in intervals: 166 | intervals[interval] = 0 167 | intervals[interval] += 1 168 | accumulator += 1 169 | last_abs_tick = abs_tick 170 | intervals_list = [(interval, intervals[interval]/float(accumulator)) for interval in intervals] 171 | intervals_list.sort(key=lambda i: i[1], reverse=True) 172 | return intervals_list[:k] 173 | 174 | 175 | def get_polyphony_score(midi_pattern): 176 | """ 177 | returns a fraction of the noteon events in midi_pattern that are polyphonous 178 | (several notes occurring at the same time). 179 | Here, two note on events are counted as the same event if they 180 | occur at the same time, and in this case it is considered a polyphonous event. 
181 | """ 182 | 183 | abs_ticks = get_abs_ticks(midi_pattern) 184 | monophonous_events = 0 185 | polyphonous_events = 0 186 | 187 | last_abs_tick = 0 188 | tones_in_current_event = 0 189 | for abs_tick in abs_ticks: 190 | if abs_tick == last_abs_tick: 191 | tones_in_current_event += 1 192 | else: 193 | if tones_in_current_event == 1: 194 | monophonous_events += 1 195 | elif tones_in_current_event > 1: 196 | polyphonous_events += 1 197 | tones_in_current_event = 1 198 | last_abs_tick = abs_tick 199 | if tones_in_current_event == 1: 200 | monophonous_events += 1 201 | elif tones_in_current_event > 1: 202 | polyphonous_events += 1 203 | if polyphonous_events == 0: 204 | return 0.0 205 | return float(polyphonous_events)/(polyphonous_events+monophonous_events) 206 | 207 | 208 | def get_rhythm_stats(midi_pattern): 209 | """ 210 | returns a dict of statistics, keys: [scale_distribution, 211 | """ 212 | 213 | abs_ticks = [] 214 | 215 | # Tempo: 216 | ticks_per_quarter_note = float(midi_pattern.resolution) 217 | 218 | # Multiply with output_ticks_pr_input_tick for output ticks. 
219 | for track in midi_pattern: 220 | abs_tick=0 221 | for event in track: 222 | abs_tick += event.tick 223 | if type(event) == midi.events.SetTempoEvent: 224 | pass # These are currently ignored 225 | elif (type(event) == midi.events.NoteOffEvent) or \ 226 | (type(event) == midi.events.NoteOnEvent and \ 227 | event.velocity == 0): 228 | pass 229 | elif type(event) == midi.events.NoteOnEvent: 230 | abs_ticks.append(abs_tick) 231 | stats = {} 232 | for abs_tick in abs_ticks: 233 | ticks_since_quarter_note = int(abs_tick%ticks_per_quarter_note) 234 | if ticks_since_quarter_note not in stats: 235 | stats[ticks_since_quarter_note] = 1 236 | else: 237 | stats[ticks_since_quarter_note] += 1 238 | return stats 239 | 240 | 241 | def get_intensities(midi_pattern): 242 | """ 243 | returns a dict of statistics, keys: [scale_distribution, 244 | """ 245 | 246 | intensities = [] 247 | 248 | for track in midi_pattern: 249 | abs_tick=0 250 | for event in track: 251 | abs_tick += event.tick 252 | if type(event) == midi.events.SetTempoEvent: 253 | pass # These are currently ignored 254 | elif (type(event) == midi.events.NoteOffEvent) or \ 255 | (type(event) == midi.events.NoteOnEvent and \ 256 | event.velocity == 0): 257 | pass 258 | elif type(event) == midi.events.NoteOnEvent: 259 | intensities.append(event.velocity) 260 | return (min(intensities), max(intensities)) 261 | 262 | 263 | def get_midi_pattern(filename): 264 | try: 265 | return midi.read_midifile(filename) 266 | except: 267 | print ('Error reading {}'.format(filename)) 268 | return None 269 | 270 | def tones_to_scales(tones): 271 | """ 272 | Midi to tone name (octave: -5): 273 | 0: C 274 | 1: C# 275 | 2: D 276 | 3: D# 277 | 4: E 278 | 5: F 279 | 6: F# 280 | 7: G 281 | 8: G# 282 | 9: A 283 | 10: A# 284 | 11: B 285 | 286 | Melodic minor scale is ignored. 287 | 288 | One octave is 12 tones. 
289 | """ 290 | counts = {} 291 | for base_tone in base_tones: 292 | counts[base_tone] = {} 293 | counts[base_tone]['major'] = 0 294 | counts[base_tone]['natural_minor'] = 0 295 | counts[base_tone]['harmonic_minor'] = 0 296 | 297 | if not len(tones): 298 | frequencies = {} 299 | for base_tone in base_tones: 300 | frequencies[base_tone] = {} 301 | for scale_label in scale: 302 | frequencies[base_tone][scale_label] = 0.0 303 | return frequencies 304 | for tone in tones: 305 | for base_tone in base_tones: 306 | for scale_label in scale: 307 | if tone%12-base_tones[base_tone] in scale[scale_label]: 308 | counts[base_tone][scale_label] += 1 309 | frequencies = {} 310 | for base_tone in counts: 311 | frequencies[base_tone] = {} 312 | for scale_label in counts[base_tone]: 313 | frequencies[base_tone][scale_label] = float(counts[base_tone][scale_label])/float(len(tones)) 314 | return frequencies 315 | 316 | def repetitions(tones): 317 | rs = {} 318 | #print(tones) 319 | #print(len(tones)/2) 320 | for l in range(2, min(len(tones)//2, 10)): 321 | #print (l) 322 | rs[l] = 0 323 | for i in range(len(tones)-l*2): 324 | for j in range(i+l,len(tones)-l): 325 | #print('comparing \'{}\' and \'{}\''.format(tones[i:i+l], tones[j:j+l])) 326 | if tones[i:i+l] == tones[j:j+l]: 327 | rs[l] += 1 328 | rs2 = {} 329 | for r in rs: 330 | if rs[r]: 331 | rs2[r] = rs[r] 332 | return rs2 333 | 334 | 335 | def tone_to_tone_name(tone): 336 | """ 337 | Midi to tone name (octave: -5): 338 | 0: C 339 | 1: C# 340 | 2: D 341 | 3: D# 342 | 4: E 343 | 5: F 344 | 6: F# 345 | 7: G 346 | 8: G# 347 | 9: A 348 | 10: A# 349 | 11: B 350 | 351 | One octave is 12 tones. 
352 | """ 353 | 354 | base_tone = tone_names[tone%12] 355 | octave = tone//12-5 356 | return '{} {}'.format(base_tone, octave) 357 | 358 | def max_likelihood_scale(tones): 359 | scale_statistics = tones_to_scales(tones) 360 | stat_list = [] 361 | for base_tone in scale_statistics: 362 | for scale_label in scale_statistics[base_tone]: 363 | stat_list.append((base_tone, scale_label, scale_statistics[base_tone][scale_label])) 364 | stat_list.sort(key=lambda e: e[2], reverse=True) 365 | return (stat_list[0][0]+' '+stat_list[0][1], stat_list[0][2]) 366 | 367 | def tone_to_freq(tone): 368 | """ 369 | returns the frequency of a tone. 370 | 371 | formulas from 372 | * https://en.wikipedia.org/wiki/MIDI_Tuning_Standard 373 | * https://en.wikipedia.org/wiki/Cent_(music) 374 | """ 375 | return math.pow(2, ((float(tone)-69.0)/12.0)) * 440.0 376 | 377 | def freq_to_tone(freq): 378 | """ 379 | returns a dict d where 380 | d['tone'] is the base tone in midi standard 381 | d['cents'] is the cents to make the tone into the exact-ish frequency provided. 382 | multiply this with 8192 to get the midi pitch level. 
383 | 384 | formulas from 385 | * https://en.wikipedia.org/wiki/MIDI_Tuning_Standard 386 | * https://en.wikipedia.org/wiki/Cent_(music) 387 | """ 388 | if freq == 0.0: 389 | return None 390 | float_tone = (69.0+12*math.log(float(freq)/440.0, 2)) 391 | int_tone = int(float_tone) 392 | cents = int(1200*math.log(float(freq)/tone_to_freq(int_tone), 2)) 393 | return {'tone': int_tone, 'cents': cents} 394 | 395 | def cents_to_pitchwheel_units(cents): 396 | return int(40.96*(float(cents))) 397 | 398 | def get_all_stats(midi_pattern): 399 | stats = {} 400 | if not midi_pattern: 401 | print('Failed to read midi pattern.') 402 | return None 403 | tones = get_tones(midi_pattern) 404 | if len(tones) == 0: 405 | print('This is an empty song.') 406 | return None 407 | stats['num_tones'] = len(tones) 408 | stats['tone_min'] = min(tones) 409 | stats['freq_min'] = tone_to_freq(min(tones)) 410 | stats['tone_max'] = max(tones) 411 | stats['freq_max'] = tone_to_freq(max(tones)) 412 | stats['tone_span'] = max(tones)-min(tones) 413 | stats['freq_span'] = tone_to_freq(max(tones))-tone_to_freq(min(tones)) 414 | stats['tones_unique'] = len(set(tones)) 415 | rs = repetitions(tones) 416 | for r in range(2,10): 417 | if r in rs: 418 | stats['repetitions_{}'.format(r)] = rs[r] 419 | else: 420 | stats['repetitions_{}'.format(r)] = 0 421 | 422 | ml = max_likelihood_scale(tones) 423 | stats['scale'] = ml[0] 424 | stats['scale_score'] = ml[1] 425 | 426 | beat_stats = detect_beat(midi_pattern) 427 | minval = float(midi_pattern.resolution) 428 | argmin = -1 429 | for beat in beat_stats: 430 | #print('Looking at beat: {}. 
Avg ticks off: {}.'.format(beat, beat_stats[beat])) 431 | if beat_stats[beat] < minval: 432 | minval = beat_stats[beat] 433 | argmin = beat 434 | stats['estimated_beat'] = argmin 435 | stats['estimated_beat_avg_ticks_off'] = minval 436 | (min_int, max_int) = get_intensities(midi_pattern) 437 | stats['intensity_min'] = min_int 438 | stats['intensity_max'] = max_int 439 | stats['intensity_span'] = max_int-min_int 440 | 441 | stats['polyphony_score'] = get_polyphony_score(midi_pattern) 442 | stats['top_10_intervals'] = get_top_k_intervals(midi_pattern, 10) # NOT melodic interval, tick interval 443 | stats['top_2_interval_difference'] = 0.0 444 | if len(stats['top_10_intervals']) > 1: 445 | stats['top_2_interval_difference'] = abs(stats['top_10_intervals'][1][0]-stats['top_10_intervals'][0][0]) 446 | stats['top_3_interval_difference'] = 0.0 447 | if len(stats['top_10_intervals']) > 2: 448 | stats['top_3_interval_difference'] = abs(stats['top_10_intervals'][2][0]-stats['top_10_intervals'][0][0]) 449 | 450 | return stats 451 | 452 | def get_gnuplot_line(midi_patterns, i, showheader=True): 453 | stats = [] 454 | print('#getting stats...') 455 | stats_time = time.time() 456 | for p in midi_patterns: 457 | stats.append(get_all_stats(p)) 458 | print('done. 
time: {}'.format(time.time()-stats_time)) 459 | #print(stats) 460 | stats_keys_string = ['scale'] 461 | stats_keys = ['scale_score', 'tone_min', 'tone_max', 'tone_span', 'freq_min', 'freq_max', 'freq_span', 'tones_unique', 'repetitions_2', 'repetitions_3', 'repetitions_4', 'repetitions_5', 'repetitions_6', 'repetitions_7', 'repetitions_8', 'repetitions_9', 'estimated_beat', 'estimated_beat_avg_ticks_off', 'intensity_min', 'intensity_max', 'intensity_span', 'polyphony_score', 'top_2_interval_difference', 'top_3_interval_difference', 'num_tones'] 462 | gnuplotline = '' 463 | if showheader: 464 | gnuplotline = '# global-step {} {}\n'.format(' '.join([s.replace(' ', '_') for s in stats_keys_string]), ' '.join(stats_keys)) 465 | gnuplotline += '{} {} {}\n'.format(i, ' '.join(['{}'.format(stats[0][key].replace(' ', '_')) for key in stats_keys_string]), ' '.join(['{:.3f}'.format(sum([s[key] for s in stats])/float(len(stats))) for key in stats_keys])) 466 | return gnuplotline 467 | 468 | 469 | 470 | def main(): 471 | if len(sys.argv) > 2 and sys.argv[1] == '--gnuplot': 472 | #number = sys.argv[2] 473 | patterns = [] 474 | for i in range(3,len(sys.argv)): 475 | #print(i) 476 | filename = sys.argv[i] 477 | print('#File: {}'.format(filename)) 478 | #patterns.append(get_midi_pattern(filename)) 479 | print(get_gnuplot_line([get_midi_pattern(filename)], i, showheader=(i==0))) 480 | 481 | else: 482 | for i in range(1,len(sys.argv)): 483 | filename = sys.argv[i] 484 | print('File: {}'.format(filename)) 485 | midi_pattern = get_midi_pattern(filename) 486 | stats = get_all_stats(midi_pattern) 487 | if stats is None: 488 | print('Could not extract stats.') 489 | else: 490 | print ('ML scale estimate: {}: {:.2f}'.format(stats['scale'], stats['scale_score'])) 491 | print ('Min tone: {}, {:.1f} Hz.'.format(tone_to_tone_name(stats['tone_min']), stats['freq_min'])) 492 | print ('Max tone: {}, {:.1f} Hz.'.format(tone_to_tone_name(stats['tone_max']), stats['freq_max'])) 493 | print ('Span: 
{} tones, {:.1f} Hz.'.format(stats['tone_span'], stats['freq_span'])) 494 | print ('Unique tones: {}'.format(stats['tones_unique'])) 495 | for r in range(2,10): 496 | print('Repetitions of len {}: {}'.format(r, stats['repetitions_{}'.format(r)])) 497 | print('Estimated beat: {}. Avg ticks off: {:.2f}.'.format(stats['estimated_beat'], stats['estimated_beat_avg_ticks_off'])) 498 | print('Intensity: min: {}, max: {}.'.format(stats['intensity_min'], stats['intensity_max'])) 499 | print('Polyphonous events: {:.2f}.'.format(stats['polyphony_score'])) 500 | print('Top intervals:') 501 | for interval,score in stats['top_10_intervals']: 502 | print('{}: {:.2f}.'.format(interval,score)) 503 | print('Top 2 interval difference: {}.'.format(stats['top_2_interval_difference'])) 504 | print('Top 3 interval difference: {}.'.format(stats['top_3_interval_difference'])) 505 | 506 | 507 | if __name__ == "__main__": 508 | main() 509 | 510 | -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/models/.gitkeep -------------------------------------------------------------------------------- /music_data_utils.py: -------------------------------------------------------------------------------- 1 | # Tools to load and save midi files for the rnn-gan-project. 2 | # 3 | # This file has been modified by Christopher John Bayron to support 4 | # operations in c-rnn-gan.pytorch project. 5 | # 6 | # Written by Olof Mogren, http://mogren.one/ 7 | # 8 | # This file has been modified by Christopher John Bayron to support 9 | # c-rnn-gan.pytorch operations. 
Original file is available in: 10 | # 11 | # https://github.com/olofmogren/c-rnn-gan 12 | # 13 | # 14 | # Licensed under the Apache License, Version 2.0 (the "License"); 15 | # you may not use this file except in compliance with the License. 16 | # You may obtain a copy of the License at 17 | # 18 | # http://www.apache.org/licenses/LICENSE-2.0 19 | # 20 | # Unless required by applicable law or agreed to in writing, software 21 | # distributed under the License is distributed on an "AS IS" BASIS, 22 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 23 | # See the License for the specific language governing permissions and 24 | # limitations under the License. 25 | # ============================================================================== 26 | 27 | import os, midi, math, random, re, sys 28 | import numpy as np 29 | from io import BytesIO 30 | 31 | GENRE = 0 32 | COMPOSER = 1 33 | SONG_DATA = 2 34 | 35 | # INDICES IN BATCHES (LENGTH,TONE,VELOCITY are repeated self.tones_per_cell times): 36 | TICKS_FROM_PREV_START = 0 37 | LENGTH = 1 38 | TONE = 2 39 | VELOCITY = 3 40 | 41 | # INDICES IN SONG DATA (NOT YET BATCHED): 42 | BEGIN_TICK = 0 43 | 44 | NUM_FEATURES_PER_TONE = 3 45 | IDEAL_TEMPO = 120.0 46 | 47 | # hand-picked values for normalization 48 | # "NZ" = "normalizer" 49 | NZ = { 50 | TICKS_FROM_PREV_START: {'u': 60.0, 's': 80.0}, 51 | LENGTH: {'u': 64.0, 's': 64.0}, 52 | TONE: {'min': 0, 'max': 127}, 53 | VELOCITY: {'u': 64.0, 's': 128.0}, 54 | } 55 | 56 | debug = '' 57 | #debug = 'overfit' 58 | 59 | sources = {} 60 | sources['classical'] = {} 61 | 62 | file_list = {} 63 | 64 | file_list['validation'] = [ 65 | 'classical/sonata-ish/mozk333c.mid', \ 66 | 'classical/sonata-ish/mozk331b.mid', \ 67 | 'classical/sonata-ish/mozk313a.mid', \ 68 | 'classical/sonata-ish/mozk310b.mid', \ 69 | 'classical/sonata-ish/mozk299a.mid', \ 70 | 'classical/sonata-ish/mozk622c.mid', \ 71 | 'classical/sonata-ish/mozk545b.mid', \ 72 | 
'classical/sonata-ish/mozk299a.mid' 73 | ] 74 | 75 | file_list['test'] = [] 76 | 77 | 78 | # normalization, de-normalization functions 79 | def norm_std(batch_songs, ix): 80 | vals = batch_songs[:, :, ix] 81 | vals = (vals - NZ[ix]['u']) / NZ[ix]['s'] 82 | batch_songs[:, :, ix] = vals 83 | 84 | def norm_minmax(batch_songs, ix): 85 | ''' Min-max normalization, to range: [-1, 1] 86 | ''' 87 | vals = batch_songs[:, :, ix] 88 | vals = 2*((vals - NZ[ix]['min']) / (NZ[ix]['max'] - NZ[ix]['min'])) - 1 89 | batch_songs[:, :, ix] = vals 90 | 91 | def de_norm_std(song_data, ix): 92 | vals = song_data[:, ix] 93 | vals = (vals * NZ[ix]['s']) + NZ[ix]['u'] 94 | song_data[:, ix] = vals 95 | 96 | def de_norm_minmax(song_data, ix): 97 | vals = song_data[:, ix] 98 | vals = ((vals + 1) / 2)*(NZ[ix]['max'] - NZ[ix]['min']) + NZ[ix]['min'] 99 | song_data[:, ix] = vals 100 | 101 | 102 | class MusicDataLoader(object): 103 | 104 | def __init__(self, datadir, pace_events=False, tones_per_cell=1, single_composer=None): 105 | self.datadir = datadir 106 | self.output_ticks_per_quarter_note = 384.0 107 | self.tones_per_cell = tones_per_cell 108 | self.single_composer = single_composer 109 | self.pointer = {} 110 | self.pointer['validation'] = 0 111 | self.pointer['test'] = 0 112 | self.pointer['train'] = 0 113 | if not datadir is None: 114 | print ('Data loader: datadir: {}'.format(datadir)) 115 | self.read_data(pace_events) 116 | 117 | 118 | def read_data(self, pace_events): 119 | """ 120 | read_data takes a datadir with genre subdirs, and composer subsubdirs 121 | containing midi files, reads them into training data for an rnn-gan model. 122 | Midi music information will be real-valued frequencies of the 123 | tones, and intensity taken from the velocity information in 124 | the midi files. 125 | 126 | returns a list of tuples, [genre, composer, song_data] 127 | Also saves this list in self.songs. 128 | 129 | Time steps will be fractions of beat notes (32th notes). 
130 | """ 131 | 132 | self.genres = sorted(sources.keys()) 133 | print (('num genres:{}'.format(len(self.genres)))) 134 | if self.single_composer is not None: 135 | self.composers = [self.single_composer] 136 | else: 137 | self.composers = [] 138 | for genre in self.genres: 139 | self.composers.extend(sources[genre].keys()) 140 | if debug == 'overfit': 141 | self.composers = self.composers[0:1] 142 | self.composers = list(set(self.composers)) 143 | self.composers.sort() 144 | print (('num composers: {}'.format(len(self.composers)))) 145 | 146 | self.songs = {} 147 | self.songs['validation'] = [] 148 | self.songs['test'] = [] 149 | self.songs['train'] = [] 150 | 151 | # OVERFIT 152 | count = 0 153 | 154 | for genre in self.genres: 155 | # OVERFIT 156 | if debug == 'overfit' and count > 20: break 157 | for composer in self.composers: 158 | # OVERFIT 159 | if debug == 'overfit' and composer not in self.composers: continue 160 | if debug == 'overfit' and count > 20: break 161 | current_path = os.path.join(self.datadir,os.path.join(genre, composer)) 162 | if not os.path.exists(current_path): 163 | print ( 'Path does not exist: {}'.format(current_path)) 164 | continue 165 | files = os.listdir(current_path) 166 | #composer_id += 1 167 | #if composer_id > max_composers: 168 | # print (('Only using {} composers.'.format(max_composers)) 169 | # break 170 | for i,f in enumerate(files): 171 | # OVERFIT 172 | if debug == 'overfit' and count > 20: break 173 | count += 1 174 | 175 | if i % 100 == 99 or i+1 == len(files): 176 | print ( 'Reading files {}/{}: {}'.format(genre, composer, (i+1))) 177 | if os.path.isfile(os.path.join(current_path,f)): 178 | song_data = self.read_one_file(current_path, f, pace_events) 179 | if song_data is None: 180 | continue 181 | if os.path.join(os.path.join(genre, composer), f) in file_list['validation']: 182 | self.songs['validation'].append([genre, composer, song_data]) 183 | elif os.path.join(os.path.join(genre, composer), f) in 
file_list['test']: 184 | self.songs['test'].append([genre, composer, song_data]) 185 | else: 186 | self.songs['train'].append([genre, composer, song_data]) 187 | 188 | random.shuffle(self.songs['train']) 189 | self.pointer['validation'] = 0 190 | self.pointer['test'] = 0 191 | self.pointer['train'] = 0 192 | # DEBUG: OVERFIT. overfit. 193 | if debug == 'overfit': 194 | self.songs['train'] = self.songs['train'][0:1] 195 | #print (('DEBUG: trying to overfit on the following (repeating for train/validation/test):') 196 | for i in range(200): 197 | self.songs['train'].append(self.songs['train'][0]) 198 | self.songs['validation'] = self.songs['train'][0:1] 199 | self.songs['test'] = self.songs['train'][0:1] 200 | #print (('lens: train: {}, val: {}, test: {}'.format(len(self.songs['train']), len(self.songs['validation']), len(self.songs['test']))) 201 | return self.songs 202 | 203 | def read_one_file(self, path, filename, pace_events): 204 | try: 205 | if debug: 206 | print (('Reading {}'.format(os.path.join(path,filename)))) 207 | midi_pattern = midi.read_midifile(os.path.join(path,filename)) 208 | except: 209 | print ( 'Error reading {}'.format(os.path.join(path,filename))) 210 | return None 211 | # 212 | # Interpreting the midi pattern. 213 | # A pattern has a list of tracks 214 | # (midi.Track()). 215 | # Each track is a list of events: 216 | # * midi.events.SetTempoEvent: tick, data([int, int, int]) 217 | # (The three ints are really three bytes representing one integer.) 
218 | # * midi.events.TimeSignatureEvent: tick, data([int, int, int, int]) 219 | # (ignored) 220 | # * midi.events.KeySignatureEvent: tick, data([int, int]) 221 | # (ignored) 222 | # * midi.events.MarkerEvent: tick, text, data 223 | # * midi.events.PortEvent: tick(int), data 224 | # * midi.events.TrackNameEvent: tick(int), text(string), data([ints]) 225 | # * midi.events.ProgramChangeEvent: tick, channel, data 226 | # * midi.events.ControlChangeEvent: tick, channel, data 227 | # * midi.events.PitchWheelEvent: tick, data(two bytes, 14 bits) 228 | # 229 | # * midi.events.NoteOnEvent: tick(int), channel(int), data([int,int])) 230 | # - data[0] is the note (0-127) 231 | # - data[1] is the velocity. 232 | # - if velocity is 0, this is equivalent of a midi.NoteOffEvent 233 | # * midi.events.NoteOffEvent: tick(int), channel(int), data([int,int])) 234 | # 235 | # * midi.events.EndOfTrackEvent: tick(int), data() 236 | # 237 | # Ticks are relative. 238 | # 239 | # Tempo are in microseconds/quarter note. 240 | # 241 | # This interpretation was done after reading 242 | # http://electronicmusic.wikia.com/wiki/Velocity 243 | # http://faydoc.tripod.com/formats/mid.htm 244 | # http://www.lastrayofhope.co.uk/2009/12/23/midi-delta-time-ticks-to-seconds/2/ 245 | # and looking at some files. It will hopefully be enough 246 | # for the use in this project. 247 | # 248 | # We'll save the data intermediately with a dict representing each tone. 249 | # The dicts we put into a list. Times are microseconds. 250 | # Keys: 'freq', 'velocity', 'begin-tick', 'tick-length' 251 | # 252 | # 'Output ticks resolution' are fixed at a 32th note, 253 | # - so 8 ticks per quarter note. 254 | # 255 | # This approach means that we do not currently support 256 | # tempo change events. 257 | # 258 | # TODO 1: Figure out pitch. 259 | # TODO 2: Figure out different channels and instruments. 
260 | # 261 | 262 | song_data = [] 263 | tempos = [] 264 | 265 | # Tempo: 266 | ticks_per_quarter_note = float(midi_pattern.resolution) 267 | #print (('Resoluton: {}'.format(ticks_per_quarter_note)) 268 | input_ticks_per_output_tick = ticks_per_quarter_note/self.output_ticks_per_quarter_note 269 | #if debug == 'overfit': input_ticks_per_output_tick = 1.0 270 | 271 | # Multiply with output_ticks_pr_input_tick for output ticks. 272 | for track in midi_pattern: 273 | last_event_input_tick=0 274 | not_closed_notes = [] 275 | for event in track: 276 | if type(event) == midi.events.SetTempoEvent: 277 | td = event.data # tempo data 278 | tempo = 60 * 1000000 / (td[0]*(256**2) + td[1]*256 + td[2]) 279 | tempos.append(tempo) 280 | 281 | elif (type(event) == midi.events.NoteOffEvent) or \ 282 | (type(event) == midi.events.NoteOnEvent and \ 283 | event.velocity == 0): 284 | retained_not_closed_notes = [] 285 | for e in not_closed_notes: 286 | if event.data[0] == e[TONE]: 287 | event_abs_tick = float(event.tick+last_event_input_tick)/input_ticks_per_output_tick 288 | #current_note['length'] = float(ticks*microseconds_per_tick) 289 | e[LENGTH] = event_abs_tick-e[BEGIN_TICK] 290 | song_data.append(e) 291 | else: 292 | retained_not_closed_notes.append(e) 293 | 294 | not_closed_notes = retained_not_closed_notes 295 | elif type(event) == midi.events.NoteOnEvent: 296 | begin_tick = float(event.tick+last_event_input_tick)/input_ticks_per_output_tick 297 | note = [0.0]*(NUM_FEATURES_PER_TONE+1) 298 | note[TONE] = event.data[0] 299 | note[VELOCITY] = float(event.data[1]) 300 | note[BEGIN_TICK] = begin_tick 301 | not_closed_notes.append(note) 302 | 303 | last_event_input_tick += event.tick 304 | for e in not_closed_notes: 305 | #print (('Warning: found no NoteOffEvent for this note. Will close it. 
{}'.format(e)) 306 | e[LENGTH] = float(ticks_per_quarter_note)/input_ticks_per_output_tick 307 | song_data.append(e) 308 | song_data.sort(key=lambda e: e[BEGIN_TICK]) 309 | if (pace_events): 310 | pace_event_list = [] 311 | pace_tick = 0.0 312 | song_tick_length = song_data[-1][BEGIN_TICK]+song_data[-1][LENGTH] 313 | while pace_tick < song_tick_length: 314 | song_data.append([0.0, 440.0, 0.0, pace_tick, 0.0]) 315 | pace_tick += float(ticks_per_quarter_note)/input_ticks_per_output_tick 316 | song_data.sort(key=lambda e: e[BEGIN_TICK]) 317 | 318 | # tick adjustment (based on tempo) 319 | avg_tempo = sum(tempos) / len(tempos) 320 | for frame in song_data: 321 | frame[BEGIN_TICK] = frame[BEGIN_TICK] * IDEAL_TEMPO/avg_tempo 322 | 323 | return song_data 324 | 325 | def rewind(self, part='train'): 326 | self.pointer[part] = 0 327 | 328 | def get_batch(self, batchsize, songlength, part='train', normalize=True): 329 | """ 330 | get_batch() returns a batch from self.songs, as a 331 | pair of tensors (genrecomposer, song_data). 332 | 333 | The first tensor is a tensor of genres and composers 334 | (as two one-hot vectors that are concatenated). 335 | The second tensor contains song data. 336 | Song data has dimensions [batchsize, songlength, num_song_features] 337 | 338 | To have the sequence be the primary index is convention in 339 | tensorflow's rnn api. 340 | The tensors will have to be split later. 341 | Songs are currently chopped off after songlength. 342 | TODO: handle this in a better way. 343 | 344 | Since self.songs was shuffled in read_data(), the batch is 345 | a random selection without repetition. 346 | 347 | songlength is related to internal sample frequency. 348 | We fix this to be every 32th notes. # 50 milliseconds. 349 | This means 8 samples per quarter note. 350 | There is currently no notion of tempo in the representation. 351 | 352 | composer and genre is concatenated to each event 353 | in the sequence. 
There might be more clever ways 354 | of doing this. It's not reasonable to change composer 355 | or genre in the middle of a song. 356 | 357 | A tone has a feature telling us the pause before it. 358 | 359 | """ 360 | #print (('get_batch(): pointer: {}, len: {}, batchsize: {}'.format(self.pointer[part], len(self.songs[part]), batchsize)) 361 | if self.pointer[part] > len(self.songs[part])-batchsize: 362 | batchsize = len(self.songs[part]) - self.pointer[part] 363 | if batchsize == 0: 364 | return None, None 365 | 366 | if self.songs[part]: 367 | batch = self.songs[part][self.pointer[part]:self.pointer[part]+batchsize] 368 | self.pointer[part] += batchsize 369 | # subtract two for start-time and channel, which we don't include. 370 | num_meta_features = len(self.genres)+len(self.composers) 371 | # All features except timing are multiplied with tones_per_cell (default 1) 372 | num_song_features = NUM_FEATURES_PER_TONE*self.tones_per_cell+1 373 | batch_genrecomposer = np.ndarray(shape=[batchsize, num_meta_features]) 374 | batch_songs = np.ndarray(shape=[batchsize, songlength, num_song_features]) 375 | 376 | for s in range(len(batch)): 377 | songmatrix = np.ndarray(shape=[songlength, num_song_features]) 378 | composeronehot = onehot(self.composers.index(batch[s][1]), len(self.composers)) 379 | genreonehot = onehot(self.genres.index(batch[s][0]), len(self.genres)) 380 | genrecomposer = np.concatenate([genreonehot, composeronehot]) 381 | 382 | #random position: 383 | begin = 0 384 | if len(batch[s][SONG_DATA]) > songlength*self.tones_per_cell: 385 | begin = random.randint(0, len(batch[s][SONG_DATA])-songlength*self.tones_per_cell) 386 | matrixrow = 0 387 | n = begin 388 | while matrixrow < songlength: 389 | eventindex = 0 390 | event = np.zeros(shape=[num_song_features]) 391 | if n < len(batch[s][SONG_DATA]): 392 | event[LENGTH] = batch[s][SONG_DATA][n][LENGTH] 393 | event[TONE] = batch[s][SONG_DATA][n][TONE] 394 | event[VELOCITY] = batch[s][SONG_DATA][n][VELOCITY] 395 
| ticks_from_start_of_prev_tone = 0.0 396 | if n>0: 397 | # beginning of this tone, minus starting of previous 398 | ticks_from_start_of_prev_tone = batch[s][SONG_DATA][n][BEGIN_TICK]-batch[s][SONG_DATA][n-1][BEGIN_TICK] 399 | # we don't include start-time at index 0: 400 | # and not channel at -1. 401 | # tones are allowed to overlap. This is indicated with 402 | # relative time zero in the midi spec. 403 | event[TICKS_FROM_PREV_START] = ticks_from_start_of_prev_tone 404 | tone_count = 1 405 | for simultaneous in range(1,self.tones_per_cell): 406 | if n+simultaneous >= len(batch[s][SONG_DATA]): 407 | break 408 | if batch[s][SONG_DATA][n+simultaneous][BEGIN_TICK]-batch[s][SONG_DATA][n][BEGIN_TICK] == 0: 409 | offset = simultaneous*NUM_FEATURES_PER_TONE 410 | event[offset+LENGTH] = batch[s][SONG_DATA][n+simultaneous][LENGTH] 411 | event[offset+TONE] = batch[s][SONG_DATA][n+simultaneous][TONE] 412 | event[offset+VELOCITY] = batch[s][SONG_DATA][n+simultaneous][VELOCITY] 413 | tone_count += 1 414 | else: 415 | break 416 | songmatrix[matrixrow,:] = event 417 | matrixrow += 1 418 | n += tone_count 419 | #if s == 0 and self.pointer[part] == batchsize: 420 | # print ( songmatrix[0:10,:] 421 | batch_genrecomposer[s,:] = genrecomposer 422 | batch_songs[s,:,:] = songmatrix 423 | 424 | # input normalization 425 | if normalize: 426 | norm_std(batch_songs, TICKS_FROM_PREV_START) 427 | norm_std(batch_songs, LENGTH) 428 | norm_std(batch_songs, VELOCITY) 429 | norm_minmax(batch_songs, TONE) 430 | 431 | return batch_genrecomposer, batch_songs 432 | 433 | else: 434 | raise 'get_batch() called but self.songs is not initialized.' 
435 | 436 | def get_num_song_features(self): 437 | return NUM_FEATURES_PER_TONE*self.tones_per_cell+1 438 | def get_num_meta_features(self): 439 | return len(self.genres)+len(self.composers) 440 | 441 | def get_midi_pattern(self, song_data, bpm, normalized=True): 442 | """ 443 | get_midi_pattern takes a song in internal representation 444 | (a tensor of dimensions [songlength, self.num_song_features]). 445 | the three values are length, frequency, velocity. 446 | if velocity of a frame is zero, no midi event will be 447 | triggered at that frame. 448 | 449 | returns the midi_pattern. 450 | 451 | Can be used with filename == None. Then nothing is saved, but only returned. 452 | """ 453 | 454 | # 455 | # Interpreting the midi pattern. 456 | # A pattern has a list of tracks 457 | # (midi.Track()). 458 | # Each track is a list of events: 459 | # * midi.events.SetTempoEvent: tick, data([int, int, int]) 460 | # (The three ints are really three bytes representing one integer.) 461 | # * midi.events.TimeSignatureEvent: tick, data([int, int, int, int]) 462 | # (ignored) 463 | # * midi.events.KeySignatureEvent: tick, data([int, int]) 464 | # (ignored) 465 | # * midi.events.MarkerEvent: tick, text, data 466 | # * midi.events.PortEvent: tick(int), data 467 | # * midi.events.TrackNameEvent: tick(int), text(string), data([ints]) 468 | # * midi.events.ProgramChangeEvent: tick, channel, data 469 | # * midi.events.ControlChangeEvent: tick, channel, data 470 | # * midi.events.PitchWheelEvent: tick, data(two bytes, 14 bits) 471 | # 472 | # * midi.events.NoteOnEvent: tick(int), channel(int), data([int,int])) 473 | # - data[0] is the note (0-127) 474 | # - data[1] is the velocity. 475 | # - if velocity is 0, this is equivalent of a midi.NoteOffEvent 476 | # * midi.events.NoteOffEvent: tick(int), channel(int), data([int,int])) 477 | # 478 | # * midi.events.EndOfTrackEvent: tick(int), data() 479 | # 480 | # Ticks are relative. 481 | # 482 | # Tempo are in microseconds/quarter note. 
483 | # 484 | # This interpretation was done after reading 485 | # http://electronicmusic.wikia.com/wiki/Velocity 486 | # http://faydoc.tripod.com/formats/mid.htm 487 | # http://www.lastrayofhope.co.uk/2009/12/23/midi-delta-time-ticks-to-seconds/2/ 488 | # and looking at some files. It will hopefully be enough 489 | # for the use in this project. 490 | # 491 | # This approach means that we do not currently support 492 | # tempo change events. 493 | # 494 | 495 | # Tempo: 496 | # Multiply with output_ticks_pr_input_tick for output ticks. 497 | midi_pattern = midi.Pattern([], resolution=int(self.output_ticks_per_quarter_note)) 498 | cur_track = midi.Track([]) 499 | cur_track.append(midi.events.SetTempoEvent(tick=0, bpm=IDEAL_TEMPO)) 500 | future_events = {} 501 | last_event_tick = 0 502 | 503 | ticks_to_this_tone = 0.0 504 | song_events_absolute_ticks = [] 505 | abs_tick_note_beginning = 0.0 506 | 507 | if type(song_data) != np.ndarray: 508 | song_data = np.array(song_data) 509 | 510 | # de-normalize 511 | if normalized: 512 | de_norm_std(song_data, TICKS_FROM_PREV_START) 513 | de_norm_std(song_data, LENGTH) 514 | de_norm_std(song_data, VELOCITY) 515 | de_norm_minmax(song_data, TONE) 516 | 517 | for frame in song_data: 518 | abs_tick_note_beginning += int(round(frame[TICKS_FROM_PREV_START])) 519 | for subframe in range(self.tones_per_cell): 520 | offset = subframe*NUM_FEATURES_PER_TONE 521 | tick_len = int(round(frame[offset+LENGTH])) 522 | tone = int(round(frame[offset+TONE])) 523 | velocity = min(int(round(frame[offset+VELOCITY])),127) 524 | 525 | if tone is not None and velocity > 0 and tick_len > 0: 526 | # range-check with preserved tone, changed one octave: 527 | while tone < 0: tone += 12 528 | while tone > 127: tone -= 12 529 | 530 | song_events_absolute_ticks.append((abs_tick_note_beginning, 531 | midi.events.NoteOnEvent( 532 | tick=0, 533 | velocity=velocity, 534 | pitch=tone))) 535 | song_events_absolute_ticks.append((abs_tick_note_beginning+tick_len, 536 
| midi.events.NoteOffEvent( 537 | tick=0, 538 | velocity=0, 539 | pitch=tone))) 540 | song_events_absolute_ticks.sort(key=lambda e: e[0]) 541 | abs_tick_note_beginning = 0.0 542 | for abs_tick,event in song_events_absolute_ticks: 543 | rel_tick = abs_tick-abs_tick_note_beginning 544 | event.tick = int(round(rel_tick)) 545 | cur_track.append(event) 546 | abs_tick_note_beginning=abs_tick 547 | 548 | cur_track.append(midi.EndOfTrackEvent(tick=int(self.output_ticks_per_quarter_note))) 549 | midi_pattern.append(cur_track) 550 | 551 | return midi_pattern 552 | 553 | def save_midi_pattern(self, filename, midi_pattern): 554 | if filename is not None: 555 | midi.write_midifile(filename, midi_pattern) 556 | 557 | def save_data(self, filename, song_data, bpm=IDEAL_TEMPO): 558 | """ 559 | save_data takes a filename and a song in internal representation 560 | (a tensor of dimensions [songlength, 3]). 561 | the three values are length, frequency, velocity. 562 | if velocity of a frame is zero, no midi event will be 563 | triggered at that frame. 564 | 565 | returns the midi_pattern. 566 | 567 | Can be used with filename == None. Then nothing is saved, but only returned. 568 | """ 569 | midi_pattern = self.get_midi_pattern(song_data, bpm=bpm) 570 | self.save_midi_pattern(filename, midi_pattern) 571 | return midi_pattern 572 | 573 | def tone_to_freq(tone): 574 | """ 575 | returns the frequency of a tone. 576 | 577 | formulas from 578 | * https://en.wikipedia.org/wiki/MIDI_Tuning_Standard 579 | * https://en.wikipedia.org/wiki/Cent_(music) 580 | """ 581 | return math.pow(2, ((float(tone)-69.0)/12.0)) * 440.0 582 | 583 | def freq_to_tone(freq): 584 | """ 585 | returns a dict d where 586 | d['tone'] is the base tone in midi standard 587 | d['cents'] is the cents to make the tone into the exact-ish frequency provided. 588 | multiply this with 8192 to get the midi pitch level. 
589 | 590 | formulas from 591 | * https://en.wikipedia.org/wiki/MIDI_Tuning_Standard 592 | * https://en.wikipedia.org/wiki/Cent_(music) 593 | """ 594 | if freq <= 0.0: 595 | return None 596 | float_tone = (69.0+12*math.log(float(freq)/440.0, 2)) 597 | int_tone = int(float_tone) 598 | cents = int(1200*math.log(float(freq)/tone_to_freq(int_tone), 2)) 599 | return {'tone': int_tone, 'cents': cents} 600 | 601 | def onehot(i, length): 602 | a = np.zeros(shape=[length]) 603 | a[i] = 1 604 | return a 605 | -------------------------------------------------------------------------------- /samples/sample1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/samples/sample1.mid -------------------------------------------------------------------------------- /samples/sample2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjbayron/c-rnn-gan.pytorch/61fc00e6758aafb5bd8c92c79e964b63199d1a6d/samples/sample2.mid -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Christopher John Bayron 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # This file has been created by Christopher John Bayron based on "rnn_gan.py" 16 | # by Olof Mogren. The referenced code is available in: 17 | # 18 | # https://github.com/olofmogren/c-rnn-gan 19 | 20 | import os 21 | from argparse import ArgumentParser 22 | 23 | import torch 24 | import torch.nn as nn 25 | from torch import optim 26 | 27 | from c_rnn_gan import Generator, Discriminator 28 | import music_data_utils 29 | 30 | DATA_DIR = 'data' 31 | CKPT_DIR = 'models' 32 | COMPOSER = 'sonata-ish' 33 | 34 | G_FN = 'c_rnn_gan_g.pth' 35 | D_FN = 'c_rnn_gan_d.pth' 36 | 37 | G_LRN_RATE = 0.001 38 | D_LRN_RATE = 0.001 39 | MAX_GRAD_NORM = 5.0 40 | # following values are modified at runtime 41 | MAX_SEQ_LEN = 200 42 | BATCH_SIZE = 32 43 | 44 | EPSILON = 1e-40 # value to use to approximate zero (to prevent undefined results) 45 | 46 | class GLoss(nn.Module): 47 | ''' C-RNN-GAN generator loss 48 | ''' 49 | def __init__(self): 50 | super(GLoss, self).__init__() 51 | 52 | def forward(self, logits_gen): 53 | logits_gen = torch.clamp(logits_gen, EPSILON, 1.0) 54 | batch_loss = -torch.log(logits_gen) 55 | 56 | return torch.mean(batch_loss) 57 | 58 | 59 | class DLoss(nn.Module): 60 | ''' C-RNN-GAN discriminator loss 61 | ''' 62 | def __init__(self, label_smoothing=False): 63 | super(DLoss, self).__init__() 64 | self.label_smoothing = label_smoothing 65 | 66 | def forward(self, logits_real, logits_gen): 67 | ''' Discriminator loss 68 | 69 | logits_real: logits from D, when input is real 70 | logits_gen: logits from D, when input is from Generator 71 | 72 | loss = -(ylog(p) + (1-y)log(1-p)) 73 | 74 | ''' 75 | logits_real = torch.clamp(logits_real, EPSILON, 1.0) 76 | d_loss_real = -torch.log(logits_real) 77 | 78 | if self.label_smoothing: 79 | p_fake = torch.clamp((1 - logits_real), EPSILON, 1.0) 80 | d_loss_fake = -torch.log(p_fake) 81 | d_loss_real = 0.9*d_loss_real + 0.1*d_loss_fake 82 | 83 | logits_gen = torch.clamp((1 - logits_gen), EPSILON, 1.0) 84 | d_loss_gen = 
-torch.log(logits_gen) 85 | 86 | batch_loss = d_loss_real + d_loss_gen 87 | return torch.mean(batch_loss) 88 | 89 | 90 | def run_training(model, optimizer, criterion, dataloader, freeze_g=False, freeze_d=False): 91 | ''' Run single training epoch 92 | ''' 93 | 94 | num_feats = dataloader.get_num_song_features() 95 | dataloader.rewind(part='train') 96 | batch_meta, batch_song = dataloader.get_batch(BATCH_SIZE, MAX_SEQ_LEN, part='train') 97 | 98 | model['g'].train() 99 | model['d'].train() 100 | 101 | loss = {} 102 | g_loss_total = 0.0 103 | d_loss_total = 0.0 104 | num_corrects = 0 105 | num_sample = 0 106 | 107 | while batch_meta is not None and batch_song is not None: 108 | 109 | real_batch_sz = batch_song.shape[0] 110 | 111 | # get initial states 112 | # each batch is independent i.e. not a continuation of previous batch 113 | # so we reset states for each batch 114 | # POSSIBLE IMPROVEMENT: next batch is continuation of previous batch 115 | g_states = model['g'].init_hidden(real_batch_sz) 116 | d_state = model['d'].init_hidden(real_batch_sz) 117 | 118 | #### GENERATOR #### 119 | if not freeze_g: 120 | optimizer['g'].zero_grad() 121 | # prepare inputs 122 | z = torch.empty([real_batch_sz, MAX_SEQ_LEN, num_feats]).uniform_() # random vector 123 | batch_song = torch.Tensor(batch_song) 124 | 125 | # feed inputs to generator 126 | g_feats, _ = model['g'](z, g_states) 127 | 128 | # calculate loss, backprop, and update weights of G 129 | if isinstance(criterion['g'], GLoss): 130 | d_logits_gen, _, _ = model['d'](g_feats, d_state) 131 | loss['g'] = criterion['g'](d_logits_gen) 132 | else: # feature matching 133 | # feed real and generated input to discriminator 134 | _, d_feats_real, _ = model['d'](batch_song, d_state) 135 | _, d_feats_gen, _ = model['d'](g_feats, d_state) 136 | loss['g'] = criterion['g'](d_feats_real, d_feats_gen) 137 | 138 | if not freeze_g: 139 | loss['g'].backward() 140 | nn.utils.clip_grad_norm_(model['g'].parameters(), max_norm=MAX_GRAD_NORM) 141 
| optimizer['g'].step() 142 | 143 | #### DISCRIMINATOR #### 144 | if not freeze_d: 145 | optimizer['d'].zero_grad() 146 | # feed real and generated input to discriminator 147 | d_logits_real, _, _ = model['d'](batch_song, d_state) 148 | # need to detach from operation history to prevent backpropagating to generator 149 | d_logits_gen, _, _ = model['d'](g_feats.detach(), d_state) 150 | # calculate loss, backprop, and update weights of D 151 | loss['d'] = criterion['d'](d_logits_real, d_logits_gen) 152 | if not freeze_d: 153 | loss['d'].backward() 154 | nn.utils.clip_grad_norm_(model['d'].parameters(), max_norm=MAX_GRAD_NORM) 155 | optimizer['d'].step() 156 | 157 | g_loss_total += loss['g'].item() 158 | d_loss_total += loss['d'].item() 159 | num_corrects += (d_logits_real > 0.5).sum().item() + (d_logits_gen < 0.5).sum().item() 160 | num_sample += real_batch_sz 161 | 162 | # fetch next batch 163 | batch_meta, batch_song = dataloader.get_batch(BATCH_SIZE, MAX_SEQ_LEN, part='train') 164 | 165 | g_loss_avg, d_loss_avg = 0.0, 0.0 166 | d_acc = 0.0 167 | if num_sample > 0: 168 | g_loss_avg = g_loss_total / num_sample 169 | d_loss_avg = d_loss_total / num_sample 170 | d_acc = 100 * num_corrects / (2 * num_sample) # 2 because (real + generated) 171 | 172 | return model, g_loss_avg, d_loss_avg, d_acc 173 | 174 | 175 | def run_validation(model, criterion, dataloader): 176 | ''' Run single validation epoch 177 | ''' 178 | num_feats = dataloader.get_num_song_features() 179 | dataloader.rewind(part='validation') 180 | batch_meta, batch_song = dataloader.get_batch(BATCH_SIZE, MAX_SEQ_LEN, part='validation') 181 | 182 | model['g'].eval() 183 | model['d'].eval() 184 | 185 | g_loss_total = 0.0 186 | d_loss_total = 0.0 187 | num_corrects = 0 188 | num_sample = 0 189 | 190 | while batch_meta is not None and batch_song is not None: 191 | 192 | real_batch_sz = batch_song.shape[0] 193 | 194 | # initial states 195 | g_states = model['g'].init_hidden(real_batch_sz) 196 | d_state = 
model['d'].init_hidden(real_batch_sz) 197 | 198 | #### GENERATOR #### 199 | # prepare inputs 200 | z = torch.empty([real_batch_sz, MAX_SEQ_LEN, num_feats]).uniform_() # random vector 201 | batch_song = torch.Tensor(batch_song) 202 | 203 | # feed inputs to generator 204 | g_feats, _ = model['g'](z, g_states) 205 | # feed real and generated input to discriminator 206 | d_logits_real, d_feats_real, _ = model['d'](batch_song, d_state) 207 | d_logits_gen, d_feats_gen, _ = model['d'](g_feats, d_state) 208 | # calculate loss 209 | if isinstance(criterion['g'], GLoss): 210 | g_loss = criterion['g'](d_logits_gen) 211 | else: # feature matching 212 | g_loss = criterion['g'](d_feats_real, d_feats_gen) 213 | 214 | d_loss = criterion['d'](d_logits_real, d_logits_gen) 215 | 216 | g_loss_total += g_loss.item() 217 | d_loss_total += d_loss.item() 218 | num_corrects += (d_logits_real > 0.5).sum().item() + (d_logits_gen < 0.5).sum().item() 219 | num_sample += real_batch_sz 220 | 221 | # fetch next batch 222 | batch_meta, batch_song = dataloader.get_batch(BATCH_SIZE, MAX_SEQ_LEN, part='validation') 223 | 224 | g_loss_avg, d_loss_avg = 0.0, 0.0 225 | d_acc = 0.0 226 | if num_sample > 0: 227 | g_loss_avg = g_loss_total / num_sample 228 | d_loss_avg = d_loss_total / num_sample 229 | d_acc = 100 * num_corrects / (2 * num_sample) # 2 because (real + generated) 230 | 231 | return g_loss_avg, d_loss_avg, d_acc 232 | 233 | 234 | def run_epoch(model, optimizer, criterion, dataloader, ep, num_ep, 235 | freeze_g=False, freeze_d=False, pretraining=False): 236 | ''' Run a single epoch 237 | ''' 238 | model, trn_g_loss, trn_d_loss, trn_acc = \ 239 | run_training(model, optimizer, criterion, dataloader, freeze_g=freeze_g, freeze_d=freeze_d) 240 | 241 | val_g_loss, val_d_loss, val_acc = run_validation(model, criterion, dataloader) 242 | 243 | if pretraining: 244 | print("Pretraining Epoch %d/%d " % (ep+1, num_ep), "[Freeze G: ", freeze_g, ", Freeze D: ", freeze_d, "]") 245 | else: 246 | print("Epoch 
%d/%d " % (ep+1, num_ep), "[Freeze G: ", freeze_g, ", Freeze D: ", freeze_d, "]") 247 | 248 | print("\t[Training] G_loss: %0.8f, D_loss: %0.8f, D_acc: %0.2f\n" 249 | "\t[Validation] G_loss: %0.8f, D_loss: %0.8f, D_acc: %0.2f" % 250 | (trn_g_loss, trn_d_loss, trn_acc, 251 | val_g_loss, val_d_loss, val_acc)) 252 | 253 | # -- DEBUG -- 254 | # This is for monitoring the current output from generator 255 | # generate from model then save to MIDI file 256 | g_states = model['g'].init_hidden(1) 257 | num_feats = dataloader.get_num_song_features() 258 | z = torch.empty([1, MAX_SEQ_LEN, num_feats]).uniform_() # random vector 259 | if torch.cuda.is_available(): 260 | z = z.cuda() 261 | model['g'].cuda() 262 | 263 | model['g'].eval() 264 | g_feats, _ = model['g'](z, g_states) 265 | song_data = g_feats.squeeze().cpu() 266 | song_data = song_data.detach().numpy() 267 | 268 | if (ep+1) == num_ep: 269 | midi_data = dataloader.save_data('sample.mid', song_data) 270 | else: 271 | midi_data = dataloader.save_data(None, song_data) 272 | print(midi_data[0][:16]) 273 | # -- DEBUG -- 274 | 275 | return model, trn_acc 276 | 277 | 278 | def main(args): 279 | ''' Training sequence 280 | ''' 281 | dataloader = music_data_utils.MusicDataLoader(DATA_DIR, single_composer=COMPOSER) 282 | num_feats = dataloader.get_num_song_features() 283 | 284 | # First checking if GPU is available 285 | train_on_gpu = torch.cuda.is_available() 286 | if train_on_gpu: 287 | print('Training on GPU.') 288 | else: 289 | print('No GPU available, training on CPU.') 290 | 291 | model = { 292 | 'g': Generator(num_feats, use_cuda=train_on_gpu), 293 | 'd': Discriminator(num_feats, use_cuda=train_on_gpu) 294 | } 295 | 296 | if args.use_sgd: 297 | optimizer = { 298 | 'g': optim.SGD(model['g'].parameters(), lr=args.g_lrn_rate, momentum=0.9), 299 | 'd': optim.SGD(model['d'].parameters(), lr=args.d_lrn_rate, momentum=0.9) 300 | } 301 | else: 302 | optimizer = { 303 | 'g': optim.Adam(model['g'].parameters(), args.g_lrn_rate), 
304 | 'd': optim.Adam(model['d'].parameters(), args.d_lrn_rate) 305 | } 306 | 307 | criterion = { 308 | 'g': nn.MSELoss(reduction='sum') if args.feature_matching else GLoss(), 309 | 'd': DLoss(args.label_smoothing) 310 | } 311 | 312 | if args.load_g: 313 | ckpt = torch.load(os.path.join(CKPT_DIR, G_FN)) 314 | model['g'].load_state_dict(ckpt) 315 | print("Continue training of %s" % os.path.join(CKPT_DIR, G_FN)) 316 | 317 | if args.load_d: 318 | ckpt = torch.load(os.path.join(CKPT_DIR, D_FN)) 319 | model['d'].load_state_dict(ckpt) 320 | print("Continue training of %s" % os.path.join(CKPT_DIR, D_FN)) 321 | 322 | if train_on_gpu: 323 | model['g'].cuda() 324 | model['d'].cuda() 325 | 326 | if not args.no_pretraining: 327 | for ep in range(args.d_pretraining_epochs): 328 | model, _ = run_epoch(model, optimizer, criterion, dataloader, 329 | ep, args.d_pretraining_epochs, freeze_g=True, pretraining=True) 330 | 331 | for ep in range(args.g_pretraining_epochs): 332 | model, _ = run_epoch(model, optimizer, criterion, dataloader, 333 | ep, args.g_pretraining_epochs, freeze_d=True, pretraining=True) 334 | 335 | freeze_d = False 336 | for ep in range(args.num_epochs): 337 | # if ep % args.freeze_d_every == 0: 338 | # freeze_d = not freeze_d 339 | 340 | model, trn_acc = run_epoch(model, optimizer, criterion, dataloader, ep, args.num_epochs, freeze_d=freeze_d) 341 | if args.conditional_freezing: 342 | # conditional freezing 343 | freeze_d = False 344 | if trn_acc >= 95.0: 345 | freeze_d = True 346 | 347 | if not args.no_save_g: 348 | torch.save(model['g'].state_dict(), os.path.join(CKPT_DIR, G_FN)) 349 | print("Saved generator: %s" % os.path.join(CKPT_DIR, G_FN)) 350 | 351 | if not args.no_save_d: 352 | torch.save(model['d'].state_dict(), os.path.join(CKPT_DIR, D_FN)) 353 | print("Saved discriminator: %s" % os.path.join(CKPT_DIR, D_FN)) 354 | 355 | 356 | if __name__ == "__main__": 357 | 358 | ARG_PARSER = ArgumentParser() 359 | ARG_PARSER.add_argument('--load_g', 
action='store_true') 360 | ARG_PARSER.add_argument('--load_d', action='store_true') 361 | ARG_PARSER.add_argument('--no_save_g', action='store_true') 362 | ARG_PARSER.add_argument('--no_save_d', action='store_true') 363 | 364 | ARG_PARSER.add_argument('--num_epochs', default=300, type=int) 365 | ARG_PARSER.add_argument('--seq_len', default=256, type=int) 366 | ARG_PARSER.add_argument('--batch_size', default=16, type=int) 367 | ARG_PARSER.add_argument('--g_lrn_rate', default=0.001, type=float) 368 | ARG_PARSER.add_argument('--d_lrn_rate', default=0.001, type=float) 369 | 370 | ARG_PARSER.add_argument('--no_pretraining', action='store_true') 371 | ARG_PARSER.add_argument('--g_pretraining_epochs', default=5, type=int) 372 | ARG_PARSER.add_argument('--d_pretraining_epochs', default=5, type=int) 373 | # ARG_PARSER.add_argument('--freeze_d_every', default=5, type=int) 374 | ARG_PARSER.add_argument('--use_sgd', action='store_true') 375 | ARG_PARSER.add_argument('--conditional_freezing', action='store_true') 376 | ARG_PARSER.add_argument('--label_smoothing', action='store_true') 377 | ARG_PARSER.add_argument('--feature_matching', action='store_true') 378 | 379 | ARGS = ARG_PARSER.parse_args() 380 | MAX_SEQ_LEN = ARGS.seq_len 381 | BATCH_SIZE = ARGS.batch_size 382 | 383 | main(ARGS) 384 | -------------------------------------------------------------------------------- /train_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Christopher John Bayron 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file has been created by Christopher John Bayron based on "rnn_gan.py" 16 | # by Olof Mogren. The referenced code is available in: 17 | # 18 | # https://github.com/olofmogren/c-rnn-gan 19 | 20 | import os 21 | from argparse import ArgumentParser 22 | 23 | import numpy as np 24 | import torch 25 | import torch.nn as nn 26 | from torch import optim 27 | from torch.utils.data import TensorDataset, DataLoader 28 | 29 | from c_rnn_gan import Generator, Discriminator 30 | import music_data_utils 31 | 32 | DATA_DIR = 'data' 33 | CKPT_DIR = 'models' 34 | G_FN = 'c_rnn_gan_g.pth' 35 | D_FN = 'c_rnn_gan_d.pth' 36 | 37 | G_LRN_RATE = 0.001 38 | D_LRN_RATE = 0.001 39 | MAX_GRAD_NORM = 5.0 40 | BATCH_SIZE = 32 41 | MAX_EPOCHS = 500 42 | L2_DECAY = 1.0 43 | 44 | COMPOSER = 'mozart' 45 | MAX_SEQ_LEN = 200 46 | 47 | PERFORM_LOSS_CHECKING = False 48 | FREEZE_G = False 49 | FREEZE_D = False 50 | 51 | NUM_DUMMY_TRN = 256 52 | NUM_DUMMY_VAL = 128 53 | 54 | EPSILON = 1e-40 # value to use to approximate zero (to prevent undefined results) 55 | 56 | def get_accuracy(logits_real, logits_gen): 57 | ''' Discriminator accuracy 58 | ''' 59 | real_corrects = (logits_real > 0.5).sum() 60 | gen_corrects = (logits_gen < 0.5).sum() 61 | 62 | acc = (real_corrects + gen_corrects) / (len(logits_real) + len(logits_gen)) 63 | return acc.item() 64 | 65 | class DLoss(nn.Module): 66 | ''' C-RNN-GAN discriminator loss 67 | ''' 68 | def __init__(self): 69 | super(DLoss, self).__init__() 70 | 71 | def forward(self, logits_real, logits_gen): 72 | ''' 
Discriminator loss 73 | 74 | logits_real: logits from D, when input is real 75 | logits_gen: logits from D, when input is from Generator 76 | ''' 77 | logits_real = torch.clamp(logits_real, EPSILON, 1.0) 78 | d_loss_real = -torch.log(logits_real) 79 | 80 | logits_gen = torch.clamp((1 - logits_gen), EPSILON, 1.0) 81 | d_loss_gen = -torch.log(logits_gen) 82 | 83 | batch_loss = d_loss_real + d_loss_gen 84 | return torch.mean(batch_loss) 85 | 86 | 87 | def control_grad(model, freeze=True): 88 | ''' Freeze/unfreeze optimization of model 89 | ''' 90 | if freeze: 91 | for param in model.parameters(): 92 | param.requires_grad = False 93 | 94 | else: # unfreeze 95 | for param in model.parameters(): 96 | param.requires_grad = True 97 | 98 | 99 | def check_loss(model, loss): 100 | ''' Check loss and control gradients if necessary 101 | ''' 102 | control_grad(model['g'], freeze=False) 103 | control_grad(model['d'], freeze=False) 104 | 105 | if loss['d'] == 0.0 and loss['g'] == 0.0: 106 | print('Both G and D train loss are zero. 
Exiting.') 107 | return False 108 | elif loss['d'] == 0.0: # freeze D 109 | control_grad(model['d'], freeze=True) 110 | elif loss['g'] == 0.0: # freeze G 111 | control_grad(model['g'], freeze=True) 112 | elif loss['g'] < 2.0 or loss['d'] < 2.0: 113 | control_grad(model['d'], freeze=True) 114 | if loss['g']*0.7 > loss['d']: 115 | control_grad(model['g'], freeze=True) 116 | 117 | return True 118 | 119 | 120 | def dummy_dataloader(seq_len, batch_size, num_sample): 121 | ''' Dummy data generator (for debugging purposes) 122 | ''' 123 | # the following code generates random data of numbers 124 | # where each number is twice the prev number 125 | np_data = np.stack([(2 ** np.arange(seq_len))[:, np.newaxis] \ 126 | * np.random.rand() for i in range(num_sample)]) 127 | 128 | data = TensorDataset(torch.from_numpy(np_data)) 129 | return DataLoader(data, shuffle=True, batch_size=batch_size) 130 | 131 | 132 | def run_training(model, optimizer, criterion, dataloader, ep, freeze_g=False, freeze_d=False): 133 | ''' Run single training epoch 134 | ''' 135 | 136 | loss = { 137 | 'g': 10.0, 138 | 'd': 10.0 139 | } 140 | 141 | num_feats = model['g'].num_feats 142 | 143 | model['g'].train() 144 | model['d'].train() 145 | 146 | g_loss_total = 0.0 147 | d_loss_total = 0.0 148 | num_corrects = 0 149 | num_sample = 0 150 | 151 | log_sum_real = 0.0 152 | log_sum_gen = 0.0 153 | 154 | for (batch_input, ) in dataloader: 155 | 156 | real_batch_sz = len(batch_input) 157 | batch_input = batch_input.type(torch.FloatTensor) 158 | 159 | # loss checking 160 | if PERFORM_LOSS_CHECKING == True: 161 | if not check_loss(model, loss): 162 | break 163 | 164 | # get initial states 165 | # each batch is independent i.e. 
not a continuation of previous batch 166 | # so we reset states for each batch 167 | # POSSIBLE IMPROVEMENT: next batch is continuation of previous batch 168 | g_states = model['g'].init_hidden(real_batch_sz) 169 | d_state = model['d'].init_hidden(real_batch_sz) 170 | 171 | #### GENERATOR #### 172 | if not freeze_g: 173 | optimizer['g'].zero_grad() 174 | # prepare inputs 175 | z = torch.empty([real_batch_sz, MAX_SEQ_LEN, num_feats]).uniform_() # random vector 176 | 177 | # feed inputs to generator 178 | g_feats, _ = model['g'](z, g_states) 179 | # feed real and generated input to discriminator 180 | _, d_feats_real, _ = model['d'](batch_input, d_state) 181 | _, d_feats_gen, _ = model['d'](g_feats, d_state) 182 | 183 | # calculate loss, backprop, and update weights of G 184 | loss['g'] = criterion['g'](d_feats_real, d_feats_gen) 185 | if not freeze_g: 186 | loss['g'].backward() 187 | # nn.utils.clip_grad_norm_(model['g'].parameters(), max_norm=MAX_GRAD_NORM) 188 | optimizer['g'].step() 189 | 190 | #### DISCRIMINATOR #### 191 | if not freeze_d: 192 | optimizer['d'].zero_grad() 193 | 194 | # feed real and generated input to discriminator 195 | d_logits_real, _, _ = model['d'](batch_input, d_state) 196 | # need to detach from operation history to prevent backpropagating to generator 197 | d_logits_gen, _, _ = model['d'](g_feats.detach(), d_state) 198 | # calculate loss, backprop, and update weights of D 199 | loss['d'] = criterion['d'](d_logits_real, d_logits_gen) 200 | 201 | # print("Trn: ", d_logits_real.mean(), d_logits_gen.mean()) 202 | log_sum_real += d_logits_real.sum().item() 203 | log_sum_gen += d_logits_gen.sum().item() 204 | 205 | if not freeze_d: 206 | loss['d'].backward() 207 | nn.utils.clip_grad_norm_(model['d'].parameters(), max_norm=MAX_GRAD_NORM) 208 | optimizer['d'].step() 209 | 210 | g_loss_total += loss['g'].item() 211 | d_loss_total += loss['d'].item() 212 | num_corrects += (d_logits_real > 0.5).sum().item() + (d_logits_gen < 0.5).sum().item() 213 | 
num_sample += real_batch_sz 214 | 215 | g_loss_avg, d_loss_avg = 0.0, 0.0 216 | d_acc = 0.0 217 | if num_sample > 0: 218 | g_loss_avg = g_loss_total / num_sample 219 | d_loss_avg = d_loss_total / num_sample 220 | d_acc = 100 * num_corrects / (2 * num_sample) # 2 because (real + generated) 221 | 222 | print("Trn: ", log_sum_real / num_sample, log_sum_gen / num_sample) 223 | 224 | return model, g_loss_avg, d_loss_avg, d_acc 225 | 226 | 227 | def run_validation(model, criterion, dataloader): 228 | ''' Run single validation epoch 229 | ''' 230 | num_feats = model['g'].num_feats 231 | 232 | model['g'].eval() 233 | model['d'].eval() 234 | 235 | g_loss_total = 0.0 236 | d_loss_total = 0.0 237 | num_corrects = 0 238 | num_sample = 0 239 | 240 | log_sum_real = 0.0 241 | log_sum_gen = 0.0 242 | 243 | for (batch_input, ) in dataloader: 244 | 245 | real_batch_sz = len(batch_input) 246 | batch_input = batch_input.type(torch.FloatTensor) 247 | 248 | # initial states 249 | g_states = model['g'].init_hidden(real_batch_sz) 250 | d_state = model['d'].init_hidden(real_batch_sz) 251 | 252 | #### GENERATOR #### 253 | # prepare inputs 254 | z = torch.empty([real_batch_sz, MAX_SEQ_LEN, num_feats]).uniform_() # random vector 255 | 256 | # feed inputs to generator 257 | g_feats, _ = model['g'](z, g_states) 258 | # feed real and generated input to discriminator 259 | d_logits_real, d_feats_real, _ = model['d'](batch_input, d_state) 260 | d_logits_gen, d_feats_gen, _ = model['d'](g_feats, d_state) 261 | # print("Val: ", d_logits_real.mean(), d_logits_gen.mean()) 262 | log_sum_real += d_logits_real.sum().item() 263 | log_sum_gen += d_logits_gen.sum().item() 264 | 265 | # calculate loss 266 | g_loss = criterion['g'](d_feats_real, d_feats_gen) 267 | d_loss = criterion['d'](d_logits_real, d_logits_gen) 268 | 269 | g_loss_total += g_loss.item() 270 | d_loss_total += d_loss.item() 271 | num_corrects += (d_logits_real > 0.5).sum().item() + (d_logits_gen < 0.5).sum().item() 272 | num_sample += 
real_batch_sz 273 | 274 | g_loss_avg, d_loss_avg = 0.0, 0.0 275 | d_acc = 0.0 276 | if num_sample > 0: 277 | g_loss_avg = g_loss_total / num_sample 278 | d_loss_avg = d_loss_total / num_sample 279 | d_acc = 100 * num_corrects / (2 * num_sample) # 2 because (real + generated) 280 | 281 | print("Val: ", log_sum_real / num_sample, log_sum_gen / num_sample) 282 | 283 | return g_loss_avg, d_loss_avg, d_acc 284 | 285 | 286 | def generate_sample(g_model, num_sample=1): 287 | ''' Sample from generator 288 | ''' 289 | num_feats = g_model.num_feats 290 | g_states = g_model.init_hidden(num_sample) 291 | 292 | z = torch.empty([num_sample, MAX_SEQ_LEN, num_feats]).uniform_() # random vector 293 | 294 | g_feats, _ = g_model(z, g_states) 295 | return g_feats 296 | 297 | 298 | def main(args): 299 | ''' Training sequence 300 | ''' 301 | trn_dataloader = dummy_dataloader(MAX_SEQ_LEN, BATCH_SIZE, NUM_DUMMY_TRN) 302 | val_dataloader = dummy_dataloader(MAX_SEQ_LEN, BATCH_SIZE, NUM_DUMMY_VAL) 303 | 304 | # First checking if GPU is available 305 | train_on_gpu = torch.cuda.is_available() 306 | if train_on_gpu: 307 | print('Training on GPU.') 308 | else: 309 | print('No GPU available, training on CPU.') 310 | 311 | model = { 312 | 'g': Generator(num_feats=1, use_cuda=train_on_gpu), 313 | 'd': Discriminator(num_feats=1, use_cuda=train_on_gpu) 314 | } 315 | 316 | optimizer = { 317 | # 'g': optim.SGD(model['g'].parameters(), G_LRN_RATE, weight_decay=L2_DECAY), 318 | 'g': optim.Adam(model['g'].parameters(), G_LRN_RATE), 319 | 'd': optim.Adam(model['d'].parameters(), D_LRN_RATE) 320 | } 321 | 322 | criterion = { 323 | 'g': nn.MSELoss(reduction='sum'), # feature matching 324 | 'd': DLoss() 325 | } 326 | 327 | if args.load_g: 328 | ckpt = torch.load(os.path.join(CKPT_DIR, G_FN)) 329 | model['g'].load_state_dict(ckpt) 330 | print("Continue training of %s" % os.path.join(CKPT_DIR, G_FN)) 331 | 332 | if args.load_d: 333 | ckpt = torch.load(os.path.join(CKPT_DIR, D_FN)) 334 | 
model['d'].load_state_dict(ckpt) 335 | print("Continue training of %s" % os.path.join(CKPT_DIR, D_FN)) 336 | 337 | if train_on_gpu: 338 | model['g'].cuda() 339 | model['d'].cuda() 340 | 341 | if not args.no_pretraining: 342 | for ep in range(args.pretraining_epochs): 343 | model, trn_g_loss, trn_d_loss, trn_acc = \ 344 | run_training(model, optimizer, criterion, trn_dataloader, ep, freeze_g=True) 345 | val_g_loss, val_d_loss, val_acc = run_validation(model, criterion, val_dataloader) 346 | 347 | sample = generate_sample(model['g']) 348 | 349 | print("Epoch %d/%d\n" 350 | "\t[Training] G_loss: %0.8f, D_loss: %0.8f, D_acc: %0.2f\n" 351 | "\t[Validation] G_loss: %0.8f, D_loss: %0.8f, D_acc: %0.2f" % 352 | (ep+1, args.num_epochs, trn_g_loss, trn_d_loss, trn_acc, 353 | val_g_loss, val_d_loss, val_acc)) 354 | 355 | print(sample) 356 | 357 | for ep in range(args.pretraining_epochs): 358 | model, trn_g_loss, trn_d_loss, trn_acc = \ 359 | run_training(model, optimizer, criterion, trn_dataloader, ep, freeze_d=False) 360 | val_g_loss, val_d_loss, val_acc = run_validation(model, criterion, val_dataloader) 361 | 362 | sample = generate_sample(model['g']) 363 | 364 | print("Epoch %d/%d\n" 365 | "\t[Training] G_loss: %0.8f, D_loss: %0.8f, D_acc: %0.2f\n" 366 | "\t[Validation] G_loss: %0.8f, D_loss: %0.8f, D_acc: %0.2f" % 367 | (ep+1, args.num_epochs, trn_g_loss, trn_d_loss, trn_acc, 368 | val_g_loss, val_d_loss, val_acc)) 369 | 370 | print(sample) 371 | 372 | 373 | for ep in range(args.num_epochs): 374 | model, trn_g_loss, trn_d_loss, trn_acc = run_training(model, optimizer, criterion, trn_dataloader, ep) 375 | val_g_loss, val_d_loss, val_acc = run_validation(model, criterion, val_dataloader) 376 | 377 | sample = generate_sample(model['g']) 378 | 379 | print("Epoch %d/%d\n" 380 | "\t[Training] G_loss: %0.8f, D_loss: %0.8f, D_acc: %0.2f\n" 381 | "\t[Validation] G_loss: %0.8f, D_loss: %0.8f, D_acc: %0.2f" % 382 | (ep+1, args.num_epochs, trn_g_loss, trn_d_loss, trn_acc, 383 | 
val_g_loss, val_d_loss, val_acc)) 384 | print(sample) 385 | 386 | # sampling (to check if generator really learns) 387 | 388 | if not args.no_save_g: 389 | torch.save(model['g'].state_dict(), os.path.join(CKPT_DIR, G_FN)) 390 | print("Saved generator: %s" % os.path.join(CKPT_DIR, G_FN)) 391 | 392 | if not args.no_save_d: 393 | torch.save(model['d'].state_dict(), os.path.join(CKPT_DIR, D_FN)) 394 | print("Saved discriminator: %s" % os.path.join(CKPT_DIR, D_FN)) 395 | 396 | 397 | if __name__ == "__main__": 398 | 399 | ARG_PARSER = ArgumentParser() 400 | ARG_PARSER.add_argument('--load_g', action='store_true') 401 | ARG_PARSER.add_argument('--load_d', action='store_true') 402 | ARG_PARSER.add_argument('--no_save_g', action='store_true') 403 | ARG_PARSER.add_argument('--no_save_d', action='store_true') 404 | ARG_PARSER.add_argument('--freeze_g', action='store_true') 405 | ARG_PARSER.add_argument('--freeze_d', action='store_true') 406 | ARG_PARSER.add_argument('--num_epochs', default=200, type=int) 407 | ARG_PARSER.add_argument('--seq_len', default=8, type=int) 408 | ARG_PARSER.add_argument('--batch_size', default=32, type=int) 409 | 410 | ARG_PARSER.add_argument('-m', action='store_true') 411 | ARG_PARSER.add_argument('--no_pretraining', action='store_true') 412 | ARG_PARSER.add_argument('--pretraining_epochs', default=10, type=int) 413 | 414 | ARGS = ARG_PARSER.parse_args() 415 | MAX_SEQ_LEN = ARGS.seq_len 416 | BATCH_SIZE = ARGS.batch_size 417 | FREEZE_G = ARGS.freeze_g 418 | FREEZE_D = ARGS.freeze_d 419 | 420 | main(ARGS) 421 | --------------------------------------------------------------------------------