├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── environment.yml ├── images └── sing.png ├── nsynth_100_test.txt ├── requirements.txt ├── setup.py └── sing ├── __init__.py ├── ae ├── __init__.py ├── models.py ├── trainer.py └── utils.py ├── dsp.py ├── fondation ├── __init__.py ├── batch.py ├── datasets.py ├── trainer.py └── utils.py ├── generate.py ├── nsynth ├── __init__.py └── examples.json.gz ├── parser.py ├── sequence ├── __init__.py ├── models.py ├── trainer.py └── utils.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .* 3 | models 4 | data -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to SING 2 | 3 | ## Pull Requests 4 | 5 | In order to accept your pull request, we need you to submit a CLA. You only need 6 | to do this once to work on any of Facebook's open source projects. 7 | 8 | Complete your CLA here: 9 | 10 | SING is the implementation of a research paper. 11 | Therefore, we do not plan on accepting many pull requests for new features. 12 | We certainly welcome them for bug fixes. 13 | 14 | 15 | ## Issues 16 | 17 | We use GitHub issues to track public bugs. Please ensure your description is 18 | clear and has sufficient instructions to be able to reproduce the issue. 19 | 20 | 21 | ## License 22 | By contributing to this repository, you agree that your contributions will be licensed 23 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 
23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. 
Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. 
For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. 
identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. 
Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include nsynth/examples.json.gz -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SING: Symbol-to-Instrument Neural Generator 2 | 3 | SING is a deep learning based musical note synthesizer that can be trained on the 4 | [NSynth dataset][nsynth]. 5 | Despite being 32 times faster to train and 2,500 times faster for inference, 6 | SING produces audio with significantly improved perceptual quality compared to 7 | the NSynth WaveNet-like autoencoder [[1]](#ref_nsynth), as measured by 8 | Mean Opinion Scores based on human evaluations. 9 | 10 | The architecture and results obtained are detailed in our paper 11 | [SING: Symbol-to-Instrument Neural Generator][sing_nips]. 12 | SING combines an LSTM-based sequence generator with a 13 | convolutional decoder: 14 | 15 |

16 | ![Schema representing the structure of SING: an LSTM followed by a convolutional decoder](./images/sing.png)

17 | 18 | 19 | 20 | ## Requirements 21 | 22 | SING works with Python 3.6 and newer. 23 | To use SING, you must have reasonably recent versions of the following 24 | packages installed: 25 | 26 | - numpy 27 | - requests 28 | - pytorch (needs to be >= 0.4.1 as we use torch.stft) 29 | - scipy 30 | - tqdm 31 | 32 | If you have anaconda installed, you can run from the root of this repository: 33 | 34 | conda env update 35 | conda activate sing 36 | 37 | This will create a `sing` environment with all the dependencies installed. 38 | Alternatively, you can use pip to install them: 39 | 40 | pip3 install -r requirements.txt 41 | 42 | 43 | SING can optionally be installed using the usual `setup.py`, 44 | although this is not required. 45 | 46 | ### Obtaining the NSynth dataset 47 | 48 | If you want to train SING from scratch, you will need a copy of the NSynth 49 | dataset [[1]](#ref_nsynth). To download it, you can use the following commands 50 | (**WARNING**: NSynth is 30GB, so this will take some time): 51 | 52 | mkdir data && cd data &&\ 53 | wget http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-train.jsonwav.tar.gz &&\ 54 | tar xf nsynth-train.jsonwav.tar.gz 55 | 56 | 57 | ## Using SING 58 | 59 | Once SING is installed, or from the root of this repository, you can use the family 60 | of commands detailed hereafter, all of the form 61 | 62 | python3 -m sing.* 63 | 64 | 65 | ### Common flags 66 | 67 | For either training or generation, use the `--cuda` flag for GPU acceleration 68 | and the `--parallel` flag to use all available GPUs. Depending on the memory 69 | and number of GPUs available, consider tweaking the batch size using the 70 | `--batch-size` flag. The default is 64 but 256 was used in the paper. 71 | 72 | 73 | ### Training 74 | 75 | If you already have the NSynth dataset downloaded somewhere, run 76 | 77 | python3 -m sing.train [--cuda [--parallel]] --data PATH_TO_NSYNTH \ 78 | --output PATH_TO_SING_MODEL [--checkpoint PATH_TO_CHECKPOINTS] 79 | 80 | `PATH_TO_NSYNTH` is by default set to `data/nsynth-train`. 81 | The final model will be saved at `PATH_TO_SING_MODEL` (default is `models/sing.th`). If you want 82 | to save checkpoints after each epoch, or to resume a previously interrupted 83 | training, use the `--checkpoint` option. 84 | 85 | ### Generation 86 | 87 | For generation, you do not need the NSynth dataset, but you do need a trained SING model. 88 | 89 | python3 -m sing.generate [--cuda [--parallel]] \ 90 | --model PATH_TO_SING_MODEL PATH_TO_ITEM_LIST 91 | 92 | `PATH_TO_ITEM_LIST` should be a file with one dataset item name per line, 93 | for instance `organ_electronic_044-055-127`. 94 | 95 | Alternatively, you can download a pretrained model using 96 | 97 | python3 -m sing.generate [--cuda [--parallel]] --dl PATH_TO_ITEM_LIST 98 | 99 | By default, the model will be downloaded under `models/sing.th` but a 100 | different path can be provided using the `--model` option. 101 | The pretrained model can also be downloaded directly [here](https://dl.fbaipublicfiles.com/sing/sing.th).
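As an illustration, a complete generation run that combines the common flags above with the pretrained model download might look like the following (all flags are documented above; the batch size is only a suggestion and should be adapted to your GPUs):

    python3 -m sing.generate --cuda --parallel --batch-size 256 --dl nsynth_100_test.txt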
102 | 103 | ### Results reproduction 104 | 105 | To reproduce the results of Table 1 in our paper, simply run 106 | 107 | ```bash 108 | # For the L1 spectral loss 109 | python3 -m sing.train [--cuda [--parallel]] --l1 110 | # For the L1 spectral loss without time embeddings 111 | python3 -m sing.train [--cuda [--parallel]] --l1 --time-dim=0 112 | # For the Wav loss 113 | python3 -m sing.train [--cuda [--parallel]] --wav 114 | ``` 115 | 116 | To reproduce the audio samples used for the human evaluations, simply run 117 | from the root of the git repository 118 | 119 | python3 -m sing.generate [--cuda [--parallel]] --dl nsynth_100_test.txt 120 | 121 | The file `nsynth_100_test.txt` has been generated using the following code: 122 | 123 | ```python 124 | from sing import nsynth 125 | from sing.fondation.datasets import RandomSubset 126 | dset = nsynth.get_nsynth_metadata() 127 | train, valid, test = nsynth.make_datasets(dset) 128 | 129 | evaluation = RandomSubset(test, 100) 130 | open("nsynth_100_test.txt", "w").write("\n".join( 131 | evaluation[i].metadata['name'] for i in range(len(evaluation)))) 132 | ``` 133 | 134 | ## Generated audio 135 | 136 | A comparison of audio samples generated by SING and the NSynth WaveNet-based autoencoder [[1]](#ref_nsynth) 137 | is available on [the paper webpage](https://research.fb.com/wp-content/themes/fb-research/research/sing-paper/). 138 | 139 | 140 | ## Thanks 141 | 142 | We thank the Magenta team for their inspiring work on NSynth. 143 | 144 | ## License 145 | 146 | For convenience we have included a copy of the metadata of the NSynth dataset 147 | in this repository. The dataset has been released by Google Inc. 148 | under the Creative Commons Attribution 4.0 International (CC BY 4.0) license. 149 | 150 | SING is released under the Creative Commons Attribution-NonCommercial 4.0 International 151 | (CC BY-NC 4.0) license, as found in the LICENSE file. 152 | 153 | ## Bibliography 154 | 155 | [1]: Jesse Engel, Cinjon Resnick, Adam Roberts, 156 | Sander Dieleman, Douglas Eck, 157 | Karen Simonyan, and Mohammad Norouzi. [Neural Audio Synthesis of Musical Notes with WaveNet Autoencoders](https://arxiv.org/pdf/1704.01279.pdf). 2017.
158 | 159 | 160 | 161 | [nsynth]: https://magenta.tensorflow.org/datasets/nsynth 162 | [sing_nips]: https://research.fb.com/publications/sing-symbol-to-instrument-neural-generator 163 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: sing 2 | channels: 3 | - pytorch 4 | dependencies: 5 | - numpy>=1.15 6 | - python>=3.6 7 | - pytorch>=0.4.1 8 | - requests>=2.19 9 | - scipy>=1.1 10 | - tqdm>=4.26 -------------------------------------------------------------------------------- /images/sing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SING/72054bdb23b4ced393c0d435c124db64c1e4cb26/images/sing.png -------------------------------------------------------------------------------- /nsynth_100_test.txt: -------------------------------------------------------------------------------- 1 | bass_synthetic_126-025-025 2 | synth_lead_synthetic_006-045-050 3 | bass_synthetic_065-070-025 4 | keyboard_electronic_063-087-100 5 | reed_acoustic_031-057-050 6 | bass_synthetic_044-094-025 7 | bass_synthetic_095-091-050 8 | bass_synthetic_087-091-127 9 | mallet_electronic_011-074-127 10 | mallet_electronic_007-059-100 11 | brass_acoustic_049-082-050 12 | keyboard_electronic_000-091-075 13 | string_acoustic_048-062-075 14 | string_acoustic_008-033-075 15 | keyboard_electronic_026-055-127 16 | keyboard_electronic_070-049-025 17 | organ_electronic_050-105-127 18 | string_acoustic_044-060-025 19 | bass_synthetic_064-092-075 20 | organ_electronic_085-083-127 21 | mallet_acoustic_013-071-075 22 | keyboard_electronic_026-098-100 23 | mallet_acoustic_002-060-127 24 | keyboard_electronic_100-039-100 25 | bass_synthetic_120-071-100 26 | organ_electronic_061-044-127 27 | vocal_acoustic_023-057-127 28 | string_acoustic_030-051-025 29 | brass_acoustic_003-041-050 30 | bass_synthetic_093-052-075 31 | organ_electronic_016-026-100 32 | organ_electronic_080-025-100 33 | brass_acoustic_001-054-127 34 | guitar_acoustic_009-069-100 35 | brass_acoustic_014-061-050 36 | keyboard_acoustic_001-061-100 37 | organ_electronic_077-047-100 38 | bass_synthetic_130-056-075 39 | guitar_electronic_032-081-127 40 | mallet_synthetic_001-066-100 41 | keyboard_electronic_042-057-050 42 | bass_synthetic_052-059-025 43 | keyboard_acoustic_002-062-075 44 | guitar_acoustic_019-074-025 45 | bass_synthetic_128-097-100 46 | guitar_electronic_004-038-050 47 | bass_synthetic_086-048-100 48 | keyboard_electronic_055-068-127 49 | guitar_electronic_023-046-100 50 | guitar_electronic_015-065-025 51 | bass_synthetic_015-098-127 52 | brass_acoustic_035-058-127 53 | bass_synthetic_063-036-025 54 | reed_acoustic_033-044-050 55 | organ_electronic_065-041-100 56 | bass_synthetic_111-036-100 57 | organ_electronic_037-047-050 58 | bass_synthetic_140-108-050 59 | brass_acoustic_040-059-025 60 | organ_electronic_092-042-025 61 | keyboard_electronic_019-099-025 62 | reed_acoustic_029-077-127 63 | string_acoustic_043-027-025 64 | bass_synthetic_121-043-127 65 | string_acoustic_007-060-025 66 | keyboard_electronic_015-089-050 67 | organ_electronic_050-057-025 68 | bass_synthetic_078-045-075 69 | keyboard_electronic_088-081-075 70 | brass_acoustic_000-039-127 71 | guitar_electronic_017-093-075 72 | bass_synthetic_117-068-127 73 | mallet_electronic_013-087-050 74 | flute_acoustic_012-090-127 75 | bass_synthetic_136-073-025 76 | 
mallet_electronic_006-085-127 77 | mallet_acoustic_016-082-050 78 | organ_electronic_044-025-025 79 | mallet_acoustic_046-032-127 80 | guitar_acoustic_032-067-100 81 | organ_electronic_014-084-127 82 | organ_electronic_098-030-025 83 | mallet_acoustic_042-103-050 84 | keyboard_electronic_065-093-127 85 | mallet_acoustic_058-108-050 86 | mallet_electronic_007-036-127 87 | keyboard_electronic_092-090-127 88 | string_acoustic_059-070-025 89 | guitar_electronic_009-054-100 90 | bass_synthetic_044-075-050 91 | mallet_acoustic_051-065-075 92 | bass_electronic_030-047-100 93 | flute_acoustic_027-077-050 94 | bass_synthetic_123-086-050 95 | bass_synthetic_117-086-100 96 | mallet_acoustic_004-065-100 97 | bass_synthetic_094-023-025 98 | organ_electronic_020-054-025 99 | brass_acoustic_011-056-050 100 | keyboard_electronic_058-038-127 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | requests 3 | scipy 4 | torch>=0.4.1 5 | tqdm -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) 2018-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # Inspired from https://github.com/kennethreitz/setup.py 9 | 10 | from pathlib import Path 11 | 12 | from setuptools import find_packages, setup 13 | 14 | NAME = 'sing' 15 | DESCRIPTION = 'SING: Symbol-to-Instrument Neural Generator' 16 | URL = 'https://github.com/facebookresearch/SING' 17 | EMAIL = 'defossez@fb.com' 18 | AUTHOR = 'Alexandre Defossez' 19 | REQUIRES_PYTHON = '>=3.6.0' 20 | VERSION = "1.0" 21 | 22 | HERE = Path(__file__).parent 23 | 24 | REQUIRED = [i.strip() for i in open(HERE / "requirements.txt").readlines()] 25 | 26 | try: 27 | with open(HERE / "README.md", encoding='utf-8') as f: 28 | long_description = '\n' + f.read() 29 | except FileNotFoundError: 30 | long_description = DESCRIPTION 31 | 32 | setup( 33 | name=NAME, 34 | version=VERSION, 35 | description=DESCRIPTION, 36 | long_description=long_description, 37 | long_description_content_type='text/markdown', 38 | author=AUTHOR, 39 | author_email=EMAIL, 40 | python_requires=REQUIRES_PYTHON, 41 | url=URL, 42 | packages=find_packages(), 43 | install_requires=REQUIRED, 44 | include_package_data=True, 45 | license='Creative Common Attribution-NonCommercial 4.0 International', 46 | classifiers=[ 47 | # Trove classifiers 48 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 49 | 'Programming Language :: Python', 50 | 'Programming Language :: Python :: 3', 51 | 'Programming Language :: Python :: 3.6', 52 | 'Programming Language :: Python :: Implementation :: CPython', 53 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 54 | ], 55 | ) 56 | -------------------------------------------------------------------------------- /sing/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | # 8 | -------------------------------------------------------------------------------- /sing/ae/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | -------------------------------------------------------------------------------- /sing/ae/models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | from torch import nn 10 | 11 | from .utils import WindowedConv1d, WindowedConvTranpose1d 12 | 13 | 14 | class ConvolutionalDecoder(nn.Module): 15 | """ 16 | Convolutional decoder that takes a downsampled embedding and turns it 17 | into a waveform. 18 | Together with :class:`ConvolutionalEncoder`, it forms a 19 | :class:`ConvolutionalAE` 20 | 21 | Arguments: 22 | channels (int): number of channels accross all the inner layers 23 | stride (int): stride of the final :class:`nn.ConvTranspose1d` 24 | dimension (int): dimension of the embedding 25 | kernel_size (int): size of the kernel of the final 26 | :class:`nn.ConvTranspose1d` 27 | context_size (int): kernel size of the first convolution, 28 | this is called a context as one can see it as providing 29 | information about the previous and following embeddings 30 | rewrite_layers (int): after the first convolution, perform 31 | `rewrite_layers` `1x1` convolutions 32 | window_name (str or None): name of the window used to smooth 33 | the convolutions. See :func:`sing.dsp.get_window` 34 | squared_window (bool): if `True`, square the smoothing window 35 | """ 36 | 37 | def __init__(self, 38 | channels=4096, 39 | stride=256, 40 | dimension=128, 41 | kernel_size=1024, 42 | context_size=9, 43 | rewrite_layers=2, 44 | window_name="hann", 45 | squared_window=True): 46 | super(ConvolutionalDecoder, self).__init__() 47 | layers = [] 48 | layers.extend([ 49 | nn.Conv1d( 50 | in_channels=dimension, 51 | out_channels=channels, 52 | kernel_size=context_size), 53 | nn.ReLU() 54 | ]) 55 | for rewrite in range(rewrite_layers): 56 | layers.extend([ 57 | nn.Conv1d( 58 | in_channels=channels, out_channels=channels, 59 | kernel_size=1), 60 | nn.ReLU() 61 | ]) 62 | 63 | conv_tr = nn.ConvTranspose1d( 64 | in_channels=channels, 65 | out_channels=1, 66 | kernel_size=kernel_size, 67 | stride=stride, 68 | padding=kernel_size - stride) 69 | if window_name is not None: 70 | conv_tr = WindowedConvTranpose1d(conv_tr, window_name, 71 | squared_window) 72 | layers.append(conv_tr) 73 | self.layers = nn.Sequential(*layers) 74 | self.context_size = context_size 75 | self.stride = stride 76 | self.kernel_size = kernel_size 77 | self.strip = kernel_size - stride + (context_size - 1) * stride // 2 78 | 79 | def __repr__(self): 80 | return "ConvolutionalDecoder({})".format(repr(self.layers)) 81 | 82 | def forward(self, embeddings): 83 | return self.layers.forward(embeddings).squeeze(1) 84 | 85 | def wav_length(self, embedding_length): 86 | """ 87 | Given an embedding of a certain size `embedding_length`, 88 | returns the length of the wav that would be generated from it. 
89 | """ 90 | return (embedding_length - self.context_size + 2 91 | ) * self.stride - self.kernel_size 92 | 93 | def embedding_length(self, wav_length): 94 | """ 95 | Return the embedding length necessary to generate a wav of length 96 | `wav_length`. 97 | """ 98 | return self.context_size - 2 + ( 99 | wav_length + self.kernel_size) // self.stride 100 | 101 | 102 | class ConvolutionalEncoder(nn.Module): 103 | """ 104 | Convolutional encoder that takes a waveform and turns it 105 | into a downsampled embedding. 106 | Together with :class:`ConvolutionalDecoder`, it forms a 107 | :class:`ConvolutionalAE` 108 | 109 | Arguments: 110 | channels (int): number of channels accross all the inner layers 111 | stride (int): stride of the initial :class:`nn.Conv1d` 112 | dimension (int): dimension of the embedding 113 | kernel_size (int): size of the kernel of the initial 114 | :class:`nn.Conv1d` 115 | rewrite_layers (int): after the first convolution, perform 116 | `rewrite_layers` `1x1` convolutions. 117 | window_name (str or None): name of the window used to smooth 118 | the convolutions. See :func:`sing.dsp.get_window` 119 | squared_window (bool): if `True`, square the smoothing window 120 | """ 121 | 122 | def __init__(self, 123 | channels=4096, 124 | stride=256, 125 | dimension=128, 126 | kernel_size=1024, 127 | rewrite_layers=2, 128 | window_name="hann", 129 | squared_window=True): 130 | super(ConvolutionalEncoder, self).__init__() 131 | layers = [] 132 | conv = nn.Conv1d( 133 | in_channels=1, 134 | out_channels=channels, 135 | kernel_size=kernel_size, 136 | stride=stride) 137 | if window_name is not None: 138 | conv = WindowedConv1d(conv, window_name, squared_window) 139 | layers.extend([conv, nn.ReLU()]) 140 | for rewrite in range(rewrite_layers): 141 | layers.extend([ 142 | nn.Conv1d( 143 | in_channels=channels, out_channels=channels, 144 | kernel_size=1), 145 | nn.ReLU() 146 | ]) 147 | 148 | layers.append( 149 | nn.Conv1d( 150 | in_channels=channels, out_channels=dimension, kernel_size=1)) 151 | self.layers = nn.Sequential(*layers) 152 | 153 | def __repr__(self): 154 | return "ConvolutionalEncoder({!r})".format(self.layers) 155 | 156 | def forward(self, signal): 157 | return self.layers.forward(signal.unsqueeze(1)) 158 | 159 | 160 | class ConvolutionalAE(nn.Module): 161 | """ 162 | Convolutional autoencoder made from :class:`ConvolutionalEncoder` and 163 | :class:`ConvolutionalDecoder`. 164 | 165 | Arguments: 166 | channels (int): number of channels accross all the inner layers 167 | stride (int): downsampling stride going from the waveform 168 | to the embedding 169 | dimension (int): dimension of the embedding 170 | kernel_size (int): kernel size of the initial convolution 171 | and last conv transpose 172 | context_size (int): kernel size of the first 173 | convolution of the decoder 174 | rewrite_layers (int): after the first convolution, perform 175 | `rewrite_layers` `1x1` convolutions, both in the encoder 176 | and decoder. 177 | window_name (str or None): name of the window used to smooth 178 | the convolutions. 
See :func:`sing.dsp.get_window` 179 | squared_window (bool): if `True`, square the smoothing window 180 | """ 181 | 182 | def __init__(self, 183 | channels=4096, 184 | stride=256, 185 | dimension=128, 186 | kernel_size=1024, 187 | context_size=9, 188 | rewrite_layers=2, 189 | window_name="hann", 190 | squared_window=True): 191 | super(ConvolutionalAE, self).__init__() 192 | self.encoder = ConvolutionalEncoder( 193 | channels=channels, 194 | stride=stride, 195 | dimension=dimension, 196 | kernel_size=kernel_size, 197 | rewrite_layers=rewrite_layers, 198 | window_name=window_name, 199 | squared_window=squared_window) 200 | self.decoder = ConvolutionalDecoder( 201 | channels=channels, 202 | stride=stride, 203 | dimension=dimension, 204 | kernel_size=kernel_size, 205 | context_size=context_size, 206 | rewrite_layers=rewrite_layers, 207 | window_name=window_name, 208 | squared_window=squared_window) 209 | 210 | print(self) 211 | 212 | def encode(self, signal): 213 | """ 214 | Returns the embedding for the waveform `signal`. 215 | """ 216 | return self.encoder.forward(signal) 217 | 218 | def decode(self, embeddings): 219 | """ 220 | Return the waveforms from `embeddings` 221 | """ 222 | return self.decoder.forward(embeddings) 223 | 224 | def forward(self, signal): 225 | return self.decode(self.encode(signal)) 226 | 227 | def __repr__(self): 228 | return "ConvolutionalAE(encoder={!r},decoder={!r})".format( 229 | self.encoder, self.decoder) 230 | -------------------------------------------------------------------------------- /sing/ae/trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | from ..fondation import utils, trainer 10 | 11 | 12 | class AutoencoderTrainer(trainer.BaseTrainer): 13 | """ 14 | Trainer for the autoencoder. 15 | """ 16 | 17 | def _train_batch(self, batch): 18 | rebuilt, target = self._get_rebuilt_target(batch) 19 | self.optimizer.zero_grad() 20 | loss = self.train_loss(rebuilt, target) 21 | loss.backward() 22 | self.optimizer.step() 23 | return loss.item() 24 | 25 | def _get_rebuilt_target(self, batch): 26 | wav = batch.tensors['wav'] 27 | target = utils.unpad1d(wav, self.model.decoder.strip) 28 | rebuilt = self.parallel.forward(wav) 29 | return rebuilt, target 30 | -------------------------------------------------------------------------------- /sing/ae/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from .. import dsp 13 | 14 | 15 | class WindowedConv1d(nn.Module): 16 | """ 17 | Smooth a convolution using a window. 18 | 19 | Arguments: 20 | conv (nn.Conv1d): convolution module to wrap 21 | window_name (str or None): name of the window used to smooth 22 | the convolutions. 
See :func:`sing.dsp.get_window` 23 | squared (bool): if `True`, square the smoothing window 24 | """ 25 | 26 | def __init__(self, conv, window_name='hann', squared=True): 27 | super(WindowedConv1d, self).__init__() 28 | self.window_name = window_name 29 | if squared: 30 | self.window_name += "**2" 31 | self.register_buffer('window', 32 | dsp.get_window( 33 | window_name, 34 | conv.weight.size(-1), 35 | squared=squared)) 36 | self.conv = conv 37 | 38 | def forward(self, input): 39 | weight = self.window * self.conv.weight 40 | return F.conv1d( 41 | input, 42 | weight, 43 | bias=self.conv.bias, 44 | stride=self.conv.stride, 45 | dilation=self.conv.dilation, 46 | groups=self.conv.groups, 47 | padding=self.conv.padding) 48 | 49 | def __repr__(self): 50 | return "WindowedConv1d(window={},conv={})".format( 51 | self.window_name, self.conv) 52 | 53 | 54 | class WindowedConvTranpose1d(nn.Module): 55 | """ 56 | Smooth a transposed convolution using a window. 57 | 58 | Arguments: 59 | conv (nn.Conv1d): convolution module to wrap 60 | window_name (str or None): name of the window used to smooth 61 | the convolutions. See :func:`sing.dsp.get_window` 62 | squared (bool): if `True`, square the smoothing window 63 | """ 64 | 65 | def __init__(self, conv_tr, window_name='hann', squared=True): 66 | super(WindowedConvTranpose1d, self).__init__() 67 | self.window_name = window_name 68 | if squared: 69 | self.window_name += "**2" 70 | self.register_buffer('window', 71 | dsp.get_window( 72 | window_name, 73 | conv_tr.weight.size(-1), 74 | squared=squared)) 75 | self.conv_tr = conv_tr 76 | 77 | def forward(self, input): 78 | weight = self.window * self.conv_tr.weight 79 | return F.conv_transpose1d( 80 | input, 81 | weight, 82 | bias=self.conv_tr.bias, 83 | stride=self.conv_tr.stride, 84 | padding=self.conv_tr.padding, 85 | output_padding=self.conv_tr.output_padding, 86 | groups=self.conv_tr.groups, 87 | dilation=self.conv_tr.dilation) 88 | 89 | def __repr__(self): 90 | return "WindowedConvTranpose1d(window={},conv_tr={})".format( 91 | self.window_name, self.conv_tr) 92 | -------------------------------------------------------------------------------- /sing/dsp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | import torch 10 | from torch import nn 11 | import torch.nn.functional as F 12 | 13 | 14 | def power(spec): 15 | """ 16 | Given a complex spectrogram, return the power spectrum. 17 | 18 | Shape: 19 | - `spec`: `(*, 2, F, T)` 20 | - Output: `(*, F, T)` 21 | """ 22 | return spec[..., 0]**2 + spec[..., 1]**2 23 | 24 | 25 | def get_window(name, window_length, squared=False): 26 | """ 27 | Returns a windowing function. 28 | 29 | Arguments: 30 | window (str): name of the window, currently only 'hann' is available 31 | window_length (int): length of the window 32 | squared (bool): if true, square the window 33 | 34 | Returns: 35 | torch.FloatTensor: window of size `window_length` 36 | """ 37 | if name == "hann": 38 | window = torch.hann_window(window_length) 39 | else: 40 | raise ValueError("Invalid window name {}".format(name)) 41 | if squared: 42 | window *= window 43 | return window 44 | 45 | 46 | class STFT(nn.Module): 47 | """ 48 | Compute the STFT. 49 | See :mod:`torch.stft` for a definition of the parameters. 
50 | 51 | Arguments: 52 | n_fft (int): performs a FFT over `n_fft` samples 53 | hop_length (int or None): stride of the STFT transform. If `None` 54 | uses `n_fft // 4` 55 | window_name (str or None): name of the window used for the STFT. 56 | No window is used if `None`. 57 | 58 | """ 59 | 60 | def __init__(self, n_fft=1024, hop_length=None, window_name='hann'): 61 | super(STFT, self).__init__() 62 | assert n_fft % 2 == 0 63 | window = None 64 | if window_name is not None: 65 | window = get_window(window_name, n_fft) 66 | self.register_buffer("window", window) 67 | self.hop_length = hop_length or n_fft // 4 68 | self.n_fft = n_fft 69 | 70 | def forward(self, input): 71 | return torch.stft( 72 | input, 73 | window=self.window, 74 | n_fft=self.n_fft, 75 | hop_length=self.hop_length, 76 | center=False) 77 | 78 | 79 | class SpectralLoss(nn.Module): 80 | """ 81 | Compute a loss between two log power-spectrograms. 82 | 83 | Arguments: 84 | base_loss (function): loss used to compare the log power-spectrograms. 85 | For instance :func:`F.mse_loss` 86 | epsilon (float): offset for the log, i.e. `log(epsilon + ...)` 87 | **kwargs (dict): see :class:`STFT` 88 | """ 89 | 90 | def __init__(self, base_loss=F.mse_loss, epsilon=1, **kwargs): 91 | super(SpectralLoss, self).__init__() 92 | self.base_loss = base_loss 93 | self.epsilon = epsilon 94 | self.stft = STFT(**kwargs) 95 | 96 | def _log_spectrogram(self, signal): 97 | return torch.log(self.epsilon + power(self.stft.forward(signal))) 98 | 99 | def forward(self, a, b): 100 | spec_a = self._log_spectrogram(a) 101 | spec_b = self._log_spectrogram(b) 102 | return self.base_loss(spec_a, spec_b) 103 | 104 | 105 | def float_wav_to_short(wav): 106 | """ 107 | Given a float waveform, return a short waveform. 108 | The input waveform will be clamped between -1 and 1. 109 | """ 110 | return (wav.clamp(-1, 1) * (2**15 - 1)).short() 111 | -------------------------------------------------------------------------------- /sing/fondation/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | -------------------------------------------------------------------------------- /sing/fondation/batch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | from torch.utils.data.dataloader import default_collate 10 | 11 | 12 | class BatchItem: 13 | """ 14 | Reprensents a single batch item. A :class:`Batch` can be built 15 | from multiple :class:`BatchItem` using :func:`collate`. 16 | 17 | Attributes: 18 | metadata (dict[str, object]): Contains all the metadata 19 | about the batch item. Those elements will not be 20 | collated together 21 | when building a batch 22 | tensors (dict[str, tensor]): Contains all the tensors 23 | for the batch item. Those elements will be collated 24 | together when building a batch using :func:`default_collate`. 
25 | 26 | """ 27 | 28 | def __init__(self, metadata=None, tensors=None): 29 | self.metadata = dict(metadata) if metadata else {} 30 | self.tensors = dict(tensors) if tensors else {} 31 | 32 | 33 | def collate(items): 34 | """ 35 | Collate together all the items into a :class:`Batch`. 36 | The metadata dictionaries will be added to a list 37 | and the tensors will be collated using 38 | :func:`torch.utils.data.dataloader.default_collate`. 39 | 40 | Args: 41 | items (list[BatchItem]): list of the items in the batch 42 | 43 | Returns: 44 | Batch: a batch made from `items`. 45 | """ 46 | metadata = [item.metadata for item in items] 47 | tensors = default_collate([item.tensors for item in items]) 48 | return Batch(metadata=metadata, tensors=tensors) 49 | 50 | 51 | class Batch: 52 | """ 53 | Represents a batch. Supports iteration 54 | (yields individual :class:`BatchItem`) and indexing. Slice 55 | indexing will return another :class:`Batch`. 56 | 57 | Attributes: 58 | metadata (list[dict[str, object]]): a list of dictionaries 59 | for each element in the batch. 60 | Each dictionary contains information 61 | about the corresponding item. 62 | tensors (dict[str, tensor]): a dictionary of collated tensors. 63 | The first dimension of each tensor will always be `B`, 64 | the batch size. 65 | """ 66 | 67 | def __init__(self, metadata, tensors): 68 | self.metadata = metadata 69 | self.tensors = tensors 70 | 71 | def __len__(self): 72 | return len(self.metadata) 73 | 74 | def __iter__(self): 75 | for i in range(len(self)): 76 | yield self[i] 77 | 78 | def __getitem__(self, index): 79 | if isinstance(index, slice): 80 | if index.step is not None: 81 | raise IndexError("Does not support slice with step") 82 | metadata = self.metadata[index] 83 | tensors = { 84 | name: tensor[index] 85 | for name, tensor in self.tensors.items() 86 | } 87 | return Batch(metadata=metadata, tensors=tensors) 88 | else: 89 | return BatchItem( 90 | metadata=self.metadata[index], 91 | tensors={ 92 | name: tensor[index] 93 | for name, tensor in self.tensors.items() 94 | }) 95 | 96 | def apply(self, function): 97 | """ 98 | Apply a function to all tensors. 99 | 100 | Arguments: 101 | function: callable to be applied to all tensors. 102 | 103 | Returns: 104 | Batch: A new batch 105 | """ 106 | tensors = { 107 | name: function(tensor) 108 | for name, tensor in self.tensors.items() 109 | } 110 | return Batch(metadata=self.metadata, tensors=tensors) 111 | 112 | def apply_(self, function): 113 | """ 114 | Inplace variance of :meth:`apply`. 115 | """ 116 | other = self.apply(function) 117 | self.tensors = other.tensors 118 | return self 119 | 120 | def cuda(self, *args, **kwargs): 121 | """ 122 | Returns a new batch on GPU. 123 | """ 124 | return self.apply(lambda x: x.cuda()) 125 | 126 | def cuda_(self, *args, **kwargs): 127 | """ 128 | Move the batch inplace to GPU. 129 | """ 130 | return self.apply_(lambda x: x.cuda()) 131 | 132 | def cpu(self, *args, **kwargs): 133 | """ 134 | Returns a new batch on CPU. 135 | """ 136 | return self.apply(lambda x: x.cpu()) 137 | 138 | def cpu_(self, *args, **kwargs): 139 | """ 140 | Move the batch inplace to CPU. 141 | """ 142 | return self.apply_(lambda x: x.cpu_()) 143 | -------------------------------------------------------------------------------- /sing/fondation/datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | import random 10 | 11 | import torch 12 | 13 | from .utils import random_seed_manager 14 | 15 | 16 | class DatasetSubset: 17 | """ 18 | Represents a subset of a dataset. 19 | 20 | Arguments: 21 | dataset (Dataset): dataset to take a subset of. 22 | indexes (list[int]): list of indexes to keep. 23 | """ 24 | 25 | def __init__(self, dataset, indexes): 26 | self.dataset = dataset 27 | self.indexes = torch.LongTensor(indexes) 28 | 29 | def __len__(self): 30 | return len(self.indexes) 31 | 32 | def __getitem__(self, index): 33 | return self.dataset[self.indexes[index]] 34 | 35 | 36 | class RandomSubset(DatasetSubset): 37 | """ 38 | A random subset of a given size built from another dataset. 39 | 40 | Arguments: 41 | dataset (Dataset): dataset to take a random subset of. 42 | size (int): size of the random subset 43 | random_seed (int): random seed used to select the indexes. 44 | """ 45 | 46 | def __init__(self, dataset, size, random_seed=42): 47 | indexes = list(range(len(dataset))) 48 | with random_seed_manager(random_seed): 49 | random.shuffle(indexes) 50 | 51 | super(RandomSubset, self).__init__( 52 | dataset=dataset, indexes=indexes[:size]) 53 | -------------------------------------------------------------------------------- /sing/fondation/trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | import torch 10 | from torch import nn 11 | from torch import optim 12 | from torch.utils.data import DataLoader 13 | import tqdm 14 | 15 | from . import utils 16 | from .batch import collate 17 | 18 | 19 | class BaseTrainer: 20 | """ 21 | Base class for all the epoch-based trainers. Takes care of various 22 | task common to all training like checkpointing, 23 | iterating over the different datasets, computing evaluation metrics etc. 24 | 25 | Arguments: 26 | model (nn.Module): model to train 27 | train_loss (nn.Module): loss used for training 28 | eval_losses (dict[str, nn.Module]): dictionary of evaluation losses 29 | train_dataset (Dataset): dataset used for training 30 | eval_datasets (dict[str, Dataset]): dictionary of datasets 31 | on which each evaluation loss will be computed 32 | epochs (int): number of epochs to train for 33 | suffix (str): suffix used for logging, for instance 34 | if the suffix is `"_phase1"`, during training, `"train_phase1"` 35 | will be displayed 36 | batch_size (int): batch size 37 | cuda (bool): if true, runs on GPU 38 | parallel (bool): if true, use all available GPUs 39 | lr (float): learning rate for :class:`optim.Adam` 40 | checkpoint_path (Path): path to save checkpoint to. If `None`, no 41 | checkpointing is performed. Otherwise, a checkpoint is saved 42 | at the end of each epoch and overwrites the previous one. 
43 | 44 | """ 45 | 46 | def __init__(self, 47 | model, 48 | train_loss, 49 | eval_losses, 50 | train_dataset, 51 | eval_datasets, 52 | epochs, 53 | suffix="", 54 | batch_size=32, 55 | cuda=True, 56 | parallel=False, 57 | lr=0.0001, 58 | checkpoint_path=None): 59 | self.model = model 60 | self.parallel = nn.DataParallel(model) if parallel else model 61 | self.is_parallel = parallel 62 | self.train_loss = train_loss 63 | self.eval_losses = nn.ModuleDict(eval_losses) 64 | self.batch_size = batch_size 65 | self.cuda = cuda 66 | self.suffix = suffix 67 | 68 | self.train_dataset = train_dataset 69 | self.eval_datasets = eval_datasets 70 | self.epochs = epochs 71 | self.checkpoint_path = checkpoint_path 72 | 73 | parameters = [p for p in self.model.parameters() if p.requires_grad] 74 | self.optimizer = optim.Adam(parameters, lr=lr) 75 | 76 | if self.cuda: 77 | self.model.cuda() 78 | self.train_loss.cuda() 79 | self.eval_losses.cuda() 80 | else: 81 | self.model.cpu() 82 | self.train_loss.cpu() 83 | self.eval_losses.cpu() 84 | 85 | def _train_epoch(self, dataset, epoch): 86 | """ 87 | Train a single epoch on the given dataset and display 88 | statistics from time to time. 89 | """ 90 | loader = DataLoader( 91 | dataset, 92 | batch_size=self.batch_size, 93 | shuffle=True, 94 | collate_fn=collate) 95 | iterator = utils.progress_iterator(loader, divisions=20) 96 | 97 | total_loss = 0 98 | with tqdm.tqdm(total=len(dataset), unit="ex") as bar: 99 | for idx, (progress, batch) in enumerate(iterator): 100 | if self.cuda: 101 | batch.cuda_() 102 | total_loss += self._train_batch(batch) 103 | bar.update(len(batch)) 104 | if progress: 105 | tqdm.tqdm.write( 106 | "[train{}][{:03d}] {:.1f}%, loss {:.6f}".format( 107 | self.suffix, epoch, progress, 108 | total_loss / (idx + 1))) 109 | return total_loss 110 | 111 | def _eval_dataset(self, dataset_name, dataset, epoch): 112 | """ 113 | Evaluate all the losses in `eval_losses` on the given dataset 114 | and report the metrics averaged over the entire dataset. 115 | """ 116 | loader = DataLoader( 117 | dataset, batch_size=self.batch_size, collate_fn=collate) 118 | total_losses = {loss_name: 0 for loss_name in self.eval_losses} 119 | with tqdm.tqdm(total=len(dataset), unit="ex") as bar: 120 | for batch in loader: 121 | if self.cuda: 122 | batch.cuda_() 123 | rebuilt, target = self._get_rebuilt_target(batch) 124 | for name, loss in self.eval_losses.items(): 125 | total_losses[name] += loss(rebuilt, 126 | target).item() * len(batch) 127 | bar.update(len(batch)) 128 | 129 | print("[{}{}][{:03d}] Evaluation: \n{}\n".format( 130 | dataset_name, self.suffix, epoch, "\n".join( 131 | "\t{}={:.6f}".format(name, loss / len(dataset)) 132 | for name, loss in total_losses.items()))) 133 | return total_losses 134 | 135 | def _train_batch(self, batch): 136 | """ 137 | Given a batch, call :meth:`_get_rebuilt_target` 138 | to obtain the `target` and `rebuilt` tensors and call 139 | :attr:`train_loss` on them, compute the gradient and perform 140 | one optimizer step. 141 | 142 | This method can be overridden in subclasses. 143 | """ 144 | rebuilt, target = self._get_rebuilt_target(batch) 145 | self.optimizer.zero_grad() 146 | loss = self.train_loss(rebuilt, target) 147 | loss.backward() 148 | self.optimizer.step() 149 | 150 | return loss.item() 151 | 152 | def _get_rebuilt_target(self, batch): 153 | """ 154 | Should be implemented in subclasses. 155 | Given a batch, returns a tuple (rebuilt, target). 156 | This tuple will be passed to all the losses in `eval_losses`. 
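        A minimal sketch of an override (illustrative only; `WavTrainer` is
        hypothetical and it assumes the batch carries a `'wav'` tensor, as the
        NSynth datasets in this repository provide)::

            >>> class WavTrainer(BaseTrainer):
            ...     def _get_rebuilt_target(self, batch):
            ...         target = batch.tensors['wav']
            ...         rebuilt = self.parallel.forward(target)
            ...         return rebuilt, target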
157 | """ 158 | raise NotImplementedError() 159 | 160 | def train(self): 161 | """ 162 | Train :attr:`model` for :attr:`epochs` 163 | """ 164 | last_epoch, state = utils.load_checkpoint(self.checkpoint_path) 165 | if state is not None: 166 | self.model.load_state_dict(state, strict=False) 167 | start_epoch = last_epoch + 1 168 | if start_epoch > self.epochs: 169 | raise ValueError(("Checkpoint has been trained for {} " 170 | "epochs but we aim for {} epochs").format( 171 | start_epoch, self.epochs)) 172 | if start_epoch > 0: 173 | print("Resuming training at epoch {}".format(start_epoch)) 174 | for epoch in range(start_epoch, self.epochs): 175 | self._train_epoch(self.train_dataset, epoch) 176 | utils.save_checkpoint(self.checkpoint_path, epoch, 177 | self.model.state_dict()) 178 | with torch.no_grad(): 179 | for name, dataset in self.eval_datasets.items(): 180 | self._eval_dataset(name, dataset, epoch) 181 | -------------------------------------------------------------------------------- /sing/fondation/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | import contextlib 10 | import hashlib 11 | from pathlib import Path 12 | import random 13 | import requests 14 | import sys 15 | 16 | import torch 17 | import tqdm 18 | 19 | 20 | @contextlib.contextmanager 21 | def random_seed_manager(seed): 22 | """ 23 | Context manager that will save the python RNG state, 24 | set the seed to `seed` and on exit, set back the python RNG state. 25 | """ 26 | state = random.getstate() 27 | try: 28 | random.seed(seed) 29 | yield None 30 | finally: 31 | random.setstate(state) 32 | 33 | 34 | def progress_iterator(iterator, divisions=100): 35 | """ 36 | Wraps an iterator of known length and yield a tuple `(progress, item)` 37 | for each `item` in `iterator`. `progress` will be None except `divisions` 38 | times that are evenly spaced. When `progress` is not None 39 | it will contain the current percentage of items that have been seen. 40 | 41 | Arguments: 42 | iterator (iterator): source iterator, should support :func:`len`. 43 | divisions (int): progress will be every 1/divisions of the 44 | iterator length. 45 | 46 | Examples:: 47 | >>> for (progress, element) in progress_iterator(range(500)): 48 | ... if progress: 49 | ... print("{:.0f}% done".format(progress)) 50 | """ 51 | 52 | length = len(iterator) 53 | division_width = length / divisions 54 | next_division = division_width 55 | for idx, element in enumerate(iterator): 56 | progress = None 57 | if (idx + 1) >= next_division or idx + 1 == length: 58 | next_division += division_width 59 | progress = (idx + 1) / length * 100 60 | yield progress, element 61 | 62 | 63 | def unpad1d(tensor, pad): 64 | """ 65 | Opposite of padding, will remove `pad` items on each side 66 | of the last dimension of `tensor`. 67 | 68 | Arguments: 69 | tensor (tensor): tensor to unpad 70 | pad (int): amount of padding to remove on each side. 71 | """ 72 | if pad > 0: 73 | return tensor[..., pad:-pad] 74 | return tensor 75 | 76 | 77 | def load_checkpoint(path): 78 | """ 79 | Arguments: 80 | path (str or Path): path to load 81 | Returns: 82 | (int, object): returns a tuple (epoch, state). 
83 | """ 84 | if path is None or not Path(path).exists(): 85 | return -1, None 86 | return torch.load(path) 87 | 88 | 89 | def save_checkpoint(path, epoch, state): 90 | """ 91 | Save a new checkpoint. A temporary file is created 92 | and then renamed to the target path. 93 | 94 | Arguments: 95 | path (str or Path): path to write to 96 | epoch (int): current epoch 97 | state (object): state to save 98 | """ 99 | if path is None: 100 | return 101 | path = Path(path) 102 | tmp_path = path.parent / (path.name + ".tmp") 103 | 104 | torch.save((epoch, state), str(tmp_path)) 105 | tmp_path.rename(path) 106 | 107 | 108 | def download_file(target, url, sha256=None): 109 | """ 110 | Download a file with a progress bar. 111 | 112 | Arguments: 113 | target (Path): target path to write to 114 | url (str): url to download 115 | sha256 (str or None): expected sha256 hexdigest of the file 116 | """ 117 | response = requests.get(url, stream=True) 118 | total_length = int(response.headers.get('content-length', 0)) 119 | 120 | if sha256 is not None: 121 | sha = hashlib.sha256() 122 | update = sha.update 123 | else: 124 | update = lambda x: None 125 | 126 | with tqdm.tqdm(total=total_length, unit="B", unit_scale=True) as bar: 127 | with open(target, "wb") as output: 128 | for data in response.iter_content(chunk_size=4096): 129 | output.write(data) 130 | update(data) 131 | bar.update(len(data)) 132 | if sha256 is not None: 133 | signature = sha.hexdigest() 134 | if sha256 != signature: 135 | target.unlink() 136 | raise ValueError("Invalid sha256 signature when downloading {}. " 137 | "Expected {} but got {}".format( 138 | url, sha256, signature)) 139 | 140 | 141 | def fatal(message, error_code=1): 142 | """ 143 | Print `message` to stderr and exit with the code `error_code`. 144 | """ 145 | print(message, file=sys.stderr) 146 | sys.exit(1) 147 | -------------------------------------------------------------------------------- /sing/generate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | import argparse 10 | from pathlib import Path 11 | import sys 12 | 13 | from scipy.io import wavfile 14 | import torch 15 | from torch import nn 16 | from torch.utils.data import DataLoader 17 | import tqdm 18 | 19 | from . 
import dsp, nsynth 20 | from .fondation.batch import collate 21 | from .fondation.datasets import DatasetSubset 22 | from .fondation import utils 23 | from .sequence.models import download_pretrained_model 24 | 25 | 26 | def get_parser(): 27 | parser = argparse.ArgumentParser( 28 | "sing.generate", 29 | description="Generate audio samples from a trained SING model", 30 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 31 | parser.add_argument( 32 | "--model", 33 | type=Path, 34 | default="models/sing.th", 35 | help="Path to the trained SING model as output by sing.train") 36 | parser.add_argument( 37 | "--dl", 38 | action="store_true", 39 | help="Download a pretrained SING model if necessary.") 40 | parser.add_argument( 41 | "--output", 42 | type=Path, 43 | default="generated", 44 | help="Path where the generated samples will be saved") 45 | parser.add_argument( 46 | "--metadata", 47 | default=nsynth.get_metadata_path(), 48 | type=Path, 49 | help="path to the dataset metadata file") 50 | 51 | parser.add_argument( 52 | "list", 53 | type=Path, 54 | help="File containing a list of names from the nsynth dataset. " 55 | "Those notes will be generated by SING") 56 | parser.add_argument( 57 | "--batch-size", type=int, default=32, help="Batch size") 58 | parser.add_argument("--cuda", action="store_true", help="Use cuda") 59 | parser.add_argument( 60 | "--parallel", action="store_true", help="Use multiple GPUs") 61 | parser.add_argument( 62 | "--unpad", 63 | default=512, 64 | type=int, 65 | help="Amount of unpadding to perform") 66 | return parser 67 | 68 | 69 | def main(): 70 | args = get_parser().parse_args() 71 | 72 | if not args.model.exists(): 73 | if args.dl: 74 | print("Downloading pretrained SING model") 75 | args.model.parent.mkdir(parents=True, exist_ok=True) 76 | download_pretrained_model(args.model) 77 | else: 78 | utils.fatal("No model found for path {}. To download " 79 | "a pretrained model, use --dl".format(args.model)) 80 | elif args.dl: 81 | print( 82 | "WARNING: --dl is set but {} already exists.".format(args.model), 83 | file=sys.stderr) 84 | model = torch.load(args.model) 85 | 86 | if args.cuda: 87 | model.cuda() 88 | if args.parallel: 89 | model = nn.DataParallel(model) 90 | 91 | args.output.mkdir(exist_ok=True, parents=True) 92 | dataset = nsynth.NSynthMetadata(args.metadata) 93 | 94 | names = [name.strip() for name in open(args.list)] 95 | indexes = [dataset.names.index(name) for name in names] 96 | 97 | to_generate = DatasetSubset(dataset, indexes) 98 | loader = DataLoader( 99 | to_generate, batch_size=args.batch_size, collate_fn=collate) 100 | 101 | with tqdm.tqdm(total=len(to_generate), unit="ex") as bar: 102 | for batch in loader: 103 | if args.cuda: 104 | batch.cuda_() 105 | with torch.no_grad(): 106 | rebuilt = model.forward(**batch.tensors) 107 | rebuilt = utils.unpad1d(rebuilt, args.unpad) 108 | for metadata, wav in zip(batch.metadata, rebuilt): 109 | path = args.output / (metadata['name'] + ".wav") 110 | wavfile.write( 111 | str(path), metadata['sample_rate'], 112 | dsp.float_wav_to_short(wav).cpu().detach().numpy()) 113 | bar.update(len(batch)) 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | -------------------------------------------------------------------------------- /sing/nsynth/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | from collections import defaultdict 10 | import gzip 11 | import json 12 | from pathlib import Path 13 | import random 14 | 15 | from scipy.io import wavfile 16 | import torch 17 | from torch.nn import functional as F 18 | 19 | from ..fondation.batch import BatchItem 20 | from ..fondation.datasets import DatasetSubset 21 | from ..fondation import utils 22 | 23 | 24 | class NSynthMetadata: 25 | """ 26 | NSynth metadata without the waveforms. 27 | 28 | Arguments: 29 | path (Path): path to the NSynth metadata file, either 30 | `examples.json` or a gzipped `examples.json.gz`. 31 | 32 | An item of the nsynth metadata dataset will contain the following tensors: 33 | - instrument (LongTensor) 34 | - pitch (LongTensor) 35 | - velocity (LongTensor) 36 | - instrument_family (LongTensor) 37 | The item name and index are stored in the metadata dictionary. 38 | 39 | Attributes: 40 | cardinalities (dict[str, int]): cardinality of 41 | instrument, instrument_family, pitch and velocity 42 | instruments (dict[str, int]): mapping from instrument 43 | name to instrument index 44 | """ 45 | _json_cache = {} 46 | 47 | _FEATURES = ['instrument', 'instrument_family', 'pitch', 'velocity'] 48 | 49 | def _map_velocity(self, metadata): 50 | velocity_mapping = { 51 | 25: 0, 52 | 50: 1, 53 | 75: 2, 54 | 100: 3, 55 | 127: 4, 56 | } 57 | for meta in metadata.values(): 58 | meta["velocity"] = velocity_mapping[meta['velocity']] 59 | 60 | def __init__(self, path): 61 | self.path = Path(path) 62 | 63 | # Cache the json to avoid reparsing it every time 64 | if self.path in self._json_cache: 65 | self._metadata = self._json_cache[self.path] 66 | else: 67 | if self.path.suffix == ".gz": 68 | file = gzip.open(self.path) 69 | else: 70 | file = open(self.path, "rb") 71 | self._metadata = json.load(file) 72 | self._map_velocity(self._metadata) 73 | self._json_cache[self.path] = self._metadata 74 | 75 | self.names = sorted(self._metadata.keys()) 76 | 77 | # Compute the mapping instrument_name -> instrument id 78 | self.instruments = {} 79 | for meta in self._metadata.values(): 80 | self.instruments[meta["instrument_str"]] = meta["instrument"] 81 | 82 | # Compute the cardinality for the features velocity, instrument, 83 | # pitch and instrument_family 84 | self.cardinalities = {} 85 | for feature in self._FEATURES: 86 | self.cardinalities[feature] = 1 + max( 87 | i[feature] for i in self._metadata.values()) 88 | 89 | def __len__(self): 90 | return len(self.names) 91 | 92 | def __getitem__(self, index): 93 | if hasattr(index, "item"): 94 | index = index.item() 95 | name = self.names[index] 96 | metadata = self._metadata[name] 97 | tensors = {} 98 | 99 | metadata['name'] = name 100 | metadata['index'] = index 101 | for feature in self._FEATURES: 102 | tensors[feature] = torch.LongTensor([metadata[feature]]) 103 | 104 | return BatchItem(metadata=metadata, tensors=tensors) 105 | 106 | 107 | class NSynthDataset: 108 | """ 109 | NSynth dataset. 110 | 111 | Arguments: 112 | path (Path): path to the NSynth dataset. 113 | This path should contain a `examples.json` file 114 | and an `audio` folder containing the wav files. 115 | pad (int): amount of padding to add to the waveforms. 116 | 117 | Items from this dataset will contain all the information 118 | coming from :class:`NSynthMetadata` as well as a `'wav'` 119 | tensor containing the waveform. 
120 | 121 | Attributes: 122 | metadata (NSynthMetadata): metadata-only dataset 123 | """ 124 | 125 | def __init__(self, path, pad=0): 126 | self.metadata = NSynthMetadata(Path(path) / "examples.json") 127 | self.pad = pad 128 | 129 | def __len__(self): 130 | return len(self.metadata) 131 | 132 | def __getitem__(self, index): 133 | item = self.metadata[index] 134 | 135 | path = self.metadata.path.parent / "audio" / "{}.wav".format( 136 | item.metadata['name']) 137 | item.metadata['path'] = path 138 | 139 | _, wav = wavfile.read(str(path), mmap=True) 140 | wav = torch.as_tensor(wav, dtype=torch.float) 141 | wav /= 2**15 - 1 142 | item.tensors['wav'] = F.pad(wav, (self.pad, self.pad)) 143 | 144 | return item 145 | 146 | 147 | def make_datasets(dataset, valid_ratio=0.1, test_ratio=0.1, random_seed=42): 148 | """ 149 | Take the original NSynth training dataset and split it into 150 | a train, valid and test set, making sure that for a given instrument, 151 | a pitch is present in only one dataset (each pair of instrument and pitch 152 | has multiple occurrences, one for each velocity). 153 | """ 154 | 155 | per_pitch_instrument = defaultdict(list) 156 | 157 | if isinstance(dataset, NSynthDataset): 158 | metadata = dataset.metadata 159 | elif isinstance(dataset, NSynthMetadata): 160 | metadata = dataset 161 | else: 162 | raise ValueError( 163 | "Invalid dataset {}, should be an instance of " 164 | "either NSynthDataset or NSynthMetadata.".format(dataset)) 165 | 166 | for index in range(len(metadata)): 167 | item = metadata[index] 168 | per_pitch_instrument[(item.metadata['instrument'], 169 | item.metadata['pitch'])].append(index) 170 | 171 | with utils.random_seed_manager(random_seed): 172 | train = [] 173 | valid = [] 174 | test = [] 175 | for indexes in per_pitch_instrument.values(): 176 | score = random.random() 177 | if score < valid_ratio: 178 | valid.extend(indexes) 179 | elif score < valid_ratio + test_ratio: 180 | test.extend(indexes) 181 | else: 182 | train.extend(indexes) 183 | 184 | return DatasetSubset(dataset, train), DatasetSubset( 185 | dataset, valid), DatasetSubset(dataset, test) 186 | 187 | 188 | def get_metadata_path(): 189 | """ 190 | Get the path to the nsynth-train metadata included with SING. 191 | """ 192 | return Path(__file__).parent / "examples.json.gz" 193 | 194 | 195 | def get_nsynth_metadata(): 196 | return NSynthMetadata(get_metadata_path()) 197 | -------------------------------------------------------------------------------- /sing/nsynth/examples.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/SING/72054bdb23b4ced393c0d435c124db64c1e4cb26/sing/nsynth/examples.json.gz -------------------------------------------------------------------------------- /sing/parser.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | import argparse 10 | from pathlib import Path 11 | 12 | 13 | def get_parser(): 14 | """ 15 | Returns: 16 | argparse.ArgumentParser: parser with all the options 17 | for the training of a SING model. 
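    Example (a sketch; the flags shown are a small subset of those defined
    below, parsed here with non-default values)::

        >>> parser = get_parser()
        >>> args = parser.parse_args(["--cuda", "--batch-size", "32"])
        >>> args.cuda, args.batch_size
        (True, 32)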
18 | """ 19 | parser = argparse.ArgumentParser( 20 | "sing.train", 21 | description="Train a SING model on the NSynth dataset", 22 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 23 | 24 | # Datasets arguments 25 | parser.add_argument( 26 | "--data", 27 | default="data/nsynth-train", 28 | type=Path, 29 | help="path to the dataset, e.g. .../nsynth-train") 30 | parser.add_argument( 31 | "--pad", 32 | type=int, 33 | default=2304, 34 | help="Extra padding added to the waveforms", 35 | ) 36 | 37 | # Loss arguments 38 | parser.add_argument("--wav", action="store_true", help="Use a Wav loss") 39 | parser.add_argument( 40 | "--epsilon", 41 | default=1, 42 | type=float, 43 | help="Offset for power spectrum before taking the log") 44 | parser.add_argument( 45 | "--l1", 46 | action="store_true", 47 | help="Use L1 loss instead of mse", 48 | ) 49 | 50 | # Misc arguments 51 | parser.add_argument("--cuda", action="store_true", help="Use cuda") 52 | parser.add_argument( 53 | "--parallel", action="store_true", help="Use multiple gpus") 54 | parser.add_argument( 55 | "--checkpoint", 56 | type=Path, 57 | default=None, 58 | help="Path to the checkpoint folder") 59 | parser.add_argument( 60 | "--output", 61 | type=Path, 62 | default="models/sing.th", 63 | help="Path to output final SING model") 64 | parser.add_argument( 65 | "-d", "--debug", action="store_true", help="Debug flag") 66 | parser.add_argument( 67 | "-f", "--debug-fast", action="store_true", help="Debug fast flag") 68 | 69 | # Common arguments 70 | parser.add_argument( 71 | "--lr", type=float, default=0.0003, help="Learning rate for Adam") 72 | parser.add_argument( 73 | "--batch-size", type=int, default=64, help="Batch size") 74 | 75 | # Autoencoder arguments 76 | parser.add_argument( 77 | "--ae-epochs", 78 | type=int, 79 | default=50, 80 | help="Number of epochs for the autoencoder") 81 | 82 | parser.add_argument( 83 | "--ae-channels", 84 | type=int, 85 | default=4096, 86 | help="Number of channels in the autoencoder") 87 | parser.add_argument( 88 | "--ae-stride", type=int, default=256, help="Stride of the autoencoder") 89 | parser.add_argument( 90 | "--ae-dimension", 91 | type=int, 92 | default=128, 93 | help="Dimension of the autoencoder embedding") 94 | parser.add_argument( 95 | "--ae-kernel", 96 | type=int, 97 | default=1024, 98 | help="Kernel size of the autoencoder") 99 | parser.add_argument( 100 | "--ae-rewrite", 101 | type=int, 102 | default=2, 103 | help="Number of rewrite layers in the autoencoder") 104 | parser.add_argument( 105 | "--ae-context", 106 | type=int, 107 | default=9, 108 | help="Context size of the decoder") 109 | parser.add_argument( 110 | "--ae-window", 111 | default="hann", 112 | help="Window to use to smooth convolutions. Default to 'hann'. " 113 | "To deactivate, use --ae-no-window") 114 | parser.add_argument( 115 | "--ae-no-window", dest="ae_window", action="store_const", const=None) 116 | parser.add_argument( 117 | "--ae-squared-window", 118 | action="store_true", 119 | default=True, 120 | help="Square the window used to smooth convolutions. 
" 121 | "To deactivate, use --ae-no-squared-window.") 122 | parser.add_argument( 123 | "--ae-no-squared-window", 124 | action="store_false", 125 | dest="ae_squared_window") 126 | 127 | # Sequence generator arguments 128 | parser.add_argument( 129 | "--seq-hidden-size", 130 | type=int, 131 | default=1024, 132 | help="Size of the LSTM hidden layers") 133 | parser.add_argument( 134 | "--seq-layers", 135 | type=int, 136 | default=3, 137 | help="Number of layers in the LSTM") 138 | parser.add_argument( 139 | "--seq-epochs", 140 | type=int, 141 | default=50, 142 | help="Number of epochs for the sequence generator") 143 | parser.add_argument( 144 | "--seq-truncated", 145 | type=int, 146 | default=32, 147 | help="Truncated gradient for the sequence generator. " 148 | "0 means using the full sequence.") 149 | parser.add_argument( 150 | "--sing-epochs", 151 | type=int, 152 | default=20, 153 | help="Number of fine tuning epochs for the full SING model") 154 | 155 | # Lookup tables arguments 156 | parser.add_argument( 157 | "--time-dim", 158 | type=int, 159 | default=4, 160 | help="Dimension of the time step lookup table") 161 | parser.add_argument( 162 | "--instrument-dim", 163 | type=int, 164 | default=16, 165 | help="Dimension of the instrument embedding") 166 | parser.add_argument( 167 | "--pitch-dim", 168 | type=int, 169 | default=8, 170 | help="Dimension of the pitch embedding") 171 | parser.add_argument( 172 | "--velocity-dim", 173 | type=int, 174 | default=2, 175 | help="Dimension of the velocity embedding") 176 | return parser 177 | -------------------------------------------------------------------------------- /sing/sequence/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | -------------------------------------------------------------------------------- /sing/sequence/models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from ..fondation import utils 13 | 14 | 15 | class SequenceGenerator(nn.Module): 16 | """ 17 | LSTM part of the SING model. 18 | 19 | Arguments: 20 | embeddings (dict[str, (int, int)]): 21 | represents the lookup tables used by the model. 22 | For each entry under the key `name`, with value 23 | `(cardinality, dimension)`, the tensor named `name` will be 24 | retrieved. 
Its values should be in `[0, cardinality - 1]` 25 | and the lookup table will have dimension `dimension` 26 | length (int): length of the generated sequence 27 | output_dimension (int): dimension of each generated sequence item 28 | hidden_size (int): size of each layer, see the documentation of 29 | :class:`nn.LSTM` 30 | num_layers (int): number of layers, see the documentation of 31 | :class:`nn.LSTM` 32 | 33 | """ 34 | 35 | def __init__(self, 36 | embeddings, 37 | length, 38 | time_dimension=4, 39 | output_dimension=128, 40 | hidden_size=1024, 41 | num_layers=3): 42 | super(SequenceGenerator, self).__init__() 43 | self.tables = nn.ModuleList() 44 | self.inputs = [] 45 | input_size = time_dimension 46 | for name, (cardinality, dimension) in sorted(embeddings.items()): 47 | input_size += dimension 48 | self.inputs.append(name) 49 | self.tables.append( 50 | nn.Embedding( 51 | num_embeddings=cardinality, embedding_dim=dimension)) 52 | 53 | if time_dimension == 0: 54 | self.time_table = None 55 | else: 56 | self.time_table = nn.Embedding( 57 | num_embeddings=length, embedding_dim=time_dimension) 58 | self.length = length 59 | 60 | self.lstm = nn.LSTM( 61 | input_size=input_size, 62 | hidden_size=hidden_size, 63 | num_layers=num_layers) 64 | 65 | self.decoder = nn.Linear(hidden_size, output_dimension) 66 | 67 | def forward(self, start=0, length=None, hidden=None, **tensors): 68 | """ 69 | Arguments: 70 | start (int): first time step to generate 71 | length (int): length of the sequence to generate. If `None`, 72 | will be taken to be `self.length - start` 73 | hidden ((torch.FloatTensor, torch.FloatTensor)): 74 | hidden state of the LSTM or `None` to start 75 | from a blank one 76 | **tensors (dict[str, torch.LongTensor]): 77 | dictionary containing the tensors used as inputs 78 | to the lookup tables specified by the `embeddings` 79 | parameter of the constructor 80 | """ 81 | length = self.length - start if length is None else length 82 | 83 | inputs = [] 84 | for name, table in zip(self.inputs, self.tables): 85 | value = tensors[name].transpose(0, 1) 86 | embedding = table.forward(value) 87 | inputs.append(embedding.expand(length, -1, -1)) 88 | 89 | reference = inputs[0] 90 | if self.time_table is not None: 91 | times = torch.arange( 92 | start, start + length, 93 | device=reference.device).view(-1, 1).expand( 94 | -1, reference.size(1)) 95 | inputs.append(self.time_table.forward(times)) 96 | input = torch.cat(inputs, dim=-1) 97 | if hidden is not None: 98 | hidden = [h.transpose(0, 1).contiguous() for h in hidden] 99 | 100 | self.lstm.flatten_parameters() 101 | output, hidden = self.lstm.forward(input, hidden) 102 | decoded = self.decoder(output.view(-1, output.size(-1))).view( 103 | output.size(0), output.size(1), -1) 104 | hidden = [h.transpose(0, 1) for h in hidden] 105 | return decoded.transpose(0, 1).transpose(1, 2), hidden 106 | 107 | 108 | class SING(nn.Module): 109 | """ 110 | Complete SING model. 
111 | 112 | Arguments: 113 | sequence_generator (SequenceGenerator): the LSTM based 114 | sequence generator part of SING 115 | decoder (sing.ae.models.ConvolutionalDecoder): 116 | the convolutional decoder part of SING 117 | """ 118 | 119 | def __init__(self, sequence_generator, decoder): 120 | super(SING, self).__init__() 121 | self.sequence_generator = sequence_generator 122 | self.decoder = decoder 123 | 124 | def forward(self, **tensors): 125 | """ 126 | Arguments: 127 | **tensors (dict[str, torch.LongTensor]): 128 | Tensors used as inputs 129 | to the lookup tables specified by the `embeddings` 130 | parameter of :class:`SequenceGenerator` 131 | """ 132 | return self.decoder.forward(self.sequence_generator(**tensors)[0]) 133 | 134 | 135 | def download_pretrained_model(target): 136 | """ 137 | Download a pretrained version of SING. 138 | """ 139 | url = "https://dl.fbaipublicfiles.com/sing/sing.th" 140 | sha256 = "eda8a7ce66f1ccf31cdd34a920290d80aabf96584c4d53df866b744f2862dc1c" 141 | utils.download_file(target, url, sha256=sha256) 142 | -------------------------------------------------------------------------------- /sing/sequence/trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | from torch import nn 10 | 11 | from ..fondation import utils, trainer 12 | 13 | 14 | class SequenceGeneratorTrainer(trainer.BaseTrainer): 15 | """ 16 | Trainer for the sequence generator (LSTM) part of SING. 17 | 18 | Arguments: 19 | decoder (sing.ae.models.ConvolutionalDecoder): 20 | decoder, used to compute the metrics on the waveforms 21 | truncated_gradient (int): size of sequence to compute 22 | the gradients over. 
If `None`, the whole sequence is used 23 | 24 | """ 25 | 26 | def __init__(self, decoder, truncated_gradient=32, **kwargs): 27 | super(SequenceGeneratorTrainer, self).__init__(**kwargs) 28 | self.truncated_gradient = truncated_gradient 29 | self.decoder = decoder 30 | if self.is_parallel: 31 | self.parallel_decoder = nn.DataParallel(decoder) 32 | else: 33 | self.parallel_decoder = decoder 34 | 35 | def _train_batch(self, batch): 36 | embeddings = batch.tensors['embeddings'] 37 | assert embeddings.size(-1) == self.model.length 38 | total_length = self.model.length 39 | hidden = None 40 | 41 | if self.truncated_gradient: 42 | truncated_gradient = self.truncated_gradient 43 | else: 44 | truncated_gradient = total_length 45 | 46 | steps = list(range(0, total_length, truncated_gradient)) 47 | total_loss = 0 48 | for start_time in steps: 49 | sequence_length = min(truncated_gradient, 50 | total_length - start_time) 51 | target = embeddings[..., start_time:start_time + sequence_length] 52 | rebuilt, hidden = self.parallel.forward( 53 | start=start_time, 54 | length=sequence_length, 55 | hidden=hidden, 56 | **batch.tensors) 57 | hidden = tuple([h.detach() for h in hidden]) 58 | self.optimizer.zero_grad() 59 | loss = self.train_loss(rebuilt, target) 60 | loss.backward() 61 | self.optimizer.step() 62 | total_loss += loss.item() / len(steps) 63 | return total_loss 64 | 65 | def _get_rebuilt_target(self, batch): 66 | wav = batch.tensors['wav'] 67 | target = utils.unpad1d(wav, self.decoder.strip) 68 | embeddings, _ = self.parallel.forward(**batch.tensors) 69 | rebuilt = self.parallel_decoder.forward(embeddings) 70 | return rebuilt, target 71 | 72 | 73 | class SINGTrainer(trainer.BaseTrainer): 74 | """ 75 | Trainer for the entire SING model. 76 | """ 77 | 78 | def _get_rebuilt_target(self, batch): 79 | wav = batch.tensors['wav'] 80 | rebuilt = self.parallel.forward(**batch.tensors) 81 | target = utils.unpad1d(wav, self.model.decoder.strip) 82 | return rebuilt, target 83 | -------------------------------------------------------------------------------- /sing/sequence/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | import torch 10 | from torch import nn 11 | from torch.utils.data import DataLoader 12 | import tqdm 13 | 14 | from ..fondation.batch import collate 15 | 16 | 17 | def generate_embeddings_dataset(dataset, encoder, batch_size, cuda, parallel): 18 | """ 19 | Pre-compute all the embeddings for a given dataset. 20 | 21 | Arguments: 22 | dataset (Dataset): dataset to compute the embeddings for. It should 23 | contain a `'wav'` tensor 24 | encoder (sing.ae.models.ConvolutionalEncoder): 25 | encoder to use to generate the embedding 26 | batch_size (int): batch size to use 27 | cuda (bool): if `True`, performs the computation on GPU 28 | parallel (bool): if `True`, use all available GPUs 29 | 30 | Returns: 31 | Dataset: dataset of the same size as `dataset` but with the `'wav'` 32 | tensor replaced by an `'embeddings'` tensor. 
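    Example (an illustrative sketch; `autoencoder` and `train_dataset` are
    assumed to come from the training pipeline in :mod:`sing.train`)::

        >>> embeddings = generate_embeddings_dataset(
        ...     train_dataset, encoder=autoencoder.encoder,
        ...     batch_size=32, cuda=False, parallel=False)
        >>> 'embeddings' in embeddings[0].tensors
        True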
33 | 34 | """ 35 | 36 | loader = DataLoader( 37 | dataset, batch_size=batch_size, shuffle=False, collate_fn=collate) 38 | embeddings_dataset = [None] * len(dataset) 39 | 40 | if cuda: 41 | encoder.cuda() 42 | if parallel: 43 | encoder = nn.DataParallel(encoder) 44 | 45 | row = 0 46 | with tqdm.tqdm(total=len(dataset), unit="ex") as bar: 47 | for batch in loader: 48 | if cuda: 49 | batch.cuda_() 50 | with torch.no_grad(): 51 | batch.tensors['embeddings'] = encoder.forward( 52 | batch.tensors['wav']) 53 | del batch.tensors['wav'] 54 | 55 | for item in batch.cpu(): 56 | embeddings_dataset[row] = item 57 | row += 1 58 | bar.update(len(batch)) 59 | return embeddings_dataset 60 | -------------------------------------------------------------------------------- /sing/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) 2018-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | import functools 10 | 11 | from torch import nn 12 | import torch 13 | 14 | from .parser import get_parser 15 | from . import nsynth, dsp 16 | from .ae.models import ConvolutionalAE 17 | from .ae.trainer import AutoencoderTrainer 18 | from .fondation import utils, datasets 19 | from .sequence.models import SequenceGenerator, SING 20 | from .sequence.trainer import SequenceGeneratorTrainer, SINGTrainer 21 | from .sequence.utils import generate_embeddings_dataset 22 | 23 | 24 | def train_autoencoder(args, **kwargs): 25 | checkpoint_path = args.checkpoint / "ae.torch" if args.checkpoint else None 26 | model = ConvolutionalAE( 27 | channels=args.ae_channels, 28 | stride=args.ae_stride, 29 | dimension=args.ae_dimension, 30 | kernel_size=args.ae_kernel, 31 | context_size=args.ae_context, 32 | rewrite_layers=args.ae_rewrite, 33 | window_name=args.ae_window, 34 | squared_window=args.ae_squared_window) 35 | advised_pad = model.decoder.strip + 512 36 | if args.pad != advised_pad: 37 | print("Warning, best padding for the current settings is {}, " 38 | "current value is {}.".format(advised_pad, args.pad)) 39 | if args.ae_epochs: 40 | print("Training autoencoder") 41 | AutoencoderTrainer( 42 | suffix="_ae", 43 | model=model, 44 | epochs=args.ae_epochs, 45 | checkpoint_path=checkpoint_path, 46 | **kwargs).train() 47 | return model 48 | 49 | 50 | def train_sequence_generator(args, autoencoder, cardinalities, train_dataset, 51 | eval_datasets, train_loss, eval_losses, **kwargs): 52 | checkpoint_path = (args.checkpoint / "seq.torch" 53 | if args.checkpoint else None) 54 | 55 | wav_length = train_dataset[0].tensors['wav'].size(-1) 56 | embedding_length = autoencoder.decoder.embedding_length( 57 | wav_length - 2 * autoencoder.decoder.strip) 58 | embeddings = { 59 | name: (cardinalities[name], getattr(args, '{}_dim'.format(name))) 60 | for name in ['velocity', 'instrument', 'pitch'] 61 | } 62 | 63 | model = SequenceGenerator( 64 | embeddings=embeddings, 65 | length=embedding_length, 66 | time_dimension=args.time_dim, 67 | output_dimension=args.ae_dimension, 68 | hidden_size=args.seq_hidden_size, 69 | num_layers=args.seq_layers) 70 | 71 | if args.seq_epochs: 72 | print("Precomputing embeddings for all datasets") 73 | generate_embeddings = functools.partial( 74 | generate_embeddings_dataset, 75 | encoder=autoencoder.encoder, 76 | batch_size=args.batch_size, 77 | cuda=args.cuda, 78 | parallel=args.parallel) 79 | 
train_dataset = generate_embeddings(train_dataset) 80 | 81 | print("Training sequence generator") 82 | SequenceGeneratorTrainer( 83 | suffix="_seq", 84 | model=model, 85 | decoder=autoencoder.decoder, 86 | epochs=args.seq_epochs, 87 | train_loss=nn.MSELoss(), 88 | eval_losses=eval_losses, 89 | train_dataset=train_dataset, 90 | eval_datasets=eval_datasets, 91 | truncated_gradient=args.seq_truncated, 92 | checkpoint_path=checkpoint_path, 93 | **kwargs).train() 94 | return model 95 | 96 | 97 | def fine_tune_sing(args, sequence_generator, decoder, **kwargs): 98 | print("Fine tuning SING") 99 | checkpoint_path = (args.checkpoint / "sing.torch" 100 | if args.checkpoint else None) 101 | model = SING(sequence_generator=sequence_generator, decoder=decoder) 102 | 103 | if args.sing_epochs: 104 | SINGTrainer( 105 | suffix="_sing", 106 | epochs=args.sing_epochs, 107 | model=model, 108 | checkpoint_path=checkpoint_path, 109 | **kwargs).train() 110 | return model 111 | 112 | 113 | def main(): 114 | args = get_parser().parse_args() 115 | 116 | if args.debug: 117 | args.ae_epochs = 1 118 | args.seq_epochs = 1 119 | args.sing_epochs = 1 120 | 121 | if args.debug_fast: 122 | args.ae_channels = 128 123 | args.ae_dimension = 16 124 | args.ae_rewrite = 1 125 | args.seq_hidden_size = 128 126 | args.seq_layers = 1 127 | 128 | if args.checkpoint: 129 | args.checkpoint.mkdir(exist_ok=True, parents=True) 130 | 131 | if not args.data.exists(): 132 | utils.fatal("Could not find the nsynth dataset. " 133 | "To download it, follow the instructions at " 134 | "https://github.com/facebookresearch/SING") 135 | 136 | nsynth_dataset = nsynth.NSynthDataset(args.data, pad=args.pad) 137 | cardinalities = nsynth_dataset.metadata.cardinalities 138 | 139 | train_dataset, valid, test = nsynth.make_datasets(nsynth_dataset) 140 | if args.debug: 141 | train_dataset = datasets.RandomSubset(train_dataset, size=100) 142 | eval_train = datasets.RandomSubset(train_dataset, size=10000) 143 | 144 | if args.debug: 145 | eval_datasets = { 146 | 'eval_train': eval_train, 147 | } 148 | else: 149 | eval_datasets = { 150 | 'eval_train': eval_train, 151 | 'valid': valid, 152 | 'test': test, 153 | } 154 | 155 | base_loss = nn.L1Loss() if args.l1 else nn.MSELoss() 156 | train_loss = base_loss if args.wav else dsp.SpectralLoss( 157 | base_loss, epsilon=args.epsilon) 158 | eval_losses = { 159 | 'wav_l1': nn.L1Loss(), 160 | 'wav_mse': nn.MSELoss(), 161 | 'spec_l1': dsp.SpectralLoss(nn.L1Loss(), epsilon=args.epsilon), 162 | 'spec_mse': dsp.SpectralLoss(nn.MSELoss(), epsilon=args.epsilon), 163 | } 164 | 165 | kwargs = { 166 | 'train_dataset': train_dataset, 167 | 'eval_datasets': eval_datasets, 168 | 'train_loss': train_loss, 169 | 'eval_losses': eval_losses, 170 | 'batch_size': args.batch_size, 171 | 'lr': args.lr, 172 | 'cuda': args.cuda, 173 | 'parallel': args.parallel, 174 | } 175 | 176 | autoencoder = train_autoencoder(args, **kwargs) 177 | sequence_generator = train_sequence_generator(args, autoencoder, 178 | cardinalities, **kwargs) 179 | sing = fine_tune_sing(args, sequence_generator, autoencoder.decoder, 180 | **kwargs) 181 | torch.save(sing.cpu(), str(args.output)) 182 | 183 | 184 | if __name__ == "__main__": 185 | main() 186 | --------------------------------------------------------------------------------