├── .gitignore ├── LICENSE ├── README.md ├── data ├── midi │ └── .gitignore ├── play │ └── .gitignore ├── samples │ └── .gitignore └── test │ ├── initiator.json │ ├── initiator_old.json │ └── midi_keyboard_correspondance.png ├── deepmusic ├── __init__.py ├── composer.py ├── imgconnector.py ├── keyboardcell.py ├── midiconnector.py ├── model.py ├── model_old.py ├── moduleloader.py ├── modulemanager.py ├── modules │ ├── batchbuilder.py │ ├── decoder.py │ ├── encoder.py │ ├── learningratepolicy.py │ └── loopprocessing.py ├── musicdata.py ├── songstruct.py └── tfutils.py ├── docs ├── ideas.md ├── imgs │ ├── basic_rnn.png │ ├── endecell.png │ ├── training_begin.png │ └── training_end.png ├── midi.md ├── midi │ ├── basic_rnn_joplin.mid │ ├── basic_rnn_ragtime_-38000-0-C4.mid │ ├── basic_rnn_ragtime_BaseStructure.mid │ ├── basic_rnn_ragtime_TempoChange.mid │ └── basic_rnn_ragtime_structure.mid └── models.md ├── main.py ├── save └── .gitignore └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | docs/mp3/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # Pycharm 10 | .idea/ 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *,cover 52 | .hypothesis/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # IPython Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MusicGenerator 2 | 3 | ## Presentation 4 | 5 | Experiment with diverse deep learning models for music generation using TensorFlow. 6 | 7 | ## Results 8 | 9 | The different models and experiments are explained [here](docs/models.md). 10 | 11 | ## Installation 12 | 13 | The program requires the following dependencies (easy to install using pip): 14 | * Python 3 15 | * TensorFlow (tested with v0.10.0rc0. Won't work with previous versions) 16 | * CUDA (for GPU support; see the TensorFlow [installation page](https://www.tensorflow.org/versions/master/get_started/os_setup.html#optional-install-cuda-gpus-on-linux) for more details) 17 | * Numpy (should be installed with TensorFlow) 18 | * Mido (MIDI library) 19 | * Tqdm (for the nice progress bars) 20 | * OpenCv (unfortunately, there is no simple way to install it for Python 3. It is only used as a visualisation tool to plot the piano roll, so it is quite optional. All OpenCv calls are contained inside the imgconnector file, so if you want to test the program without OpenCv, you can try removing the functions inside that file) 21 | 22 | ## Running 23 | 24 | To train the model, simply run `main.py`. Once trained, you can generate results with `main.py --test --sample_length 500`. For more help and options, use `python main.py -h`. 25 | 26 | To visualize the computational graph and the cost with TensorBoard, run `tensorboard --logdir save/`.
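A possible end-to-end workflow is sketched below. This is only an illustration using flags already defined in `main.py`'s argument parser; it assumes your training MIDI files have been copied into `data/midi/` and uses the default `ragtimemusic` dataset tag.

```bash
# Build the dataset from the MIDI files (no training/testing is performed)
python main.py --create_dataset --dataset_tag ragtimemusic

# Train a model (press Ctrl+C to save a checkpoint and exit)
python main.py --dataset_tag ragtimemusic

# Generate songs from the saved checkpoints (written under the model directory)
python main.py --test --sample_length 500

# Monitor the training and testing curves
tensorboard --logdir save/
```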
27 | -------------------------------------------------------------------------------- /data/midi/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /data/play/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /data/samples/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /data/test/initiator.json: -------------------------------------------------------------------------------- 1 | {"initiator":[ 2 | {"name":"30_40_60", 3 | "seq":[ 4 | {"notes":[30, 40, 60]} 5 | ]}, 6 | {"name":"DoReMi", 7 | "seq":[ 8 | {"notes":[60]}, 9 | {"notes":[62]}, 10 | {"notes":[64]} 11 | ]}, 12 | {"name":"Do_re_mi_", 13 | "seq":[ 14 | {"notes":[60]}, 15 | {"notes":[]}, 16 | {"notes":[]}, 17 | {"notes":[]}, 18 | {"notes":[61]}, 19 | {"notes":[]}, 20 | {"notes":[]}, 21 | {"notes":[]}, 22 | {"notes":[62]}, 23 | {"notes":[]}, 24 | {"notes":[]}, 25 | {"notes":[]} 26 | ]}, 27 | {"name":"Chromatic", 28 | "seq":[ 29 | {"notes":[60]}, 30 | {"notes":[61]}, 31 | {"notes":[62]} 32 | ]}, 33 | {"name":"Chromatic_", 34 | "seq":[ 35 | {"notes":[60]}, 36 | {"notes":[]}, 37 | {"notes":[]}, 38 | {"notes":[]}, 39 | {"notes":[61]}, 40 | {"notes":[]}, 41 | {"notes":[]}, 42 | {"notes":[]}, 43 | {"notes":[62]}, 44 | {"notes":[]}, 45 | {"notes":[]}, 46 | {"notes":[]} 47 | ]}, 48 | {"name":"C_major", 49 | "seq":[ 50 | {"notes":[60, 64, 67]} 51 | ]}, 52 | {"name":"C_M_m", 53 | "seq":[ 54 | {"notes":[60, 64, 67]}, 55 | {"notes":[60, 63, 67]} 56 | ]}, 57 | {"name":"C_minor", 58 | "seq":[ 59 | {"notes":[60, 63, 67]} 60 | ]} 61 | ]} 62 | -------------------------------------------------------------------------------- /data/test/initiator_old.json: -------------------------------------------------------------------------------- 1 | {"initiator":[ 2 | {"name":"An_empty_song", 3 | "seq":[ 4 | {"notes":[]} 5 | ]}, 6 | {"name":"C4", 7 | "seq":[ 8 | {"notes":[60]} 9 | ]}, 10 | {"name":"30_40_60", 11 | "seq":[ 12 | {"notes":[30, 40, 60]} 13 | ]}, 14 | {"name":"D4", 15 | "seq":[ 16 | {"notes":[62]} 17 | ]}, 18 | {"name":"E4", 19 | "seq":[ 20 | {"notes":[64]} 21 | ]}, 22 | {"name":"F4", 23 | "seq":[ 24 | {"notes":[65]} 25 | ]}, 26 | {"name":"C4_sharp", 27 | "seq":[ 28 | {"notes":[61]} 29 | ]}, 30 | {"name":"D4_sharp", 31 | "seq":[ 32 | {"notes":[63]} 33 | ]}, 34 | {"name":"G4", 35 | "seq":[ 36 | {"notes":[67]} 37 | ]}, 38 | {"name":"DoReMi", 39 | "seq":[ 40 | {"notes":[60]}, 41 | {"notes":[62]}, 42 | {"notes":[64]} 43 | ]}, 44 | {"name":"Do_re_mi_", 45 | "seq":[ 46 | {"notes":[60]}, 47 | {"notes":[]}, 48 | {"notes":[]}, 49 | {"notes":[]}, 50 | {"notes":[61]}, 51 | {"notes":[]}, 52 | {"notes":[]}, 53 | {"notes":[]}, 54 | {"notes":[62]}, 55 | {"notes":[]}, 56 | {"notes":[]}, 57 | {"notes":[]} 58 | ]}, 59 | {"name":"Chromatic", 60 | "seq":[ 61 | {"notes":[60]}, 62 | {"notes":[61]}, 63 | {"notes":[62]} 64 | ]}, 65 | {"name":"Chromatic_", 66 | "seq":[ 67 | {"notes":[60]}, 68 | {"notes":[]}, 69 | {"notes":[]}, 70 | {"notes":[]}, 71 | 
{"notes":[61]}, 72 | {"notes":[]}, 73 | {"notes":[]}, 74 | {"notes":[]}, 75 | {"notes":[62]}, 76 | {"notes":[]}, 77 | {"notes":[]}, 78 | {"notes":[]} 79 | ]}, 80 | {"name":"C_major", 81 | "seq":[ 82 | {"notes":[60, 64, 67]} 83 | ]}, 84 | {"name":"C_minor", 85 | "seq":[ 86 | {"notes":[60, 63, 67]} 87 | ]} 88 | ]} 89 | -------------------------------------------------------------------------------- /data/test/midi_keyboard_correspondance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Conchylicultor/MusicGenerator/adea76dccaba923b7d3807082ec6f5b512d16bb9/data/test/midi_keyboard_correspondance.png -------------------------------------------------------------------------------- /deepmusic/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["composer"] 2 | 3 | from deepmusic.composer import Composer 4 | -------------------------------------------------------------------------------- /deepmusic/composer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Conchylicultor. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """ 17 | Music composer. Act as the coordinator. Orchestrate and call the different models, see the readme for more details. 
18 | 19 | Use Python 3 20 | """ 21 | 22 | import argparse # Command line parsing 23 | import configparser # Saving the model parameters 24 | import datetime # Chronometer 25 | import os # File management 26 | from tqdm import tqdm # Progress bar 27 | import tensorflow as tf 28 | import gc # Manual garbage collection before each epoch 29 | 30 | from deepmusic.moduleloader import ModuleLoader 31 | from deepmusic.musicdata import MusicData 32 | from deepmusic.midiconnector import MidiConnector 33 | from deepmusic.imgconnector import ImgConnector 34 | from deepmusic.model import Model 35 | 36 | 37 | class Composer: 38 | """ 39 | Main class which launches the training or testing mode 40 | """ 41 | 42 | class TestMode: 43 | """ Simple structure representing the different testing modes 44 | """ 45 | ALL = 'all' # The network tries to generate a new original composition with all models present (with the tag) 46 | DAEMON = 'daemon' # Runs in the background and can regularly be called to predict something (Not implemented) 47 | INTERACTIVE = 'interactive' # The user starts a melody and the neural network completes it (Not implemented) 48 | 49 | @staticmethod 50 | def get_test_modes(): 51 | """ Return the list of the different testing modes 52 | Useful when parsing the command line arguments 53 | """ 54 | return [Composer.TestMode.ALL, Composer.TestMode.DAEMON, Composer.TestMode.INTERACTIVE] 55 | 56 | def __init__(self): 57 | """ 58 | """ 59 | # Model/dataset parameters 60 | self.args = None 61 | 62 | # Task specific objects 63 | self.music_data = None # Dataset 64 | self.model = None # Base model class 65 | 66 | # TensorFlow utilities for convenience saving/logging 67 | self.writer = None 68 | self.writer_test = None 69 | self.saver = None 70 | self.model_dir = '' # Where the model is saved 71 | self.glob_step = 0 # Represents the number of iterations for the current model 72 | 73 | # TensorFlow main session (we keep track for the daemon) 74 | self.sess = None 75 | 76 | # Filename and directories constants 77 | self.MODEL_DIR_BASE = 'save/model' 78 | self.MODEL_NAME_BASE = 'model' 79 | self.MODEL_EXT = '.ckpt' 80 | self.CONFIG_FILENAME = 'params.ini' 81 | self.CONFIG_VERSION = '0.5' # Ensures a warning is raised if there is a change in the format 82 | 83 | self.TRAINING_VISUALIZATION_STEP = 1000 # Plot a training sample every x iterations (Warning: there is a small probability that, within an epoch, it is always the same testing batch which is visualized) 84 | self.TRAINING_VISUALIZATION_DIR = 'progression' 85 | self.TESTING_VISUALIZATION_DIR = 'midi' # Would 'generated', 'output' or 'testing' be a better folder name ? 86 | 87 | @staticmethod 88 | def _parse_args(args): 89 | """ 90 | Parse the arguments from the given command line 91 | Args: 92 | args (list): List of arguments to parse.
If None, the default sys.argv will be parsed 93 | """ 94 | 95 | parser = argparse.ArgumentParser() 96 | 97 | # Global options 98 | global_args = parser.add_argument_group('Global options') 99 | global_args.add_argument('--test', nargs='?', choices=Composer.TestMode.get_test_modes(), const=Composer.TestMode.ALL, default=None, 100 | help='if present, launch the program in test mode: generate new songs from the initiators in data/test/ with' 101 | ' the defined model(s); in interactive mode, the user can write his own melody;' 102 | ' use daemon mode to integrate the composer in another program') 103 | global_args.add_argument('--reset', action='store_true', help='use this if you want to ignore the previous model present in the model directory (Warning: the model will be destroyed with all the folder content)') 104 | global_args.add_argument('--keep_all', action='store_true', help='if this option is set, all saved models will be kept (Warning: make sure you have enough free disk space or increase save_every)') # TODO: Add an option to delimit the max size 105 | global_args.add_argument('--model_tag', type=str, default=None, help='tag to differentiate which model to store/load') 106 | global_args.add_argument('--sample_length', type=int, default=40, help='number of time units (steps) of a training sequence, i.e. the length of the sequence to generate') # Warning: the unit is defined by the MusicData.MAXIMUM_SONG_RESOLUTION parameter 107 | global_args.add_argument('--root_dir', type=str, default=None, help='folder where to look for the models and data') 108 | global_args.add_argument('--device', type=str, default=None, help='\'gpu\' or \'cpu\' (Warning: make sure you have enough free RAM), allows choosing on which hardware to run the model') 109 | global_args.add_argument('--temperature', type=float, default=1.0, help='Used when testing, controls the output sampling') 110 | 111 | # Dataset options 112 | dataset_args = parser.add_argument_group('Dataset options') 113 | dataset_args.add_argument('--dataset_tag', type=str, default='ragtimemusic', help='tag to differentiate which data to use (if the data are not present, the program will try to generate them from the midi folder)') 114 | dataset_args.add_argument('--create_dataset', action='store_true', help='if present, the program will only generate the dataset from the corpus (no training/testing)') 115 | dataset_args.add_argument('--play_dataset', type=int, nargs='?', const=10, default=None, help='if set, the program will randomly play some samples (can be used in conjunction with create_dataset if this is the only action you want to perform)') # TODO: Play midi ? / Or show sample images ? Both ? 116 | dataset_args.add_argument('--ratio_dataset', type=float, default=0.9, help='ratio of songs between training/testing.
The ratio is fixed at the beginning and cannot be changed') 117 | ModuleLoader.batch_builders.add_argparse(dataset_args, 'Control the song representation for the inputs of the neural network.') 118 | 119 | # Network options (Warning: if modifying something here, also make the change on save/restore_params() ) 120 | nn_args = parser.add_argument_group('Network options', 'architecture related option') 121 | ModuleLoader.enco_cells.add_argparse(nn_args, 'Encoder cell used.') 122 | ModuleLoader.deco_cells.add_argparse(nn_args, 'Decoder cell used.') 123 | nn_args.add_argument('--hidden_size', type=int, default=512, help='Size of one neural network layer') 124 | nn_args.add_argument('--num_layers', type=int, default=2, help='Nb of layers of the RNN') 125 | nn_args.add_argument('--scheduled_sampling', type=str, nargs='+', default=[Model.ScheduledSamplingPolicy.NONE], help='Define the schedule sampling policy. If set, have to indicates the parameters of the chosen policy') 126 | nn_args.add_argument('--target_weights', nargs='?', choices=Model.TargetWeightsPolicy.get_policies(), default=Model.TargetWeightsPolicy.LINEAR, 127 | help='policy to choose the loss contribution of each step') 128 | ModuleLoader.loop_processings.add_argparse(nn_args, 'Transformation to apply on each ouput.') 129 | 130 | # Training options (Warning: if modifying something here, also make the change on save/restore_params() ) 131 | training_args = parser.add_argument_group('Training options') 132 | training_args.add_argument('--num_epochs', type=int, default=0, help='maximum number of epochs to run (0 for infinity)') 133 | training_args.add_argument('--save_every', type=int, default=1000, help='nb of mini-batch step before creating a model checkpoint') 134 | training_args.add_argument('--batch_size', type=int, default=64, help='mini-batch size') 135 | ModuleLoader.learning_rate_policies.add_argparse(training_args, 'Learning rate initial value and decay policy.') 136 | training_args.add_argument('--testing_curve', type=int, default=10, help='Also record the testing curve each every x iteration (given by the parameter)') 137 | 138 | return parser.parse_args(args) 139 | 140 | def main(self, args=None): 141 | """ 142 | Launch the training and/or the interactive mode 143 | """ 144 | print('Welcome to DeepMusic v0.1 !') 145 | print() 146 | print('TensorFlow detected: v{}'.format(tf.__version__)) 147 | 148 | # General initialisations 149 | 150 | tf.logging.set_verbosity(tf.logging.INFO) # DEBUG, INFO, WARN (default), ERROR, or FATAL 151 | 152 | ModuleLoader.register_all() # Load available modules 153 | self.args = self._parse_args(args) 154 | if not self.args.root_dir: 155 | self.args.root_dir = os.getcwd() # Use the current working directory 156 | 157 | self._restore_params() # Update the self.model_dir and self.glob_step, for now, not used when loading Model 158 | self._print_params() 159 | 160 | self.music_data = MusicData(self.args) 161 | if self.args.create_dataset: 162 | print('Dataset created! You can start training some models.') 163 | return # No need to go further 164 | 165 | with tf.device(self._get_device()): 166 | self.model = Model(self.args) 167 | 168 | # Saver/summaries 169 | self.writer = tf.train.SummaryWriter(os.path.join(self.model_dir, 'train')) 170 | self.writer_test = tf.train.SummaryWriter(os.path.join(self.model_dir, 'test')) 171 | self.saver = tf.train.Saver(max_to_keep=200) # Set the arbitrary limit ? 
172 | 173 | # TODO: Fixed seed (WARNING: If dataset shuffling, make sure to do that after saving the 174 | # dataset, otherwise, all what comes after the shuffling won't be replicable when 175 | # reloading the dataset). How to restore the seed after loading ?? (with get_state/set_state) 176 | # Also fix seed for np.random (does it works globally for all files ?) 177 | 178 | # Running session 179 | 180 | self.sess = tf.Session() 181 | 182 | print('Initialize variables...') 183 | self.sess.run(tf.initialize_all_variables()) 184 | 185 | # Reload the model eventually (if it exist), on testing mode, the models are not loaded here (but in main_test()) 186 | self._restore_previous_model(self.sess) 187 | 188 | if self.args.test: 189 | if self.args.test == Composer.TestMode.ALL: 190 | self._main_test() 191 | elif self.args.test == Composer.TestMode.DAEMON: 192 | print('Daemon mode, running in background...') 193 | raise NotImplementedError('No daemon mode') # Come back later 194 | else: 195 | raise RuntimeError('Unknown test mode: {}'.format(self.args.test)) # Should never happen 196 | else: 197 | self._main_train() 198 | 199 | if self.args.test != Composer.TestMode.DAEMON: 200 | self.sess.close() 201 | print('The End! Thanks for using this program') 202 | 203 | def _main_train(self): 204 | """ Training loop 205 | """ 206 | assert self.sess 207 | 208 | # Specific training dependent loading (Warning: When restoring a model, we don't restore the progression 209 | # bar, nor the current batch.) 210 | 211 | merged_summaries = tf.merge_all_summaries() 212 | if self.glob_step == 0: # Not restoring from previous run 213 | self.writer.add_graph(self.sess.graph) # First time only 214 | 215 | print('Start training (press Ctrl+C to save and exit)...') 216 | 217 | try: # If the user exit while training, we still try to save the model 218 | e = 0 219 | while self.args.num_epochs == 0 or e < self.args.num_epochs: # Main training loop (infinite if num_epoch==0) 220 | e += 1 221 | 222 | print() 223 | print('------- Epoch {} (lr={}) -------'.format( 224 | '{}/{}'.format(e, self.args.num_epochs) if self.args.num_epochs else '{}'.format(e), 225 | self.model.learning_rate_policy.get_learning_rate(self.glob_step)) 226 | ) 227 | 228 | # Explicit garbage collector call (clear the previous batches) 229 | gc.collect() # TODO: Better memory management (use generators,...) 230 | 231 | batches_train, batches_test = self.music_data.get_batches() 232 | 233 | # Also update learning parameters eventually ?? 
(Some is done in the model class with the policy classes) 234 | 235 | tic = datetime.datetime.now() 236 | for next_batch in tqdm(batches_train, desc='Training'): # Iterate over the batches 237 | # TODO: Could compute the perfs (time forward pass vs time batch pre-processing) 238 | # Indicate if the output should be computed or not 239 | is_output_visualized = self.glob_step % self.TRAINING_VISUALIZATION_STEP == 0 240 | 241 | # Training pass 242 | ops, feed_dict = self.model.step( 243 | next_batch, 244 | train_set=True, 245 | glob_step=self.glob_step, 246 | ret_output=is_output_visualized 247 | ) 248 | outputs_train = self.sess.run((merged_summaries,) + ops, feed_dict) 249 | self.writer.add_summary(outputs_train[0], self.glob_step) 250 | 251 | # Testing pass (record the testing curve and visualize some testing predictions) 252 | # TODO: It makes no sense to completely disable the ground truth feeding (it's impossible to the 253 | # network to do a good prediction with only the first step) 254 | if is_output_visualized or (self.args.testing_curve and self.glob_step % self.args.testing_curve == 0): 255 | next_batch_test = batches_test[self.glob_step % len(batches_test)] # Generate test batches in a cycling way (test set smaller than train set) 256 | ops, feed_dict = self.model.step( 257 | next_batch_test, 258 | train_set=False, 259 | ret_output=is_output_visualized 260 | ) 261 | outputs_test = self.sess.run((merged_summaries,) + ops, feed_dict) 262 | self.writer_test.add_summary(outputs_test[0], self.glob_step) 263 | 264 | # Some visualisation (we compute some training/testing samples and compare them to the ground truth) 265 | if is_output_visualized: 266 | visualization_base_name = os.path.join(self.model_dir, self.TRAINING_VISUALIZATION_DIR, str(self.glob_step)) 267 | tqdm.write('Visualizing: ' + visualization_base_name) 268 | self._visualize_output( 269 | visualization_base_name, 270 | outputs_train[-1], 271 | outputs_test[-1] # The network output will always be the last operator returned by model.step() 272 | ) 273 | 274 | # Checkpoint 275 | self.glob_step += 1 # Iterate here to avoid saving at the first iteration 276 | if self.glob_step % self.args.save_every == 0: 277 | self._save_session(self.sess) 278 | 279 | toc = datetime.datetime.now() 280 | 281 | print('Epoch finished in {}'.format(toc-tic)) # Warning: Will overflow if an epoch takes more than 24 hours, and the output isn't really nicer 282 | except (KeyboardInterrupt, SystemExit): # If the user press Ctrl+C while testing progress 283 | print('Interruption detected, exiting the program...') 284 | 285 | self._save_session(self.sess) # Ultimate saving before complete exit 286 | 287 | def _main_test(self): 288 | """ Generate some songs 289 | The midi files will be saved on the same model_dir 290 | """ 291 | assert self.sess 292 | assert self.args.batch_size == 1 293 | 294 | print('Start predicting...') 295 | 296 | model_list = self._get_model_list() 297 | if not model_list: 298 | print('Warning: No model found in \'{}\'. Please train a model before trying to predict'.format(self.model_dir)) 299 | return 300 | 301 | batches, names = self.music_data.get_batches_test_old() 302 | samples = list(zip(batches, names)) 303 | 304 | # Predicting for each model present in modelDir 305 | for model_name in tqdm(sorted(model_list), desc='Model', unit='model'): # TODO: Natural sorting / TODO: tqdm ? 
306 | self.saver.restore(self.sess, model_name) 307 | 308 | for next_sample in tqdm(samples, desc='Generating ({})'.format(os.path.basename(model_name)), unit='songs'): 309 | batch = next_sample[0] 310 | name = next_sample[1] # Unzip 311 | 312 | ops, feed_dict = self.model.step(batch) 313 | assert len(ops) == 2 # sampling, output 314 | chosen_labels, outputs = self.sess.run(ops, feed_dict) 315 | 316 | model_dir, model_filename = os.path.split(model_name) 317 | model_dir = os.path.join(model_dir, self.TESTING_VISUALIZATION_DIR) 318 | model_filename = model_filename[:-len(self.MODEL_EXT)] + '-' + name 319 | 320 | # Save piano roll as image (color map red/blue to see the prediction confidence) 321 | # Save the midi file 322 | self.music_data.visit_recorder( 323 | outputs, 324 | model_dir, 325 | model_filename, 326 | [ImgConnector, MidiConnector], 327 | chosen_labels=chosen_labels 328 | ) 329 | # TODO: Print song statistics (nb of generated notes, closest songs in dataset ?, try to compute a 330 | # score to indicate potentially interesting songs (low score if too repetitive) ?,...). Create new 331 | # visited recorder class ? 332 | # TODO: Include infos on potentially interesting songs (include metric in the name ?), we should try to detect 333 | # the loops, simple metric: nb of generated notes, nb of unique notes (Metric: 2d 334 | # tensor [NB_NOTES, nb_of_time_the_note_is played], could plot histogram normalized by nb of 335 | # notes). Is piano roll enough ? 336 | 337 | print('Prediction finished, {} songs generated'.format(self.args.batch_size * len(model_list) * len(batches))) 338 | 339 | def _visualize_output(self, visualization_base_name, outputs_train, outputs_test): 340 | """ Record some result/generated songs during training. 341 | This allow to see the training progression and get an idea of what the network really learned 342 | Args: 343 | visualization_base_name (str): 344 | outputs_train: Output of the forward pass(training set) 345 | outputs_test: Output of the forward pass (testing set) 346 | """ 347 | # Record: 348 | # * Training/testing: 349 | # * Prediction/ground truth: 350 | # * piano roll 351 | # * midi file 352 | # Format name: ---. 353 | # TODO: Also records the ground truth 354 | 355 | model_dir, model_filename = os.path.split(visualization_base_name) 356 | for output, set_name in [(outputs_train, 'train'), (outputs_test, 'test')]: 357 | self.music_data.visit_recorder( 358 | output, 359 | model_dir, 360 | model_filename + '-' + set_name, 361 | [ImgConnector, MidiConnector] 362 | ) 363 | 364 | def _restore_previous_model(self, sess): 365 | """ Restore or reset the model, depending of the parameters 366 | If testing mode is set, the function has no effect 367 | If the destination directory already contains some file, it will handle the conflict as following: 368 | * If --reset is set, all present files will be removed (warning: no confirmation is asked) and the training 369 | restart from scratch (globStep & cie reinitialized) 370 | * Otherwise, it will depend of the directory content. 
If the directory contains: 371 | * No model files (only summary logs): works as a reset (restart from scratch) 372 | * Other model files, but model_name not found (surely keep_all option changed): raise error, the user should 373 | decide by himself what to do 374 | * The right model file (eventually some other): no problem, simply resume the training 375 | In any case, the directory will exist as it has been created by the summary writer 376 | Args: 377 | sess: The current running session 378 | """ 379 | 380 | if self.args.test == Composer.TestMode.ALL: # On testing, the models are not restored here 381 | return 382 | 383 | print('WARNING: ', end='') 384 | 385 | model_name = self._get_model_name() 386 | 387 | if os.listdir(self.model_dir): 388 | if self.args.reset: 389 | print('Reset: Destroying previous model at {}'.format(self.model_dir)) 390 | # Analysing directory content 391 | elif os.path.exists(model_name): # Restore the model 392 | print('Restoring previous model from {}'.format(model_name)) 393 | self.saver.restore(sess, model_name) # Will crash when --reset is not activated and the model has not been saved yet 394 | print('Model restored.') 395 | elif self._get_model_list(): 396 | print('Conflict with previous models.') 397 | raise RuntimeError('Some models are already present in \'{}\'. You should check them first'.format(self.model_dir)) 398 | else: # No other model to conflict with (probably summary files) 399 | print('No previous model found, but some files/folders found at {}. Cleaning...'.format(self.model_dir)) # Warning: No confirmation asked 400 | self.args.reset = True 401 | 402 | if self.args.reset: 403 | # WARNING: No confirmation is asked. All subfolders will be deleted 404 | for root, dirs, files in os.walk(self.model_dir, topdown=False): 405 | for name in files: 406 | file_path = os.path.join(root, name) 407 | print('Removing {}'.format(file_path)) 408 | os.remove(file_path) 409 | else: 410 | print('No previous model found, starting from clean directory: {}'.format(self.model_dir)) 411 | 412 | def _save_session(self, sess): 413 | """ Save the model parameters and the variables 414 | Args: 415 | sess: the current session 416 | """ 417 | tqdm.write('Checkpoint reached: saving model (don\'t stop the run)...') 418 | self._save_params() 419 | self.saver.save(sess, self._get_model_name()) # Put a limit size (ex: 3GB for the model_dir) ? 420 | tqdm.write('Model saved.') 421 | 422 | def _restore_params(self): 423 | """ Load the some values associated with the current model, like the current glob_step value. 
424 | Needs to be called before any other function because it initialize some variables used on the rest of the 425 | program 426 | 427 | Warning: if you modify this function, make sure the changes mirror _save_params, also check if the parameters 428 | should be reset in manage_previous_model 429 | """ 430 | # Compute the current model path 431 | self.model_dir = os.path.join(self.args.root_dir, self.MODEL_DIR_BASE) 432 | if self.args.model_tag: 433 | self.model_dir += '-' + self.args.model_tag 434 | 435 | # If there is a previous model, restore some parameters 436 | config_name = os.path.join(self.model_dir, self.CONFIG_FILENAME) 437 | if not self.args.reset and not self.args.create_dataset and os.path.exists(config_name): 438 | # Loading 439 | config = configparser.ConfigParser() 440 | config.read(config_name) 441 | 442 | # Check the version 443 | current_version = config['General'].get('version') 444 | if current_version != self.CONFIG_VERSION: 445 | raise UserWarning('Present configuration version {0} does not match {1}. You can try manual changes on \'{2}\''.format(current_version, self.CONFIG_VERSION, config_name)) 446 | 447 | # Restoring the the parameters 448 | self.glob_step = config['General'].getint('glob_step') 449 | self.args.keep_all = config['General'].getboolean('keep_all') 450 | self.args.dataset_tag = config['General'].get('dataset_tag') 451 | if not self.args.test: # When testing, we don't use the training length 452 | self.args.sample_length = config['General'].getint('sample_length') 453 | 454 | self.args.hidden_size = config['Network'].getint('hidden_size') 455 | self.args.num_layers = config['Network'].getint('num_layers') 456 | self.args.target_weights = config['Network'].get('target_weights') 457 | self.args.scheduled_sampling = config['Network'].get('scheduled_sampling').split(' ') 458 | 459 | self.args.batch_size = config['Training'].getint('batch_size') 460 | self.args.save_every = config['Training'].getint('save_every') 461 | self.args.ratio_dataset = config['Training'].getfloat('ratio_dataset') 462 | self.args.testing_curve = config['Training'].getint('testing_curve') 463 | 464 | ModuleLoader.load_all(self.args, config) 465 | 466 | # Show the restored params 467 | print('Warning: Restoring parameters from previous configuration (you should manually edit the file if you want to change one of those)') 468 | 469 | # When testing, only predict one song at the time 470 | if self.args.test: 471 | self.args.batch_size = 1 472 | self.args.scheduled_sampling = [Model.ScheduledSamplingPolicy.NONE] 473 | 474 | def _save_params(self): 475 | """ Save the params of the model, like the current glob_step value 476 | Warning: if you modify this function, make sure the changes mirror load_params 477 | """ 478 | config = configparser.ConfigParser() 479 | config['General'] = {} 480 | config['General']['version'] = self.CONFIG_VERSION 481 | config['General']['glob_step'] = str(self.glob_step) 482 | config['General']['keep_all'] = str(self.args.keep_all) 483 | config['General']['dataset_tag'] = self.args.dataset_tag 484 | config['General']['sample_length'] = str(self.args.sample_length) 485 | 486 | config['Network'] = {} 487 | config['Network']['hidden_size'] = str(self.args.hidden_size) 488 | config['Network']['num_layers'] = str(self.args.num_layers) 489 | config['Network']['target_weights'] = self.args.target_weights # Could be modified manually 490 | config['Network']['scheduled_sampling'] = ' '.join(self.args.scheduled_sampling) 491 | 492 | # Keep track of the learning params 
(they are not model dependent so they can be manually edited) 493 | config['Training'] = {} 494 | config['Training']['batch_size'] = str(self.args.batch_size) 495 | config['Training']['save_every'] = str(self.args.save_every) 496 | config['Training']['ratio_dataset'] = str(self.args.ratio_dataset) 497 | config['Training']['testing_curve'] = str(self.args.testing_curve) 498 | 499 | # Save the chosen modules and their configuration 500 | ModuleLoader.save_all(config) 501 | 502 | with open(os.path.join(self.model_dir, self.CONFIG_FILENAME), 'w') as config_file: 503 | config.write(config_file) 504 | 505 | def _print_params(self): 506 | """ Print the current params 507 | """ 508 | print() 509 | print('Current parameters:') 510 | print('glob_step: {}'.format(self.glob_step)) 511 | print('keep_all: {}'.format(self.args.keep_all)) 512 | print('dataset_tag: {}'.format(self.args.dataset_tag)) 513 | print('sample_length: {}'.format(self.args.sample_length)) 514 | 515 | print('hidden_size: {}'.format(self.args.hidden_size)) 516 | print('num_layers: {}'.format(self.args.num_layers)) 517 | print('target_weights: {}'.format(self.args.target_weights)) 518 | print('scheduled_sampling: {}'.format(' '.join(self.args.scheduled_sampling))) 519 | 520 | print('batch_size: {}'.format(self.args.batch_size)) 521 | print('save_every: {}'.format(self.args.save_every)) 522 | print('ratio_dataset: {}'.format(self.args.ratio_dataset)) 523 | print('testing_curve: {}'.format(self.args.testing_curve)) 524 | 525 | ModuleLoader.print_all(self.args) 526 | print() 527 | 528 | def _get_model_name(self): 529 | """ Parse the argument to decide where to save/load the model 530 | This function is called at each checkpoint and the first time the model is loaded. If the keep_all option is set, the 531 | glob_step value will be included in the name. 532 | Return: 533 | str: The path and name where the model needs to be saved 534 | """ 535 | model_name = os.path.join(self.model_dir, self.MODEL_NAME_BASE) 536 | if self.args.keep_all: # By including the current step in the name, we do not erase the previously saved model 537 | model_name += '-' + str(self.glob_step) 538 | return model_name + self.MODEL_EXT 539 | 540 | def _get_model_list(self): 541 | """ Return the list of the model files inside the model directory 542 | """ 543 | return [os.path.join(self.model_dir, f) for f in os.listdir(self.model_dir) if f.endswith(self.MODEL_EXT)] 544 | 545 | def _get_device(self): 546 | """ Parse the argument to decide on which device to run the model 547 | Return: 548 | str: The name of the device on which to run the program 549 | """ 550 | if self.args.device == 'cpu': 551 | return '/cpu:0' 552 | elif self.args.device == 'gpu': # Won't work in case of multiple GPUs 553 | return '/gpu:0' 554 | elif self.args.device is None: # No specified device (default) 555 | return None 556 | else: 557 | print('Warning: Error in the device name: {}, using the default device'.format(self.args.device)) 558 | return None 559 | -------------------------------------------------------------------------------- /deepmusic/imgconnector.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Conchylicultor. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """ 17 | Image connector interface 18 | 19 | """ 20 | 21 | import cv2 as cv 22 | import numpy as np 23 | 24 | import deepmusic.songstruct as music # Should we use that to truncate the top and bottom of the image ? 25 | 26 | 27 | class ImgConnector: 28 | """ Class to read and write songs (piano roll arrays) as images 29 | """ 30 | 31 | @staticmethod 32 | def load_file(filename): 33 | """ Extract data from an image file 34 | Args: 35 | filename (str): a valid image file 36 | Return: 37 | np.array: the piano roll associated with the image 38 | """ 39 | # TODO ? Could be useful to load initiators created with Gimp (more intuitive than the current version) 40 | 41 | @staticmethod 42 | def write_song(piano_roll, filename): 43 | """ Save the song on disk 44 | Args: 45 | piano_roll (np.array): the piano roll array to save 46 | filename (str): the path where to save the song (don't add the file extension) 47 | """ 48 | note_played = piano_roll > 0.5 49 | piano_roll_int = np.uint8(piano_roll*255) 50 | 51 | b = piano_roll_int * (~note_played).astype(np.uint8) # Notes silenced 52 | g = np.zeros(piano_roll_int.shape, dtype=np.uint8) # Empty channel 53 | r = piano_roll_int * note_played.astype(np.uint8) # Notes played 54 | 55 | img = cv.merge((b, g, r)) 56 | 57 | # TODO: We could insert a first column indicating the piano keys (black/white key) 58 | 59 | cv.imwrite(filename + '.png', img) 60 | 61 | @staticmethod 62 | def get_input_type(): 63 | return 'array' 64 | -------------------------------------------------------------------------------- /deepmusic/keyboardcell.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Conchylicultor. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """ 17 | Main cell which predict the next keyboard configuration 18 | 19 | """ 20 | 21 | import collections 22 | import tensorflow as tf 23 | 24 | from deepmusic.moduleloader import ModuleLoader 25 | import deepmusic.songstruct as music 26 | 27 | 28 | class KeyboardCell(tf.contrib.rnn.RNNCell): 29 | """ Cell which wrap the encoder/decoder network 30 | """ 31 | 32 | def __init__(self, args): 33 | self.args = args 34 | self.is_init = False 35 | 36 | # Get the chosen enco/deco 37 | self.encoder = ModuleLoader.enco_cells.build_module(self.args) 38 | self.decoder = ModuleLoader.deco_cells.build_module(self.args) 39 | 40 | @property 41 | def state_size(self): 42 | raise NotImplementedError('Abstract method') 43 | 44 | @property 45 | def output_size(self): 46 | raise NotImplementedError('Abstract method') 47 | 48 | def __call__(self, prev_keyboard, prev_state, scope=None): 49 | """ Run the cell at step t 50 | Args: 51 | prev_keyboard: keyboard configuration for the step t-1 (Ground truth or previous step) 52 | prev_state: a tuple (prev_state_enco, prev_state_deco) 53 | scope: TensorFlow scope 54 | Return: 55 | Tuple: the keyboard configuration and the enco and deco states 56 | """ 57 | 58 | # First time only (we do the initialisation here to be on the global rnn loop scope) 59 | if not self.is_init: 60 | with tf.variable_scope('weights_keyboard_cell'): 61 | # TODO: With self.args, see which network we have chosen (create map 'network name':class) 62 | self.encoder.build() 63 | self.decoder.build() 64 | 65 | prev_state = self.encoder.init_state(), self.decoder.init_state() 66 | self.is_init = True 67 | 68 | # TODO: If encoder act as VAE, we should sample here, from the previous state 69 | 70 | # Encoder/decoder network 71 | with tf.variable_scope(scope or type(self).__name__): 72 | with tf.variable_scope('Encoder'): 73 | # TODO: Should be enco_output, enco_state 74 | next_state_enco = self.encoder.get_cell(prev_keyboard, prev_state) 75 | with tf.variable_scope('Decoder'): # Reset gate and update gate. 76 | next_keyboard, next_state_deco = self.decoder.get_cell(prev_keyboard, (next_state_enco, prev_state[1])) 77 | return next_keyboard, (next_state_enco, next_state_deco) 78 | -------------------------------------------------------------------------------- /deepmusic/midiconnector.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Conchylicultor. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """ 17 | Mid-level interface for the python files 18 | """ 19 | 20 | import mido # Midi lib 21 | 22 | import deepmusic.songstruct as music 23 | 24 | 25 | class MidiInvalidException(Exception): 26 | pass 27 | 28 | 29 | class MidiConnector: 30 | """ Class which manage the midi files at the message level 31 | """ 32 | META_INFO_TYPES = [ # Can safely be ignored 33 | 'midi_port', 34 | 'track_name', 35 | 'lyrics', 36 | 'end_of_track', 37 | 'copyright', 38 | 'marker', 39 | 'text' 40 | ] 41 | META_TEMPO_TYPES = [ # Have an impact on how the song is played 42 | 'key_signature', 43 | 'set_tempo', 44 | 'time_signature' 45 | ] 46 | 47 | MINIMUM_TRACK_LENGTH = 4 # Bellow this value, the track will be ignored 48 | 49 | MIDI_CHANNEL_DRUMS = 10 # The channel reserved for the drums (according to the specs) 50 | 51 | # Define a max song length ? 52 | 53 | #self.resolution = 0 # bpm 54 | #self.initial_tempo = 0.0 55 | 56 | #self.data = None # Sparse tensor of size [NB_KEYS,nb_bars*BAR_DIVISION] or simply a list of note ? 57 | 58 | @staticmethod 59 | def load_file(filename): 60 | """ Extract data from midi file 61 | Args: 62 | filename (str): a valid midi file 63 | Return: 64 | Song: a song object containing the tracks and melody 65 | """ 66 | # Load in the MIDI data using the midi module 67 | midi_data = mido.MidiFile(filename) 68 | 69 | # Get header values 70 | 71 | # 3 midi types: 72 | # * type 0 (single track): all messages are saved in one multi-channel track 73 | # * type 1 (synchronous): all tracks start at the same time 74 | # * type 2 (asynchronous): each track is independent of the others 75 | 76 | # Division (ticks per beat notes or SMTPE timecode) 77 | # If negative (first byte=1), the mode is SMTPE timecode (unsupported) 78 | # 1 MIDI clock = 1 beat = 1 quarter note 79 | 80 | # Assert 81 | if midi_data.type != 1: 82 | raise MidiInvalidException('Only type 1 supported ({} given)'.format(midi_data.type)) 83 | if not 0 < midi_data.ticks_per_beat < 128: 84 | raise MidiInvalidException('SMTPE timecode not supported ({} given)'.format(midi_data.ticks_per_beat)) 85 | 86 | # TODO: Support at least for type 0 87 | 88 | # Get tracks messages 89 | 90 | # The tracks are a mix of meta messages, which determine the tempo and signature, and note messages, which 91 | # correspond to the melodie. 92 | # Generally, the meta event are set at the beginning of each tracks. In format 1, these meta-events should be 93 | # contained in the first track (known as 'Tempo Map'). 94 | 95 | # If not set, default parameters are: 96 | # * time signature: 4/4 97 | # * tempo: 120 beats per minute 98 | 99 | # Each event contain begins by a delta time value, which correspond to the number of ticks from the previous 100 | # event (0 for simultaneous event) 101 | 102 | tempo_map = midi_data.tracks[0] # Will contains the tick scales 103 | # TODO: smpte_offset 104 | 105 | # Warning: The drums are filtered 106 | 107 | # Merge tracks ? < Not when creating the dataset 108 | #midi_data.tracks = [mido.merge_tracks(midi_data.tracks)] ?? 109 | 110 | new_song = music.Song() 111 | 112 | new_song.ticks_per_beat = midi_data.ticks_per_beat 113 | 114 | # TODO: Normalize the ticks per beats (same for all songs) 115 | 116 | for message in tempo_map: 117 | # TODO: Check we are only 4/4 (and there is no tempo changes ?) 
118 | if not isinstance(message, mido.MetaMessage): 119 | raise MidiInvalidException('Tempo map should not contains notes') 120 | if message.type in MidiConnector.META_INFO_TYPES: 121 | pass 122 | elif message.type == 'set_tempo': 123 | new_song.tempo_map.append(message) 124 | elif message.type in MidiConnector.META_TEMPO_TYPES: # We ignore the key signature and time_signature ? 125 | pass 126 | elif message.type == 'smpte_offset': 127 | pass # TODO 128 | else: 129 | err_msg = 'Header track contains unsupported meta-message type ({})'.format(message.type) 130 | raise MidiInvalidException(err_msg) 131 | 132 | for i, track in enumerate(midi_data.tracks[1:]): # We ignore the tempo map 133 | i += 1 # Warning: We have skipped the track 0 so shift the track id 134 | #tqdm.write('Track {}: {}'.format(i, track.name)) 135 | 136 | new_track = music.Track() 137 | 138 | buffer_notes = [] # Store the current notes (pressed but not released) 139 | abs_tick = 0 # Absolute nb of ticks from the beginning of the track 140 | for message in track: 141 | abs_tick += message.time 142 | if isinstance(message, mido.MetaMessage): # Lyrics, track name and other meta info 143 | if message.type in MidiConnector.META_INFO_TYPES: 144 | pass 145 | elif message.type in MidiConnector.META_TEMPO_TYPES: 146 | # TODO: Could be just a warning 147 | raise MidiInvalidException('Track {} should not contain {}'.format(i, message.type)) 148 | else: 149 | err_msg = 'Track {} contains unsupported meta-message type ({})'.format(i, message.type) 150 | raise MidiInvalidException(err_msg) 151 | # What about 'sequence_number', cue_marker ??? 152 | else: # Note event 153 | if message.type == 'note_on' and message.velocity != 0: # Note added 154 | new_note = music.Note() 155 | new_note.tick = abs_tick 156 | new_note.note = message.note 157 | if message.channel+1 != i and message.channel+1 != MidiConnector.MIDI_CHANNEL_DRUMS: # Warning: Mido shift the channels (start at 0) # TODO: Channel management for type 0 158 | raise MidiInvalidException('Notes belong to the wrong tracks ({} instead of {})'.format(i, message.channel)) # Warning: May not be an error (drums ?) but probably 159 | buffer_notes.append(new_note) 160 | elif message.type == 'note_off' or message.type == 'note_on': # Note released 161 | for note in buffer_notes: 162 | if note.note == message.note: 163 | note.duration = abs_tick - note.tick 164 | buffer_notes.remove(note) 165 | new_track.notes.append(note) 166 | elif message.type == 'program_change': # Instrument change 167 | if not new_track.set_instrument(message): 168 | # TODO: We should create another track with the new instrument 169 | raise MidiInvalidException('Track {} as already a program defined'.format(i)) 170 | pass 171 | elif message.type == 'control_change': # Damper pedal, mono/poly, channel volume,... 172 | # Ignored 173 | pass 174 | elif message.type == 'aftertouch': # Signal send after a key has been press. What real effect ? 175 | # Ignored ? 
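                    # Clarifying comment: 'aftertouch' (channel pressure) is sent while a key is
                    # already held down and only modulates the sound (e.g. vibrato or volume on
                    # some synths); it never starts or ends a note, so ignoring it does not change
                    # the extracted melody.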
176 | pass 177 | elif message.type == 'pitchwheel': # Modulate the song 178 | # Ignored 179 | pass 180 | else: 181 | err_msg = 'Track {} contains unsupported message type ({})'.format(i, message) 182 | raise MidiInvalidException(err_msg) 183 | # Message read 184 | # Track read 185 | 186 | # Assert 187 | if buffer_notes: # All notes should have ended 188 | raise MidiInvalidException('Some notes ({}) did not ended'.format(len(buffer_notes))) 189 | if len(new_track.notes) < MidiConnector.MINIMUM_TRACK_LENGTH: 190 | #tqdm.write('Track {} ignored (too short): {} notes'.format(i, len(new_track.notes))) 191 | continue 192 | if new_track.is_drum: 193 | #tqdm.write('Track {} ignored (is drum)'.format(i)) 194 | continue 195 | 196 | new_song.tracks.append(new_track) 197 | # All track read 198 | 199 | if not new_song.tracks: 200 | raise MidiInvalidException('Empty song. No track added') 201 | 202 | return new_song 203 | 204 | @staticmethod 205 | def write_song(song, filename): 206 | """ Save the song on disk 207 | Args: 208 | song (Song): a song object containing the tracks and melody 209 | filename (str): the path were to save the song (don't add the file extension) 210 | """ 211 | 212 | midi_data = mido.MidiFile(ticks_per_beat=song.ticks_per_beat) 213 | 214 | # Define track 0 215 | new_track = mido.MidiTrack() 216 | midi_data.tracks.append(new_track) 217 | new_track.extend(song.tempo_map) 218 | 219 | for i, track in enumerate(song.tracks): 220 | # Define the track 221 | new_track = mido.MidiTrack() 222 | midi_data.tracks.append(new_track) 223 | new_track.append(mido.Message('program_change', program=0, time=0)) # Played with standard piano 224 | 225 | messages = [] 226 | for note in track.notes: 227 | # Add all messages in absolute time 228 | messages.append(mido.Message( 229 | 'note_on', 230 | note=note.note, # WARNING: The note should be int (NOT np.int64) 231 | velocity=64, 232 | channel=i, 233 | time=note.tick)) 234 | messages.append(mido.Message( 235 | 'note_off', 236 | note=note.note, 237 | velocity=64, 238 | channel=i, 239 | time=note.tick+note.duration) 240 | ) 241 | 242 | # Reorder the messages chronologically 243 | messages.sort(key=lambda x: x.time) 244 | 245 | # Convert absolute tick in relative tick 246 | last_time = 0 247 | for message in messages: 248 | message.time -= last_time 249 | last_time += message.time 250 | 251 | new_track.append(message) 252 | 253 | midi_data.save(filename + '.mid') 254 | 255 | @staticmethod 256 | def get_input_type(): 257 | return 'song' 258 | -------------------------------------------------------------------------------- /deepmusic/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Conchylicultor. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """ 17 | Model to generate new songs 18 | 19 | """ 20 | 21 | import numpy as np # To generate random numbers 22 | import tensorflow as tf 23 | 24 | from deepmusic.moduleloader import ModuleLoader 25 | from deepmusic.keyboardcell import KeyboardCell 26 | import deepmusic.songstruct as music 27 | 28 | 29 | class Model: 30 | """ 31 | Base class which manage the different models and experimentation. 32 | """ 33 | 34 | class TargetWeightsPolicy: 35 | """ Structure to represent the different policy for choosing the target weights 36 | This is used to scale the contribution of each timestep to the global loss 37 | """ 38 | NONE = 'none' # All weights equals (=1.0) (default behavior) 39 | LINEAR = 'linear' # The first outputs are less penalized than the last ones 40 | STEP = 'step' # We start penalizing only after x steps (enco/deco behavior) 41 | 42 | def __init__(self, args): 43 | """ 44 | Args: 45 | args: parameters of the model 46 | """ 47 | self.args = args 48 | 49 | def get_weight(self, i): 50 | """ Return the target weight for the given step i using the chosen policy 51 | """ 52 | if not self.args.target_weights or self.args.target_weights == Model.TargetWeightsPolicy.NONE: 53 | return 1.0 54 | elif self.args.target_weights == Model.TargetWeightsPolicy.LINEAR: 55 | return i / (self.args.sample_length - 1) # Gradually increment the loss weight 56 | elif self.args.target_weights == Model.TargetWeightsPolicy.STEP: 57 | raise NotImplementedError('Step target weight policy not implemented yet, please consider another policy') 58 | else: 59 | raise ValueError('Unknown chosen target weight policy: {}'.format(self.args.target_weights)) 60 | 61 | @staticmethod 62 | def get_policies(): 63 | """ Return the list of the different modes 64 | Useful when parsing the command lines arguments 65 | """ 66 | return [ 67 | Model.TargetWeightsPolicy.NONE, 68 | Model.TargetWeightsPolicy.LINEAR, 69 | Model.TargetWeightsPolicy.STEP 70 | ] 71 | 72 | class ScheduledSamplingPolicy: 73 | """ Container for the schedule sampling policy 74 | See http://arxiv.org/abs/1506.03099 for more details 75 | """ 76 | NONE = 'none' # No scheduled sampling (always take the given input) 77 | ALWAYS = 'always' # Always samples from the predicted output 78 | LINEAR = 'linear' # Gradually increase the sampling rate 79 | 80 | def __init__(self, args): 81 | self.sampling_policy_fct = None 82 | 83 | assert args.scheduled_sampling 84 | assert len(args.scheduled_sampling) > 0 85 | 86 | policy = args.scheduled_sampling[0] 87 | if policy == Model.ScheduledSamplingPolicy.NONE: 88 | self.sampling_policy_fct = lambda step: 1.0 89 | elif policy == Model.ScheduledSamplingPolicy.ALWAYS: 90 | self.sampling_policy_fct = lambda step: 0.0 91 | elif policy == Model.ScheduledSamplingPolicy.LINEAR: 92 | if len(args.scheduled_sampling) != 5: 93 | raise ValueError('Not the right arguments for the sampling linear policy ({} instead of 4)'.format(len(args.scheduled_sampling)-1)) 94 | 95 | start_step = int(args.scheduled_sampling[1]) 96 | end_step = int(args.scheduled_sampling[2]) 97 | start_value = float(args.scheduled_sampling[3]) 98 | end_value = float(args.scheduled_sampling[4]) 99 | 100 | if (start_step >= end_step or 101 | not (0.0 <= start_value <= 1.0) or 102 | not (0.0 <= end_value <= 1.0)): 103 | raise ValueError('Some schedule sampling parameters incorrect.') 104 | 105 | # TODO: Add default values (as optional arguments) 106 | 107 | def linear_policy(step): 
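                # The body below linearly interpolates the sampling threshold between
                # (start_step, start_value) and (end_step, end_value) and clamps it outside
                # that range:
                #     threshold(step) = start_value
                #                       + (end_value - start_value) * (step - start_step)
                #                         / (end_step - start_step)
                # Worked example with hypothetical arguments start_step=0, end_step=10000,
                # start_value=1.0, end_value=0.0: at step 5000 the threshold is 0.5, so the
                # network is fed its own previous output roughly half of the time.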
108 | if step < start_step: 109 | threshold = start_value 110 | elif start_step <= step < end_step: 111 | slope = (start_value-end_value)/(start_step-end_step) # < 0 (because end_step>start_step and start_value>end_value) 112 | threshold = slope*(step-start_step) + start_value 113 | elif end_step <= step: 114 | threshold = end_value 115 | else: 116 | raise RuntimeError('Invalid value for the sampling policy') # Parameters have not been correctly defined! 117 | assert 0.0 <= threshold <= 1.0 118 | return threshold 119 | 120 | self.sampling_policy_fct = linear_policy 121 | else: 122 | raise ValueError('Unknown chosen schedule sampling policy: {}'.format(policy)) 123 | 124 | def get_prev_threshold(self, glob_step, i=0): 125 | """ Return the previous sampling probability for the current step. 126 | If above, the RNN should use the previous step instead of the given input. 127 | Args: 128 | glob_step (int): the global iteration step for the training 129 | i (int): the timestep of the RNN (TODO: implement incrementive slope (progression like -\|), remove the '=0') 130 | """ 131 | return self.sampling_policy_fct(glob_step) 132 | 133 | def __init__(self, args): 134 | """ 135 | Args: 136 | args: parameters of the model 137 | """ 138 | print('Model creation...') 139 | 140 | self.args = args # Keep track of the parameters of the model 141 | 142 | # Placeholders 143 | self.inputs = None 144 | self.targets = None 145 | self.use_prev = None # Boolean tensor which say at Graph evaluation time if we use the input placeholder or the previous output. 146 | self.current_learning_rate = None # Allow to have a dynamic learning rate 147 | 148 | # Main operators 149 | self.opt_op = None # Optimizer 150 | self.outputs = None # Outputs of the network 151 | self.final_state = None # When testing, we feed this value as initial state ? 152 | 153 | # Other options 154 | self.target_weights_policy = None 155 | self.schedule_policy = None 156 | self.learning_rate_policy = None 157 | self.loop_processing = None 158 | 159 | # Construct the graphs 160 | self._build_network() 161 | 162 | def _build_network(self): 163 | """ Create the computational graph 164 | """ 165 | input_dim = ModuleLoader.batch_builders.get_module().get_input_dim() 166 | 167 | # Placeholders (Use tf.SparseTensor with training=False instead) (TODO: Try restoring dynamic batch_size) 168 | with tf.name_scope('placeholder_inputs'): 169 | self.inputs = [ 170 | tf.placeholder( 171 | tf.float32, # -1.0/1.0 ? Probably better for the sigmoid 172 | [self.args.batch_size, input_dim], # TODO: Get input size from batch_builder 173 | name='input') 174 | for _ in range(self.args.sample_length) 175 | ] 176 | with tf.name_scope('placeholder_targets'): 177 | self.targets = [ 178 | tf.placeholder( 179 | tf.int32, # 0/1 # TODO: Int for sofmax, Float for sigmoid 180 | [self.args.batch_size,], # TODO: For softmax, only 1d, for sigmoid, 2d (batch_size, num_class) 181 | name='target') 182 | for _ in range(self.args.sample_length) 183 | ] 184 | with tf.name_scope('placeholder_use_prev'): 185 | self.use_prev = [ 186 | tf.placeholder( 187 | tf.bool, 188 | [], 189 | name='use_prev') 190 | for _ in range(self.args.sample_length) # The first value will never be used (always takes self.input for the first step) 191 | ] 192 | 193 | # Define the network 194 | self.loop_processing = ModuleLoader.loop_processings.build_module(self.args) 195 | def loop_rnn(prev, i): 196 | """ Loop function used to connect one output of the rnn to the next input. 
197 | The previous input and returned value have to be from the same shape. 198 | This is useful to use the same network for both training and testing. 199 | Args: 200 | prev: the previous predicted keyboard configuration at step i-1 201 | i: the current step id (Warning: start at 1, 0 is ignored) 202 | Return: 203 | tf.Tensor: the input at the step i 204 | """ 205 | next_input = self.loop_processing(prev) 206 | 207 | # On training, we force the correct input, on testing, we use the previous output as next input 208 | return tf.cond(self.use_prev[i], lambda: next_input, lambda: self.inputs[i]) 209 | 210 | # TODO: Try attention decoder/use dynamic_rnn instead 211 | self.outputs, self.final_state = tf.nn.seq2seq.rnn_decoder( 212 | decoder_inputs=self.inputs, 213 | initial_state=None, # The initial state is defined inside KeyboardCell 214 | cell=KeyboardCell(self.args), 215 | loop_function=loop_rnn 216 | ) 217 | 218 | # For training only 219 | if not self.args.test: 220 | # Finally, we define the loss function 221 | 222 | # The network will predict a mix a wrong and right notes. For the loss function, we would like to 223 | # penalize note which are wrong. Eventually, the penalty should be less if the network predict the same 224 | # note but not in the right pitch (ex: C4 instead of C5), with a decay the further the prediction 225 | # is (D5 and D1 more penalized than D4 and D3 if the target is D2) 226 | 227 | # For the piano roll mode, by using sigmoid_cross_entropy_with_logits, the task is formulated as a NB_NOTES binary 228 | # classification problems 229 | 230 | # For the relative note experiment, it use a standard SoftMax where the label is the relative position to the previous 231 | # note 232 | 233 | self.schedule_policy = Model.ScheduledSamplingPolicy(self.args) 234 | self.target_weights_policy = Model.TargetWeightsPolicy(self.args) 235 | self.learning_rate_policy = ModuleLoader.learning_rate_policies.build_module(self.args) # Load the chosen policies 236 | 237 | # TODO: If train on different length, check that the loss is proportional to the length or average ??? 238 | loss_fct = tf.nn.seq2seq.sequence_loss( 239 | self.outputs, 240 | self.targets, 241 | [tf.constant(self.target_weights_policy.get_weight(i), shape=self.targets[0].get_shape()) for i in range(len(self.targets))], # Weights 242 | #softmax_loss_function=tf.nn.softmax_cross_entropy_with_logits, # Previous: tf.nn.sigmoid_cross_entropy_with_logits TODO: Use option to choose. (new module ?) 243 | average_across_timesteps=True, # Before: I think it's best for variables length sequences (specially with the target weights=0), isn't it (it implies also that short sequences are less penalized than long ones) ? (TODO: For variables length sequences, be careful about the target weights) 244 | average_across_batch=True # Before: Penalize by sample (should allows dynamic batch size) Warning: need to tune the learning rate 245 | ) 246 | tf.scalar_summary('training_loss', loss_fct) # Keep track of the cost 247 | 248 | self.current_learning_rate = tf.placeholder(tf.float32, []) 249 | 250 | # Initialize the optimizer 251 | opt = tf.train.AdamOptimizer( 252 | learning_rate=self.current_learning_rate, 253 | beta1=0.9, 254 | beta2=0.999, 255 | epsilon=1e-08 256 | ) 257 | 258 | # TODO: Also keep track of magnitudes (how much is updated) 259 | self.opt_op = opt.minimize(loss_fct) 260 | 261 | def step(self, batch, train_set=True, glob_step=-1, ret_output=False): 262 | """ Forward/training step operation. 
263 | Does not perform run on itself but just return the operators to do so. Those have then to be run by the 264 | main program. 265 | If the output operator is returned, it will always be the last one on the list 266 | Args: 267 | batch (Batch): Input data on testing mode, input and target on output mode 268 | train_set (Bool): indicate if the batch come from the test/train set (not used when generating) 269 | glob_step (int): indicate the global step for the schedule sampling 270 | ret_output (Bool): for the training mode, if true, 271 | Return: 272 | Tuple[ops], dict: The list of the operators to run (training_step or outputs) with the associated feed dictionary 273 | """ 274 | # TODO: Could optimize feeding between train/test/generating (compress code) 275 | 276 | feed_dict = {} 277 | ops = () # For small length, it seems (from my investigations) that tuples are faster than list for merging 278 | batch.generate(target=False if self.args.test else True) 279 | 280 | # Feed placeholders and choose the ops 281 | if not self.args.test: # Training 282 | if train_set: # We update the learning rate every x iterations # TODO: What happens when we don't feed the learning rate ??? Stays at the last value ? 283 | assert glob_step >= 0 284 | feed_dict[self.current_learning_rate] = self.learning_rate_policy.get_learning_rate(glob_step) 285 | 286 | for i in range(self.args.sample_length): 287 | feed_dict[self.inputs[i]] = batch.inputs[i] 288 | feed_dict[self.targets[i]] = batch.targets[i] 289 | #if np.random.rand() >= self.schedule_policy.get_prev_threshold(glob_step)*self.target_weights_policy.get_weight(i): # Regular Schedule sample (TODO: Try sampling with the weigths or a mix of weights/sampling) 290 | if np.random.rand() >= self.schedule_policy.get_prev_threshold(glob_step): # Weight the threshold by the target weights (don't schedule sample if weight=0) 291 | feed_dict[self.use_prev[i]] = True 292 | else: 293 | feed_dict[self.use_prev[i]] = False 294 | 295 | if train_set: 296 | ops += (self.opt_op,) 297 | if ret_output: 298 | ops += (self.outputs,) 299 | else: # Generating (batch_size == 1) 300 | # TODO: What to put for initialisation state (empty ? random ?) ? 301 | # TODO: Modify use_prev 302 | for i in range(self.args.sample_length): 303 | if i < len(batch.inputs): 304 | feed_dict[self.inputs[i]] = batch.inputs[i] 305 | feed_dict[self.use_prev[i]] = False 306 | else: # Even not used, we still need to feed a placeholder 307 | feed_dict[self.inputs[i]] = batch.inputs[0] # Could be anything but we need it to be from the right shape 308 | feed_dict[self.use_prev[i]] = True # When we don't have an input, we use the previous output instead 309 | 310 | ops += (self.loop_processing.get_op(), self.outputs,) # The loop_processing operator correspond to the recorded softmax sampled 311 | 312 | # Return one pass operator 313 | return ops, feed_dict 314 | -------------------------------------------------------------------------------- /deepmusic/model_old.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Conchylicultor. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """ 17 | Model to generate new songs 18 | 19 | """ 20 | 21 | import numpy as np # To generate random numbers 22 | import tensorflow as tf 23 | 24 | from deepmusic.musicdata import Batch 25 | import deepmusic.songstruct as music 26 | 27 | 28 | class Model: 29 | """ 30 | Base class which manage the different models and experimentation. 31 | """ 32 | 33 | class TargetWeightsPolicy: 34 | """ Structure to represent the different policy for choosing the target weights 35 | This is used to scale the contribution of each timestep to the global loss 36 | """ 37 | NONE = 'none' # All weights equals (=1.0) (default behavior) 38 | LINEAR = 'linear' # The first outputs are less penalized than the last ones 39 | STEP = 'step' # We start penalizing only after x steps (enco/deco behavior) 40 | 41 | def __init__(self, args): 42 | """ 43 | Args: 44 | args: parameters of the model 45 | """ 46 | self.args = args 47 | 48 | def get_weight(self, i): 49 | """ Return the target weight for the given step i using the chosen policy 50 | """ 51 | if not self.args.target_weights or self.args.target_weights == Model.TargetWeightsPolicy.NONE: 52 | return 1.0 53 | elif self.args.target_weights == Model.TargetWeightsPolicy.LINEAR: 54 | return i / (self.args.sample_length - 1) # Gradually increment the loss weight 55 | elif self.args.target_weights == Model.TargetWeightsPolicy.STEP: 56 | raise NotImplementedError('Step target weight policy not implemented yet, please consider another policy') 57 | else: 58 | raise ValueError('Unknown chosen target weight policy: {}'.format(self.args.target_weights)) 59 | 60 | @staticmethod 61 | def get_policies(): 62 | """ Return the list of the different modes 63 | Useful when parsing the command lines arguments 64 | """ 65 | return [ 66 | Model.TargetWeightsPolicy.NONE, 67 | Model.TargetWeightsPolicy.LINEAR, 68 | Model.TargetWeightsPolicy.STEP 69 | ] 70 | 71 | class ScheduledSamplingPolicy: 72 | """ Container for the schedule sampling policy 73 | See http://arxiv.org/abs/1506.03099 for more details 74 | """ 75 | NONE = 'none' # No scheduled sampling (always take the given input) 76 | ALWAYS = 'always' # Always samples from the predicted output 77 | LINEAR = 'linear' # Gradually increase the sampling rate 78 | 79 | def __init__(self, args): 80 | self.sampling_policy_fct = None 81 | 82 | assert args.scheduled_sampling 83 | assert len(args.scheduled_sampling) > 0 84 | 85 | policy = args.scheduled_sampling[0] 86 | if policy == Model.ScheduledSamplingPolicy.NONE: 87 | self.sampling_policy_fct = lambda step: 1.0 88 | elif policy == Model.ScheduledSamplingPolicy.ALWAYS: 89 | self.sampling_policy_fct = lambda step: 0.0 90 | elif policy == Model.ScheduledSamplingPolicy.LINEAR: 91 | if len(args.scheduled_sampling) != 5: 92 | raise ValueError('Not the right arguments for the sampling linear policy ({} instead of 4)'.format(len(args.scheduled_sampling)-1)) 93 | 94 | start_step = int(args.scheduled_sampling[1]) 95 | end_step = int(args.scheduled_sampling[2]) 96 | 
start_value = float(args.scheduled_sampling[3]) 97 | end_value = float(args.scheduled_sampling[4]) 98 | 99 | if (start_step >= end_step or 100 | not (0.0 <= start_value <= 1.0) or 101 | not (0.0 <= end_value <= 1.0)): 102 | raise ValueError('Some schedule sampling parameters incorrect.') 103 | 104 | # TODO: Add default values (as optional arguments) 105 | 106 | def linear_policy(step): 107 | if step < start_step: 108 | threshold = start_value 109 | elif start_step <= step < end_step: 110 | slope = (start_value-end_value)/(start_step-end_step) # < 0 (because end_step>start_step and start_value>end_value) 111 | threshold = slope*(step-start_step) + start_value 112 | elif end_step <= step: 113 | threshold = end_value 114 | else: 115 | raise RuntimeError('Invalid value for the sampling policy') # Parameters have not been correctly defined! 116 | assert 0.0 <= threshold <= 1.0 117 | return threshold 118 | 119 | self.sampling_policy_fct = linear_policy 120 | else: 121 | raise ValueError('Unknown chosen schedule sampling policy: {}'.format(policy)) 122 | 123 | def get_prev_threshold(self, glob_step, i=0): 124 | """ Return the previous sampling probability for the current step. 125 | If above, the RNN should use the previous step instead of the given input. 126 | Args: 127 | glob_step (int): the global iteration step for the training 128 | i (int): the timestep of the RNN (TODO: implement incrementive slope (progression like -\|), remove the '=0') 129 | """ 130 | return self.sampling_policy_fct(glob_step) 131 | 132 | class LearningRatePolicy: 133 | """ Contains the different policies for the learning rate decay 134 | """ 135 | CST = 'cst' # Fixed learning rate over all steps (default behavior) 136 | STEP = 'step' # We divide the learning rate every x iterations 137 | EXPONENTIAL = 'exponential' # 138 | 139 | @staticmethod 140 | def get_policies(): 141 | """ Return the list of the different modes 142 | Useful when parsing the command lines arguments 143 | """ 144 | return [ 145 | Model.LearningRatePolicy.CST, 146 | Model.LearningRatePolicy.STEP, 147 | Model.LearningRatePolicy.EXPONENTIAL 148 | ] 149 | 150 | def __init__(self, args): 151 | """ 152 | Args: 153 | args: parameters of the model 154 | """ 155 | self.learning_rate_fct = None 156 | 157 | assert args.learning_rate 158 | assert len(args.learning_rate) > 0 159 | 160 | policy = args.learning_rate[0] 161 | 162 | if policy == Model.LearningRatePolicy.CST: 163 | if not len(args.learning_rate) == 2: 164 | raise ValueError('Learning rate cst policy should be on the form: {} lr_value'.format(Model.LearningRatePolicy.CST)) 165 | self.learning_rate_init = float(args.learning_rate[1]) 166 | self.learning_rate_fct = self._lr_cst 167 | 168 | elif policy == Model.LearningRatePolicy.STEP: 169 | if not len(args.learning_rate) == 3: 170 | raise ValueError('Learning rate step policy should be on the form: {} lr_init decay_period'.format(Model.LearningRatePolicy.STEP)) 171 | self.learning_rate_init = float(args.learning_rate[1]) 172 | self.decay_period = int(args.learning_rate[2]) 173 | self.learning_rate_fct = self._lr_step 174 | 175 | elif policy == Model.LearningRatePolicy.EXPONENTIAL: 176 | raise NotImplementedError('Exponential learning rate policy not implemented yet, please consider another policy') 177 | 178 | else: 179 | raise ValueError('Unknown chosen learning rate policy: {}'.format(policy)) 180 | 181 | def _lr_cst(self, glob_step): 182 | """ Just a constant learning rate 183 | """ 184 | return self.learning_rate_init 185 | 186 | def _lr_step(self, 
glob_step): 187 | """ Every decay period, the learning rate is divided by 2 188 | """ 189 | return self.learning_rate_init / 2**(glob_step//self.decay_period) 190 | 191 | def get_learning_rate(self, glob_step): 192 | """ Return the learning rate associated at the current training step 193 | Args: 194 | glob_step (int): Number of iterations since the beginning of training 195 | Return: 196 | float: the learning rate at the given step 197 | """ 198 | return self.learning_rate_fct(glob_step) 199 | 200 | def __init__(self, args): 201 | """ 202 | Args: 203 | args: parameters of the model 204 | """ 205 | print('Model creation...') 206 | 207 | self.args = args # Keep track of the parameters of the model 208 | 209 | # Placeholders 210 | self.inputs = None 211 | self.targets = None 212 | self.use_prev = None # Boolean tensor which say at Graph evaluation time if we use the input placeholder or the previous output. 213 | self.current_learning_rate = None # Allow to have a dynamic learning rate 214 | 215 | # Main operators 216 | self.opt_op = None # Optimizer 217 | self.outputs = None # Outputs of the network 218 | self.final_state = None # When testing, we feed this value as initial state ? 219 | 220 | # Other options 221 | self.target_weights_policy = None 222 | self.schedule_policy = None 223 | self.learning_rate_policy = None 224 | 225 | # Construct the graphs 226 | self._build_network() 227 | 228 | def _build_network(self): 229 | """ Create the computational graph 230 | """ 231 | 232 | # Placeholders (Use tf.SparseTensor with training=False instead) (TODO: Try restoring dynamic batch_size) 233 | with tf.name_scope('placeholder_inputs'): 234 | self.inputs = [ 235 | tf.placeholder( 236 | tf.float32, # -1.0/1.0 ? Probably better for the sigmoid 237 | [self.args.batch_size, music.NB_NOTES], 238 | name='input') 239 | for _ in range(self.args.sample_length) 240 | ] 241 | with tf.name_scope('placeholder_targets'): 242 | self.targets = [ 243 | tf.placeholder( 244 | tf.float32, # 0/1 245 | [self.args.batch_size, music.NB_NOTES], 246 | name='target') 247 | for _ in range(self.args.sample_length) 248 | ] 249 | with tf.name_scope('placeholder_use_prev'): 250 | self.use_prev = [ 251 | tf.placeholder( 252 | tf.bool, 253 | [], 254 | name='use_prev') 255 | for _ in range(self.args.sample_length) # The first value will never be used (always takes self.input for the first step) 256 | ] 257 | 258 | # Projection on the keyboard 259 | with tf.name_scope('note_projection_weights'): 260 | W = tf.Variable( 261 | tf.truncated_normal([self.args.hidden_size, music.NB_NOTES]), 262 | name='weights' 263 | ) 264 | b = tf.Variable( 265 | tf.truncated_normal([music.NB_NOTES]), # Tune the initializer ? 266 | name='bias', 267 | ) 268 | 269 | def project_note(X): 270 | with tf.name_scope('note_projection'): 271 | return tf.matmul(X, W) + b # [batch_size, NB_NOTE] 272 | 273 | # RNN network 274 | rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(self.args.hidden_size, state_is_tuple=True) # Or GRUCell, LSTMCell(args.hidden_size) 275 | #rnn_cell = tf.nn.rnn_cell.DropoutWrapper(rnn_cell, input_keep_prob=1.0, output_keep_prob=1.0) # TODO: Custom values (WARNING: No dropout when testing !!!, possible to use placeholder ?) 276 | rnn_cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell] * self.args.num_layers, state_is_tuple=True) 277 | 278 | initial_state = rnn_cell.zero_state(batch_size=self.args.batch_size, dtype=tf.float32) 279 | 280 | def loop_rnn(prev, i): 281 | """ Loop function used to connect one output of the rnn to the next input. 
282 | Will re-adapt the output shape to the input one. 283 | This is useful to use the same network for both training and testing. Warning: Because of the fixed 284 | batch size, we have to predict batch_size sequences when testing. 285 | """ 286 | # Predict the output from prev and scale the result on [-1, 1] 287 | next_input = project_note(prev) 288 | next_input = tf.sub(tf.mul(2.0, tf.nn.sigmoid(next_input)), 1.0) # x_{i} = 2*sigmoid(y_{i-1}) - 1 289 | 290 | # On training, we force the correct input, on testing, we use the previous output as next input 291 | return tf.cond(self.use_prev[i], lambda: next_input, lambda: self.inputs[i]) 292 | 293 | (outputs, self.final_state) = tf.nn.seq2seq.rnn_decoder( 294 | decoder_inputs=self.inputs, 295 | initial_state=initial_state, 296 | cell=rnn_cell, 297 | loop_function=loop_rnn 298 | ) 299 | 300 | # Final projection 301 | with tf.name_scope('final_output'): 302 | self.outputs = [] 303 | for output in outputs: 304 | proj = project_note(output) 305 | self.outputs.append(proj) 306 | 307 | # For training only 308 | if not self.args.test: 309 | # Finally, we define the loss function 310 | 311 | # The network will predict a mix a wrong and right notes. For the loss function, we would like to 312 | # penalize note which are wrong. Eventually, the penalty should be less if the network predict the same 313 | # note but not in the right pitch (ex: C4 instead of C5), with a decay the further the prediction 314 | # is (D5 and D1 more penalized than D4 and D3 if the target is D2) 315 | 316 | # For now, by using sigmoid_cross_entropy_with_logits, the task is formulated as a NB_NOTES binary 317 | # classification problems 318 | 319 | self.schedule_policy = Model.ScheduledSamplingPolicy(self.args) 320 | self.target_weights_policy = Model.TargetWeightsPolicy(self.args) 321 | self.learning_rate_policy = Model.LearningRatePolicy(self.args) # Load the chosen policies 322 | 323 | # TODO: If train on different length, check that the loss is proportional to the length or average ??? 324 | loss_fct = tf.nn.seq2seq.sequence_loss( 325 | self.outputs, 326 | self.targets, 327 | [tf.constant(self.target_weights_policy.get_weight(i), shape=self.targets[0].get_shape()) for i in range(len(self.targets))], # Weights 328 | softmax_loss_function=tf.nn.sigmoid_cross_entropy_with_logits, 329 | average_across_timesteps=False, # I think it's best for variables length sequences (specially with the target weights=0), isn't it (it implies also that short sequences are less penalized than long ones) ? (TODO: For variables length sequences, be careful about the target weights) 330 | average_across_batch=False # Penalize by sample (should allows dynamic batch size) Warning: need to tune the learning rate 331 | ) 332 | tf.scalar_summary('training_loss', loss_fct) # Keep track of the cost 333 | 334 | self.current_learning_rate = tf.placeholder(tf.float32, []) 335 | 336 | # Initialize the optimizer 337 | opt = tf.train.AdamOptimizer( 338 | learning_rate=self.current_learning_rate, 339 | beta1=0.9, 340 | beta2=0.999, 341 | epsilon=1e-08 342 | ) 343 | 344 | # TODO: Also keep track of magnitudes (how much is updated) 345 | self.opt_op = opt.minimize(loss_fct) 346 | 347 | def step(self, batch, train_set=True, glob_step=-1, ret_output=False): 348 | """ Forward/training step operation. 349 | Does not perform run on itself but just return the operators to do so. Those have then to be run by the 350 | main program. 
351 | If the output operator is returned, it will always be the last one on the list 352 | Args: 353 | batch (Batch): Input data on testing mode, input and target on output mode 354 | train_set (Bool): indicate if the batch come from the test/train set 355 | glob_step (int): indicate the global step for the schedule sampling 356 | ret_output (Bool): for the training mode, if true, 357 | Return: 358 | Tuple[ops], dict: The list of the operators to run (training_step or outputs) with the associated feed dictionary 359 | """ 360 | # TODO: Could optimize feeding between train/test/generating (compress code) 361 | 362 | # Feed the dictionary 363 | feed_dict = {} 364 | ops = () # For small length, it seems (from my investigations) that tuples are faster than list for merging 365 | 366 | # Feed placeholders and choose the ops 367 | if not self.args.test: # Training 368 | if train_set: 369 | assert glob_step >= 0 370 | feed_dict[self.current_learning_rate] = self.learning_rate_policy.get_learning_rate(glob_step) 371 | 372 | for i in range(self.args.sample_length): 373 | feed_dict[self.inputs[i]] = batch.inputs[i] 374 | feed_dict[self.targets[i]] = batch.targets[i] 375 | #if not train_set or np.random.rand() > self.schedule_policy.get_prev_threshold(glob_step)*self.target_weights_policy.get_weight(i): # Regular Schedule sample (TODO: Try sampling with the weigths or a mix of weights/sampling) 376 | if np.random.rand() > self.schedule_policy.get_prev_threshold(glob_step): # Weight the threshold by the target weights (don't schedule sample if weight=0) 377 | feed_dict[self.use_prev[i]] = True 378 | else: 379 | feed_dict[self.use_prev[i]] = False 380 | 381 | if train_set: 382 | ops += (self.opt_op,) 383 | if ret_output: 384 | ops += (self.outputs,) 385 | else: # Generating (batch_size == 1) 386 | # TODO: What to put for initialisation state (empty ? random ?) ? 387 | # TODO: Modify use_prev 388 | for i in range(self.args.sample_length): 389 | if i < len(batch.inputs): 390 | feed_dict[self.inputs[i]] = batch.inputs[i] 391 | feed_dict[self.use_prev[i]] = False 392 | else: # Even not used, we still need to feed a placeholder 393 | feed_dict[self.inputs[i]] = batch.inputs[0] # Could be anything but we need it to be from the right shape 394 | feed_dict[self.use_prev[i]] = True # When we don't have an input, we use the previous output instead 395 | 396 | ops += (self.outputs,) 397 | 398 | # Return one pass operator 399 | return ops, feed_dict 400 | -------------------------------------------------------------------------------- /deepmusic/moduleloader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Conchylicultor. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """ Register all available modules 17 | All new module should be added here 18 | """ 19 | 20 | from deepmusic.modulemanager import ModuleManager 21 | 22 | from deepmusic.modules import batchbuilder 23 | from deepmusic.modules import learningratepolicy 24 | from deepmusic.modules import encoder 25 | from deepmusic.modules import decoder 26 | from deepmusic.modules import loopprocessing 27 | 28 | 29 | class ModuleLoader: 30 | """ Global module manager, synchronize the loading, printing, parsing of 31 | all modules. 32 | The modules are then instantiated and use in their respective class 33 | """ 34 | enco_cells = None 35 | deco_cells = None 36 | batch_builders = None 37 | learning_rate_policies = None 38 | loop_processings = None 39 | 40 | @staticmethod 41 | def register_all(): 42 | """ List all available modules for the current session 43 | This function should be called only once at the beginning of the 44 | program, before parsing the command lines arguments 45 | It doesn't instantiate anything here (just notify the program). 46 | The module manager name will define the command line flag 47 | which will be used 48 | """ 49 | ModuleLoader.batch_builders = ModuleManager('batch_builder') 50 | ModuleLoader.batch_builders.register(batchbuilder.Relative) 51 | ModuleLoader.batch_builders.register(batchbuilder.PianoRoll) 52 | 53 | ModuleLoader.learning_rate_policies = ModuleManager('learning_rate') 54 | ModuleLoader.learning_rate_policies.register(learningratepolicy.Cst) 55 | ModuleLoader.learning_rate_policies.register(learningratepolicy.StepsWithDecay) 56 | ModuleLoader.learning_rate_policies.register(learningratepolicy.Adaptive) 57 | 58 | ModuleLoader.enco_cells = ModuleManager('enco_cell') 59 | ModuleLoader.enco_cells.register(encoder.Identity) 60 | ModuleLoader.enco_cells.register(encoder.Rnn) 61 | ModuleLoader.enco_cells.register(encoder.Embedding) 62 | 63 | ModuleLoader.deco_cells = ModuleManager('deco_cell') 64 | ModuleLoader.deco_cells.register(decoder.Lstm) 65 | ModuleLoader.deco_cells.register(decoder.Perceptron) 66 | ModuleLoader.deco_cells.register(decoder.Rnn) 67 | 68 | ModuleLoader.loop_processings = ModuleManager('loop_processing') 69 | ModuleLoader.loop_processings.register(loopprocessing.SampleSoftmax) 70 | ModuleLoader.loop_processings.register(loopprocessing.ActivateScale) 71 | 72 | @staticmethod 73 | def save_all(config): 74 | """ Save the modules configurations 75 | """ 76 | config['Modules'] = {} 77 | ModuleLoader.batch_builders.save(config['Modules']) 78 | ModuleLoader.learning_rate_policies.save(config['Modules']) 79 | ModuleLoader.enco_cells.save(config['Modules']) 80 | ModuleLoader.deco_cells.save(config['Modules']) 81 | ModuleLoader.loop_processings.save(config['Modules']) 82 | 83 | @staticmethod 84 | def load_all(args, config): 85 | """ Restore the module configuration 86 | """ 87 | ModuleLoader.batch_builders.load(args, config['Modules']) 88 | ModuleLoader.learning_rate_policies.load(args, config['Modules']) 89 | ModuleLoader.enco_cells.load(args, config['Modules']) 90 | ModuleLoader.deco_cells.load(args, config['Modules']) 91 | ModuleLoader.loop_processings.load(args, config['Modules']) 92 | 93 | @staticmethod 94 | def print_all(args): 95 | """ Print modules current configuration 96 | """ 97 | ModuleLoader.batch_builders.print(args) 98 | ModuleLoader.learning_rate_policies.print(args) 99 | ModuleLoader.enco_cells.print(args) 100 | ModuleLoader.deco_cells.print(args) 101 | 
ModuleLoader.loop_processings.print(args) 102 | -------------------------------------------------------------------------------- /deepmusic/modulemanager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Conchylicultor. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """ Module manager class definition 17 | """ 18 | 19 | from collections import OrderedDict 20 | 21 | 22 | class ModuleManager: 23 | """ Class which manage modules 24 | A module can be any class, as long as it implement the static method 25 | get_module_id and has a compatible constructor. The role of the module 26 | manager is to ensure that only one of the registered classes are used 27 | on the program. 28 | The first added module will be the one used by default. 29 | For now the modules can't save their state. 30 | """ 31 | def __init__(self, name): 32 | """ 33 | Args: 34 | name (str): the name of the module manager (useful for saving/printing) 35 | """ 36 | self.name = name 37 | self.modules = OrderedDict() # The order on which the modules are added is conserved 38 | self.module_instance = None # Reference to the chosen module 39 | self.module_name = '' # Type of module chosen (useful when saving/loading) 40 | self.module_parameters = [] # Arguments passed (for saving/loading) 41 | 42 | def register(self, module): 43 | """ Register a new module 44 | The only restriction is that the given class has to implement the static 45 | method get_module_id 46 | Args: 47 | module (Class): the module class to register 48 | """ 49 | assert not module.get_module_id() in self.modules # Overwriting not allowed 50 | self.modules[module.get_module_id()] = module 51 | 52 | def get_modules_ids(self): 53 | """ Return the list of added modules 54 | Useful for instance for the command line parser 55 | Returns: 56 | list[str]: the list of modules 57 | """ 58 | return self.modules.keys() 59 | 60 | def get_chosen_name(self): 61 | """ Return the name of the chosen module 62 | Name is defined by get_module_id 63 | Returns: 64 | str: the name of the chosen module 65 | """ 66 | return self.module_name 67 | 68 | def get_module(self): 69 | """ Return the chosen module 70 | Returns: 71 | Obj: the reference on the module instance 72 | """ 73 | assert self.module_instance is not None 74 | return self.module_instance 75 | 76 | def build_module(self, args): 77 | """ Instantiate the chosen module 78 | This function can be called only once when initializing the module 79 | Args: 80 | args (Obj): the global program parameters 81 | Returns: 82 | Obj: the created module 83 | """ 84 | assert self.module_instance is None 85 | 86 | module_args = getattr(args, self.name) # Get the name of the module and its eventual additional parameters (Exception will be raised if the user try incorrect module) 87 | 88 | self.module_name = module_args[0] 89 | 
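        # Hypothetical example of the expected format: a command line flag such as
        #     --learning_rate cst 0.0005
        # gives module_args == ['cst', '0.0005'], so the module name is 'cst' and the
        # remaining strings are forwarded below to the module constructor as extra
        # positional arguments.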
self.module_parameters = module_args[1:] 90 | self.module_instance = self.modules[self.module_name](args, *self.module_parameters) 91 | return self.module_instance 92 | 93 | def add_argparse(self, group_args, comment): 94 | """ Add the module to the command line parser 95 | All modules have to be registered before that call 96 | Args: 97 | group_args (ArgumentParser): 98 | comment (str): help to add 99 | """ 100 | assert len(self.modules.keys()) # Should contain at least 1 module 101 | 102 | keys = list(self.modules.keys()) 103 | group_args.add_argument( 104 | '--{}'.format(self.name), 105 | type=str, 106 | nargs='+', 107 | default=[keys[0]], # No defaults optional argument (implemented in the module class) 108 | help=comment + ' Choices available: {}'.format(', '.join(keys)) 109 | ) 110 | 111 | def save(self, config_group): 112 | """ Save the current module parameters 113 | Args: 114 | config_group (dict): dictionary where to write the configuration 115 | """ 116 | config_group[self.name] = ' '.join([self.module_name] + self.module_parameters) 117 | # TODO: The module state should be saved here 118 | 119 | def load(self, args, config_group): 120 | """ Restore the parameters from the configuration group 121 | Args: 122 | args (parse_args() returned Obj): the parameters of the models (will be modified) 123 | config_group (dict): the module group parameters to extract 124 | Warning: Only restore the arguments. The instantiation is not done here 125 | """ 126 | setattr(args, self.name, config_group.get(self.name).split(' ')) 127 | 128 | def print(self, args): 129 | """ Just print the current module configuration 130 | We use the args parameters because the function is called 131 | before build_module 132 | Args: 133 | args (parse_args() returned Obj): the parameters of the models 134 | """ 135 | print('{}: {}'.format( 136 | self.name, 137 | ' '.join(getattr(args, self.name)) 138 | )) 139 | -------------------------------------------------------------------------------- /deepmusic/modules/batchbuilder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Conchylicultor. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """ 17 | The batch builder convert the songs into data readable by the neural networks. 
18 | Used for training, testing and generating 19 | """ 20 | 21 | import random # Shuffling 22 | import operator # Multi-level sorting 23 | import numpy as np 24 | 25 | import deepmusic.songstruct as music 26 | 27 | 28 | class Batch: 29 | """Structure containing batches info 30 | Should be in a tf placeholder compatible format 31 | """ 32 | def __init__(self): 33 | self.inputs = [] 34 | self.targets = [] 35 | 36 | def generate(self, target=True): 37 | """ Is called just before feeding the placeholder, allows additional 38 | pre-processing 39 | Args: 40 | target(Bool): is true if the bach also need to generate the target 41 | """ 42 | pass 43 | 44 | 45 | class BatchBuilder: 46 | """ Class which create and manage batches 47 | Batches are created from the songs 48 | Define the song representation input (and output) format so the network must 49 | support the format 50 | The class has the choice to either entirely create the 51 | batches when get list is called or to create the batches 52 | as the training progress (more memory efficient) 53 | """ 54 | # TODO: Should add a option to pre-compute a lot of batches and 55 | # cache them in the hard drive 56 | # TODO: For generating mode, add another function 57 | # TODO: Add a function to get the length too (for tqdm when generators) ? 58 | def __init__(self, args): 59 | """ 60 | """ 61 | self.args = args 62 | 63 | @staticmethod 64 | def get_module_id(): 65 | """ Return the unique id associated with the builder 66 | Ultimately, the id will be used for saving/loading the dataset, and 67 | as parameter argument. 68 | Returns: 69 | str: The name of the builder 70 | """ 71 | raise NotImplementedError('Abstract class') 72 | 73 | def get_list(self, dataset, name): 74 | """ Compute the batches for the current epoch 75 | Is called twice (for training and testing) 76 | Args: 77 | dataset (list[Objects]): the training/testing set 78 | name (str): indicate the dataset type 79 | Return: 80 | list[Batch]: the batches to process 81 | """ 82 | raise NotImplementedError('Abstract class') 83 | 84 | def build_next(self, batch): 85 | """ In case of a generator (batches non precomputed), compute the batch given 86 | the batch id passed 87 | Args: 88 | batch: the current testing or training batch or id of batch to generate 89 | Return: 90 | Batch: the computed batch 91 | """ 92 | # TODO: Unused function. Instead Batch.generate does the same thing. Is it 93 | # a good idea ? Probably not. Instead should prefer this factory function 94 | return batch 95 | 96 | def build_placeholder_input(self): 97 | """ Create a placeholder compatible with the batch input 98 | Allow to control the dimensions 99 | Return: 100 | tf.placeholder: the placeholder for a single timestep 101 | """ 102 | raise NotImplementedError('Abstract class') 103 | 104 | def build_placeholder_target(self): 105 | """ Create a placeholder compatible with the target 106 | Allow to control the dimensions 107 | Return: 108 | tf.placeholder: the placeholder for a single timestep 109 | """ 110 | # TODO: The target also depend of the loss function (sigmoid, softmax,...) How to redefined that ? 111 | raise NotImplementedError('Abstract class') 112 | 113 | def process_song(self, song): 114 | """ Apply some pre-processing to the songs so the song 115 | already get the right input representation. 
116 | Do it once globally for all songs 117 | Args: 118 | song (Song): the training/testing set 119 | Return: 120 | Object: the song after formatting 121 | """ 122 | return song # By default no pre-processing 123 | 124 | def reconstruct_song(self, song): 125 | """ Reconstruct the original raw song from the preprocessed data 126 | We should have: 127 | reconstruct_song(process_song(my_song)) == my_song 128 | 129 | Args: 130 | song (Object): the training/testing set 131 | Return: 132 | Song: the song after formatting 133 | """ 134 | return song # By default no pre-processing 135 | 136 | def process_batch(self, raw_song): 137 | """ Create the batch associated with the song 138 | Called when generating songs to create the initial input batch 139 | Args: 140 | raw_song (Song): The song to convert 141 | Return: 142 | Batch 143 | """ 144 | raise NotImplementedError('Abstract class') 145 | 146 | def reconstruct_batch(self, output, batch_id, chosen_labels=None): 147 | """ Create the song associated with the network output 148 | Args: 149 | output (list[np.Array]): The ouput of the network (size batch_size*output_dim) 150 | batch_id (int): The batch that we must reconstruct 151 | chosen_labels (list[np.Array[batch_size, int]]): the sampled class at each timestep (useful to reconstruct the generated song) 152 | Return: 153 | Song: The reconstructed song 154 | """ 155 | raise NotImplementedError('Abstract class') 156 | 157 | def get_input_dim(): 158 | """ Return the input dimension 159 | Return: 160 | int: 161 | """ 162 | raise NotImplementedError('Abstract class') 163 | 164 | 165 | class Relative(BatchBuilder): 166 | """ Prepare the batches for the current epoch. 167 | Generate batches of the form: 168 | 12 values for relative position with previous notes (modulo 12) 169 | 14 values for the relative pitch (+/-7) 170 | 12 values for the relative positions with the previous note 171 | """ 172 | NB_NOTES_SCALE = 12 173 | OFFSET_SCALES = 0 # Start at A0 174 | NB_SCALES = 7 # Up to G7 (Official order is A6, B6, C7, D7, E7,... G7) 175 | 176 | # Experiments on the relative note representation: 177 | # Experiment 1: 178 | # * As baseline, we only project the note on one scale (C5: 51) 179 | BASELINE_OFFSET = 51 180 | 181 | # Options: 182 | # * Note absolute (A,B,C,...G) vs relative ((current-prev)%12) 183 | NOTE_ABSOLUTE = False 184 | # * Use separation token between the notes (a note with class_pitch=-1 is a separation token) 185 | HAS_EMPTY = True 186 | 187 | class RelativeNote: 188 | """ Struct which define a note in a relative way with respect to 189 | the previous note 190 | Can only play 7 octave (so the upper and lower notes of the 191 | piano are never reached (not that important in practice)) 192 | """ 193 | def __init__(self): 194 | # TODO: Should the network get some information about the absolute pitch ?? An other way could be to 195 | # always start by a note from the base 196 | # TODO: Define behavior when saturating 197 | # TODO: Try with time relative to prev vs next 198 | # TODO: Try to randomly permute chords vs low to high pitch 199 | # TODO: Try pitch %7 vs fixed +/-7 200 | # TODO: Try to add a channel number for each note (2 class SoftMax) <= Would require a clean database where the melodie/bass are clearly separated 201 | self.pitch_class = 0 # A, B, C,... 
+/- %12 202 | self.scale = 0 # Octave +/- % 7 203 | self.prev_tick = 0 # Distance from previous note (from -0 up to -MAXIMUM_SONG_RESOLUTION*NOTES_PER_BAR (=1 bar)) 204 | 205 | class RelativeSong: 206 | """ Struct which define a song in a relative way (intern class format) 207 | Can only play 7 octave (so the upper and lower notes of the 208 | piano are never reached (not that important in practice)) 209 | """ 210 | def __init__(self): 211 | """ All attribute are defined with respect with the previous one 212 | """ 213 | self.first_note = None # Define the reference note 214 | self.notes = [] 215 | 216 | class RelativeBatch(Batch): 217 | """ Struct which contains temporary information necessary to reconstruct the 218 | batch 219 | """ 220 | class SongExtract: # Define a subsong 221 | def __init__(self): 222 | self.song = None # The song reference 223 | self.begin = 0 224 | self.end = 0 225 | 226 | def __init__(self, extracts): 227 | """ 228 | Args: 229 | extracts(list[SongExtract]): Should be of length batch_size, or at least all from the same size 230 | """ 231 | super().__init__() 232 | self.extracts = extracts 233 | 234 | def generate(self, target=True): 235 | """ 236 | Args: 237 | target(Bool): is true if the bach also need to generate the target 238 | """ 239 | # TODO: Could potentially be optimized (one big numpy array initialized only one, each input is a sub-arrays) 240 | # TODO: Those inputs should be cleared once the training pass has be run (Use class with generator, __next__ and __len__) 241 | sequence_length = self.extracts[0].end - self.extracts[0].begin 242 | shape_input = (len(self.extracts), Relative.RelativeBatch.get_input_dim()) # (batch_size, note_space) +1 because of the token 243 | 244 | def gen_input(i): 245 | array = np.zeros(shape_input) 246 | for j, extract in enumerate(self.extracts): # Iterate over the batches 247 | # Set the one-hot vector (chose label between ,A,...,G) 248 | label = extract.song.notes[extract.begin + i].pitch_class 249 | array[j, 0 if not label else label + 1] = 1 250 | return array 251 | 252 | def gen_target(i): # TODO: Could merge with the previous function to optimize the calls 253 | array = np.zeros([len(self.extracts)], dtype=int) # Int for SoftMax compatibility 254 | for j, extract in enumerate(self.extracts): # Iterate over the batches 255 | # Set the one-hot label (chose label between ,A,...,G) 256 | label = extract.song.notes[extract.begin + i + 1].pitch_class # Warning: +1 because targets are shifted with respect to the inputs 257 | array[j] = 0 if not label else label + 1 258 | return array 259 | 260 | self.inputs = [gen_input(i) for i in range(sequence_length)] # Generate each input sequence 261 | if target: 262 | self.targets = [gen_target(i) for i in range(sequence_length)] 263 | 264 | @staticmethod 265 | def get_input_dim(): 266 | """ 267 | """ 268 | # TODO: Refactoring. Where to place this functions ?? Should be accessible from model, and batch and depend of 269 | # batch_builder, also used in enco/deco modules. Ideally should not be static 270 | return 1 + Relative.NB_NOTES_SCALE # +1 because of the token 271 | 272 | def __init__(self, args): 273 | super().__init__(args) 274 | 275 | @staticmethod 276 | def get_module_id(): 277 | return 'relative' 278 | 279 | def process_song(self, old_song): 280 | """ Pre-process the data once globally 281 | Do it once globally. 
282 | Args: 283 | old_song (Song): original song 284 | Returns: 285 | list[RelativeSong]: the new formatted song 286 | """ 287 | new_song = Relative.RelativeSong() 288 | 289 | old_song.normalize() 290 | 291 | # Gather all notes and sort them by absolute time 292 | all_notes = [] 293 | for track in old_song.tracks: 294 | for note in track.notes: 295 | all_notes.append(note) 296 | all_notes.sort(key=operator.attrgetter('tick', 'note')) # Sort first by tick, then by pitch 297 | 298 | # Compute the relative position for each note 299 | prev_note = all_notes[0] 300 | new_song.first_note = prev_note # TODO: What if the song start by a chord ? 301 | for note in all_notes[1:]: 302 | # Check if we should insert an empty token 303 | temporal_distance = note.tick - prev_note.tick 304 | assert temporal_distance >= 0 305 | if Relative.HAS_EMPTY and temporal_distance > 0: 306 | for i in range(temporal_distance): 307 | separator = Relative.RelativeNote() # Separation token 308 | separator.pitch_class = None 309 | new_song.notes.append(separator) 310 | 311 | # Insert the new relative note 312 | new_note = Relative.RelativeNote() 313 | if Relative.NOTE_ABSOLUTE: 314 | new_note.pitch_class = note.note % Relative.NB_NOTES_SCALE 315 | else: 316 | new_note.pitch_class = (note.note - prev_note.note) % Relative.NB_NOTES_SCALE 317 | new_note.scale = (note.note//Relative.NB_NOTES_SCALE - prev_note.note//Relative.NB_NOTES_SCALE) % Relative.NB_SCALES # TODO: add offset for the notes ? (where does the game begins ?) 318 | new_note.prev_tick = temporal_distance 319 | new_song.notes.append(new_note) 320 | 321 | prev_note = note 322 | 323 | return new_song 324 | 325 | def reconstruct_song(self, rel_song): 326 | """ Reconstruct the original raw song from the preprocessed data 327 | See parent class for details 328 | 329 | Some information will be lost compare to the original song: 330 | * Only one track left 331 | * Original tempo lost 332 | Args: 333 | rel_song (RelativeSong): the song to reconstruct 334 | Return: 335 | Song: the reconstructed song 336 | """ 337 | raw_song = music.Song() 338 | main_track = music.Track() 339 | 340 | prev_note = rel_song.first_note 341 | main_track.notes.append(rel_song.first_note) 342 | current_tick = rel_song.first_note.tick 343 | for next_note in rel_song.notes: 344 | # Case of separator 345 | if next_note.pitch_class is None: 346 | current_tick += 1 347 | continue 348 | 349 | # Adding the new note 350 | new_note = music.Note() 351 | # * Note 352 | if Relative.NOTE_ABSOLUTE: 353 | new_note.note = Relative.BASELINE_OFFSET + next_note.pitch_class 354 | else: 355 | new_note.note = Relative.BASELINE_OFFSET + ((prev_note.note-Relative.BASELINE_OFFSET) + next_note.pitch_class) % Relative.NB_NOTES_SCALE 356 | # * Tick 357 | if Relative.HAS_EMPTY: 358 | new_note.tick = current_tick 359 | else: 360 | new_note.tick = prev_note.tick + next_note.prev_tick 361 | # * Scale 362 | # ... 
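# Note: next_note.scale is currently ignored during reconstruction, so every
# generated note stays within one octave above BASELINE_OFFSET.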
363 | main_track.notes.append(new_note)
364 | prev_note = new_note
365 | 
366 | raw_song.tracks.append(main_track)
367 | raw_song.normalize(inverse=True)
368 | return raw_song
369 | 
370 | def process_batch(self, raw_song):
371 | """ Create the batch associated with the song
372 | Args:
373 | raw_song (Song): The song to convert
374 | Return:
375 | RelativeBatch
376 | """
377 | processed_song = self.process_song(raw_song)
378 | extract = self.create_extract(processed_song, 0, len(processed_song.notes))
379 | batch = Relative.RelativeBatch([extract])
380 | return batch
381 | 
382 | def reconstruct_batch(self, output, batch_id, chosen_labels=None):
383 | """ Create the song associated with the network output
384 | Args:
385 | output (list[np.Array]): The output of the network (size batch_size*output_dim)
386 | batch_id (int): The batch id
387 | chosen_labels (list[np.Array[batch_size, int]]): the sampled class at each timestep (useful to reconstruct the generated song)
388 | Return:
389 | Song: The reconstructed song
390 | """
391 | assert Relative.HAS_EMPTY
392 | 
393 | processed_song = Relative.RelativeSong()
394 | processed_song.first_note = music.Note()
395 | processed_song.first_note.note = 56 # TODO: Define what should be the first note
396 | print('Reconstruct')
397 | for i, note in enumerate(output):
398 | relative = Relative.RelativeNote()
399 | # If the output was sampled, we retrieve which output was actually selected
400 | if not chosen_labels or i == len(chosen_labels): # If chosen_labels, the last generated note has not been sampled
401 | chosen_label = int(np.argmax(note[batch_id,:])) # Cast np.int64 to int to avoid compatibility issues with mido
402 | else:
403 | chosen_label = int(chosen_labels[i][batch_id])
404 | print(chosen_label, end=' ') # TODO: Add a text output connector
405 | if chosen_label == 0: # token
406 | relative.pitch_class = None
407 | #relative.scale = # Not used
408 | #relative.prev_tick =
409 | else:
410 | relative.pitch_class = chosen_label-1
411 | #relative.scale =
412 | #relative.prev_tick =
413 | processed_song.notes.append(relative)
414 | print()
415 | return self.reconstruct_song(processed_song)
416 | 
417 | def create_extract(self, processed_song, start, length):
418 | """ Preprocessed song -> song extract
419 | """
420 | extract = Relative.RelativeBatch.SongExtract()
421 | extract.song = processed_song
422 | extract.begin = start
423 | extract.end = extract.begin + length
424 | return extract
425 | 
426 | # TODO: How to optimize !! (precompute all values, use sparse arrays ?)
427 | def get_list(self, dataset, name):
428 | """ See parent class for more details
429 | Args:
430 | dataset (list[Song]): the training/testing set
431 | name (str): indicates the dataset type
432 | Return:
433 | list[Batch]: the batches to process
434 | """
435 | # Randomly extract subsamples of the songs
436 | print('Subsampling the songs ({})...'.format(name))
437 | 
438 | extracts = []
439 | sample_subsampling_length = self.args.sample_length+1 # We add 1 because each input has to predict the next output
440 | for song in dataset:
441 | len_song = len(song.notes)
442 | max_start = len_song - sample_subsampling_length
443 | assert max_start >= 0 # TODO: Error handling (and if =0, compatible with randint ?)
444 | nb_sample_song = 2*len_song // self.args.sample_length # The number of subsample is proportional to the song length (TODO: Could control the factor) 445 | for _ in range(nb_sample_song): 446 | extracts.append(self.create_extract( 447 | song, 448 | random.randrange(max_start), # Begin TODO: Add mode to only start at the beginning of a bar 449 | self.args.sample_length # End 450 | )) 451 | 452 | # Shuffle the song extracts 453 | print('Shuffling the dataset...') 454 | random.shuffle(extracts) 455 | 456 | # Group the samples together to create the batches 457 | print('Generating batches...') 458 | 459 | def gen_next_samples(): 460 | """ Generator over the mini-batch training samples 461 | Warning: the last samples will be ignored if the number of batch does not match the number of samples 462 | """ 463 | nb_samples = len(extracts) 464 | for i in range(nb_samples//self.args.batch_size): 465 | yield extracts[i*self.args.batch_size:(i+1)*self.args.batch_size] 466 | 467 | batch_set = [Relative.RelativeBatch(e) for e in gen_next_samples()] 468 | return batch_set 469 | 470 | def get_input_dim(self): 471 | """ In the case of the relative song, the input dim is the number of 472 | note on the scale (12) + 1 for the next token 473 | Return: 474 | int: 475 | """ 476 | return Relative.RelativeBatch.get_input_dim() 477 | 478 | 479 | class PianoRoll(BatchBuilder): 480 | """ Old piano roll format (legacy code). Won't work as it is 481 | """ 482 | def __init__(self, args): 483 | super().__init__(args) 484 | 485 | @staticmethod 486 | def get_module_id(): 487 | return 'pianoroll' 488 | 489 | def get_list(self, dataset): 490 | 491 | # On the original version, the songs were directly converted to piano roll 492 | # self._convert_song2array() 493 | 494 | batches = [] 495 | 496 | # TODO: Create batches (randomly cut each song in some small parts (need to know the total length for that) 497 | # then create the big matrix (NB_NOTE*sample_length) and turn that into batch). If process too long, 498 | # could save the created batches in a new folder, data/samples or save/model. 499 | 500 | # TODO: Create batches from multiples length (buckets). How to change the loss functions weights (longer 501 | # sequences more penalized ?) 502 | 503 | # TODO: Optimize memory management 504 | 505 | # First part: Randomly extract subsamples of the songs 506 | print('Subsampling songs ({})...'.format('train' if train_set else 'test')) 507 | 508 | sample_subsampling_length = self.args.sample_length+1 # We add 1 because each input has to predict the next output 509 | 510 | sub_songs = [] 511 | songs_set = dataset 512 | for song in songs_set: 513 | len_song = song.shape[-1] # The last dimension correspond to the song duration 514 | max_start = len_song - sample_subsampling_length 515 | assert max_start >= 0 # TODO: Error handling (and if =0, compatible with randint ?) 
516 | nb_sample_song = 2*len_song // self.args.sample_length # The number of subsample is proportional to the song length 517 | for _ in range(nb_sample_song): 518 | start = np.random.randint(max_start) # TODO: Add mode to only start at the begining of a bar 519 | sub_song = song[:, start:start+sample_subsampling_length] 520 | sub_songs.append(sub_song) 521 | 522 | # Second part: Shuffle the song extracts 523 | print("Shuffling the dataset...") 524 | np.random.shuffle(sub_songs) 525 | 526 | # Third part: Group the samples together to create the batches 527 | print("Generating batches...") 528 | 529 | def gen_next_samples(): 530 | """ Generator over the mini-batch training samples 531 | Warning: the last samples will be ignored if the number of batch does not match the number of samples 532 | """ 533 | nb_samples = len(sub_songs) 534 | for i in range(nb_samples//self.args.batch_size): 535 | yield sub_songs[i*self.args.batch_size:(i+1)*self.args.batch_size] 536 | 537 | for samples in gen_next_samples(): # TODO: tqdm with persist = False / will this work with generators ? 538 | batch = Batch() 539 | 540 | # samples has shape [batch_size, NB_NOTES, sample_subsampling_length] 541 | assert len(samples) == self.args.batch_size 542 | assert samples[0].shape == (music.NB_NOTES, sample_subsampling_length) 543 | 544 | # Define targets and inputs 545 | for i in range(self.args.sample_length): 546 | input = -np.ones([len(samples), music.NB_NOTES]) 547 | target = np.zeros([len(samples), music.NB_NOTES]) 548 | for j, sample in enumerate(samples): # len(samples) == self.args.batch_size 549 | # TODO: Could reuse boolean idx computed (from target to next input) 550 | input[j, sample[:, i] == 1] = 1.0 551 | target[j, sample[:, i+1] == 1] = 1.0 552 | 553 | batch.inputs.append(input) 554 | batch.targets.append(target) 555 | 556 | batches.append(batch) 557 | 558 | # Use tf.train.batch() ?? 559 | 560 | # TODO: Save some batches as midi to see if correct 561 | 562 | return batches 563 | 564 | def get_batches_test(self): # TODO: Move that to BatchBuilder 565 | """ Return the batches which initiate the RNN when generating 566 | The initial batches are loaded from a json file containing the first notes of the song. The note values 567 | are the standard midi ones. Here is an examples of an initiator file: 568 | 569 | ``` 570 | {"initiator":[ 571 | {"name":"Simple_C4", 572 | "seq":[ 573 | {"notes":[60]} 574 | ]}, 575 | {"name":"some_chords", 576 | "seq":[ 577 | {"notes":[60,64]} 578 | {"notes":[66,68,71]} 579 | {"notes":[60,64]} 580 | ]} 581 | ]} 582 | ``` 583 | 584 | Return: 585 | List[Batch], List[str]: The generated batches with the associated names 586 | """ 587 | assert self.args.batch_size == 1 588 | 589 | batches = [] 590 | names = [] 591 | 592 | with open(self.TEST_INIT_FILE) as init_file: 593 | initiators = json.load(init_file) 594 | 595 | for initiator in initiators['initiator']: 596 | batch = Batch() 597 | 598 | for seq in initiator['seq']: # We add a few notes 599 | new_input = -np.ones([self.args.batch_size, music.NB_NOTES]) # No notes played by default 600 | for note in seq['notes']: 601 | new_input[0, note] = 1.0 602 | batch.inputs.append(new_input) 603 | 604 | names.append(initiator['name']) 605 | batches.append(batch) 606 | 607 | return batches, names 608 | -------------------------------------------------------------------------------- /deepmusic/modules/decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Conchylicultor. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """ 17 | """ 18 | 19 | import tensorflow as tf 20 | 21 | import deepmusic.tfutils as tfutils 22 | import deepmusic.songstruct as music 23 | 24 | 25 | # TODO: Some class from the encoder and decoder are really similar. Could they be merged ? 26 | class DecoderNetwork: 27 | """ Predict a keyboard configuration at step t 28 | This is just an abstract class 29 | Warning: To encapsulate the weights in the right tf scope, they should be defined 30 | within the build function 31 | """ 32 | def __init__(self, args): 33 | """ 34 | Args: 35 | args: parameters of the model 36 | """ 37 | self.args = args 38 | 39 | def build(self): 40 | """ Initialize the weights of the model 41 | """ 42 | pass 43 | 44 | def init_state(self): 45 | """ Return the initial cell state 46 | """ 47 | return None 48 | 49 | def get_cell(self, prev_keyboard, prev_state_enco): 50 | """ Predict the next keyboard state 51 | Args: 52 | prev_keyboard (?): the previous keyboard configuration 53 | prev_state_enco (?): the encoder output state 54 | Return: 55 | Tuple: A tuple containing the predicted keyboard configuration and last decoder state 56 | """ 57 | raise NotImplementedError('Abstract class') 58 | 59 | 60 | class Rnn(DecoderNetwork): 61 | """ Predict a keyboard configuration at step t 62 | Use a RNN to predict the next configuration 63 | """ 64 | @staticmethod 65 | def get_module_id(): 66 | return 'rnn' 67 | 68 | def __init__(self, args): 69 | """ 70 | Args: 71 | args: parameters of the model 72 | """ 73 | super().__init__(args) 74 | self.rnn_cell = None 75 | self.project_key = None # Fct which project the decoder output into a single key space 76 | 77 | def build(self): 78 | """ Initialize the weights of the model 79 | """ 80 | self.rnn_cell = tfutils.get_rnn_cell(self.args, "deco_cell") 81 | self.project_key = tfutils.single_layer_perceptron([self.args.hidden_size, 1], 82 | 'project_key') 83 | 84 | def init_state(self): 85 | """ Return the initial cell state 86 | """ 87 | return self.rnn_cell.zero_state(batch_size=self.args.batch_size, dtype=tf.float32) 88 | 89 | def get_cell(self, prev_keyboard, prev_state_enco): 90 | """ a RNN decoder 91 | See parent class for arguments details 92 | """ 93 | 94 | axis = 1 # The first dimension is the batch, we split the keys 95 | assert prev_keyboard.get_shape()[axis].value == music.NB_NOTES 96 | inputs = tf.split(axis, music.NB_NOTES, prev_keyboard) 97 | 98 | outputs, final_state = tf.nn.seq2seq.rnn_decoder( 99 | decoder_inputs=inputs, 100 | initial_state=prev_state_enco, 101 | cell=self.rnn_cell 102 | # TODO: Which loop function (should use prediction) ? : Should take the previous generated input/ground truth (as the global model loop_fct). Need to add a new bool placeholder 103 | ) 104 | 105 | # Is it better to do the projection before or after the packing ? 
106 | next_keys = [] 107 | for output in outputs: 108 | next_keys.append(self.project_key(output)) 109 | 110 | next_keyboard = tf.concat(axis, next_keys) 111 | 112 | return next_keyboard, final_state 113 | 114 | 115 | class Perceptron(DecoderNetwork): 116 | """ Single layer perceptron. Just a proof of concept for the architecture 117 | """ 118 | @staticmethod 119 | def get_module_id(): 120 | return 'perceptron' 121 | 122 | def __init__(self, args): 123 | """ 124 | Args: 125 | args: parameters of the model 126 | """ 127 | super().__init__(args) 128 | 129 | self.project_hidden = None # Fct which decode the previous state 130 | self.project_keyboard = None # Fct which project the decoder output into the keyboard space 131 | 132 | def build(self): 133 | """ Initialize the weights of the model 134 | """ 135 | # For projecting on the keyboard space 136 | self.project_hidden = tfutils.single_layer_perceptron([music.NB_NOTES, self.args.hidden_size], 137 | 'project_hidden') 138 | 139 | # For projecting on the keyboard space 140 | self.project_keyboard = tfutils.single_layer_perceptron([self.args.hidden_size, music.NB_NOTES], 141 | 'project_keyboard') # Should we do the activation sigmoid here ? 142 | 143 | def get_cell(self, prev_keyboard, prev_state_enco): 144 | """ Simple 1 hidden layer perceptron 145 | See parent class for arguments details 146 | """ 147 | # Don't change the state 148 | next_state_deco = prev_state_enco # Return the last state (Useful ?) 149 | 150 | # Compute the next output 151 | hidden_state = self.project_hidden(prev_keyboard) 152 | next_keyboard = self.project_keyboard(hidden_state) # Should we do the activation sigmoid here ? Maybe not because the loss function does it 153 | 154 | return next_keyboard, next_state_deco 155 | 156 | 157 | class Lstm(DecoderNetwork): 158 | """ Multi-layer Lstm. Just a wrapper around the official tf 159 | """ 160 | @staticmethod 161 | def get_module_id(): 162 | return 'lstm' 163 | 164 | def __init__(self, args, *module_args): 165 | """ 166 | Args: 167 | args: parameters of the model 168 | """ 169 | super().__init__(args) 170 | self.args = args 171 | 172 | self.rnn_cell = None 173 | self.project_keyboard = None # Fct which project the decoder output into the ouput space 174 | 175 | def build(self): 176 | """ Initialize the weights of the model 177 | """ 178 | # TODO: Control over the the Cell using module arguments instead of global arguments (hidden_size and num_layer) !! 179 | # RNN network 180 | rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(self.args.hidden_size, state_is_tuple=True) # Or GRUCell, LSTMCell(args.hidden_size) 181 | if not self.args.test: # TODO: Should use a placeholder instead 182 | rnn_cell = tf.nn.rnn_cell.DropoutWrapper(rnn_cell, input_keep_prob=1.0, output_keep_prob=0.9) # TODO: Custom values 183 | rnn_cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell] * self.args.num_layers, state_is_tuple=True) 184 | 185 | self.rnn_cell = rnn_cell 186 | 187 | # For projecting on the keyboard space 188 | self.project_output = tfutils.single_layer_perceptron([self.args.hidden_size, 12 + 1], # TODO: HACK: Input/output space hardcoded !!! 189 | 'project_output') # Should we do the activation sigmoid here ? 
190 | 
191 | def init_state(self):
192 | """ Return the initial cell state
193 | """
194 | return self.rnn_cell.zero_state(batch_size=self.args.batch_size, dtype=tf.float32)
195 | 
196 | def get_cell(self, prev_input, prev_states):
197 | """
198 | """
199 | next_output, next_state = self.rnn_cell(prev_input, prev_states[1])
200 | next_output = self.project_output(next_output)
201 | # No activation function here: SoftMax is computed by the loss function
202 | 
203 | return next_output, next_state
204 | 
-------------------------------------------------------------------------------- /deepmusic/modules/encoder.py: --------------------------------------------------------------------------------
1 | # Copyright 2016 Conchylicultor. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """
17 | """
18 | 
19 | import tensorflow as tf
20 | 
21 | import deepmusic.tfutils as tfutils
22 | import deepmusic.songstruct as music
23 | 
24 | 
25 | class EncoderNetwork:
26 | """ From the previous keyboard configuration, prepare the state for the next one
27 | Encode the keyboard configuration at a state t
28 | This abstract class has no effect but is here to be subclassed
29 | Warning: To encapsulate the weights in the right tf scope, they should be defined
30 | within the build function
31 | """
32 | def __init__(self, args):
33 | """
34 | Args:
35 | args: parameters of the model
36 | """
37 | self.args = args
38 | 
39 | def build(self):
40 | """ Initialize the weights of the model
41 | """
42 | 
43 | def init_state(self):
44 | """ Return the initial cell state
45 | """
46 | return None
47 | 
48 | def get_cell(self, prev_keyboard, prev_state):
49 | """ Predict the next keyboard state
50 | Args:
51 | prev_keyboard (tf.Tensor): the previous keyboard configuration
52 | prev_state (Tuple): the previous decoder state
53 | Return:
54 | tf.Tensor: the final encoder state
55 | """
56 | raise NotImplementedError('Abstract Class')
57 | 
58 | 
59 | class Identity(EncoderNetwork):
60 | """ Simply pass the previous encoder state through unchanged
61 | """
62 | 
63 | @staticmethod
64 | def get_module_id():
65 | return 'identity'
66 | 
67 | def __init__(self, args):
68 | """
69 | Args:
70 | args: parameters of the model
71 | """
72 | super().__init__(args)
73 | 
74 | def get_cell(self, prev_keyboard, prev_state):
75 | """ Predict the next keyboard state
76 | Args:
77 | prev_keyboard (tf.Tensor): the previous keyboard configuration
78 | prev_state (Tuple): the previous decoder state
79 | Return:
80 | tf.Tensor: the final encoder state
81 | """
82 | prev_state_enco, prev_state_deco = prev_state
83 | 
84 | # This simple class just passes the previous state
85 | next_state_enco = prev_state_enco
86 | 
87 | return next_state_enco
88 | 
89 | 
90 | class Rnn(EncoderNetwork):
91 | """ Read each keyboard configuration note by note and encode its configuration
92 | """
93 | @staticmethod
94 | def get_module_id():
95 | return
'rnn' 96 | 97 | def __init__(self, args): 98 | """ 99 | Args: 100 | args: parameters of the model 101 | """ 102 | super().__init__(args) 103 | self.rnn_cell = None 104 | 105 | def build(self): 106 | """ Initialize the weights of the model 107 | """ 108 | self.rnn_cell = tfutils.get_rnn_cell(self.args, "deco_cell") 109 | 110 | def init_state(self): 111 | """ Return the initial cell state 112 | """ 113 | return self.rnn_cell.zero_state(batch_size=self.args.batch_size, dtype=tf.float32) 114 | 115 | def get_cell(self, prev_keyboard, prev_state): 116 | """ a RNN encoder 117 | See parent class for arguments details 118 | """ 119 | prev_state_enco, prev_state_deco = prev_state 120 | 121 | axis = 1 # The first dimension is the batch, we split the keys 122 | assert prev_keyboard.get_shape()[axis].value == music.NB_NOTES 123 | inputs = tf.split(axis, music.NB_NOTES, prev_keyboard) 124 | 125 | _, final_state = tf.nn.rnn( 126 | self.rnn_cell, 127 | inputs, 128 | initial_state=prev_state_deco 129 | ) 130 | 131 | return final_state 132 | 133 | 134 | class Embedding(EncoderNetwork): 135 | """ Implement lookup for note embedding 136 | """ 137 | @staticmethod 138 | def get_module_id(): 139 | return 'embedding' 140 | 141 | def __init__(self, args): 142 | """ 143 | Args: 144 | args: parameters of the model 145 | """ 146 | super().__init__(args) 147 | 148 | def build(self): 149 | """ Initialize the weights of the model 150 | """ 151 | 152 | def init_state(self): 153 | """ Return the initial cell state 154 | """ 155 | 156 | def get_cell(self, prev_keyboard, prev_state): 157 | """ a RNN encoder 158 | See parent class for arguments details 159 | """ 160 | # TODO: 161 | return 162 | 163 | -------------------------------------------------------------------------------- /deepmusic/modules/learningratepolicy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Conchylicultor. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """ 17 | The learning rate policy control the evolution of the learning rate during 18 | the training 19 | """ 20 | 21 | 22 | class LearningRatePolicy: 23 | """ Contains the different policies for the learning rate decay 24 | """ 25 | def __init__(self, args): 26 | """ 27 | Args: 28 | args: parameters of the model 29 | """ 30 | 31 | def get_learning_rate(self, glob_step): 32 | """ Return the learning rate associated at the current training step 33 | Args: 34 | glob_step (int): Number of iterations since the beginning of training 35 | Return: 36 | float: the learning rate at the given step 37 | """ 38 | raise NotImplementedError('Abstract class') 39 | 40 | 41 | class LearningRatePolicyOld: 42 | """ Contains the different policies for the learning rate decay 43 | """ 44 | CST = 'cst' # Fixed learning rate over all steps (default behavior) 45 | STEP = 'step' # We divide the learning rate every x iterations 46 | EXPONENTIAL = 'exponential' # 47 | 48 | @staticmethod 49 | def get_policies(): 50 | """ Return the list of the different modes 51 | Useful when parsing the command lines arguments 52 | """ 53 | return [ 54 | LearningRatePolicy.CST, 55 | LearningRatePolicy.STEP, 56 | LearningRatePolicy.EXPONENTIAL 57 | ] 58 | 59 | def __init__(self, args): 60 | """ 61 | Args: 62 | args: parameters of the model 63 | """ 64 | self.learning_rate_fct = None 65 | 66 | assert args.learning_rate 67 | assert len(args.learning_rate) > 0 68 | 69 | policy = args.learning_rate[0] 70 | 71 | if policy == LearningRatePolicy.CST: 72 | if not len(args.learning_rate) == 2: 73 | raise ValueError( 74 | 'Learning rate cst policy should be on the form: {} lr_value'.format(Model.LearningRatePolicy.CST)) 75 | self.learning_rate_init = float(args.learning_rate[1]) 76 | self.learning_rate_fct = self._lr_cst 77 | 78 | elif policy == LearningRatePolicy.STEP: 79 | if not len(args.learning_rate) == 3: 80 | raise ValueError('Learning rate step policy should be on the form: {} lr_init decay_period'.format( 81 | LearningRatePolicy.STEP)) 82 | self.learning_rate_init = float(args.learning_rate[1]) 83 | self.decay_period = int(args.learning_rate[2]) 84 | self.learning_rate_fct = self._lr_step 85 | 86 | else: 87 | raise ValueError('Unknown chosen learning rate policy: {}'.format(policy)) 88 | 89 | def _lr_cst(self, glob_step): 90 | """ Just a constant learning rate 91 | """ 92 | return self.learning_rate_init 93 | 94 | def _lr_step(self, glob_step): 95 | """ Every decay period, the learning rate is divided by 2 96 | """ 97 | return self.learning_rate_init / 2 ** (glob_step // self.decay_period) 98 | 99 | def get_learning_rate(self, glob_step): 100 | """ Return the learning rate associated at the current training step 101 | Args: 102 | glob_step (int): Number of iterations since the beginning of training 103 | Return: 104 | float: the learning rate at the given step 105 | """ 106 | return self.learning_rate_fct(glob_step) 107 | 108 | 109 | class Cst(LearningRatePolicy): 110 | """ Fixed learning rate over all steps (default behavior) 111 | """ 112 | @staticmethod 113 | def get_module_id(): 114 | return 'cst' 115 | 116 | def __init__(self, args, lr=0.0001): 117 | """ 118 | Args: 119 | args: parameters of the model 120 | """ 121 | self.lr = lr 122 | 123 | def get_learning_rate(self, glob_step): 124 | """ Return the learning rate associated at the current training step 125 | Args: 126 | glob_step (int): Number of iterations since the beginning of 
training 127 | Return: 128 | float: the learning rate at the given step 129 | """ 130 | return self.lr 131 | 132 | 133 | class StepsWithDecay(LearningRatePolicy): 134 | """ 135 | """ 136 | 137 | @staticmethod 138 | def get_module_id(): 139 | return 'step' 140 | 141 | 142 | class Adaptive(LearningRatePolicy): 143 | """ Decrease the learning rate when training error 144 | reach a plateau 145 | """ 146 | 147 | @staticmethod 148 | def get_module_id(): 149 | return 'adaptive' 150 | -------------------------------------------------------------------------------- /deepmusic/modules/loopprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Conchylicultor. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """ 17 | """ 18 | 19 | import tensorflow as tf 20 | 21 | 22 | class LoopProcessing: 23 | """ Apply some processing to the rnn loop which connect the output to 24 | the next input. 25 | Is called on the loop_function attribute of the rnn decoder 26 | """ 27 | def __init__(self, args): 28 | pass 29 | 30 | def __call__(self, prev_output): 31 | """ Function which apply the preprocessing 32 | Args: 33 | prev_output (tf.Tensor): the ouput on which applying the transformation 34 | Return: 35 | tf.Ops: the processing operator 36 | """ 37 | raise NotImplementedError('Abstract Class') 38 | 39 | def get_op(self): 40 | """ Return the chosen labels from the softmax distribution 41 | Allows to reconstruct the song 42 | """ 43 | return () # Empty tuple 44 | 45 | 46 | class SampleSoftmax(LoopProcessing): 47 | """ Sample from the softmax distribution 48 | """ 49 | @staticmethod 50 | def get_module_id(): 51 | return 'sample_softmax' 52 | 53 | def __init__(self, args, *args_module): 54 | 55 | self.temperature = args.temperature # Control the sampling (more or less concervative predictions) (TODO: Could be a argument of modeule, but in this case will automatically be restored when --test, should also be in the save name) 56 | self.chosen_labels = [] # Keep track of the chosen labels (to reconstruct the chosen song) 57 | 58 | def __call__(self, prev_output): 59 | """ Use TODO formula 60 | Args: 61 | prev_output (tf.Tensor): the ouput on which applying the transformation 62 | Return: 63 | tf.Ops: the processing operator 64 | """ 65 | # prev_output size: [batch_size, nb_labels] 66 | nb_labels = prev_output.get_shape().as_list()[-1] 67 | 68 | if False: # TODO: Add option to control argmax 69 | #label_draws = tf.argmax(prev_output, 1) 70 | label_draws = tf.multinomial(tf.log(prev_output), 1) # Draw 1 sample from the distribution 71 | label_draws = tf.squeeze(label_draws, [1]) 72 | self.chosen_labels.append(label_draws) 73 | next_input = tf.one_hot(label_draws, nb_labels) 74 | return next_input 75 | # Could use the Gumbel-Max trick to sample from a softmax distribution ? 
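# The lines below implement temperature sampling: p_i = exp(o_i / T) / sum_k exp(o_k / T).
# A low temperature sharpens the distribution (more conservative predictions),
# a high temperature flattens it (more random sampling).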
76 | 
77 | soft_values = tf.exp(tf.div(prev_output, self.temperature)) # Pi = exp(pi/t)
78 | # soft_values size: [batch_size, nb_labels]
79 | 
80 | normalisation_coeff = tf.expand_dims(tf.reduce_sum(soft_values, 1), -1)
81 | # normalisation_coeff size: [batch_size, 1]
82 | probs = tf.div(soft_values, normalisation_coeff + 1e-8) # = Pi / sum(Pk)
83 | # probs size: [batch_size, nb_labels]
84 | label_draws = tf.multinomial(tf.log(probs), 1) # Draw 1 sample from the log-probability distribution
85 | # label_draws size: [batch_size, 1]
86 | label_draws = tf.squeeze(label_draws, [1])
87 | # label_draws size: [batch_size,]
88 | self.chosen_labels.append(label_draws)
89 | next_input = tf.one_hot(label_draws, nb_labels) # Reencode the next input vector
90 | # next_input size: [batch_size, nb_labels]
91 | return next_input
92 | 
93 | def get_op(self):
94 | """ Return the chosen labels from the softmax distribution
95 | Allows to reconstruct the song
96 | """
97 | return self.chosen_labels
98 | 
99 | 
100 | class ActivateScale(LoopProcessing):
101 | """ Activate using sigmoid and scale the prediction on [-1, 1]
102 | """
103 | @staticmethod
104 | def get_module_id():
105 | return 'activate_and_scale'
106 | 
107 | def __init__(self, args):
108 | pass
109 | 
110 | def __call__(self, X):
111 | """ Predict the output from prev and scale the result on [-1, 1]
112 | Use sigmoid activation
113 | Args:
114 | X (tf.Tensor): the input
115 | Return:
116 | tf.Ops: the activate_and_scale operator
117 | """
118 | # TODO: Use tanh instead ? tanh=2*sigm(2*x)-1
119 | with tf.name_scope('activate_and_scale'):
120 | return tf.sub(tf.mul(2.0, tf.nn.sigmoid(X)), 1.0) # x_{i} = 2*sigmoid(y_{i-1}) - 1
121 | 
-------------------------------------------------------------------------------- /deepmusic/musicdata.py: --------------------------------------------------------------------------------
1 | # Copyright 2016 Conchylicultor. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """ 17 | Loads the midi song, build the dataset 18 | """ 19 | 20 | from tqdm import tqdm # Progress bar when creating dataset 21 | import pickle # Saving the data 22 | import os # Checking file existence 23 | import numpy as np # Batch data 24 | import json # Load initiators (inputs for generating new songs) 25 | 26 | from deepmusic.moduleloader import ModuleLoader 27 | from deepmusic.midiconnector import MidiConnector 28 | from deepmusic.midiconnector import MidiInvalidException 29 | import deepmusic.songstruct as music 30 | 31 | 32 | class MusicData: 33 | """Dataset class 34 | """ 35 | 36 | def __init__(self, args): 37 | """Load all conversations 38 | Args: 39 | args: parameters of the model 40 | """ 41 | 42 | # Filename and directories constants 43 | self.DATA_VERSION = '0.2' # Assert compatibility between versions 44 | self.DATA_DIR_MIDI = 'data/midi' # Originals midi files 45 | self.DATA_DIR_PLAY = 'data/play' # Target folder to show the reconstructed files 46 | self.DATA_DIR_SAMPLES = 'data/samples' # Training/testing samples after pre-processing 47 | self.DATA_SAMPLES_RAW = 'raw' # Unpreprocessed songs container tag 48 | self.DATA_SAMPLES_EXT = '.pkl' 49 | self.TEST_INIT_FILE = 'data/test/initiator.json' # Initial input for the generated songs 50 | self.FILE_EXT = '.mid' # Could eventually add support for other format later ? 51 | 52 | # Model parameters 53 | self.args = args 54 | 55 | # Dataset 56 | self.songs = [] 57 | self.songs_train = None 58 | self.songs_test = None 59 | 60 | # TODO: Dynamic loading of the the associated dataset flag (ex: data/samples/pianoroll/...) 61 | self.batch_builder = ModuleLoader.batch_builders.build_module(args) 62 | 63 | if not self.args.test: # No need to load the dataset when testing 64 | self._restore_dataset() 65 | 66 | if self.args.play_dataset: 67 | print('Play some songs from the formatted data') 68 | # Generate songs 69 | for i in range(min(10, len(self.songs))): 70 | raw_song = self.batch_builder.reconstruct_song(self.songs[i]) 71 | MidiConnector.write_song(raw_song, os.path.join(self.DATA_DIR_PLAY, str(i))) 72 | # TODO: Display some images corresponding to the loaded songs 73 | raise NotImplementedError('Can\'t play a song for now') 74 | 75 | self._split_dataset() # Warning: the list order will determine the train/test sets (so important that it don't change from run to run) 76 | 77 | # Plot some stats: 78 | print('Loaded: {} songs ({} train/{} test)'.format( 79 | len(self.songs_train) + len(self.songs_test), 80 | len(self.songs_train), 81 | len(self.songs_test)) 82 | ) # TODO: Print average, max, min duration 83 | 84 | def _restore_dataset(self): 85 | """Load/create the conversations data 86 | Done in two steps: 87 | * Extract the midi files as a raw song format 88 | * Transform this raw song as neural networks compatible input 89 | """ 90 | 91 | # Construct the dataset names 92 | samples_path_generic = os.path.join( 93 | self.args.root_dir, 94 | self.DATA_DIR_SAMPLES, 95 | self.args.dataset_tag + '-{}' + self.DATA_SAMPLES_EXT 96 | ) 97 | samples_path_raw = samples_path_generic.format(self.DATA_SAMPLES_RAW) 98 | samples_path_preprocessed = samples_path_generic.format(ModuleLoader.batch_builders.get_chosen_name()) 99 | 100 | # TODO: the _restore_samples from the raw songs and precomputed database should have different versions number 101 | 102 | # Restoring precomputed database 103 | if os.path.exists(samples_path_preprocessed): 104 | 
print('Restoring dataset from {}...'.format(samples_path_preprocessed)) 105 | self._restore_samples(samples_path_preprocessed) 106 | 107 | # First time we load the database: creating all files 108 | else: 109 | print('Training samples not found. Creating dataset from the songs...') 110 | # Restoring raw songs 111 | if os.path.exists(samples_path_raw): 112 | print('Restoring songs from {}...'.format(samples_path_raw)) 113 | self._restore_samples(samples_path_raw) 114 | 115 | # First time we load the database: creating all files 116 | else: 117 | print('Raw songs not found. Extracting from midi files...') 118 | self._create_raw_songs() 119 | print('Saving raw songs...') 120 | self._save_samples(samples_path_raw) 121 | 122 | # At this point, self.songs contain the list of the raw songs. Each 123 | # song is then preprocessed by the batch builder 124 | 125 | # Generating the data from the raw songs 126 | print('Pre-processing songs...') 127 | for i, song in tqdm(enumerate(self.songs), total=len(self.songs)): 128 | self.songs[i] = self.batch_builder.process_song(song) 129 | 130 | print('Saving dataset...') 131 | np.random.shuffle(self.songs) # Important to do that before saving so the train/test set will be fixed each time we reload the dataset 132 | self._save_samples(samples_path_preprocessed) 133 | 134 | def _restore_samples(self, samples_path): 135 | """ Load samples from file 136 | Args: 137 | samples_path (str): The path where to load the model (all dirs should exist) 138 | Return: 139 | List[Song]: The training data 140 | """ 141 | with open(samples_path, 'rb') as handle: 142 | data = pickle.load(handle) # Warning: If adding something here, also modifying saveDataset 143 | 144 | # Check the version 145 | current_version = data['version'] 146 | if current_version != self.DATA_VERSION: 147 | raise UserWarning('Present configuration version {0} does not match {1}.'.format(current_version, self.DATA_VERSION)) 148 | 149 | # Restore parameters 150 | self.songs = data['songs'] 151 | 152 | def _save_samples(self, samples_path): 153 | """ Save samples to file 154 | Args: 155 | samples_path (str): The path where to save the model (all dirs should exist) 156 | """ 157 | 158 | with open(samples_path, 'wb') as handle: 159 | data = { # Warning: If adding something here, also modifying loadDataset 160 | 'version': self.DATA_VERSION, 161 | 'songs': self.songs 162 | } 163 | pickle.dump(data, handle, -1) # Using the highest protocol available 164 | 165 | def _create_raw_songs(self): 166 | """ Create the database from the midi files 167 | """ 168 | midi_dir = os.path.join(self.args.root_dir, self.DATA_DIR_MIDI, self.args.dataset_tag) 169 | midi_files = [os.path.join(midi_dir, f) for f in os.listdir(midi_dir) if f.endswith(self.FILE_EXT)] 170 | 171 | for filename in tqdm(midi_files): 172 | 173 | try: 174 | new_song = MidiConnector.load_file(filename) 175 | except MidiInvalidException as e: 176 | tqdm.write('File ignored ({}): {}'.format(filename, e)) 177 | else: 178 | self.songs.append(new_song) 179 | tqdm.write('Song loaded {}: {} tracks, {} notes, {} ticks/beat'.format( 180 | filename, 181 | len(new_song.tracks), 182 | sum([len(t.notes) for t in new_song.tracks]), 183 | new_song.ticks_per_beat 184 | )) 185 | 186 | if not self.songs: 187 | raise ValueError('Empty dataset. 
Check that the folder exist and contains supported midi files.') 188 | 189 | def _convert_song2array(self, song): 190 | """ Convert a given song to a numpy multi-dimensional array (piano roll) 191 | The song is temporally normalized, meaning that all ticks and duration will be converted to a specific 192 | ticks_per_beat independent unit. 193 | For now, the changes of tempo are ignored. Only 4/4 is supported. 194 | Warning: The duration is ignored: All note have the same duration (1 unit) 195 | Args: 196 | song (Song): The song to convert 197 | Return: 198 | Array: the numpy array: a binary matrix of shape [NB_NOTES, song_length] 199 | """ 200 | 201 | # Convert the absolute ticks in standardized unit 202 | song_length = len(song) 203 | scale = self._get_scale(song) 204 | 205 | # TODO: Not sure why this plot a decimal value (x.66). Investigate... 206 | # print(song_length/scale) 207 | 208 | # Use sparse array instead ? 209 | piano_roll = np.zeros([music.NB_NOTES, int(np.ceil(song_length/scale))], dtype=int) 210 | 211 | # Adding all notes 212 | for track in song.tracks: 213 | for note in track.notes: 214 | piano_roll[note.get_relative_note()][note.tick//scale] = 1 215 | 216 | return piano_roll 217 | 218 | def _convert_array2song(self, array): 219 | """ Create a new song from a numpy array 220 | A note will be created for each non empty case of the array. The song will contain a single track, and use the 221 | default beats_per_tick as midi resolution 222 | For now, the changes of tempo are ignored. Only 4/4 is supported. 223 | Warning: All note have the same duration, the default value defined in music.Note 224 | Args: 225 | np.array: the numpy array (Warning: could be a array of int or float containing the prediction before the sigmoid) 226 | Return: 227 | song (Song): The song to convert 228 | """ 229 | 230 | new_song = music.Song() 231 | main_track = music.Track() 232 | 233 | scale = self._get_scale(new_song) 234 | 235 | for index, x in np.ndenumerate(array): # Add some notes 236 | if x > 1e-12: # Note added (TODO: What should be the condition, =1 ? sigmoid>0.5 ?) 237 | new_note = music.Note() 238 | 239 | new_note.set_relative_note(index[0]) 240 | new_note.tick = index[1] * scale # Absolute time in tick from the beginning 241 | 242 | main_track.notes.append(new_note) 243 | 244 | new_song.tracks.append(main_track) 245 | 246 | return new_song 247 | 248 | def _split_dataset(self): 249 | """ Create the test/train set from the loaded songs 250 | The dataset has been shuffled when calling this function (Warning: the shuffling 251 | is done and fixed before saving the dataset the first time so it is important to 252 | NOT call shuffle a second time) 253 | """ 254 | split_nb = int(self.args.ratio_dataset * len(self.songs)) 255 | self.songs_train = self.songs[:split_nb] 256 | self.songs_test = self.songs[split_nb:] 257 | self.songs = None # Not needed anymore (free some memory) 258 | 259 | def get_batches(self): 260 | """ Prepare the batches for the current epoch 261 | WARNING: The songs are not shuffled in this functions. 
We leave the choice 262 | to the batch_builder to manage the shuffling 263 | Return: 264 | list[Batch], list[Batch]: The batches for the training and testing set (can be generators) 265 | """ 266 | return ( 267 | self.batch_builder.get_list(self.songs_train, name='train'), 268 | self.batch_builder.get_list(self.songs_test, name='test'), 269 | ) 270 | 271 | # def get_batches_test(self, ): # TODO: Should only return a single batch (loading done in main class) 272 | # """ Return the batch which initiate the RNN when generating 273 | # The initial batches are loaded from a json file containing the first notes of the song. The note values 274 | # are the standard midi ones. Here is an examples of an initiator file: 275 | # Args: 276 | # TODO 277 | # Return: 278 | # Batch: The generated batch 279 | # """ 280 | # assert self.args.batch_size == 1 281 | # batch = None # TODO 282 | # return batch 283 | 284 | def get_batches_test_old(self): # TODO: This is the old version. Ideally should use the version above 285 | """ Return the batches which initiate the RNN when generating 286 | The initial batches are loaded from a json file containing the first notes of the song. The note values 287 | are the standard midi ones. Here is an examples of an initiator file: 288 | ``` 289 | {"initiator":[ 290 | {"name":"Simple_C4", 291 | "seq":[ 292 | {"notes":[60]} 293 | ]}, 294 | {"name":"some_chords", 295 | "seq":[ 296 | {"notes":[60,64]} 297 | {"notes":[66,68,71]} 298 | {"notes":[60,64]} 299 | ]} 300 | ]} 301 | ``` 302 | Return: 303 | List[Batch], List[str]: The generated batches with the associated names 304 | """ 305 | assert self.args.batch_size == 1 306 | 307 | batches = [] 308 | names = [] 309 | 310 | with open(self.TEST_INIT_FILE) as init_file: 311 | initiators = json.load(init_file) 312 | 313 | for initiator in initiators['initiator']: 314 | raw_song = music.Song() 315 | main_track = music.Track() 316 | 317 | current_tick = 0 318 | for seq in initiator['seq']: # We add a few notes 319 | for note_pitch in seq['notes']: 320 | new_note = music.Note() 321 | new_note.note = note_pitch 322 | new_note.tick = current_tick 323 | main_track.notes.append(new_note) 324 | current_tick += 1 325 | 326 | raw_song.tracks.append(main_track) 327 | raw_song.normalize(inverse=True) 328 | 329 | batch = self.batch_builder.process_batch(raw_song) 330 | 331 | names.append(initiator['name']) 332 | batches.append(batch) 333 | 334 | return batches, names 335 | 336 | @staticmethod 337 | def _convert_to_piano_rolls(outputs): 338 | """ Create songs from the decoder outputs. 
339 | Reshape the list of outputs to a list of piano rolls
340 | Args:
341 | outputs (List[np.array]): The list of the predictions of the decoder
342 | Return:
343 | List[np.array]: the list of the songs (one song by batch) as piano roll
344 | """
345 | 
346 | # Extract the batches and recreate the array for each batch
347 | piano_rolls = []
348 | for i in range(outputs[0].shape[0]): # Iterate over the batches
349 | piano_roll = None
350 | for j in range(len(outputs)): # Iterate over the sample length
351 | # outputs[j][i, :] has shape [NB_NOTES, 1]
352 | if piano_roll is None:
353 | piano_roll = [outputs[j][i, :]]
354 | else:
355 | piano_roll = np.append(piano_roll, [outputs[j][i, :]], axis=0)
356 | piano_rolls.append(piano_roll.T)
357 | 
358 | return piano_rolls
359 | 
360 | def visit_recorder(self, outputs, base_dir, base_name, recorders, chosen_labels=None):
361 | """ Save the predicted output songs using the given recorder
362 | Args:
363 | outputs (List[np.array]): The list of the predictions of the decoder
364 | base_dir (str): Path where to save the outputs
365 | base_name (str): filename of the output (without the extension)
366 | recorders (List[Obj]): Interfaces called to convert the song into a file (ex: midi or png). The recorders
367 | need to implement the method write_song (the method has to add the file extension) and the
368 | method get_input_type.
369 | chosen_labels (list[np.Array[batch_size, int]]): the chosen class at each timestep (useful to reconstruct the generated song)
370 | """
371 | 
372 | if not os.path.exists(base_dir):
373 | os.makedirs(base_dir)
374 | 
375 | for batch_id in range(outputs[0].shape[0]): # Loop over batch_size
376 | song = self.batch_builder.reconstruct_batch(outputs, batch_id, chosen_labels)
377 | for recorder in recorders:
378 | if recorder.get_input_type() == 'song':
379 | input = song
380 | elif recorder.get_input_type() == 'array':
381 | #input = self._convert_song2array(song)
382 | continue # TODO: For now, pianoroll deactivated
383 | else:
384 | raise ValueError('Unknown recorder input type: {}'.format(recorder.get_input_type()))
385 | base_path = os.path.join(base_dir, base_name + '-' + str(batch_id))
386 | recorder.write_song(input, base_path)
387 | 
-------------------------------------------------------------------------------- /deepmusic/songstruct.py: --------------------------------------------------------------------------------
1 | # Copyright 2016 Conchylicultor. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """ 17 | Hierarchical data structures of a song 18 | """ 19 | 20 | import operator # To rescale the song 21 | 22 | 23 | MIDI_NOTES_RANGE = [21, 108] # Min and max (included) midi note on a piano 24 | # TODO: Warn/throw when we try to add a note outside this range 25 | # TODO: Easy conversion from this range to tensor vector id (midi_note2tf_id) 26 | 27 | NB_NOTES = MIDI_NOTES_RANGE[1] - MIDI_NOTES_RANGE[0] + 1 # Vertical dimension of a song (=88 of keys for a piano) 28 | 29 | BAR_DIVISION = 16 # Nb of tics in a bar (What about waltz ? is 12 better ?) 30 | 31 | 32 | class Note: 33 | """ Structure which encapsulate the song data 34 | """ 35 | def __init__(self): 36 | self.tick = 0 37 | self.note = 0 38 | self.duration = 32 # TODO: Define the default duration / TODO: Use standard musical units (quarter note/eighth note) ?, don't convert here 39 | 40 | def get_relative_note(self): 41 | """ Convert the absolute midi position into the range given by MIDI_NOTES_RANGE 42 | Return 43 | int: The new position relative to the range (position on keyboard) 44 | """ 45 | return self.note - MIDI_NOTES_RANGE[0] 46 | 47 | def set_relative_note(self, rel): 48 | """ Convert given note into a absolute midi position 49 | Args: 50 | rel (int): The new position relative to the range (position on keyboard) 51 | """ 52 | # TODO: assert range (rel < NB_NOTES)? 53 | self.note = rel + MIDI_NOTES_RANGE[0] 54 | 55 | 56 | class Track: 57 | """ Structure which encapsulate a track of the song 58 | Ideally, each track should correspond to a single instrument and one channel. Multiple tracks could correspond 59 | to the same channel if different instruments use the same channel. 60 | """ 61 | def __init__(self): 62 | #self.tempo_map = None # Use a global tempo map 63 | self.instrument = None 64 | self.notes = [] # List[Note] 65 | #self.color = (0, 0, 0) # Color of the track for visual plotting 66 | self.is_drum = False 67 | 68 | def set_instrument(self, msg): 69 | """ Initialize from a mido message 70 | Args: 71 | msg (mido.MidiMessage): a valid control_change message 72 | """ 73 | if self.instrument is not None: # Already an instrument set 74 | return False 75 | 76 | assert msg.type == 'program_change' 77 | 78 | self.instrument = msg.program 79 | if msg.channel == 9 or msg.program > 112: # Warning: Mido shift the channels (start at 0) 80 | self.is_drum = True 81 | 82 | return True 83 | 84 | 85 | class Song: 86 | """ Structure which encapsulate the song data 87 | """ 88 | 89 | # Define the time unit 90 | # TODO: musicdata should have possibility to modify those parameters (through self.args) 91 | # Invert of time note which define the maximum resolution for a song. Ex: 2 for 1/2 note, 4 for 1/4 of note 92 | MAXIMUM_SONG_RESOLUTION = 4 93 | NOTES_PER_BAR = 4 # Waltz not supported 94 | 95 | def __init__(self): 96 | self.ticks_per_beat = 96 97 | self.tempo_map = [] 98 | self.tracks = [] # List[Track] 99 | 100 | def __len__(self): 101 | """ Return the absolute tick when the last note end 102 | Note that the length is recomputed each time the function is called 103 | """ 104 | return max([max([n.tick + n.duration for n in t.notes]) for t in self.tracks]) 105 | 106 | def _get_scale(self): 107 | """ Compute the unit scale factor for the song 108 | The scale factor allow to have a tempo independent time unit, to represent the song as an array 109 | of dimension [key, time_unit]. 
Once computed, one has just to divide (//) the ticks or multiply 110 | the time units to go from one representation to the other. 111 | 112 | Return: 113 | int: the scale factor for the current song 114 | """ 115 | 116 | # TODO: Assert that the scale factor is not a float (the % =0) 117 | return 4 * self.ticks_per_beat // (Song.MAXIMUM_SONG_RESOLUTION*Song.NOTES_PER_BAR) 118 | 119 | def normalize(self, inverse=False): 120 | """ Transform the song into a tempo independent song 121 | Warning: If the resolution of the song is is more fine that the given 122 | scale, some information will be definitively lost 123 | Args: 124 | inverse (bool): if true, we reverse the normalization 125 | """ 126 | scale = self._get_scale() 127 | op = operator.floordiv if not inverse else operator.mul 128 | 129 | # TODO: Not sure why this plot a decimal value (x.66). Investigate... 130 | # print(song_length/scale) 131 | 132 | # Shifting all notes 133 | for track in self.tracks: 134 | for note in track.notes: 135 | note.tick = op(note.tick, scale) # //= or *= 136 | -------------------------------------------------------------------------------- /deepmusic/tfutils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Conchylicultor. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """ 17 | Some functions to help define neural networks 18 | """ 19 | 20 | import tensorflow as tf 21 | 22 | 23 | def single_layer_perceptron(shape, scope_name): 24 | """ Single layer perceptron 25 | Project X on the output dimension 26 | Args: 27 | shape: a tuple (input dim, output dim) 28 | scope_name (str): encapsulate variables 29 | Return: 30 | tf.Ops: The projection operator (see project_fct()) 31 | """ 32 | assert len(shape) == 2 33 | 34 | # Projection on the keyboard 35 | with tf.variable_scope('weights_' + scope_name): 36 | W = tf.get_variable( 37 | 'weights', 38 | shape, 39 | initializer=tf.truncated_normal_initializer() # TODO: Tune value (fct of input size: 1/sqrt(input_dim)) 40 | ) 41 | b = tf.get_variable( 42 | 'bias', 43 | shape[1], 44 | initializer=tf.constant_initializer() 45 | ) 46 | 47 | def project_fct(X): 48 | """ Project the output of the decoder into the note space 49 | Args: 50 | X (tf.Tensor): input value 51 | """ 52 | # TODO: Could we add an activation function as option ? 
53 | with tf.name_scope(scope_name): 54 | return tf.matmul(X, W) + b 55 | 56 | return project_fct 57 | 58 | 59 | def get_rnn_cell(args, scope_name): 60 | """ Return an RNN cell, constructed from the parameters 61 | Args: 62 | args: the rnn parameters 63 | scope_name (str): encapsulates the variables 64 | Return: 65 | tf.RNNCell: a cell 66 | """ 67 | with tf.variable_scope('weights_' + scope_name): 68 | rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(args.hidden_size, state_is_tuple=True) # Or GRUCell, LSTMCell(args.hidden_size) 69 | #rnn_cell = tf.nn.rnn_cell.DropoutWrapper(rnn_cell, input_keep_prob=1.0, output_keep_prob=1.0) # TODO: Custom values (WARNING: No dropout when testing !!!, possible to use placeholder ?) 70 | rnn_cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell] * args.num_layers, state_is_tuple=True) 71 | return rnn_cell 72 | -------------------------------------------------------------------------------- /docs/ideas.md: -------------------------------------------------------------------------------- 1 | Just an unorganised list of ideas or models to test. 2 | 3 | * From a bar/lyrical phrase, predict the next one 4 | * Use 2 networks or more (idea from [Andy T](https://github.com/aPToul/Experiments-in-Music)): one to generate short melodies, one for the global structure. 5 | * Maybe instead try some kind of inception architecture but for RNNs (multiple RNNs with different parameters (LSTM/GRU, hidden size, ReLU/...) trained simultaneously) 6 | * Apply GAN (or one of its variants) to music generation 7 | * Test with/without attention 8 | * Recurrent DBN (compress information in the middle lstm layers): something like \[500,250,125,250,500\] for the hidden layers 9 | * Somehow include a CRF or a bidirectional LSTM? 10 | * Use a 1d convolutional NN on the 1d grid input (mixed with a RNN): the relative distance between notes is more important than the absolute position. Pb: how to incorporate the relation between frames (look at what has been done for video) 11 | * Maybe instead use a 2d convolution on the 2d grid input, and reformulate musical composition as an image generation problem (with the help of an adversarial model) 12 | * Also include the tempo change as a prediction. Continuously (each time step) or as an event (always predict 0 except sometimes)? Prediction as a multi-class classification among some predetermined classes (allegro, andante,...). Multi objective function with softmax (Loss=a*LossTempo + b*LossNotes). 13 | * Try Skip-Thought Vectors, bidirectional-RNN (predict both past/future) 14 | * Encoder/decoder network: The encoder encodes the keyboard disposition at time t and the decoder predicts it at time t+1. This whole network is a cell of the global RNN network. Pb: the notes are read sequentially (solve with a bidirectional RNN?) 15 | 16 | * Training task: neural art for music, i.e. play pieces with a certain pattern/style. 17 | 18 | 19 | What note representation? 20 | 21 | * A simpler model could take as input a single vector of 88 values (one per key) for each 1/16 of a bar: the vector contains 0 if nothing has been played and 1 if the key has been pressed (note duration not taken into account). The song is then represented by a giant matrix: (88\*(16\*nbBars)). This representation will be referred to as grid input (1d, or 2d if we add the temporal dimension to the input tensor). See the sketch after this list. 22 | * Artificially increase the dataset: transpose the musical pieces? 23 | * Midi vs ABC representation: one network just prints the basic melody in ABC notation (a text file, as if it were a char-rnn), and a second network plays this file as if it were improvising, playing chords and melody over some base tablature (as jazz musicians do). 24 | * Maybe try something closer to music theory (1st, 2nd, 3rd degree instead of A, D, C) or something closer to the physics (frequency? relative distance between the notes). Somehow the model should understand that 2 notes like C3 and C4 "feel" the same. 25 | * Note2vec? Something which converts each note/chord into a multidimensional space (How to divide/separate chords/notes? Pb of multiple representations: the same chord can be played as an arpeggio or using a more complex pattern. Should be a multilevel representation)
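To make the grid input idea above concrete, here is a minimal numpy sketch of how a few notes could be rasterised into the (88, 16*nbBars) matrix. The helper name and shapes are illustrative only; the project's own conversion code may differ.

```python
import numpy as np

NB_NOTES = 88       # vertical dimension (keys on the keyboard)
BAR_DIVISION = 16   # one column per 1/16 of a bar

def to_grid(notes, nb_bars):
    """ notes: list of (relative_note, tick) pairs, with relative_note in [0, 88)
    and tick counted in 1/16 of a bar. Duration and velocity are ignored here. """
    grid = np.zeros((NB_NOTES, BAR_DIVISION * nb_bars), dtype=np.int8)
    for note, tick in notes:
        grid[note, tick] = 1  # 1 = key pressed at that time slot
    return grid

# A C major chord on the first tick and a single note half a bar later
grid = to_grid([(39, 0), (43, 0), (46, 0), (51, 8)], nb_bars=1)
print(grid.shape, int(grid.sum()))  # (88, 16) 4
```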
26 | 27 | 28 | 29 | Jointly train a recurrent CNN for the spatial dependency (use an lstm containing a CNN?) AND a standard RNN for the absolute position (best of both worlds?). At the end, a fully connected layer mixes the two outputs to produce the final result: the CNN provides the pattern and the standard network provides the position (use a simple res net) 30 | 31 | Look at deconvolution 32 | ResNet 33 | 34 | Input grid 2d: use more channels to represent velocity/duration ? 35 | Or simply have the value of the cell represent the duration instead of being binary 36 | 37 | Training: Play one bar and only after that start to backpropagate for each step (no need to penalize when the network has no way to know if it is doing right or wrong, meaning at the beginning). The idea is that the first bar is given for free, just as initiation/setup (to gain momentum). 38 | For training, force the correct answer for each timestep 39 | 40 | Visualize the filters of the CNN. 41 | 42 | Walkthrough: 43 | 44 | * At first try simple midi (Joplin, pop, Bach, Mozart). Then try more complex composers (Chopin, Rachmaninoff < difficult because of the tempo changes within the piece) 45 | 46 | 47 | What cost fct? 48 | 49 | Eventually, the penalty should be smaller if the network predicts the same note but not in the right octave (ex: C4 instead of C5), with a decay the further away the prediction is (D5 and D1 more penalized than D4 and D3 if the target is D2) 50 | A first simple solution could be to try to jointly optimize 2 tasks (the binary classification on all keyboard keys and one on the notes % 12). Pb: how to compute the prediction for the % 12 from the global keyboard prediction? A simple way could be to simply add the predictions of each note (P(C)=sigm(C1 + C2 + C3 +...)) 51 | 52 | 53 | 54 | TODO 55 | 56 | Plot piano roll images while training, with some samples of the train and test sets (every 1000 iterations). Saved in a subfolder (training/)? 57 | 58 | Try learning on multiple sample lengths at the same time (short, longer) 59 | 60 | OpenCV 61 | When testing, plot the prediction color map and the ground truth together. Do it for training/testing/generated songs 62 | 63 | Apply k-nearest neighbors to find similar segments in the dataset 64 | 65 | Include the sample length in the title of the predicted songs? 66 | 67 | 68 | 69 | 70 | 71 | Notes should be modulo 12 (no distinction between C3 and C5). 72 | Limitation of the CNN at the boundaries (solution: try cycling: copy the notes at the boundaries. Slide the kernel until it has done a full cycle). Try 12*4 kernels for the CNN > they would contain chords
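A rough numpy sketch of this cyclic-boundary idea: the pitch axis is padded in a wrap-around fashion before convolving, so a 12x4 kernel sliding past the top of the keyboard "sees" the bottom again. Kernel size, shapes and the helper name are assumptions for illustration, not the project's code.

```python
import numpy as np

def conv2d_cyclic_pitch(grid, kernel):
    """ grid: [88, time] binary piano roll; kernel: [12, 4].
    Naive 2d convolution with cyclic padding on the pitch axis only. """
    kp, kt = kernel.shape
    padded = np.pad(grid, ((0, kp - 1), (0, 0)), mode='wrap')  # wrap the pitch axis
    out = np.zeros((grid.shape[0], grid.shape[1] - kt + 1))
    for p in range(out.shape[0]):
        for t in range(out.shape[1]):
            out[p, t] = np.sum(padded[p:p + kp, t:t + kt] * kernel)
    return out

grid = np.zeros((88, 16))
grid[[39, 43, 46], 0] = 1  # a C-E-G chord on the first tick
print(conv2d_cyclic_pitch(grid, np.ones((12, 4))).shape)  # (88, 13)
```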
73 | 74 | The CNN is here to learn chords/patterns 75 | 76 | Use a 2d cnn: one dimension for the chords, the other for the pattern (Alberti bass); or use an RNN for the pattern instead 77 | 78 | If a bar is 4 ticks, divide it and create 4 images, one for each of the ticks? 79 | Another solution is to use the cnn/RNN structure where the CNN has a temporal resolution of 2 ticks for instance. The RNN part has a temporal resolution of 1/4 of a tick. That means at each RNN step, the network re-gets some information it has already seen on the previous step. 80 | 81 | One of the outputs of the neural network controls the bpm. The cnn/network itself doesn't take care of the speed (1 tick is 1 tick) but when playing, the network can send signals to increase/decrease the tempo (maybe 5 prebuilt tempos and prediction as a classification problem through softmax). 82 | 83 | Tracks as channels?? 84 | 85 | Instead of randomly splitting the tracks, only split at a bar!! With that, the predictions will be synchronised and probably more 'clean' 86 | 87 | Try shorter sequences (4, 8, 16?) for the sample length 88 | 89 | Try to train the network with variable length sequences 90 | 91 | Have a 2-way RNN network. At each timestep, first, the network slides over the notes to encode the entire keyboard configuration. A decoder predicts the next configuration. The decoder output and a second state vector are connected to the next encoder. Should be difficult to optimize due to the depth of the network (88*sample_length steps per sample). 92 | Try to use some synthetic gradient tricks? We backpropagate on each timestep independently. The state vector of the previous step is a learned parameter (how to connect both sides!?) 93 | 94 | GAN: Use an RNN to generate the piano roll (image), then a CNN to discriminate 95 | 96 | Select samples as a function of the length of the song and of the sample sequences (otherwise bias towards short songs): something like nb_sample = 2*song_length//sample_length 97 | 98 | Load a starter sequence for testing (read initial notes from files): an input sequence is given from which the network will predict what comes next. 99 | 100 | Control the loop fct with a List\[boolean placeholder\] (use_previous) 101 | 102 | Keep track of the update/weights ratio 103 | Keep track of the magnitude of some values (internal state of the rnn, weights...) 104 | magnitude = tf.sqrt(tf.reduce_sum(tf.square(dec_memory\[1\]))) 105 | tf.scalar_summary("magnitude at t=1", magnitude) 106 | 107 | The tf.train.GradientDescentOptimizer class (and related Optimizer classes) calls tf.gradients() internally. If you want to access the gradients that are computed for the optimizer, you can call optimizer.compute_gradients() and optimizer.apply_gradients() manually, instead of calling optimizer.minimize() 108 | 109 | 110 | How to monitor dead relu units?? 111 | 112 | Try peepholes with the lstm, apparently better timing (*Learning Precise Timing with LSTM Recurrent Networks*, ... et al.) 113 | 114 | Try a pixel-rnn-like approach: predict P(xt|xt-1,…,h) 115 | Instead of taking a vector corresponding to a keyboard configuration, each input is just a note (pitch, distance from the previous one) 116 | 117 | Try changing the note representation: instead of 88*1, do 12*nbOctave 118 | 119 | 120 | ## Tools 121 | 122 | Some python libraries for midi file manipulation: 123 | * Mido: seems the best one now for low level manipulation. Close to the original specs (used in this program). 124 | * pretty_midi: higher level (the piano roll functions could be really handy). 
Not python 3 compatible; I had to program the piano roll conversion myself. 125 | -------------------------------------------------------------------------------- /docs/imgs/basic_rnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Conchylicultor/MusicGenerator/adea76dccaba923b7d3807082ec6f5b512d16bb9/docs/imgs/basic_rnn.png -------------------------------------------------------------------------------- /docs/imgs/endecell.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Conchylicultor/MusicGenerator/adea76dccaba923b7d3807082ec6f5b512d16bb9/docs/imgs/endecell.png -------------------------------------------------------------------------------- /docs/imgs/training_begin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Conchylicultor/MusicGenerator/adea76dccaba923b7d3807082ec6f5b512d16bb9/docs/imgs/training_begin.png -------------------------------------------------------------------------------- /docs/imgs/training_end.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Conchylicultor/MusicGenerator/adea76dccaba923b7d3807082ec6f5b512d16bb9/docs/imgs/training_end.png -------------------------------------------------------------------------------- /docs/midi.md: -------------------------------------------------------------------------------- 1 | ## Midi 2 | 3 | For this project, due to the lack of open source, python 3 compatible libraries, I quickly implemented a simple higher level lib based on [mido](https://github.com/olemb/mido) to read/write midi files. It's really basic, so don't expect it to support the full midi specification, but for simple songs it's quite efficient. 
Here is an example to generate a new song: 4 | 5 | ```python 6 | import deepmusic.songstruct as music 7 | from deepmusic.midiconnector import MidiConnector 8 | 9 | test_song = music.Song() 10 | main_track = music.Track() 11 | 12 | for i in range(44): # Add some notes 13 | new_note = music.Note() 14 | 15 | new_note.note = (i%2)*(21+i) + ((i+1)%2)*(108-i) 16 | new_note.duration = 32 17 | new_note.tick = 32*i # Absolute time in ticks from the beginning 18 | 19 | main_track.notes.append(new_note) 20 | 21 | test_song.tracks.append(main_track) 22 | MidiConnector.write_song(test_song, 'data/midi/test.mid') 23 | ``` 24 | -------------------------------------------------------------------------------- /docs/midi/basic_rnn_joplin.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Conchylicultor/MusicGenerator/adea76dccaba923b7d3807082ec6f5b512d16bb9/docs/midi/basic_rnn_joplin.mid -------------------------------------------------------------------------------- /docs/midi/basic_rnn_ragtime_-38000-0-C4.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Conchylicultor/MusicGenerator/adea76dccaba923b7d3807082ec6f5b512d16bb9/docs/midi/basic_rnn_ragtime_-38000-0-C4.mid -------------------------------------------------------------------------------- /docs/midi/basic_rnn_ragtime_BaseStructure.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Conchylicultor/MusicGenerator/adea76dccaba923b7d3807082ec6f5b512d16bb9/docs/midi/basic_rnn_ragtime_BaseStructure.mid -------------------------------------------------------------------------------- /docs/midi/basic_rnn_ragtime_TempoChange.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Conchylicultor/MusicGenerator/adea76dccaba923b7d3807082ec6f5b512d16bb9/docs/midi/basic_rnn_ragtime_TempoChange.mid -------------------------------------------------------------------------------- /docs/midi/basic_rnn_ragtime_structure.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Conchylicultor/MusicGenerator/adea76dccaba923b7d3807082ec6f5b512d16bb9/docs/midi/basic_rnn_ragtime_structure.mid -------------------------------------------------------------------------------- /docs/models.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | Here I present my experiments and the models I used. 4 | 5 | ## Basic RNN 6 | 7 | As a baseline, I tried a simple RNN model (2 LSTM layers). Given a keyboard configuration, the network tries to predict the next one (this architecture is similar to the famous [Char-RNN](https://github.com/karpathy/char-rnn) model, used to generate English sentences). I formulate the prediction as 88 binary classification problems: for each note on the keyboard, the network tries to guess if the note is pressed or released. Because the classifications are not mutually exclusive (two keys can be pressed at the same time), I use a sigmoid cross entropy instead of a softmax. For this first try, a lot of assumptions have been made on the songs (only 4/4 time signature, quarter-note as the maximum note resolution, no tempo changes or other difficulties that the model could not handle).
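As an illustration of this formulation, here is a minimal, self-contained numpy sketch of the per-timestep loss: one independent pressed/not-pressed prediction per key, scored with a sigmoid cross entropy rather than a softmax. Names and shapes are illustrative only; the actual model is built with TensorFlow elsewhere in the repository.

```python
import numpy as np

NB_NOTES = 88

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def keyboard_loss(logits, target):
    """ logits: [88] raw network outputs for one time step;
    target: [88] ground truth, 1 where a key is pressed.
    Mean of 88 independent binary cross entropies (not a softmax:
    several keys can be pressed at the same time). """
    p = sigmoid(logits)
    eps = 1e-7  # avoid log(0)
    return -np.mean(target * np.log(p + eps) + (1 - target) * np.log(1 - p + eps))

rng = np.random.default_rng(0)
logits = rng.normal(size=NB_NOTES)      # what the RNN would output at one step
target = np.zeros(NB_NOTES)
target[[39, 43, 46]] = 1                # e.g. a C-E-G chord (relative key indices)
print(keyboard_loss(logits, target))
```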
8 | 9 | ![Basic RNN](imgs/basic_rnn.png) 10 | 11 | *Basic RNN structure. The network generates a sequence of keyboard configurations.* 12 | 13 | At first, I tried to train this model on 3 Scott Joplin songs. I chose ragtime music for the first test because ragtime songs have a really rigid and well defined structure, and the songs satisfied the assumptions made above. Each song is split into small parts and I randomly shuffle the parts between the songs so the network learns different songs simultaneously. 14 | 15 | Because there is a high bias toward the negative class (key not pressed), I was a little afraid that the network would only predict empty songs. That doesn't seem to be the case. On the contrary, the network clearly overfits. When analysing the generated songs, we see that the network has memorized entire parts of the original songs. This simple model was mainly a test to validate the fact that a small/medium sized network can encode some rhythmic and harmonic information. Of course, the database wasn't big enough to have a truly original artificial composer, but I'm quite impressed by the capability of the network to learn the songs by heart (blessed be the LSTM). 16 | 17 | Usually the first notes are quite random but after a few bars, the network stabilizes on a part of the song it remembers. When trying to generate sequences longer than the training ones, the network will simply loop over some parts and play them indefinitely. You can listen to one of the generated songs [here](https://soundcloud.com/reivalk/basic-rnn-joplin-example-overfitting?in=reivalk/sets/music-generator-experiments). We can clearly recognize parts of [Rag Dance](https://youtu.be/tCrj1s1iVas). 18 | 19 | ![Training piano roll](imgs/training_begin.png) ![Training piano roll](imgs/training_end.png) 20 | 21 | *Piano rolls of the predictions. During the first iterations, the network only predicts very repetitive patterns. When trained longer, the patterns become more complex.* 22 | 23 | I then applied this model to a larger dataset (400+ songs). Here are some samples of songs generated with this model: 24 | * [Sample 1](https://soundcloud.com/reivalk/basic-rnn-ragtime-1?in=reivalk/sets/music-generator-experiments) 25 | * [Sample 2](https://soundcloud.com/reivalk/basic-rnn-ragtime-2?in=reivalk/sets/music-generator-experiments) 26 | 27 | The original generated midi files can be found in [this folder](midi/). I converted them to mp3 using the utility script present in the root directory. 28 | 29 | There are some problems with this model. One is that even if trained jointly, the predictions are made independently. Each of the 88 classifiers does not consult the others before making its prediction, so for instance the part of the network which predicts whether an E4 should be played has no way to know if a G2 is predicted at the same time or not (one way to solve this issue could be to use a CRF). Similarly, the neural network doesn't care about the relative order of the notes in the input vector. We could invert the keys on the keyboard and the prediction would be the same. To improve our model, it should have a way to know that the interval between C4 and C4# is the same as the one between A3 and A3#. There are some architectures which could be worth exploring. 30 | 31 | One other problem is maybe a little more technical. When the network sees a new sequence, it starts from a clear memory, as if it had never seen anything before, and progressively builds and encodes what it 'sees' into its internal state. It's impossible for the network to correctly guess the first notes when the internal state is empty (as it is reset at each iteration), so during the first timesteps the network has no knowledge of the song it is supposed to predict. One way this issue has been solved in some other models is to use another network, called an encoder, which computes the internal state to give to the network for its first step. Here, using a separate network for the encoding would be cheating, because to generate new songs, we don't want our model to rely on an encoder which would have "prepared" the song to play. The solution I tried is simply to ignore the first predictions (the first steps are just there so the network can encode tonality, rhythm,...) and gradually penalize the mistakes more and more as we progress through the steps. Basically, what this change does is tell the network that it's less important to make a mistake during the first timesteps, when the structure of the song is completely unknown, than at the end, when it should have integrated the rhythm, tonality,... It's somehow similar to the encoder/decoder architecture, but using a single network for both.
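A minimal sketch of this progressive weighting, assuming a simple linear ramp with a few penalty-free steps at the start. The exact schedule used in the project may differ; the function and parameter names are made up for illustration.

```python
import numpy as np

def ramp_weights(num_steps, free_steps=16):
    """ No penalty during the first `free_steps`, then a linearly increasing weight up to 1.
    Assumes num_steps > free_steps. """
    return np.clip((np.arange(num_steps) - free_steps) / max(num_steps - free_steps, 1), 0.0, 1.0)

def weighted_loss(per_step_loss):
    """ per_step_loss: array [num_steps] of cross entropy values, one per song time step. """
    w = ramp_weights(len(per_step_loss))
    return np.sum(w * per_step_loss) / np.sum(w)

print(ramp_weights(8, free_steps=2))  # approximately [0. 0. 0. 0.17 0.33 0.5 0.67 0.83]
```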
32 | 33 | ## EnDeCell 34 | 35 | In order to solve some of those issues, I tried to develop a more audacious architecture. One of the most important things in music is the spatial relationship/distance between the different notes, so the model should somehow capture that relation. In order to integrate that information, I added a system of encoder/decoder blocks which does exactly that. Because music is made of different regular patterns (major chords, minor chords,...), those could potentially be compressed and encapsulated in the encoder output. The role of the encoder is to capture the distances between the notes and compress that information into a single state vector. The decoder, on the other hand, does exactly the opposite: it generates a keyboard configuration from a state vector. Another way to see it: what I try to create is some kind of embedding for the keyboard configurations, where each keyboard configuration would be projected onto a vector space which would represent the 'semantics' of the keyboard. The LSTM's goal would be to link those configurations together. 36 | 37 | Potentially, any kind of network could fit inside the enco/deco blocks. A simple CNN for instance could potentially learn filters which represent chords, but because CNNs learn translation invariant filters, we would lose the chord position information. I chose instead two RNNs for encoding/decoding, similar to the original seq2seq model, which read the keyboard note by note. 38 | 39 | ![EnDeCell](imgs/endecell.png) 40 | 41 | *The architecture of the EnDeCell RNN. Each cell contains an encoder/decoder block. Each enco/deco block contains an RNN which captures the keyboard state.* 42 | 43 | Because of the complexity of the model (a network inside a network) and to simplify the training, it should be possible to first train the encoder/decoder separately on some keyboard configuration samples and then integrate the weights into the complete model. 44 | 45 | I had big ambitions for this model, and a lot of variants to test. But after some preliminary tests, I had to face the reality that I simply don't have the hardware to train it. Indeed, the number of recurrent steps is simply too big to handle (nb_song_time_step*nb_notes*2 per iteration). However, I keep this model in mind for the day when computational power will be more affordable. 
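To give an idea of why the unrolled graph explodes, here is a toy numpy sketch of a single EnDeCell step, using plain tanh RNN cells instead of LSTMs. All sizes, weights and names are made up for illustration; this is not the project's implementation, just the step-count argument made concrete.

```python
import numpy as np

NB_NOTES = 88    # keys read/written one at a time by the inner RNNs
INNER_DIM = 32   # state size of the encoder/decoder RNNs
OUTER_DIM = 64   # state size of the outer (song level) RNN

rng = np.random.default_rng(0)

def rnn_step(x, h, W_x, W_h, b):
    """ One vanilla RNN step: h' = tanh(x @ W_x + h @ W_h + b). """
    return np.tanh(x @ W_x + h @ W_h + b)

# Encoder reads the keyboard one key at a time (input = 1 scalar per key)
W_xe, W_he, b_e = rng.normal(size=(1, INNER_DIM)), rng.normal(size=(INNER_DIM, INNER_DIM)), np.zeros(INNER_DIM)
# Outer RNN: one step per song time step
W_xo, W_ho, b_o = rng.normal(size=(INNER_DIM, OUTER_DIM)), rng.normal(size=(OUTER_DIM, OUTER_DIM)), np.zeros(OUTER_DIM)
# Decoder emits the keyboard one key at a time
W_xd, W_hd, b_d = rng.normal(size=(OUTER_DIM, INNER_DIM)), rng.normal(size=(INNER_DIM, INNER_DIM)), np.zeros(INNER_DIM)
w_out = rng.normal(size=INNER_DIM)

def endecell_step(keyboard, h_outer):
    """ One song time step: encode 88 keys, update the outer state, decode 88 keys. """
    h_enc = np.zeros(INNER_DIM)
    for key in keyboard:                                   # 88 inner encoder steps
        h_enc = rnn_step(np.array([key]), h_enc, W_xe, W_he, b_e)
    h_outer = rnn_step(h_enc, h_outer, W_xo, W_ho, b_o)    # 1 outer step
    h_dec, logits = np.zeros(INNER_DIM), []
    for _ in range(NB_NOTES):                              # 88 inner decoder steps
        h_dec = rnn_step(h_outer, h_dec, W_xd, W_hd, b_d)
        logits.append(h_dec @ w_out)                       # pre-sigmoid score for "this key is pressed"
    return np.array(logits), h_outer

keyboard_t = rng.integers(0, 2, size=NB_NOTES).astype(float)   # one keyboard configuration
logits, h = endecell_step(keyboard_t, np.zeros(OUTER_DIM))

sample_length = 64  # song time steps in one training sample
print('recurrent steps per sample:', sample_length * NB_NOTES * 2)  # why training is so expensive
```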
46 | 47 | Some solutions to make this model computationally more efficient and to reduce the RNN length could be to make it predict multiple song time steps at once (ex: predicting the next bar from the previous one), or to have the encoder encode multiple keys at once. 48 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2016 Conchylicultor. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | 18 | """ 19 | Main script. See README.md for more information 20 | 21 | Use python 3 22 | """ 23 | 24 | import deepmusic 25 | 26 | 27 | if __name__ == "__main__": 28 | composer = deepmusic.Composer() 29 | composer.main() 30 | -------------------------------------------------------------------------------- /save/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2016 Conchylicultor. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | 18 | """ 19 | Some utility functions to easily manipulate large volumes of downloaded files. 
20 | Independent of the main program but useful to extract/create the dataset 21 | """ 22 | 23 | import os 24 | import subprocess 25 | import glob 26 | 27 | 28 | def extract_files(): 29 | """ Recursively extract all files from a given directory 30 | """ 31 | input_dir = '../www.chopinmusic.net/' 32 | output_dir = 'chopin_clean/' 33 | 34 | os.makedirs(output_dir, exist_ok=True) 35 | 36 | print('Extracting:') 37 | i = 0 38 | for filename in glob.iglob(os.path.join(input_dir, '**/*.mid'), recursive=True): 39 | print(filename) 40 | os.rename(filename, os.path.join(output_dir, os.path.basename(filename))) 41 | i += 1 42 | print('{} files extracted.'.format(i)) 43 | 44 | 45 | def rename_files(): 46 | """ Rename all files of the given directory following some rules 47 | """ 48 | input_dir = 'chopin/' 49 | output_dir = 'chopin_clean/' 50 | 51 | assert os.path.exists(input_dir) 52 | os.makedirs(output_dir, exist_ok=True) 53 | 54 | list_files = [f for f in os.listdir(input_dir) if f.endswith('.mid')] 55 | 56 | print('Renaming {} files:'.format(len(list_files))) 57 | for prev_name in list_files: 58 | new_name = prev_name.replace('midi.asp?file=', '') 59 | new_name = new_name.replace('%2F', '_') 60 | print('{} -> {}'.format(prev_name, new_name)) 61 | os.rename(os.path.join(input_dir, prev_name), os.path.join(output_dir, new_name)) 62 | 63 | 64 | def convert_midi2mp3(): 65 | """ Convert all midi files of the given directory to mp3 66 | """ 67 | input_dir = 'docs/midi/' 68 | output_dir = 'docs/mp3/' 69 | 70 | assert os.path.exists(input_dir) 71 | os.makedirs(output_dir, exist_ok=True) 72 | 73 | print('Converting:') 74 | i = 0 75 | for filename in glob.iglob(os.path.join(input_dir, '**/*.mid'), recursive=True): 76 | print(filename) 77 | in_name = filename 78 | out_name = os.path.join(output_dir, os.path.splitext(os.path.basename(filename))[0] + '.mp3') 79 | command = 'timidity {} -Ow -o - | ffmpeg -i - -acodec libmp3lame -ab 64k {}'.format(in_name, out_name) # TODO: Redirect stdout to avoid polluting the screen (have cleaner printing) 80 | subprocess.call(command, shell=True) 81 | i += 1 82 | print('{} files converted.'.format(i)) 83 | 84 | 85 | if __name__ == '__main__': 86 | convert_midi2mp3() 87 | --------------------------------------------------------------------------------