├── .github └── workflows │ └── actions.yml ├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── darbouka_prior.mp3 ├── log_distance.png ├── log_fidelity.png ├── log_gan.png ├── maxmsp_screenshot.png ├── rave.png ├── rave_attribute.png ├── rave_buffer.png ├── rave_encode_decode.png ├── rave_high_level.png ├── rave_method_forward.png ├── tensorboard_guide.md └── training_setup.md ├── rave ├── __init__.py ├── balancer.py ├── blocks.py ├── configs │ ├── adain.gin │ ├── augmentations │ │ ├── compress.gin │ │ ├── gain.gin │ │ └── mute.gin │ ├── causal.gin │ ├── descript_discriminator.gin │ ├── discrete.gin │ ├── discrete_v3.gin │ ├── hybrid.gin │ ├── noise.gin │ ├── normalize_ambient.gin │ ├── onnx.gin │ ├── prior │ │ └── prior_v1.gin │ ├── raspberry.gin │ ├── snake.gin │ ├── spectral_discriminator.gin │ ├── spherical.gin │ ├── v1.gin │ ├── v2.gin │ ├── v2_nopqmf.gin │ ├── v2_nopqmf_small.gin │ ├── v2_small.gin │ ├── v2_with_augs.gin │ ├── v3.gin │ └── wasserstein.gin ├── core.py ├── dataset.py ├── descript_discriminator.py ├── discriminator.py ├── model.py ├── pqmf.py ├── prior │ ├── __init__.py │ ├── core.py │ ├── model.py │ └── residual_block.py ├── quantization.py ├── resampler.py ├── transforms.py └── version.py ├── requirements.txt ├── scripts ├── __init__.py ├── export.py ├── export_onnx.py ├── generate.py ├── main_cli.py ├── preprocess.py ├── remote_dataset.py ├── train.py └── train_prior.py ├── setup.py └── tests ├── __init__.py ├── test_configs.py ├── test_resampler.py └── test_residual.py /.github/workflows/actions.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | permissions: 4 | pull-requests: write 5 | issues: write 6 | repository-projects: write 7 | contents: write 8 | 9 | on: 10 | pull_request: 11 | push: 12 | branches: [master] 13 | tags: v* 14 | 15 | jobs: 16 | build: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v3 20 | - name: Set up Python 21 | uses: actions/setup-python@v3 22 | with: 23 | python-version: "3.10" 24 | cache: pip 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip setuptools wheel build pytest 28 | python -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu 29 | python -m pip install -r requirements.txt 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | if: startsWith(github.ref, 'refs/tags/v') 34 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 35 | with: 36 | user: __token__ 37 | password: ${{ secrets.PYPI_TOKEN }} 38 | 39 | test: 40 | runs-on: ubuntu-latest 41 | steps: 42 | - uses: actions/checkout@v3 43 | - name: Set up Python 44 | uses: actions/setup-python@v3 45 | with: 46 | python-version: "3.10" 47 | cache: pip 48 | - name: Install dependencies 49 | run: | 50 | python -m pip install --upgrade pip setuptools wheel build pytest 51 | python -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu 52 | python -m pip install -r requirements.txt 53 | - name: Run tests 54 | run: pytest --junitxml=.test-report.xml 55 | - uses: actions/upload-artifact@v3 56 | if: success() || failure() 57 | with: 58 | name: test-report 59 | path: .test-report.xml 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *pycache* 2 | *DS_Store 3 | lightning_logs/ 4 | *.ckpt 5 | 
*.ts 6 | *libtorch* 7 | *.wav 8 | *.txt 9 | runs 10 | *.npy 11 | *.yaml 12 | *.onnx 13 | __version__* 14 | PKG-INFO 15 | .junit-test-report.xml 16 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.formatOnSave": true, 3 | "python.formatting.provider": "yapf", 4 | "python.testing.pytestArgs": [ 5 | "." 6 | ], 7 | "python.testing.unittestEnabled": false, 8 | "python.testing.pytestEnabled": true 9 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Attribution-NonCommercial 4.0 International 2 | 3 | Creative Commons Corporation ("Creative Commons") is not a law firm and 4 | does not provide legal services or legal advice. Distribution of 5 | Creative Commons public licenses does not create a lawyer-client or 6 | other relationship. Creative Commons makes its licenses and related 7 | information available on an "as-is" basis. Creative Commons gives no 8 | warranties regarding its licenses, any material licensed under their 9 | terms and conditions, or any related information. Creative Commons 10 | disclaims all liability for damages resulting from their use to the 11 | fullest extent possible. 12 | 13 | Using Creative Commons Public Licenses 14 | 15 | Creative Commons public licenses provide a standard set of terms and 16 | conditions that creators and other rights holders may use to share 17 | original works of authorship and other material subject to copyright and 18 | certain other rights specified in the public license below. The 19 | following considerations are for informational purposes only, are not 20 | exhaustive, and do not form part of our licenses. 21 | 22 | - Considerations for licensors: Our public licenses are intended for 23 | use by those authorized to give the public permission to use 24 | material in ways otherwise restricted by copyright and certain other 25 | rights. Our licenses are irrevocable. Licensors should read and 26 | understand the terms and conditions of the license they choose 27 | before applying it. Licensors should also secure all rights 28 | necessary before applying our licenses so that the public can reuse 29 | the material as expected. Licensors should clearly mark any material 30 | not subject to the license. This includes other CC-licensed 31 | material, or material used under an exception or limitation to 32 | copyright. More considerations for licensors : 33 | wiki.creativecommons.org/Considerations_for_licensors 34 | 35 | - Considerations for the public: By using one of our public licenses, 36 | a licensor grants the public permission to use the licensed material 37 | under specified terms and conditions. If the licensor's permission 38 | is not necessary for any reason–for example, because of any 39 | applicable exception or limitation to copyright–then that use is not 40 | regulated by the license. Our licenses grant only permissions under 41 | copyright and certain other rights that a licensor has authority to 42 | grant. Use of the licensed material may still be restricted for 43 | other reasons, including because others have copyright or other 44 | rights in the material. A licensor may make special requests, such 45 | as asking that all changes be marked or described. 
Although not 46 | required by our licenses, you are encouraged to respect those 47 | requests where reasonable. More considerations for the public : 48 | wiki.creativecommons.org/Considerations_for_licensees 49 | 50 | Creative Commons Attribution-NonCommercial 4.0 International Public 51 | License 52 | 53 | By exercising the Licensed Rights (defined below), You accept and agree 54 | to be bound by the terms and conditions of this Creative Commons 55 | Attribution-NonCommercial 4.0 International Public License ("Public 56 | License"). To the extent this Public License may be interpreted as a 57 | contract, You are granted the Licensed Rights in consideration of Your 58 | acceptance of these terms and conditions, and the Licensor grants You 59 | such rights in consideration of benefits the Licensor receives from 60 | making the Licensed Material available under these terms and conditions. 61 | 62 | - Section 1 – Definitions. 63 | 64 | - a. Adapted Material means material subject to Copyright and 65 | Similar Rights that is derived from or based upon the Licensed 66 | Material and in which the Licensed Material is translated, 67 | altered, arranged, transformed, or otherwise modified in a 68 | manner requiring permission under the Copyright and Similar 69 | Rights held by the Licensor. For purposes of this Public 70 | License, where the Licensed Material is a musical work, 71 | performance, or sound recording, Adapted Material is always 72 | produced where the Licensed Material is synched in timed 73 | relation with a moving image. 74 | - b. Adapter's License means the license You apply to Your 75 | Copyright and Similar Rights in Your contributions to Adapted 76 | Material in accordance with the terms and conditions of this 77 | Public License. 78 | - c. Copyright and Similar Rights means copyright and/or similar 79 | rights closely related to copyright including, without 80 | limitation, performance, broadcast, sound recording, and Sui 81 | Generis Database Rights, without regard to how the rights are 82 | labeled or categorized. For purposes of this Public License, the 83 | rights specified in Section 2(b)(1)-(2) are not Copyright and 84 | Similar Rights. 85 | - d. Effective Technological Measures means those measures that, 86 | in the absence of proper authority, may not be circumvented 87 | under laws fulfilling obligations under Article 11 of the WIPO 88 | Copyright Treaty adopted on December 20, 1996, and/or similar 89 | international agreements. 90 | - e. Exceptions and Limitations means fair use, fair dealing, 91 | and/or any other exception or limitation to Copyright and 92 | Similar Rights that applies to Your use of the Licensed 93 | Material. 94 | - f. Licensed Material means the artistic or literary work, 95 | database, or other material to which the Licensor applied this 96 | Public License. 97 | - g. Licensed Rights means the rights granted to You subject to 98 | the terms and conditions of this Public License, which are 99 | limited to all Copyright and Similar Rights that apply to Your 100 | use of the Licensed Material and that the Licensor has authority 101 | to license. 102 | - h. Licensor means the individual(s) or entity(ies) granting 103 | rights under this Public License. 104 | - i. NonCommercial means not primarily intended for or directed 105 | towards commercial advantage or monetary compensation. 
For 106 | purposes of this Public License, the exchange of the Licensed 107 | Material for other material subject to Copyright and Similar 108 | Rights by digital file-sharing or similar means is NonCommercial 109 | provided there is no payment of monetary compensation in 110 | connection with the exchange. 111 | - j. Share means to provide material to the public by any means or 112 | process that requires permission under the Licensed Rights, such 113 | as reproduction, public display, public performance, 114 | distribution, dissemination, communication, or importation, and 115 | to make material available to the public including in ways that 116 | members of the public may access the material from a place and 117 | at a time individually chosen by them. 118 | - k. Sui Generis Database Rights means rights other than copyright 119 | resulting from Directive 96/9/EC of the European Parliament and 120 | of the Council of 11 March 1996 on the legal protection of 121 | databases, as amended and/or succeeded, as well as other 122 | essentially equivalent rights anywhere in the world. 123 | - l. You means the individual or entity exercising the Licensed 124 | Rights under this Public License. Your has a corresponding 125 | meaning. 126 | 127 | - Section 2 – Scope. 128 | 129 | - a. License grant. 130 | - 1. Subject to the terms and conditions of this Public 131 | License, the Licensor hereby grants You a worldwide, 132 | royalty-free, non-sublicensable, non-exclusive, irrevocable 133 | license to exercise the Licensed Rights in the Licensed 134 | Material to: 135 | - A. reproduce and Share the Licensed Material, in whole 136 | or in part, for NonCommercial purposes only; and 137 | - B. produce, reproduce, and Share Adapted Material for 138 | NonCommercial purposes only. 139 | - 2. Exceptions and Limitations. For the avoidance of doubt, 140 | where Exceptions and Limitations apply to Your use, this 141 | Public License does not apply, and You do not need to comply 142 | with its terms and conditions. 143 | - 3. Term. The term of this Public License is specified in 144 | Section 6(a). 145 | - 4. Media and formats; technical modifications allowed. The 146 | Licensor authorizes You to exercise the Licensed Rights in 147 | all media and formats whether now known or hereafter 148 | created, and to make technical modifications necessary to do 149 | so. The Licensor waives and/or agrees not to assert any 150 | right or authority to forbid You from making technical 151 | modifications necessary to exercise the Licensed Rights, 152 | including technical modifications necessary to circumvent 153 | Effective Technological Measures. For purposes of this 154 | Public License, simply making modifications authorized by 155 | this Section 2(a)(4) never produces Adapted Material. 156 | - 5. Downstream recipients. 157 | - A. Offer from the Licensor – Licensed Material. Every 158 | recipient of the Licensed Material automatically 159 | receives an offer from the Licensor to exercise the 160 | Licensed Rights under the terms and conditions of this 161 | Public License. 162 | - B. No downstream restrictions. You may not offer or 163 | impose any additional or different terms or conditions 164 | on, or apply any Effective Technological Measures to, 165 | the Licensed Material if doing so restricts exercise of 166 | the Licensed Rights by any recipient of the Licensed 167 | Material. 168 | - 6. No endorsement. 
Nothing in this Public License 169 | constitutes or may be construed as permission to assert or 170 | imply that You are, or that Your use of the Licensed 171 | Material is, connected with, or sponsored, endorsed, or 172 | granted official status by, the Licensor or others 173 | designated to receive attribution as provided in Section 174 | 3(a)(1)(A)(i). 175 | - b. Other rights. 176 | - 1. Moral rights, such as the right of integrity, are not 177 | licensed under this Public License, nor are publicity, 178 | privacy, and/or other similar personality rights; however, 179 | to the extent possible, the Licensor waives and/or agrees 180 | not to assert any such rights held by the Licensor to the 181 | limited extent necessary to allow You to exercise the 182 | Licensed Rights, but not otherwise. 183 | - 2. Patent and trademark rights are not licensed under this 184 | Public License. 185 | - 3. To the extent possible, the Licensor waives any right to 186 | collect royalties from You for the exercise of the Licensed 187 | Rights, whether directly or through a collecting society 188 | under any voluntary or waivable statutory or compulsory 189 | licensing scheme. In all other cases the Licensor expressly 190 | reserves any right to collect such royalties, including when 191 | the Licensed Material is used other than for NonCommercial 192 | purposes. 193 | 194 | - Section 3 – License Conditions. 195 | 196 | Your exercise of the Licensed Rights is expressly made subject to 197 | the following conditions. 198 | 199 | - a. Attribution. 200 | - 1. If You Share the Licensed Material (including in modified 201 | form), You must: 202 | - A. retain the following if it is supplied by the 203 | Licensor with the Licensed Material: 204 | - i. identification of the creator(s) of the Licensed 205 | Material and any others designated to receive 206 | attribution, in any reasonable manner requested by 207 | the Licensor (including by pseudonym if designated); 208 | - ii. a copyright notice; 209 | - iii. a notice that refers to this Public License; 210 | - iv. a notice that refers to the disclaimer of 211 | warranties; 212 | - v. a URI or hyperlink to the Licensed Material to 213 | the extent reasonably practicable; 214 | - B. indicate if You modified the Licensed Material and 215 | retain an indication of any previous modifications; and 216 | - C. indicate the Licensed Material is licensed under this 217 | Public License, and include the text of, or the URI or 218 | hyperlink to, this Public License. 219 | - 2. You may satisfy the conditions in Section 3(a)(1) in any 220 | reasonable manner based on the medium, means, and context in 221 | which You Share the Licensed Material. For example, it may 222 | be reasonable to satisfy the conditions by providing a URI 223 | or hyperlink to a resource that includes the required 224 | information. 225 | - 3. If requested by the Licensor, You must remove any of the 226 | information required by Section 3(a)(1)(A) to the extent 227 | reasonably practicable. 228 | - 4. If You Share Adapted Material You produce, the Adapter's 229 | License You apply must not prevent recipients of the Adapted 230 | Material from complying with this Public License. 231 | 232 | - Section 4 – Sui Generis Database Rights. 233 | 234 | Where the Licensed Rights include Sui Generis Database Rights that 235 | apply to Your use of the Licensed Material: 236 | 237 | - a. 
for the avoidance of doubt, Section 2(a)(1) grants You the 238 | right to extract, reuse, reproduce, and Share all or a 239 | substantial portion of the contents of the database for 240 | NonCommercial purposes only; 241 | - b. if You include all or a substantial portion of the database 242 | contents in a database in which You have Sui Generis Database 243 | Rights, then the database in which You have Sui Generis Database 244 | Rights (but not its individual contents) is Adapted Material; 245 | and 246 | - c. You must comply with the conditions in Section 3(a) if You 247 | Share all or a substantial portion of the contents of the 248 | database. 249 | 250 | For the avoidance of doubt, this Section 4 supplements and does not 251 | replace Your obligations under this Public License where the 252 | Licensed Rights include other Copyright and Similar Rights. 253 | 254 | - Section 5 – Disclaimer of Warranties and Limitation of Liability. 255 | 256 | - a. Unless otherwise separately undertaken by the Licensor, to 257 | the extent possible, the Licensor offers the Licensed Material 258 | as-is and as-available, and makes no representations or 259 | warranties of any kind concerning the Licensed Material, whether 260 | express, implied, statutory, or other. This includes, without 261 | limitation, warranties of title, merchantability, fitness for a 262 | particular purpose, non-infringement, absence of latent or other 263 | defects, accuracy, or the presence or absence of errors, whether 264 | or not known or discoverable. Where disclaimers of warranties 265 | are not allowed in full or in part, this disclaimer may not 266 | apply to You. 267 | - b. To the extent possible, in no event will the Licensor be 268 | liable to You on any legal theory (including, without 269 | limitation, negligence) or otherwise for any direct, special, 270 | indirect, incidental, consequential, punitive, exemplary, or 271 | other losses, costs, expenses, or damages arising out of this 272 | Public License or use of the Licensed Material, even if the 273 | Licensor has been advised of the possibility of such losses, 274 | costs, expenses, or damages. Where a limitation of liability is 275 | not allowed in full or in part, this limitation may not apply to 276 | You. 277 | - c. The disclaimer of warranties and limitation of liability 278 | provided above shall be interpreted in a manner that, to the 279 | extent possible, most closely approximates an absolute 280 | disclaimer and waiver of all liability. 281 | 282 | - Section 6 – Term and Termination. 283 | 284 | - a. This Public License applies for the term of the Copyright and 285 | Similar Rights licensed here. However, if You fail to comply 286 | with this Public License, then Your rights under this Public 287 | License terminate automatically. 288 | - b. Where Your right to use the Licensed Material has terminated 289 | under Section 6(a), it reinstates: 290 | 291 | - 1. automatically as of the date the violation is cured, 292 | provided it is cured within 30 days of Your discovery of the 293 | violation; or 294 | - 2. upon express reinstatement by the Licensor. 295 | 296 | For the avoidance of doubt, this Section 6(b) does not affect 297 | any right the Licensor may have to seek remedies for Your 298 | violations of this Public License. 299 | 300 | - c. 
For the avoidance of doubt, the Licensor may also offer the 301 | Licensed Material under separate terms or conditions or stop 302 | distributing the Licensed Material at any time; however, doing 303 | so will not terminate this Public License. 304 | - d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 305 | License. 306 | 307 | - Section 7 – Other Terms and Conditions. 308 | 309 | - a. The Licensor shall not be bound by any additional or 310 | different terms or conditions communicated by You unless 311 | expressly agreed. 312 | - b. Any arrangements, understandings, or agreements regarding the 313 | Licensed Material not stated herein are separate from and 314 | independent of the terms and conditions of this Public License. 315 | 316 | - Section 8 – Interpretation. 317 | 318 | - a. For the avoidance of doubt, this Public License does not, and 319 | shall not be interpreted to, reduce, limit, restrict, or impose 320 | conditions on any use of the Licensed Material that could 321 | lawfully be made without permission under this Public License. 322 | - b. To the extent possible, if any provision of this Public 323 | License is deemed unenforceable, it shall be automatically 324 | reformed to the minimum extent necessary to make it enforceable. 325 | If the provision cannot be reformed, it shall be severed from 326 | this Public License without affecting the enforceability of the 327 | remaining terms and conditions. 328 | - c. No term or condition of this Public License will be waived 329 | and no failure to comply consented to unless expressly agreed to 330 | by the Licensor. 331 | - d. Nothing in this Public License constitutes or may be 332 | interpreted as a limitation upon, or waiver of, any privileges 333 | and immunities that apply to the Licensor or You, including from 334 | the legal processes of any jurisdiction or authority. 335 | 336 | Creative Commons is not a party to its public licenses. Notwithstanding, 337 | Creative Commons may elect to apply one of its public licenses to 338 | material it publishes and in those instances will be considered the 339 | "Licensor." The text of the Creative Commons public licenses is 340 | dedicated to the public domain under the CC0 Public Domain Dedication. 341 | Except for the limited purpose of indicating that material is shared 342 | under a Creative Commons public license or as otherwise permitted by the 343 | Creative Commons policies published at creativecommons.org/policies, 344 | Creative Commons does not authorize the use of the trademark "Creative 345 | Commons" or any other trademark or logo of Creative Commons without its 346 | prior written consent including, without limitation, in connection with 347 | any unauthorized modifications to any of its public licenses or any 348 | other arrangements, understandings, or agreements concerning use of 349 | licensed material. For the avoidance of doubt, this paragraph does not 350 | form part of the public licenses. 351 | 352 | Creative Commons may be contacted at creativecommons.org. 
353 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include rave/configs/*.gin 2 | include rave/configs/augmentations/*.gin 3 | include requirements.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![rave_logo](docs/rave.png) 2 | 3 | # RAVE: Realtime Audio Variational autoEncoder 4 | 5 | Official implementation of _RAVE: A variational autoencoder for fast and high-quality neural audio synthesis_ ([article link](https://arxiv.org/abs/2111.05011)) by Antoine Caillon and Philippe Esling. 6 | 7 | If you use RAVE as part of a music performance or installation, be sure to cite either this repository or the article ! 8 | 9 | If you want to share / discuss / ask things about RAVE and other research from ACIDS, you can do so in our [discord server](https://discord.gg/r9umPrGEWv) ! 10 | 11 | Please check the FAQ before posting an issue! 12 | 13 | **RAVE VST** RAVE VST for Windows, Mac and Linux is available as beta on the [corresponding Forum IRCAM webpage](https://forum.ircam.fr/projects/detail/rave-vst/). For problems, please write an issue here or [on the Forum IRCAM discussion page](https://discussion.forum.ircam.fr/c/rave-vst/651). 14 | 15 | **Tutorials** : new tutorials are available on the Forum IRCAM webpage, and video versions are coming soon! 16 | - [Tutorial: Neural Synthesis in a DAW with RAVE](https://forum.ircam.fr/article/detail/neural-synthesis-in-a-daw-with-rave/) 17 | - [Tutorial: Neural Synthesis in Max 8 with RAVE](https://forum.ircam.fr/article/detail/tutorial-neural-synthesis-in-max-8-with-rave/) 18 | - [Tutorial: Training RAVE models on custom data](https://forum.ircam.fr/article/detail/training-rave-models-on-custom-data/) 19 | 20 | ## Previous versions 21 | 22 | The original implementation of the RAVE model can be restored using 23 | 24 | ```bash 25 | git checkout v1 26 | ``` 27 | 28 | ## Installation 29 | 30 | Install RAVE using 31 | 32 | ```bash 33 | pip install acids-rave 34 | ``` 35 | 36 | **Warning** It is strongly advised to install `torch` and `torchaudio` before `acids-rave`, so you can choose the appropriate version of torch on the [library website](http://www.pytorch.org). For future compatibility with new devices (and modern Python environments), `acids-rave` does not enforce torch==1.13 anymore. 37 | 38 | You will need **ffmpeg** on your computer. You can install it locally inside your virtual environment using 39 | 40 | ```bash 41 | conda install ffmpeg 42 | ``` 43 | 44 | 45 | 46 | ## Colab 47 | 48 | A colab to train RAVEv2 is now available thanks to [hexorcismos](https://github.com/moiseshorta) ! 49 | [![colab_badge](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1ih-gv1iHEZNuGhHPvCHrleLNXvooQMvI?usp=sharing) 50 | 51 | ## Usage 52 | 53 | Training a RAVE model usually involves 3 separate steps, namely _dataset preparation_, _training_ and _export_. 54 | 55 | ### Dataset preparation 56 | 57 | You can now prepare a dataset using two methods: regular and lazy. Lazy preprocessing allows RAVE to be trained directly on the raw files (i.e. mp3, ogg), without converting them first. **Warning**: lazy dataset loading will increase your CPU load by a large margin during training, especially on Windows.
This can however be useful when training on a large audio corpus which would not fit on a hard drive when uncompressed. In any case, prepare your dataset using 58 | 59 | ```bash 60 | rave preprocess --input_path /audio/folder --output_path /dataset/path --channels X (--lazy) 61 | ``` 62 | 63 | ### Training 64 | 65 | RAVEv2 has many different configurations. The improved version of v1 is called `v2`, and can therefore be trained with 66 | 67 | ```bash 68 | rave train --config v2 --db_path /dataset/path --out_path /model/out --name give_a_name --channels X 69 | ``` 70 | 71 | We also provide a discrete configuration, similar to SoundStream or EnCodec 72 | 73 | ```bash 74 | rave train --config discrete ... 75 | ``` 76 | 77 | By default, RAVE is built with non-causal convolutions. If you want to make the model causal (hence lowering the overall latency of the model), you can use the causal mode 78 | 79 | ```bash 80 | rave train --config discrete --config causal ... 81 | ``` 82 | 83 | New in 2.3, data augmentations are also available to improve the model's generalization in low data regimes. You can add data augmentations by passing augmentation configuration files with the `--augment` keyword 84 | 85 | ```bash 86 | rave train --config v2 --augment mute --augment compress 87 | ``` 88 | 89 | Many other configuration files are available in `rave/configs` and can be combined. Here is a list of all the available configurations & augmentations : 90 | 91 |
| Type | Name | Description |
| --- | --- | --- |
| Architecture | v1 | Original continuous model (minimum GPU memory : 8Go) |
| | v2 | Improved continuous model (faster, higher quality) (minimum GPU memory : 16Go) |
| | v2_small | v2 with a smaller receptive field, adapted adversarial training, and noise generator, adapted for timbre transfer for stationary signals (minimum GPU memory : 8Go) |
| | v2_nopqmf | (experimental) v2 without pqmf in generator (more efficient for bending purposes) (minimum GPU memory : 16Go) |
| | v3 | v2 with Snake activation, descript discriminator and Adaptive Instance Normalization for real style transfer (minimum GPU memory : 32Go) |
| | discrete | Discrete model (similar to SoundStream or EnCodec) (minimum GPU memory : 18Go) |
| | onnx | Noiseless v1 configuration for onnx usage (minimum GPU memory : 6Go) |
| | raspberry | Lightweight configuration compatible with realtime RaspberryPi 4 inference (minimum GPU memory : 5Go) |
| Regularization (v2 only) | default | Variational Auto Encoder objective (ELBO) |
| | wasserstein | Wasserstein Auto Encoder objective (MMD) |
| | spherical | Spherical Auto Encoder objective |
| Discriminator | spectral_discriminator | Use the MultiScale discriminator from EnCodec. |
| Others | causal | Use causal convolutions |
| | noise | Enables noise synthesizer V2 |
| | hybrid | Enable mel-spectrogram input |
| Augmentations | mute | Randomly mutes data batches (default prob : 0.1). Enforces the model to learn silence |
| | compress | Randomly compresses the waveform (equivalent to light non-linear amplification of batches) |
| | gain | Applies a random gain to waveform (default range : [-6, 3]) |
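Configuration files and augmentations can be freely combined by repeating the `--config` and `--augment` flags shown above. As an illustration (not an official recipe; paths, name and channel count are placeholders to adapt to your setup), a causal `v2` training with two augmentations could be launched with:

```bash
rave train --config v2 --config causal \
           --augment compress --augment gain \
           --db_path /dataset/path --out_path /model/out --name give_a_name --channels X
```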
198 | 199 | ### Export 200 | 201 | Once trained, export your model to a torchscript file using 202 | 203 | ```bash 204 | rave export --run /path/to/your/run (--streaming) 205 | ``` 206 | 207 | Setting the `--streaming` flag will enable cached convolutions, making the model compatible with realtime processing. **If you forget to use the streaming mode and try to load the model in Max, you will hear clicking artifacts.** 208 | 209 | ## Prior 210 | 211 | For discrete models, we redirect the user to the `msprior` library [here](https://github.com/caillonantoine/msprior). However, as this library is still experimental, the prior from version 1.x has been re-integrated in v2.3. 212 | 213 | ### Training 214 | 215 | To train a prior for a pretrained RAVE model : 216 | 217 | ```bash 218 | rave train_prior --model /path/to/your/run --db_path /path/to/your_preprocessed_data --out_path /path/to/output 219 | ``` 220 | 221 | this will train a prior over the latent of the pretrained model `path/to/your/run`, and save the model and tensorboard logs to folder `/path/to/output`. 222 | 223 | ### Scripting 224 | 225 | To script a prior along with a RAVE model, export your model by providing the `--prior` keyword to your pretrained prior : 226 | 227 | ```bash 228 | rave export --run /path/to/your/run --prior /path/to/your/prior (--streaming) 229 | ``` 230 | 231 | ## Pretrained models 232 | 233 | Several pretrained streaming models [are available here](https://acids-ircam.github.io/rave_models_download). We'll keep the list updated with new models. 234 | 235 | ## Realtime usage 236 | 237 | This section presents how RAVE can be loaded inside [`nn~`](https://acids-ircam.github.io/nn_tilde/) in order to be used live with Max/MSP or PureData. 238 | 239 | ### Reconstruction 240 | 241 | A pretrained RAVE model named `darbouka.gin` available on your computer can be loaded inside `nn~` using the following syntax, where the default method is set to forward (i.e. encode then decode) 242 | 243 | 244 | 245 | This does the same thing as the following patch, but slightly faster. 246 | 247 | 248 | 249 | ### High-level manipulation 250 | 251 | Having an explicit access to the latent representation yielded by RAVE allows us to interact with the representation using Max/MSP or PureData signal processing tools: 252 | 253 | 254 | 255 | ### Style transfer 256 | 257 | By default, RAVE can be used as a style transfer tool, based on the large compression ratio of the model. We recently added a technique inspired from StyleGAN to include Adaptive Instance Normalization to the reconstruction process, effectively allowing to define _source_ and _target_ styles directly inside Max/MSP or PureData, using the attribute system of `nn~`. 258 | 259 | 260 | 261 | Other attributes, such as `enable` or `gpu` can enable/disable computation, or use the gpu to speed up things (still experimental). 262 | 263 | ## Offline usage 264 | 265 | A batch generation script has been released in v2.3 to allow transformation of large amount of files 266 | 267 | ```bash 268 | rave generate model_path path_1 path_2 --out out_path 269 | ``` 270 | 271 | where `model_path` is the path to your trained model (original or scripted), `path_X` a list of audio files or directories, and `out_path` the out directory of the generations. 272 | 273 | ## Discussion 274 | 275 | If you have questions, want to share your experience with RAVE or share musical pieces done with the model, you can use the [Discussion tab](https://github.com/acids-ircam/RAVE/discussions) ! 
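For quick offline checks outside of Max/MSP or PureData, an exported torchscript file can also be loaded directly from Python. The snippet below is only a minimal sketch, assuming a mono model exported as described in the Export section; the file names are placeholders, and the exact method signatures and tensor shapes may vary between RAVE versions.

```python
import torch
import torchaudio

# Load the exported torchscript model (placeholder file name)
model = torch.jit.load("darbouka_streaming.ts").eval()

# Load a mono file, assumed to already be at the model's sampling rate
x, sr = torchaudio.load("input.wav")
x = x.reshape(1, 1, -1)  # (batch, channels, time)

with torch.no_grad():
    z = model.encode(x)   # latent trajectory
    y = model.decode(z)   # reconstruction, same layout as the input

torchaudio.save("reconstruction.wav", y.reshape(1, -1), sr)
```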
276 | 277 | ## Demonstration 278 | 279 | ### RAVE x nn~ 280 | 281 | Demonstration of what you can do with RAVE and the nn~ external for Max/MSP ! 282 | 283 | [![RAVE x nn~](http://img.youtube.com/vi/dMZs04TzxUI/mqdefault.jpg)](https://www.youtube.com/watch?v=dMZs04TzxUI) 284 | 285 | ### embedded RAVE 286 | 287 | Using nn~ for PureData, RAVE can be used in realtime on embedded platforms ! 288 | 289 | [![RAVE x nn~](http://img.youtube.com/vi/jAIRf4nGgYI/mqdefault.jpg)](https://www.youtube.com/watch?v=jAIRf4nGgYI) 290 | 291 | # Frequently Asked Questions (FAQ) 292 | 293 | **Question** : my preprocessing is stuck, showing `0it[00:00, ?it/s]`
294 | **Answer** : This means that the audio files in your dataset are too short to provide a sufficient temporal scope to RAVE. Try decreasing the signal window by passing `--num_signal XXX(samples)` to `preprocess`, and don't forget to then pass the matching `--n_signal XXX(samples)` to `train` 295 | 296 | **Question** : During training I got an exception resembling `ValueError: n_components=128 must be between 0 and min(n_samples, n_features)=64 with svd_solver='full'`
297 | **Answer** : This means that your dataset does not have enough data batches to compute the intern latent PCA, that requires at least 128 examples (then batches). 298 | 299 | 300 | # Funding 301 | 302 | This work is led at IRCAM, and has been funded by the following projects 303 | 304 | - [ANR MakiMono](https://acids.ircam.fr/course/makimono/) 305 | - [ACTOR](https://www.actorproject.org/) 306 | - [DAFNE+](https://dafneplus.eu/) N° 101061548 307 | 308 | 309 | -------------------------------------------------------------------------------- /docs/darbouka_prior.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/darbouka_prior.mp3 -------------------------------------------------------------------------------- /docs/log_distance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/log_distance.png -------------------------------------------------------------------------------- /docs/log_fidelity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/log_fidelity.png -------------------------------------------------------------------------------- /docs/log_gan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/log_gan.png -------------------------------------------------------------------------------- /docs/maxmsp_screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/maxmsp_screenshot.png -------------------------------------------------------------------------------- /docs/rave.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/rave.png -------------------------------------------------------------------------------- /docs/rave_attribute.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/rave_attribute.png -------------------------------------------------------------------------------- /docs/rave_buffer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/rave_buffer.png -------------------------------------------------------------------------------- /docs/rave_encode_decode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/rave_encode_decode.png -------------------------------------------------------------------------------- /docs/rave_high_level.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/rave_high_level.png 
-------------------------------------------------------------------------------- /docs/rave_method_forward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/rave_method_forward.png -------------------------------------------------------------------------------- /docs/tensorboard_guide.md: -------------------------------------------------------------------------------- 1 | # Tensorboard guide 2 | 3 | ## Latent space size estimation 4 | 5 | During training, RAVE regularly estimates the **size** of the latent space given a specific dataset for a given *fidelity*. The fidelity parameter is a percentage that defines how well the model should be able to reconstruct an input audio sample. 6 | 7 | Usually values around 80% yield correct yet not accurate reconstructions. Values around 95% are most of the time sufficient to have both a compact latent space and correct reconstructions. 8 | 9 | We log the estimated size of the latent space for several values of fidelity in tensorboard (80, 90, 95 and 99%). 10 | 11 | ![log_fidelity](log_fidelity.png) 12 | 13 | ## Reconstruction error 14 | 15 | The values you should look at for tracking the reconstruction error of the model are the *distance* and *validation* logs 16 | 17 | ![log_distance.png](log_distance.png) 18 | 19 | When the second phase kicks in, those values increase - **that's usually normal** 20 | 21 | ## Adversarial losses 22 | 23 | The `loss_dis, loss_gen, pred_true, pred_fake` losses only appear during the second phase. They are usually harder to read, as most GAN losses are, but we include here an example of what *normal* logs should look like 24 | 25 | ![log_gan.png](log_gan.png) -------------------------------------------------------------------------------- /docs/training_setup.md: -------------------------------------------------------------------------------- 1 | ![logo](rave.png) 2 | 3 | # Training setup 4 | 5 | 1. You should train on a _CUDA-enabled_ machine (i.e. with an NVIDIA card) 6 | - You can use either **Linux** or **Windows** 7 | - However, we advise using **Linux** if available 8 | - Training RAVE without a hardware accelerator (GPU, TPU) will take ages, and is not recommended 9 | 2. Make sure that you have CUDA enabled 10 | - Go to a terminal and enter `nvidia-smi` 11 | - If a message appears with the name of your graphics card and the available memory, it's all good ! 12 | - Otherwise, you have to install **cuda** on your computer (we don't provide support for that, lots of guides are available online) 13 | 3. Let's install python ! 14 | 15 | # Python installation 16 | 17 | Python is often pre-installed on most computers, but we won't use this version. Instead, we will install a **conda** distribution on the machine. This keeps different versions of python separate for different projects, and allows regular users to install new packages without sudo access. 18 | 19 | You can follow the [instructions here](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) to install a miniconda environment on your computer. 20 | 21 | Once installed, you know that you are inside your miniconda environment if there's a "`(base)`" at the beginning of your terminal. 22 | 23 | # RAVE installation 24 | 25 | We will create a new virtual environment for RAVE.
26 | 27 | ```bash 28 | conda create -n rave python=3.9 29 | ``` 30 | 31 | Each time we want to use RAVE, we can (and **should**) activate this environment using 32 | 33 | ```bash 34 | conda activate rave 35 | ``` 36 | 37 | Let's clone RAVE and install the requirements ! 38 | 39 | ```bash 40 | git clone https://github.com/acids-ircam/RAVE 41 | cd RAVE 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | You can now use `python cli_helper.py` to start a new training ! 46 | 47 | # About the dataset 48 | 49 | A good rule of thumb is **more is better**. You might want to have _at least_ 3h of homogeneous recordings to train RAVE, more if your dataset is complex (e.g mixtures of instruments, lots of variations...) 50 | 51 | If you have a folder filled with various audio files (any extension, any sampling rate), you can use the `resample` utility in this folder 52 | 53 | ```bash 54 | conda activate rave 55 | resample --sr TARGET_SAMPLING_RATE --augment 56 | ``` 57 | 58 | It will convert, resample, crop and augment all audio files present in the directory to an output directory called `out_TARGET_SAMPLING_RATE/` (which is the one you should give to `cli_helper.py` when asked for the path of the .wav files). 59 | -------------------------------------------------------------------------------- /rave/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cached_conv as cc 4 | import gin 5 | import torch 6 | 7 | 8 | BASE_PATH: Path = Path(__file__).parent 9 | 10 | gin.add_config_file_search_path(BASE_PATH) 11 | gin.add_config_file_search_path(BASE_PATH.joinpath('configs')) 12 | gin.add_config_file_search_path(BASE_PATH.joinpath('configs', 'augmentations')) 13 | 14 | 15 | def __safe_configurable(name): 16 | try: 17 | setattr(cc, name, gin.get_configurable(f"cc.{name}")) 18 | except ValueError: 19 | setattr(cc, name, gin.external_configurable(getattr(cc, name), module="cc")) 20 | 21 | # cc.get_padding = gin.external_configurable(cc.get_padding, module="cc") 22 | # cc.Conv1d = gin.external_configurable(cc.Conv1d, module="cc") 23 | # cc.ConvTranspose1d = gin.external_configurable(cc.ConvTranspose1d, module="cc") 24 | 25 | __safe_configurable("get_padding") 26 | __safe_configurable("Conv1d") 27 | __safe_configurable("ConvTranspose1d") 28 | 29 | from .blocks import * 30 | from .discriminator import * 31 | from .model import RAVE, BetaWarmupCallback 32 | from .pqmf import * 33 | from .balancer import * 34 | -------------------------------------------------------------------------------- /rave/balancer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import gin.torch 3 | 4 | 5 | @gin.configurable 6 | class Balancer(nn.Module): 7 | def __init__(self): 8 | super().__init__(self) 9 | 10 | def forward(self, *args, **kwargs): 11 | raise RuntimeError('Balancer has been disabled in newest RAVE version. 
\n' \ 12 | 'If you try to import checkpoint trained with a previous version, remove it from configuration.') -------------------------------------------------------------------------------- /rave/configs/adain.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import rave 4 | from rave import blocks 5 | 6 | blocks.EncoderV2: 7 | adain = @blocks.AdaptiveInstanceNormalization 8 | 9 | blocks.GeneratorV2: 10 | adain = @blocks.AdaptiveInstanceNormalization -------------------------------------------------------------------------------- /rave/configs/augmentations/compress.gin: -------------------------------------------------------------------------------- 1 | # dataset.get_dataset: 2 | # augmentations = [ 3 | # @augmentations/transforms.RandomCompress(), 4 | # ] 5 | 6 | add_augmentation: 7 | aug = @augmentations/transforms.RandomCompress() 8 | 9 | -------------------------------------------------------------------------------- /rave/configs/augmentations/gain.gin: -------------------------------------------------------------------------------- 1 | add_augmentation: 2 | aug = @augmentations/transforms.RandomGain() 3 | 4 | -------------------------------------------------------------------------------- /rave/configs/augmentations/mute.gin: -------------------------------------------------------------------------------- 1 | add_augmentation: 2 | aug = @augmentations/transforms.RandomMute() 3 | 4 | -------------------------------------------------------------------------------- /rave/configs/causal.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import cached_conv as cc 4 | 5 | cc.get_padding.mode = 'causal' 6 | -------------------------------------------------------------------------------- /rave/configs/descript_discriminator.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import rave 4 | from rave import descript_discriminator 5 | 6 | rave.RAVE: 7 | discriminator = @descript_discriminator.DescriptDiscriminator -------------------------------------------------------------------------------- /rave/configs/discrete.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | include "configs/v2.gin" 4 | 5 | import rave 6 | from rave import core 7 | from rave import blocks 8 | from rave import discriminator 9 | from rave import quantization 10 | 11 | import torch.nn as nn 12 | 13 | NUM_QUANTIZERS = 16 14 | RATIOS = [4, 4, 2, 2] 15 | LATENT_SIZE = 128 16 | CODEBOOK_SIZE = 1024 17 | DYNAMIC_MASKING = False 18 | CAPACITY = 96 19 | NOISE_AUGMENTATION = 128 20 | PHASE_1_DURATION = 200000 21 | 22 | core.AudioDistanceV1.log_epsilon = 1 23 | 24 | # ENCODER 25 | 26 | blocks.DiscreteEncoder: 27 | encoder_cls = @blocks.EncoderV2 28 | vq_cls = @quantization.ResidualVectorQuantization 29 | num_quantizers = %NUM_QUANTIZERS 30 | noise_augmentation = %NOISE_AUGMENTATION 31 | 32 | blocks.EncoderV2: 33 | n_out = 1 34 | 35 | quantization.ResidualVectorQuantization: 36 | num_quantizers = %NUM_QUANTIZERS 37 | dim = %LATENT_SIZE 38 | codebook_size = %CODEBOOK_SIZE 39 | 40 | # RAVE 41 | rave.RAVE: 42 | encoder = @blocks.DiscreteEncoder 43 | phase_1_duration = %PHASE_1_DURATION 44 | warmup_quantize = -1 45 | discriminator = @discriminator.CombineDiscriminators 46 | gan_loss = 
@core.hinge_gan 47 | valid_signal_crop = True 48 | num_skipped_features = 0 49 | update_discriminator_every = 4 50 | 51 | rave.BetaWarmupCallback: 52 | initial_value = .1 53 | target_value = .1 54 | warmup_len = 1 -------------------------------------------------------------------------------- /rave/configs/discrete_v3.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | include "configs/discrete.gin" 4 | include "configs/snake.gin" 5 | include "configs/descript_discriminator.gin" 6 | 7 | import rave 8 | 9 | rave.BetaWarmupCallback: 10 | initial_value = 1e-6 11 | target_value = 5e-2 12 | warmup_len = 20000 -------------------------------------------------------------------------------- /rave/configs/hybrid.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | from rave import blocks 4 | from rave import core 5 | from torchaudio import transforms 6 | 7 | import rave 8 | 9 | include "configs/v2.gin" 10 | 11 | N_FFT = 2048 12 | N_MELS = 128 13 | HOP_LENGTH = 256 14 | ENCODER_RATIOS = [2, 2, 2] 15 | NUM_GRU_LAYERS = 2 16 | 17 | blocks.EncoderV2: 18 | data_size = %N_MELS 19 | ratios = %ENCODER_RATIOS 20 | dilations = [1] 21 | 22 | core.n_fft_to_num_bands: 23 | n_fft = %N_FFT 24 | 25 | transforms.MelSpectrogram: 26 | sample_rate = %SAMPLING_RATE 27 | n_fft = %N_FFT 28 | win_length = %N_FFT 29 | hop_length = %HOP_LENGTH 30 | normalized = True 31 | n_mels = %N_MELS 32 | 33 | blocks.GeneratorV2: 34 | recurrent_layer = @blocks.GRU 35 | 36 | blocks.GRU: 37 | latent_size = %LATENT_SIZE 38 | num_layers = %NUM_GRU_LAYERS 39 | 40 | rave.RAVE: 41 | spectrogram = @transforms.MelSpectrogram() 42 | input_mode = "mel" 43 | -------------------------------------------------------------------------------- /rave/configs/noise.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | from rave import blocks 4 | 5 | blocks.GeneratorV2: 6 | noise_module = @blocks.NoiseGeneratorV2 7 | 8 | blocks.NoiseGeneratorV2: 9 | hidden_size = 128 10 | data_size = %N_BAND 11 | ratios = [2, 2, 2] 12 | noise_bands = 5 -------------------------------------------------------------------------------- /rave/configs/normalize_ambient.gin: -------------------------------------------------------------------------------- 1 | dataset.get_dataset: 2 | augmentations = [ 3 | @augmentations/transforms.Compress() 4 | ] 5 | 6 | augmentations/transforms.Compress: 7 | time='0.01,0.01' 8 | lookup='6:-30,-15,-10,-8,0,-5' 9 | sr=%SAMPLING_RATE -------------------------------------------------------------------------------- /rave/configs/onnx.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | include "configs/v1.gin" 4 | 5 | import rave 6 | from rave import blocks 7 | 8 | CAPACITY = 32 9 | 10 | blocks.Generator.use_noise = False -------------------------------------------------------------------------------- /rave/configs/prior/prior_v1.gin: -------------------------------------------------------------------------------- 1 | VariationalPrior: 2 | resolution = 32 3 | res_size = 512 4 | skp_size=256 5 | kernel_size=3 6 | cycle_size=4 7 | n_layers=10 8 | sr=@get_model_sr() 9 | -------------------------------------------------------------------------------- /rave/configs/raspberry.gin: 
-------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | include "configs/onnx.gin" 4 | 5 | CAPACITY = 16 -------------------------------------------------------------------------------- /rave/configs/snake.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | from rave import blocks 4 | 5 | ACTIVATION = @blocks.Snake 6 | 7 | blocks.ResidualLayer: 8 | activation = %ACTIVATION 9 | 10 | blocks.DilatedUnit: 11 | activation = %ACTIVATION 12 | 13 | blocks.UpsampleLayer: 14 | activation = %ACTIVATION 15 | 16 | blocks.NoiseGeneratorV2: 17 | activation = %ACTIVATION 18 | 19 | blocks.EncoderV2: 20 | activation = %ACTIVATION 21 | 22 | blocks.GeneratorV2: 23 | activation = %ACTIVATION -------------------------------------------------------------------------------- /rave/configs/spectral_discriminator.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import rave 4 | from rave import discriminator 5 | 6 | discriminator.MultiScaleSpectralDiscriminator: 7 | scales = [4096, 2048, 1024, 512, 256] 8 | convnet = @discriminator.EncodecConvNet 9 | 10 | discriminator.EncodecConvNet: 11 | capacity = 32 12 | 13 | discriminator.CombineDiscriminators: 14 | discriminators = [ 15 | @discriminator.MultiScaleDiscriminator, 16 | @discriminator.MultiScaleSpectralDiscriminator 17 | ] -------------------------------------------------------------------------------- /rave/configs/spherical.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import rave 4 | from rave import blocks 5 | 6 | LATENT_SIZE = 16 7 | 8 | blocks.EncoderV2.n_out = 1 9 | 10 | blocks.SphericalEncoder: 11 | encoder_cls = @blocks.EncoderV2 12 | 13 | rave.RAVE: 14 | encoder = @blocks.SphericalEncoder 15 | phase_1_duration = 200000 16 | -------------------------------------------------------------------------------- /rave/configs/v1.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import rave 4 | from rave import pqmf 5 | from rave import core 6 | from rave import blocks 7 | from rave import discriminator 8 | from rave import dataset 9 | 10 | import cached_conv as cc 11 | import torch 12 | 13 | SAMPLING_RATE = 44100 14 | CAPACITY = 64 15 | N_BAND = 16 16 | LATENT_SIZE = 128 17 | RATIOS = [4, 4, 4, 2] 18 | PHASE_1_DURATION = 1000000 19 | 20 | # CORE CONFIGURATION 21 | core.AudioDistanceV1: 22 | multiscale_stft = @core.MultiScaleSTFT 23 | log_epsilon = 1e-7 24 | 25 | core.MultiScaleSTFT: 26 | scales = [2048, 1024, 512, 256, 128] 27 | sample_rate = %SAMPLING_RATE 28 | magnitude = True 29 | 30 | dataset.split_dataset.max_residual = 1000 31 | 32 | # CONVOLUTION CONFIGURATION 33 | cc.Conv1d.bias = False 34 | cc.ConvTranspose1d.bias = False 35 | 36 | # PQMF 37 | pqmf.CachedPQMF: 38 | attenuation = 100 39 | n_band = %N_BAND 40 | 41 | blocks.normalization.mode = 'weight_norm' 42 | 43 | # ENCODER 44 | blocks.Encoder: 45 | data_size = %N_BAND 46 | capacity = %CAPACITY 47 | latent_size = %LATENT_SIZE 48 | ratios = %RATIOS 49 | sample_norm = False 50 | repeat_layers = 1 51 | 52 | variational/blocks.Encoder.n_out = 2 53 | 54 | blocks.VariationalEncoder: 55 | encoder = @variational/blocks.Encoder 56 | 57 | # DECODER 58 | blocks.Generator: 59 | 
latent_size = %LATENT_SIZE 60 | capacity = %CAPACITY 61 | data_size = %N_BAND 62 | ratios = %RATIOS 63 | loud_stride = 1 64 | use_noise = True 65 | 66 | blocks.ResidualStack: 67 | kernel_sizes = [3] 68 | dilations_list = [[1, 1], [3, 1], [5, 1]] 69 | 70 | blocks.NoiseGenerator: 71 | ratios = [4, 4, 4] 72 | noise_bands = 5 73 | 74 | # DISCRIMINATOR 75 | discriminator.ConvNet: 76 | in_size = 1 77 | out_size = 1 78 | capacity = %CAPACITY 79 | n_layers = 4 80 | stride = 4 81 | 82 | scales/discriminator.ConvNet: 83 | conv = @torch.nn.Conv1d 84 | kernel_size = 15 85 | 86 | discriminator.MultiScaleDiscriminator: 87 | n_discriminators = 3 88 | convnet = @scales/discriminator.ConvNet 89 | 90 | feature_matching/core.mean_difference: 91 | norm = 'L1' 92 | 93 | # MODEL ASSEMBLING 94 | rave.RAVE: 95 | latent_size = %LATENT_SIZE 96 | pqmf = @pqmf.CachedPQMF 97 | sampling_rate = %SAMPLING_RATE 98 | encoder = @blocks.VariationalEncoder 99 | decoder = @blocks.Generator 100 | discriminator = @discriminator.MultiScaleDiscriminator 101 | phase_1_duration = %PHASE_1_DURATION 102 | gan_loss = @core.hinge_gan 103 | valid_signal_crop = False 104 | feature_matching_fun = @feature_matching/core.mean_difference 105 | num_skipped_features = 0 106 | audio_distance = @core.AudioDistanceV1 107 | multiband_audio_distance = @core.AudioDistanceV1 108 | weights = { 109 | 'feature_matching': 10 110 | } 111 | 112 | rave.BetaWarmupCallback: 113 | initial_value = .1 114 | target_value = .1 115 | warmup_len = 1 -------------------------------------------------------------------------------- /rave/configs/v2.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import rave 4 | from rave import core 5 | from rave import blocks 6 | from rave import discriminator 7 | 8 | import torch.nn as nn 9 | 10 | include "configs/v1.gin" 11 | 12 | KERNEL_SIZE = 3 13 | DILATIONS = [ 14 | [1, 3, 9], 15 | [1, 3, 9], 16 | [1, 3, 9], 17 | [1, 3], 18 | ] 19 | RATIOS = [4, 4, 4, 2] 20 | CAPACITY = 96 21 | NOISE_AUGMENTATION = 0 22 | 23 | core.AudioDistanceV1.log_epsilon = 1e-7 24 | 25 | core.get_augmented_latent_size: 26 | latent_size = %LATENT_SIZE 27 | noise_augmentation = %NOISE_AUGMENTATION 28 | 29 | # ENCODER 30 | blocks.EncoderV2: 31 | data_size = %N_BAND 32 | capacity = %CAPACITY 33 | ratios = %RATIOS 34 | latent_size = %LATENT_SIZE 35 | n_out = 2 36 | kernel_size = %KERNEL_SIZE 37 | dilations = %DILATIONS 38 | 39 | blocks.VariationalEncoder: 40 | encoder = @variational/blocks.EncoderV2 41 | 42 | # GENERATOR 43 | blocks.GeneratorV2: 44 | data_size = %N_BAND 45 | capacity = %CAPACITY 46 | ratios = %RATIOS 47 | latent_size = @core.get_augmented_latent_size() 48 | kernel_size = %KERNEL_SIZE 49 | dilations = %DILATIONS 50 | amplitude_modulation = True 51 | 52 | # DISCRIMINATOR 53 | periods/discriminator.ConvNet: 54 | conv = @nn.Conv2d 55 | kernel_size = (5, 1) 56 | 57 | spectral/discriminator.ConvNet: 58 | conv = @nn.Conv1d 59 | kernel_size = 5 60 | stride = 2 61 | 62 | discriminator.MultiPeriodDiscriminator: 63 | periods = [2, 3, 5, 7, 11] 64 | convnet = @periods/discriminator.ConvNet 65 | 66 | discriminator.MultiScaleSpectralDiscriminator1d: 67 | scales = [4096, 2048, 1024, 512, 256] 68 | convnet = @spectral/discriminator.ConvNet 69 | 70 | discriminator.CombineDiscriminators: 71 | discriminators = [ 72 | @discriminator.MultiPeriodDiscriminator, 73 | @discriminator.MultiScaleDiscriminator, 74 | # @discriminator.MultiScaleSpectralDiscriminator1d, 75 | ] 76 | 77 | 
feature_matching/core.mean_difference: 78 | relative = True 79 | 80 | # RAVE 81 | rave.RAVE: 82 | discriminator = @discriminator.CombineDiscriminators 83 | valid_signal_crop = True 84 | num_skipped_features = 1 85 | decoder = @blocks.GeneratorV2 86 | update_discriminator_every = 4 87 | weights = { 88 | 'feature_matching': 20, 89 | } 90 | 91 | rave.BetaWarmupCallback: 92 | initial_value = 1e-6 93 | target_value = 5e-2 94 | warmup_len = 20000 95 | -------------------------------------------------------------------------------- /rave/configs/v2_nopqmf.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import rave 4 | from rave import core 5 | from rave import dataset 6 | from rave import blocks 7 | from rave import discriminator 8 | from rave import transforms 9 | 10 | import torch.nn as nn 11 | 12 | include "configs/v1.gin" 13 | 14 | KERNEL_SIZE = 3 15 | DILATIONS = [ 16 | [1, 3, 9], 17 | [1, 3, 9], 18 | [1, 3, 9], 19 | [1, 3], 20 | ] 21 | RATIOS = [4, 4, 4, 2] 22 | CAPACITY = 64 23 | NOISE_AUGMENTATION = 0 24 | 25 | core.AudioDistanceV1.log_epsilon = 1e-7 26 | 27 | core.get_augmented_latent_size: 28 | latent_size = %LATENT_SIZE 29 | noise_augmentation = %NOISE_AUGMENTATION 30 | 31 | 32 | # AUGMENTATIONS 33 | dataset.get_dataset: 34 | augmentations = [ 35 | @augmentations/transforms.RandomCompress() 36 | ] 37 | 38 | augmentations/transforms.RandomCompress: 39 | amp_range = [-60,-10] 40 | threshold=-40 41 | prob = 0.5 42 | sr=%SAMPLING_RATE 43 | 44 | # ENCODER 45 | blocks.EncoderV2: 46 | data_size = %N_BAND 47 | capacity = %CAPACITY 48 | ratios = [4, 4, 4, 2] 49 | latent_size = %LATENT_SIZE 50 | n_out = 2 51 | kernel_size = %KERNEL_SIZE 52 | dilations = %DILATIONS 53 | 54 | blocks.VariationalEncoder: 55 | encoder = @variational/blocks.EncoderV2 56 | 57 | # GENERATOR 58 | blocks.GeneratorV2: 59 | capacity = %CAPACITY 60 | ratios = [8, 8, 8, 4] 61 | latent_size = @core.get_augmented_latent_size() 62 | kernel_size = %KERNEL_SIZE 63 | dilations = %DILATIONS 64 | amplitude_modulation = True 65 | 66 | # DISCRIMINATOR 67 | periods/discriminator.ConvNet: 68 | conv = @nn.Conv2d 69 | kernel_size = (5, 1) 70 | 71 | spectral/discriminator.ConvNet: 72 | conv = @nn.Conv1d 73 | kernel_size = 5 74 | stride = 2 75 | 76 | discriminator.MultiPeriodDiscriminator: 77 | periods = [2, 3, 5, 7, 11] 78 | convnet = @periods/discriminator.ConvNet 79 | 80 | discriminator.MultiScaleSpectralDiscriminator1d: 81 | scales = [4096, 2048, 1024, 512, 256] 82 | convnet = @spectral/discriminator.ConvNet 83 | 84 | discriminator.CombineDiscriminators: 85 | discriminators = [ 86 | @discriminator.MultiPeriodDiscriminator, 87 | @discriminator.MultiScaleDiscriminator, 88 | # @discriminator.MultiScaleSpectralDiscriminator1d, 89 | ] 90 | 91 | feature_matching/core.mean_difference: 92 | relative = True 93 | 94 | # RAVE 95 | rave.RAVE: 96 | n_bands = %N_BAND 97 | discriminator = @discriminator.CombineDiscriminators 98 | valid_signal_crop = True 99 | num_skipped_features = 1 100 | decoder = @blocks.GeneratorV2 101 | phase_1_duration = 1000000 102 | weights = { 103 | 'feature_matching': 20 104 | } 105 | update_discriminator_every = 4 106 | output_mode = "raw" 107 | audio_monitor_epochs = 10 108 | 109 | rave.BetaWarmupCallback: 110 | initial_value = 1e-6 111 | target_value = 1e-2 112 | warmup_len = 500000 113 | 114 | -------------------------------------------------------------------------------- /rave/configs/v2_nopqmf_small.gin: 
-------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import rave 4 | from rave import core 5 | from rave import dataset 6 | from rave import blocks 7 | from rave import discriminator 8 | from rave import balancer 9 | from rave import transforms 10 | 11 | import torch.nn as nn 12 | 13 | include "configs/v1.gin" 14 | 15 | KERNEL_SIZE = 3 16 | DILATIONS = [ 17 | [1, 3, 9], 18 | [1, 3, 9], 19 | [1, 3, 9], 20 | [1, 3], 21 | ] 22 | RATIOS = [4, 4, 4, 2] 23 | CAPACITY = 64 24 | NOISE_AUGMENTATION = 0 25 | 26 | core.AudioDistanceV1.log_epsilon = 1e-7 27 | 28 | core.get_augmented_latent_size: 29 | latent_size = %LATENT_SIZE 30 | noise_augmentation = %NOISE_AUGMENTATION 31 | 32 | 33 | # AUGMENTATIONS 34 | dataset.get_dataset: 35 | augmentations = [ 36 | @augmentations/transforms.Compress() 37 | ] 38 | 39 | augmentations/transforms.Compress: 40 | amp_range = [-60,-10] 41 | threshold=-40 42 | prob = 0.5 43 | 44 | # ENCODER 45 | blocks.EncoderV2: 46 | data_size = %N_BAND 47 | capacity = %CAPACITY 48 | ratios = [4, 4, 4, 2] 49 | latent_size = %LATENT_SIZE 50 | n_out = 2 51 | kernel_size = %KERNEL_SIZE 52 | dilations = %DILATIONS 53 | 54 | blocks.VariationalEncoder: 55 | encoder = @variational/blocks.EncoderV2 56 | 57 | # GENERATOR 58 | blocks.GeneratorV2: 59 | capacity = %CAPACITY 60 | ratios = [8, 8, 8, 4] 61 | latent_size = @core.get_augmented_latent_size() 62 | kernel_size = %KERNEL_SIZE 63 | dilations = %DILATIONS 64 | amplitude_modulation = True 65 | 66 | # DISCRIMINATOR 67 | periods/discriminator.ConvNet: 68 | conv = @nn.Conv2d 69 | kernel_size = (5, 1) 70 | 71 | spectral/discriminator.ConvNet: 72 | conv = @nn.Conv1d 73 | kernel_size = 5 74 | stride = 2 75 | 76 | discriminator.MultiPeriodDiscriminator: 77 | periods = [2, 3, 5, 7, 11] 78 | convnet = @periods/discriminator.ConvNet 79 | 80 | discriminator.MultiScaleSpectralDiscriminator1d: 81 | scales = [4096, 2048, 1024, 512, 256] 82 | convnet = @spectral/discriminator.ConvNet 83 | 84 | discriminator.CombineDiscriminators: 85 | discriminators = [ 86 | @discriminator.MultiPeriodDiscriminator, 87 | @discriminator.MultiScaleDiscriminator, 88 | # @discriminator.MultiScaleSpectralDiscriminator1d, 89 | ] 90 | 91 | feature_matching/core.mean_difference: 92 | relative = True 93 | 94 | # RAVE 95 | rave.RAVE: 96 | n_bands = %N_BAND 97 | discriminator = @discriminator.CombineDiscriminators 98 | valid_signal_crop = True 99 | num_skipped_features = 1 100 | decoder = @blocks.GeneratorV2 101 | phase_1_duration = 500000 102 | loss_weights = {'reg': 0.02, 'feature_matching': 20} 103 | update_discriminator_every = 4 104 | enable_pqmf_encode = True 105 | enable_pqmf_decode = False 106 | audio_monitor_epochs = 10 107 | 108 | -------------------------------------------------------------------------------- /rave/configs/v2_small.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import rave 4 | from rave import core 5 | from rave import blocks 6 | from rave import discriminator 7 | 8 | import torch.nn as nn 9 | 10 | include "configs/v1.gin" 11 | 12 | KERNEL_SIZE = 3 13 | DILATIONS = [ 14 | [1, 3, 9], 15 | [1, 3, 9], 16 | [1, 3, 9], 17 | [1, 3], 18 | ] 19 | RATIOS = [4, 2, 2, 2] 20 | CAPACITY = 48 21 | NOISE_AUGMENTATION = 0 22 | 23 | core.AudioDistanceV1.log_epsilon = 1e-7 24 | 25 | core.get_augmented_latent_size: 26 | latent_size = %LATENT_SIZE 27 | noise_augmentation = %NOISE_AUGMENTATION 28 | 29 | # 
ENCODER 30 | blocks.EncoderV2: 31 | data_size = %N_BAND 32 | capacity = %CAPACITY 33 | ratios = %RATIOS 34 | latent_size = %LATENT_SIZE 35 | n_out = 2 36 | kernel_size = %KERNEL_SIZE 37 | dilations = %DILATIONS 38 | 39 | blocks.VariationalEncoder: 40 | encoder = @variational/blocks.EncoderV2 41 | 42 | blocks.NoiseGeneratorV2: 43 | hidden_size = 64 44 | data_size = %N_BAND 45 | ratios = [2, 2, 2] 46 | noise_bands = 32 47 | 48 | # GENERATOR 49 | blocks.GeneratorV2: 50 | data_size = %N_BAND 51 | capacity = %CAPACITY 52 | ratios = %RATIOS 53 | latent_size = @core.get_augmented_latent_size() 54 | kernel_size = %KERNEL_SIZE 55 | dilations = %DILATIONS 56 | amplitude_modulation = True 57 | noise_module = @blocks.NoiseGeneratorV2 58 | 59 | # DISCRIMINATOR 60 | periods/discriminator.ConvNet: 61 | conv = @nn.Conv2d 62 | kernel_size = (5, 1) 63 | 64 | spectral/discriminator.ConvNet: 65 | conv = @nn.Conv1d 66 | kernel_size = 5 67 | stride = 2 68 | 69 | discriminator.MultiPeriodDiscriminator: 70 | periods = [2, 3, 5, 7, 11] 71 | convnet = @periods/discriminator.ConvNet 72 | 73 | discriminator.MultiScaleSpectralDiscriminator1d: 74 | scales = [4096, 2048, 1024, 512, 256] 75 | convnet = @spectral/discriminator.ConvNet 76 | 77 | discriminator.CombineDiscriminators: 78 | discriminators = [ 79 | @discriminator.MultiPeriodDiscriminator, 80 | @discriminator.MultiScaleDiscriminator, 81 | # @discriminator.MultiScaleSpectralDiscriminator1d, 82 | ] 83 | 84 | feature_matching/core.mean_difference: 85 | relative = True 86 | 87 | # RAVE 88 | rave.RAVE: 89 | discriminator = @discriminator.CombineDiscriminators 90 | valid_signal_crop = True 91 | num_skipped_features = 1 92 | decoder = @blocks.GeneratorV2 93 | update_discriminator_every = 2 94 | weights = { 95 | 'feature_matching': 20, 96 | } 97 | 98 | rave.BetaWarmupCallback: 99 | initial_value = .01 100 | target_value = .01 101 | warmup_len = 300000 -------------------------------------------------------------------------------- /rave/configs/v2_with_augs.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import rave 4 | from rave import core 5 | from rave import dataset 6 | from rave import blocks 7 | from rave import discriminator 8 | from rave import transforms 9 | 10 | from torchaudio import transforms as ta_transforms 11 | 12 | import torch.nn as nn 13 | 14 | include "configs/v1.gin" 15 | 16 | KERNEL_SIZE = 3 17 | DILATIONS = [ 18 | [1, 3, 9], 19 | [1, 3, 9], 20 | [1, 3, 9], 21 | [1, 3], 22 | ] 23 | ENCODER_RATIOS = [2, 2, 2] 24 | RATIOS = [4, 4, 4, 2] 25 | CAPACITY = 96 26 | NOISE_AUGMENTATION = 0 27 | 28 | # MELSPEC PROPERTIES 29 | N_FFT = 2048 30 | N_MELS = 128 31 | HOP_LENGTH = 256 32 | NUM_GRU_LAYERS = 2 33 | 34 | core.AudioDistanceV1.log_epsilon = 1e-7 35 | 36 | core.get_augmented_latent_size: 37 | latent_size = %LATENT_SIZE 38 | noise_augmentation = %NOISE_AUGMENTATION 39 | 40 | # AUGMENTATIONS 41 | dataset.get_dataset: 42 | augmentations = [ 43 | @augmentations/transforms.RandomCompress(), 44 | # @augmentations/transforms.FrequencyMasking() 45 | ] 46 | 47 | augmentations/transforms.RandomCompress: 48 | amp_range = [-60,-10] 49 | threshold=-40 50 | prob = 0.5 51 | 52 | ta_transforms.MelSpectrogram: 53 | sample_rate = %SAMPLING_RATE 54 | n_fft = %N_FFT 55 | win_length = %N_FFT 56 | hop_length = %HOP_LENGTH 57 | normalized = True 58 | n_mels = %N_MELS 59 | 60 | # ENCODER 61 | blocks.EncoderV2: 62 | data_size = %N_MELS 63 | ratios = %ENCODER_RATIOS 64 | capacity = 
%CAPACITY 65 | latent_size = %LATENT_SIZE 66 | n_out = 2 67 | kernel_size = %KERNEL_SIZE 68 | dilations = %DILATIONS 69 | 70 | blocks.VariationalEncoder: 71 | encoder = @variational/blocks.EncoderV2 72 | 73 | # GENERATOR 74 | blocks.GeneratorV2: 75 | data_size = %N_BAND 76 | capacity = %CAPACITY 77 | ratios = %RATIOS 78 | latent_size = @core.get_augmented_latent_size() 79 | kernel_size = %KERNEL_SIZE 80 | dilations = %DILATIONS 81 | amplitude_modulation = True 82 | 83 | # DISCRIMINATOR 84 | periods/discriminator.ConvNet: 85 | conv = @nn.Conv2d 86 | kernel_size = (5, 1) 87 | 88 | spectral/discriminator.ConvNet: 89 | conv = @nn.Conv1d 90 | kernel_size = 5 91 | stride = 2 92 | 93 | discriminator.MultiPeriodDiscriminator: 94 | periods = [2, 3, 5, 7, 11] 95 | convnet = @periods/discriminator.ConvNet 96 | 97 | discriminator.MultiScaleSpectralDiscriminator1d: 98 | scales = [4096, 2048, 1024, 512, 256] 99 | convnet = @spectral/discriminator.ConvNet 100 | 101 | discriminator.CombineDiscriminators: 102 | discriminators = [ 103 | @discriminator.MultiPeriodDiscriminator, 104 | @discriminator.MultiScaleDiscriminator, 105 | # @discriminator.MultiScaleSpectralDiscriminator1d, 106 | ] 107 | 108 | feature_matching/core.mean_difference: 109 | relative = True 110 | 111 | # RAVE 112 | rave.RAVE: 113 | discriminator = @discriminator.CombineDiscriminators 114 | valid_signal_crop = True 115 | num_skipped_features = 1 116 | decoder = @blocks.GeneratorV2 117 | phase_1_duration = 1000000 118 | spectrogram = @ta_transforms.MelSpectrogram() 119 | update_discriminator_every = 4 120 | input_mode = "mel" 121 | output_mode = "pqmf" 122 | audio_monitor_epochs = 10 123 | 124 | -------------------------------------------------------------------------------- /rave/configs/v3.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | include "configs/v2.gin" 4 | include "configs/adain.gin" 5 | include "configs/snake.gin" 6 | include "configs/descript_discriminator.gin" 7 | 8 | import rave 9 | 10 | rave.BetaWarmupCallback: 11 | initial_value = 1e-6 12 | target_value = 5e-2 13 | warmup_len = 20000 -------------------------------------------------------------------------------- /rave/configs/wasserstein.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import rave 4 | from rave import blocks 5 | 6 | LATENT_SIZE = 16 7 | NOISE_AUGMENTATION = 128 8 | PHASE_1_DURATION = 200000 9 | 10 | blocks.EncoderV2.n_out = 1 11 | 12 | blocks.WasserteinEncoder: 13 | encoder_cls = @blocks.EncoderV2 14 | noise_augmentation = %NOISE_AUGMENTATION 15 | 16 | rave.RAVE: 17 | encoder = @blocks.WasserteinEncoder 18 | phase_1_duration = %PHASE_1_DURATION 19 | weights = { 20 | 'fullband_spectral_distance': 2, 21 | 'multiband_spectral_distance': 2, 22 | 'adversarial': 2, 23 | } 24 | 25 | rave.BetaWarmupCallback: 26 | initial_value = 100 27 | target_value = 100 28 | warmup_len = 1 -------------------------------------------------------------------------------- /rave/core.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | from random import random 5 | from typing import Callable, Optional, Sequence, Union 6 | 7 | import GPUtil as gpu 8 | import librosa as li 9 | import lmdb 10 | import numpy as np 11 | import pytorch_lightning as pl 12 | import torch 13 | import torch.fft as fft 14 | import 
torch.nn as nn 15 | import torchaudio 16 | from einops import rearrange 17 | from scipy.signal import lfilter 18 | 19 | 20 | def mod_sigmoid(x): 21 | return 2 * torch.sigmoid(x)**2.3 + 1e-7 22 | 23 | 24 | def random_angle(min_f=20, max_f=8000, sr=24000): 25 | min_f = np.log(min_f) 26 | max_f = np.log(max_f) 27 | rand = np.exp(random() * (max_f - min_f) + min_f) 28 | rand = 2 * np.pi * rand / sr 29 | return rand 30 | 31 | 32 | def get_augmented_latent_size(latent_size: int, noise_augmentation: int): 33 | return latent_size + noise_augmentation 34 | 35 | 36 | def pole_to_z_filter(omega, amplitude=.9): 37 | z0 = amplitude * np.exp(1j * omega) 38 | a = [1, -2 * np.real(z0), abs(z0)**2] 39 | b = [abs(z0)**2, -2 * np.real(z0), 1] 40 | return b, a 41 | 42 | def random_phase_mangle(x, min_f, max_f, amp, sr): 43 | angle = random_angle(min_f, max_f, sr) 44 | b, a = pole_to_z_filter(angle, amp) 45 | return lfilter(b, a, x) 46 | 47 | 48 | def amp_to_impulse_response(amp, target_size): 49 | """ 50 | transforms frequency amps to ir on the last dimension 51 | """ 52 | amp = torch.stack([amp, torch.zeros_like(amp)], -1) 53 | amp = torch.view_as_complex(amp) 54 | amp = fft.irfft(amp) 55 | 56 | filter_size = amp.shape[-1] 57 | 58 | amp = torch.roll(amp, filter_size // 2, -1) 59 | win = torch.hann_window(filter_size, dtype=amp.dtype, device=amp.device) 60 | 61 | amp = amp * win 62 | 63 | amp = nn.functional.pad( 64 | amp, 65 | (0, int(target_size) - int(filter_size)), 66 | ) 67 | amp = torch.roll(amp, -filter_size // 2, -1) 68 | 69 | return amp 70 | 71 | def fft_convolve(signal, kernel): 72 | """ 73 | convolves signal by kernel on the last dimension 74 | """ 75 | signal = nn.functional.pad(signal, (0, signal.shape[-1])) 76 | kernel = nn.functional.pad(kernel, (kernel.shape[-1], 0)) 77 | 78 | output = fft.irfft(fft.rfft(signal) * fft.rfft(kernel)) 79 | output = output[..., output.shape[-1] // 2:] 80 | 81 | return output 82 | 83 | 84 | def get_ckpts(folder, name=None): 85 | ckpts = map(str, Path(folder).rglob("*.ckpt")) 86 | if name: 87 | ckpts = filter(lambda e: name in os.path.basename(str(e)), ckpts) 88 | ckpts = sorted(ckpts, key=os.path.getmtime) 89 | return ckpts 90 | 91 | 92 | def get_versions(folder): 93 | ckpts = map(str, Path(folder).rglob("version_*")) 94 | ckpts = filter(lambda x: os.path.isdir(x), ckpts) 95 | return sorted(ckpts, key=os.path.getmtime) 96 | 97 | def search_for_config(folder): 98 | if os.path.isfile(folder): 99 | folder = os.path.dirname(folder) 100 | configs = list(map(str, Path(folder).rglob("config.gin"))) 101 | if configs != []: 102 | return os.path.abspath(os.path.join(folder, "config.gin")) 103 | configs = list(map(str, Path(folder).rglob("../config.gin"))) 104 | if configs != []: 105 | return os.path.abspath(os.path.join(folder, "../config.gin")) 106 | configs = list(map(str, Path(folder).rglob("../../config.gin"))) 107 | if configs != []: 108 | return os.path.abspath(os.path.join(folder, "../../config.gin")) 109 | else: 110 | return None 111 | 112 | 113 | 114 | def search_for_run(run_path, name=None): 115 | if run_path is None: return None 116 | if ".ckpt" in run_path: return run_path 117 | ckpts = get_ckpts(run_path, name) 118 | if len(ckpts) != 0: 119 | return ckpts[-1] 120 | else: 121 | print('No checkpoint found') 122 | return None 123 | 124 | 125 | def setup_gpu(): 126 | return gpu.getAvailable(maxMemory=.05) 127 | 128 | 129 | def get_beta_kl(step, warmup, min_beta, max_beta): 130 | if step > warmup: return max_beta 131 | t = step / warmup 132 | min_beta_log
= np.log(min_beta) 133 | max_beta_log = np.log(max_beta) 134 | beta_log = t * (max_beta_log - min_beta_log) + min_beta_log 135 | return np.exp(beta_log) 136 | 137 | 138 | def get_beta_kl_cyclic(step, cycle_size, min_beta, max_beta): 139 | return get_beta_kl(step % cycle_size, cycle_size // 2, min_beta, max_beta) 140 | 141 | 142 | def get_beta_kl_cyclic_annealed(step, cycle_size, warmup, min_beta, max_beta): 143 | min_beta = get_beta_kl(step, warmup, min_beta, max_beta) 144 | return get_beta_kl_cyclic(step, cycle_size, min_beta, max_beta) 145 | 146 | 147 | def n_fft_to_num_bands(n_fft: int) -> int: 148 | return n_fft // 2 + 1 149 | 150 | 151 | def hinge_gan(score_real, score_fake): 152 | loss_dis = torch.relu(1 - score_real) + torch.relu(1 + score_fake) 153 | loss_dis = loss_dis.mean() 154 | loss_gen = -score_fake.mean() 155 | return loss_dis, loss_gen 156 | 157 | 158 | def ls_gan(score_real, score_fake): 159 | loss_dis = (score_real - 1).pow(2) + score_fake.pow(2) 160 | loss_dis = loss_dis.mean() 161 | loss_gen = (score_fake - 1).pow(2).mean() 162 | return loss_dis, loss_gen 163 | 164 | 165 | def nonsaturating_gan(score_real, score_fake): 166 | score_real = torch.clamp(torch.sigmoid(score_real), 1e-7, 1 - 1e-7) 167 | score_fake = torch.clamp(torch.sigmoid(score_fake), 1e-7, 1 - 1e-7) 168 | loss_dis = -(torch.log(score_real) + torch.log(1 - score_fake)).mean() 169 | loss_gen = -torch.log(score_fake).mean() 170 | return loss_dis, loss_gen 171 | 172 | def get_minimum_size(model): 173 | N = 2**15 174 | device = next(iter(model.parameters())).device 175 | x = torch.randn(1, model.n_channels, N, requires_grad=True, device=device) 176 | z = model.encode(x) 177 | return int(x.shape[-1] / z.shape[-1]) 178 | 179 | 180 | @torch.enable_grad() 181 | def get_rave_receptive_field(model, n_channels=1): 182 | N = 2**15 183 | model.eval() 184 | device = next(iter(model.parameters())).device 185 | 186 | for module in model.modules(): 187 | if hasattr(module, 'gru_state') or hasattr(module, 'temporal'): 188 | module.disable() 189 | 190 | while True: 191 | x = torch.randn(1, model.n_channels, N, requires_grad=True, device=device) 192 | 193 | z = model.encode(x) 194 | z = model.encoder.reparametrize(z)[0] 195 | y = model.decode(z) 196 | 197 | y[0, 0, N // 2].backward() 198 | assert x.grad is not None, "input has no grad" 199 | 200 | grad = x.grad.data.reshape(-1) 201 | left_grad, right_grad = grad.chunk(2, 0) 202 | large_enough = (left_grad[0] == 0) and right_grad[-1] == 0 203 | if large_enough: 204 | break 205 | else: 206 | N *= 2 207 | left_receptive_field = len(left_grad[left_grad != 0]) 208 | right_receptive_field = len(right_grad[right_grad != 0]) 209 | model.zero_grad() 210 | 211 | for module in model.modules(): 212 | if hasattr(module, 'gru_state') or hasattr(module, 'temporal'): 213 | module.enable() 214 | ratio = x.shape[-1] // z.shape[-1] 215 | rate = model.sr / ratio 216 | print(f"Compression ratio: {ratio}x (~{rate:.1f}Hz @ {model.sr}Hz)") 217 | return left_receptive_field, right_receptive_field 218 | 219 | 220 | def valid_signal_crop(x, left_rf, right_rf): 221 | dim = x.shape[1] 222 | x = x[..., left_rf.item() // dim:] 223 | if right_rf.item(): 224 | x = x[..., :-right_rf.item() // dim] 225 | return x 226 | 227 | 228 | def relative_distance( 229 | x: torch.Tensor, 230 | y: torch.Tensor, 231 | norm: Callable[[torch.Tensor], torch.Tensor], 232 | ) -> torch.Tensor: 233 | return norm(x - y) / norm(x) 234 | 235 | 236 | def mean_difference(target: torch.Tensor, 237 | value: torch.Tensor, 238 | norm: str 
= 'L1', 239 | relative: bool = False): 240 | diff = target - value 241 | if norm == 'L1': 242 | diff = diff.abs().mean() 243 | if relative: 244 | diff = diff / target.abs().mean() 245 | return diff 246 | elif norm == 'L2': 247 | diff = (diff * diff).mean() 248 | if relative: 249 | diff = diff / (target * target).mean() 250 | return diff 251 | else: 252 | raise Exception(f'Norm must be either L1 or L2, got {norm}') 253 | 254 | 255 | class MelScale(nn.Module): 256 | 257 | def __init__(self, sample_rate: int, n_fft: int, n_mels: int) -> None: 258 | super().__init__() 259 | mel = li.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels) 260 | mel = torch.from_numpy(mel).float() 261 | self.register_buffer('mel', mel) 262 | 263 | def forward(self, x: torch.Tensor) -> torch.Tensor: 264 | mel = self.mel.type_as(x) 265 | y = torch.einsum('bft,mf->bmt', x, mel) 266 | return y 267 | 268 | 269 | class MultiScaleSTFT(nn.Module): 270 | 271 | def __init__(self, 272 | scales: Sequence[int], 273 | sample_rate: int, 274 | magnitude: bool = True, 275 | normalized: bool = False, 276 | num_mels: Optional[int] = None) -> None: 277 | super().__init__() 278 | self.scales = scales 279 | self.magnitude = magnitude 280 | self.num_mels = num_mels 281 | 282 | self.stfts = [] 283 | self.mel_scales = [] 284 | for scale in scales: 285 | self.stfts.append( 286 | torchaudio.transforms.Spectrogram( 287 | n_fft=scale, 288 | win_length=scale, 289 | hop_length=scale // 4, 290 | normalized=normalized, 291 | power=None, 292 | )) 293 | if num_mels is not None: 294 | self.mel_scales.append( 295 | MelScale( 296 | sample_rate=sample_rate, 297 | n_fft=scale, 298 | n_mels=num_mels, 299 | )) 300 | else: 301 | self.mel_scales.append(None) 302 | 303 | self.stfts = nn.ModuleList(self.stfts) 304 | self.mel_scales = nn.ModuleList(self.mel_scales) 305 | 306 | def forward(self, x: torch.Tensor) -> Sequence[torch.Tensor]: 307 | x = rearrange(x, "b c t -> (b c) t") 308 | stfts = [] 309 | for stft, mel in zip(self.stfts, self.mel_scales): 310 | y = stft(x) 311 | if mel is not None: 312 | y = mel(y) 313 | if self.magnitude: 314 | y = y.abs() 315 | else: 316 | y = torch.stack([y.real, y.imag], -1) 317 | stfts.append(y) 318 | 319 | return stfts 320 | 321 | 322 | class AudioDistanceV1(nn.Module): 323 | 324 | def __init__(self, multiscale_stft: Callable[[], nn.Module], 325 | log_epsilon: float) -> None: 326 | super().__init__() 327 | self.multiscale_stft = multiscale_stft() 328 | self.log_epsilon = log_epsilon 329 | 330 | def forward(self, x: torch.Tensor, y: torch.Tensor): 331 | stfts_x = self.multiscale_stft(x) 332 | stfts_y = self.multiscale_stft(y) 333 | distance = 0. 
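        # The loop below accumulates, for every STFT scale, two terms: a relative
        # mean-squared difference on the linear magnitudes and a mean absolute
        # difference on the log magnitudes (stabilised by log_epsilon), i.e. per scale
        #   d += mean((|X| - |Y|)^2) / mean(|X|^2) + mean(|log(|X| + eps) - log(|Y| + eps)|)
        # and the sum over all scales is returned as 'spectral_distance'.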
334 | 335 | for x, y in zip(stfts_x, stfts_y): 336 | logx = torch.log(x + self.log_epsilon) 337 | logy = torch.log(y + self.log_epsilon) 338 | 339 | lin_distance = mean_difference(x, y, norm='L2', relative=True) 340 | log_distance = mean_difference(logx, logy, norm='L1') 341 | 342 | distance = distance + lin_distance + log_distance 343 | 344 | return {'spectral_distance': distance} 345 | 346 | 347 | class WeightedInstantaneousSpectralDistance(nn.Module): 348 | 349 | def __init__(self, 350 | multiscale_stft: Callable[[], MultiScaleSTFT], 351 | weighted: bool = False) -> None: 352 | super().__init__() 353 | self.multiscale_stft = multiscale_stft() 354 | self.weighted = weighted 355 | 356 | def phase_to_instantaneous_frequency(self, 357 | x: torch.Tensor) -> torch.Tensor: 358 | x = self.unwrap(x) 359 | x = self.derivative(x) 360 | return x 361 | 362 | def derivative(self, x: torch.Tensor) -> torch.Tensor: 363 | return x[..., 1:] - x[..., :-1] 364 | 365 | def unwrap(self, x: torch.Tensor) -> torch.Tensor: 366 | x = self.derivative(x) 367 | x = (x + np.pi) % (2 * np.pi) 368 | return (x - np.pi).cumsum(-1) 369 | 370 | def forward(self, target: torch.Tensor, pred: torch.Tensor): 371 | stfts_x = self.multiscale_stft(target) 372 | stfts_y = self.multiscale_stft(pred) 373 | spectral_distance = 0. 374 | phase_distance = 0. 375 | 376 | for x, y in zip(stfts_x, stfts_y): 377 | assert x.shape[-1] == 2 378 | 379 | x = torch.view_as_complex(x) 380 | y = torch.view_as_complex(y) 381 | 382 | # AMPLITUDE DISTANCE 383 | x_abs = x.abs() 384 | y_abs = y.abs() 385 | 386 | logx = torch.log1p(x_abs) 387 | logy = torch.log1p(y_abs) 388 | 389 | lin_distance = mean_difference(x_abs, 390 | y_abs, 391 | norm='L2', 392 | relative=True) 393 | log_distance = mean_difference(logx, logy, norm='L1') 394 | 395 | spectral_distance = spectral_distance + lin_distance + log_distance 396 | 397 | # PHASE DISTANCE 398 | x_if = self.phase_to_instantaneous_frequency(x.angle()) 399 | y_if = self.phase_to_instantaneous_frequency(y.angle()) 400 | 401 | if self.weighted: 402 | mask = torch.clip(torch.log1p(x_abs[..., 2:]), 0, 1) 403 | x_if = x_if * mask 404 | y_if = y_if * mask 405 | 406 | phase_distance = phase_distance + mean_difference( 407 | x_if, y_if, norm='L2') 408 | 409 | return { 410 | 'spectral_distance': spectral_distance, 411 | 'phase_distance': phase_distance 412 | } 413 | 414 | 415 | class EncodecAudioDistance(nn.Module): 416 | 417 | def __init__(self, scales: int, 418 | spectral_distance: Callable[[int], nn.Module]) -> None: 419 | super().__init__() 420 | self.waveform_distance = WaveformDistance(norm='L1') 421 | self.spectral_distances = nn.ModuleList( 422 | [spectral_distance(scale) for scale in scales]) 423 | 424 | def forward(self, x, y): 425 | waveform_distance = self.waveform_distance(x, y) 426 | spectral_distance = 0 427 | for dist in self.spectral_distances: 428 | spectral_distance = spectral_distance + dist(x, y) 429 | 430 | return { 431 | 'waveform_distance': waveform_distance, 432 | 'spectral_distance': spectral_distance 433 | } 434 | 435 | 436 | class WaveformDistance(nn.Module): 437 | 438 | def __init__(self, norm: str) -> None: 439 | super().__init__() 440 | self.norm = norm 441 | 442 | def forward(self, x, y): 443 | return mean_difference(y, x, self.norm) 444 | 445 | 446 | class SpectralDistance(nn.Module): 447 | 448 | def __init__( 449 | self, 450 | n_fft: int, 451 | sampling_rate: int, 452 | norm: Union[str, Sequence[str]], 453 | power: Union[int, None], 454 | normalized: bool, 455 | mel: Optional[int] = 
None, 456 | ) -> None: 457 | super().__init__() 458 | if mel: 459 | self.spec = torchaudio.transforms.MelSpectrogram( 460 | sampling_rate, 461 | n_fft, 462 | hop_length=n_fft // 4, 463 | n_mels=mel, 464 | power=power, 465 | normalized=normalized, 466 | center=False, 467 | pad_mode=None, 468 | ) 469 | else: 470 | self.spec = torchaudio.transforms.Spectrogram( 471 | n_fft, 472 | hop_length=n_fft // 4, 473 | power=power, 474 | normalized=normalized, 475 | center=False, 476 | pad_mode=None, 477 | ) 478 | 479 | if isinstance(norm, str): 480 | norm = (norm, ) 481 | self.norm = norm 482 | 483 | def forward(self, x, y): 484 | x = self.spec(x) 485 | y = self.spec(y) 486 | 487 | distance = 0 488 | for norm in self.norm: 489 | distance = distance + mean_difference(y, x, norm) 490 | return distance 491 | 492 | 493 | class ProgressLogger(object): 494 | 495 | def __init__(self, name: str) -> None: 496 | self.env = lmdb.open("status") 497 | self.name = name 498 | 499 | def update(self, **new_state): 500 | current_state = self.__call__() 501 | with self.env.begin(write=True) as txn: 502 | current_state.update(new_state) 503 | current_state = json.dumps(current_state) 504 | txn.put(self.name.encode(), current_state.encode()) 505 | 506 | def __call__(self): 507 | with self.env.begin(write=True) as txn: 508 | current_state = txn.get(self.name.encode()) 509 | if current_state is not None: 510 | current_state = json.loads(current_state.decode()) 511 | else: 512 | current_state = {} 513 | return current_state 514 | 515 | 516 | class LoggerCallback(pl.Callback): 517 | 518 | def __init__(self, logger: ProgressLogger) -> None: 519 | super().__init__() 520 | self.state = {'step': 0, 'warmed': False} 521 | self.logger = logger 522 | 523 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, 524 | batch_idx) -> None: 525 | self.state['step'] += 1 526 | self.state['warmed'] = pl_module.warmed_up 527 | 528 | if not self.state['step'] % 100: 529 | self.logger.update(**self.state) 530 | 531 | def state_dict(self): 532 | return self.state.copy() 533 | 534 | def load_state_dict(self, state_dict): 535 | self.state.update(state_dict) 536 | 537 | 538 | class ModelCheckpoint(pl.callbacks.ModelCheckpoint): 539 | def __init__(self, step_period: int = None, **kwargs): 540 | super().__init__(**kwargs) 541 | self.step_period = step_period 542 | self.__counter = 0 543 | 544 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 545 | self.__counter += 1 546 | if self.step_period: 547 | if self.__counter % self.step_period == 0: 548 | filename = os.path.join(self.dirpath, f"epoch_{self.__counter}{self.FILE_EXTENSION}") 549 | self._save_checkpoint(trainer, filename) 550 | 551 | 552 | def get_valid_extensions(): 553 | import torchaudio 554 | backend = torchaudio.get_audio_backend() 555 | if backend in ["sox_io", "sox"]: 556 | return ['.'+f for f in torchaudio.utils.sox_utils.list_read_formats()] 557 | elif backend == "ffmpeg": 558 | return ['.'+f for f in torchaudio.utils.ffmpeg_utils.get_audio_decoders()] 559 | elif backend == "soundfile": 560 | return ['.wav', '.flac', '.ogg', '.aiff', '.aif', '.aifc'] 561 | 562 | -------------------------------------------------------------------------------- /rave/dataset.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import logging 3 | import math 4 | import os 5 | import subprocess 6 | from random import random 7 | from typing import Dict, Iterable, Optional, Sequence, Union, Callable 8 | 9 | import 
gin 10 | import lmdb 11 | import numpy as np 12 | import requests 13 | import torch 14 | import torchaudio 15 | import yaml 16 | from scipy.signal import lfilter 17 | from torch.utils import data 18 | from tqdm import tqdm 19 | from . import transforms 20 | from udls import AudioExample as AudioExampleWrapper 21 | from udls.generated import AudioExample 22 | 23 | 24 | def get_derivator_integrator(sr: int): 25 | alpha = 1 / (1 + 1 / sr * 2 * np.pi * 10) 26 | derivator = ([.5, -.5], [1]) 27 | integrator = ([alpha**2, -alpha**2], [1, -2 * alpha, alpha**2]) 28 | 29 | return lambda x: lfilter(*derivator, x), lambda x: lfilter(*integrator, x) 30 | 31 | 32 | class AudioDataset(data.Dataset): 33 | 34 | @property 35 | def env(self) -> lmdb.Environment: 36 | if self._env is None: 37 | self._env = lmdb.open(self._db_path, lock=False) 38 | return self._env 39 | 40 | @property 41 | def keys(self) -> Sequence[str]: 42 | if self._keys is None: 43 | with self.env.begin() as txn: 44 | self._keys = list(txn.cursor().iternext(values=False)) 45 | return self._keys 46 | 47 | def __init__(self, 48 | db_path: str, 49 | audio_key: str = 'waveform', 50 | transforms: Optional[transforms.Transform] = None, 51 | n_channels: int = 1) -> None: 52 | super().__init__() 53 | self._db_path = db_path 54 | self._audio_key = audio_key 55 | self._env = None 56 | self._keys = None 57 | self._transforms = transforms 58 | self._n_channels = n_channels 59 | lens = [] 60 | with self.env.begin() as txn: 61 | for k in self.keys: 62 | ae = AudioExample.FromString(txn.get(k)) 63 | lens.append(np.frombuffer(ae.buffers['waveform'].data, dtype=np.int16).shape) 64 | 65 | 66 | def __len__(self): 67 | return len(self.keys) 68 | 69 | def __getitem__(self, index): 70 | with self.env.begin() as txn: 71 | ae = AudioExample.FromString(txn.get(self.keys[index])) 72 | 73 | buffer = ae.buffers[self._audio_key] 74 | assert buffer.precision == AudioExample.Precision.INT16 75 | 76 | audio = np.frombuffer(buffer.data, dtype=np.int16) 77 | audio = audio.astype(np.float32) / (2**15 - 1) 78 | audio = audio.reshape(self._n_channels, -1) 79 | 80 | if self._transforms is not None: 81 | audio = self._transforms(audio) 82 | 83 | return audio 84 | 85 | 86 | class LazyAudioDataset(data.Dataset): 87 | 88 | @property 89 | def env(self) -> lmdb.Environment: 90 | if self._env is None: 91 | self._env = lmdb.open(self._db_path, lock=False) 92 | return self._env 93 | 94 | @property 95 | def keys(self) -> Sequence[str]: 96 | if self._keys is None: 97 | with self.env.begin() as txn: 98 | self._keys = list(txn.cursor().iternext(values=False)) 99 | return self._keys 100 | 101 | def __init__(self, 102 | db_path: str, 103 | n_signal: int, 104 | sampling_rate: int, 105 | transforms: Optional[transforms.Transform] = None, 106 | n_channels: int = 1) -> None: 107 | super().__init__() 108 | self._db_path = db_path 109 | self._env = None 110 | self._keys = None 111 | self._transforms = transforms 112 | self._n_signal = n_signal 113 | self._sampling_rate = sampling_rate 114 | self._n_channels = n_channels 115 | 116 | self.parse_dataset() 117 | 118 | def parse_dataset(self): 119 | items = [] 120 | for key in tqdm(self.keys, desc='Discovering dataset'): 121 | with self.env.begin() as txn: 122 | ae = AudioExample.FromString(txn.get(key)) 123 | length = float(ae.metadata['length']) 124 | n_signal = int(math.floor(length * self._sampling_rate)) 125 | n_chunks = n_signal // self._n_signal 126 | items.append(n_chunks) 127 | items = np.asarray(items) 128 | items = np.cumsum(items) 129 | 
self.items = items 130 | 131 | def __len__(self): 132 | return self.items[-1] 133 | 134 | def __getitem__(self, index): 135 | audio_id = np.where(index < self.items)[0][0] 136 | if audio_id: 137 | index -= self.items[audio_id - 1] 138 | 139 | key = self.keys[audio_id] 140 | 141 | with self.env.begin() as txn: 142 | ae = AudioExample.FromString(txn.get(key)) 143 | 144 | audio = extract_audio( 145 | ae.metadata['path'], 146 | self._n_signal, 147 | self._sampling_rate, 148 | index * self._n_signal, 149 | int(ae.metadata['channels']), 150 | self._n_channels 151 | ) 152 | 153 | if self._transforms is not None: 154 | audio = self._transforms(audio) 155 | 156 | return audio 157 | 158 | def get_channels_from_dataset(db_path): 159 | with open(os.path.join(db_path, 'metadata.yaml'), 'r') as metadata: 160 | metadata = yaml.safe_load(metadata) 161 | return metadata.get('channels') 162 | 163 | def get_training_channels(db_path, target_channels): 164 | dataset_channels = get_channels_from_dataset(db_path) 165 | if dataset_channels is not None: 166 | if target_channels > dataset_channels: 167 | raise RuntimeError('[Error] Requested number of channels is %s, but dataset has %s channels' % (target_channels, dataset_channels)) 168 | n_channels = target_channels or dataset_channels 169 | if n_channels is None: 170 | print('[Warning] channels not found in dataset, taking 1 by default') 171 | n_channels = 1 172 | return n_channels 173 | 174 | class HTTPAudioDataset(data.Dataset): 175 | 176 | def __init__(self, db_path: str): 177 | super().__init__() 178 | self.db_path = db_path 179 | logging.info("starting remote dataset session") 180 | self.length = int(requests.get("/".join([db_path, "len"])).text) 181 | logging.info("connection established !") 182 | 183 | def __len__(self): 184 | return self.length 185 | 186 | def __getitem__(self, index): 187 | example = requests.get("/".join([ 188 | self.db_path, 189 | "get", 190 | f"{index}", 191 | ])).text 192 | example = AudioExampleWrapper(base64.b64decode(example)).get("audio") 193 | return example.copy() 194 | 195 | 196 | def normalize_signal(x: np.ndarray, max_gain_db: int = 30): 197 | peak = np.max(abs(x)) 198 | if peak == 0: return x 199 | 200 | log_peak = 20 * np.log10(peak) 201 | log_gain = min(max_gain_db, -log_peak) 202 | gain = 10**(log_gain / 20) 203 | 204 | return x * gain 205 | 206 | @gin.configurable 207 | def get_dataset(db_path, 208 | sr, 209 | n_signal, 210 | derivative: bool = False, 211 | normalize: bool = False, 212 | rand_pitch: bool = False, 213 | augmentations: Union[None, Iterable[Callable]] = None, 214 | n_channels: int = 1): 215 | if db_path[:4] == "http": 216 | return HTTPAudioDataset(db_path=db_path) 217 | with open(os.path.join(db_path, 'metadata.yaml'), 'r') as metadata: 218 | metadata = yaml.safe_load(metadata) 219 | 220 | sr_dataset = metadata.get('sr', 44100) 221 | lazy = metadata['lazy'] 222 | 223 | transform_list = [ 224 | lambda x: x.astype(np.float32), 225 | transforms.RandomCrop(n_signal), 226 | transforms.RandomApply( 227 | lambda x: random_phase_mangle(x, 20, 2000, .99, sr_dataset), 228 | p=.8, 229 | ), 230 | transforms.Dequantize(16), 231 | ] 232 | 233 | if rand_pitch: 234 | rand_pitch = list(map(float, rand_pitch)) 235 | assert len(rand_pitch) == 2, "rand_pitch must be given two floats" 236 | transform_list.insert(1, transforms.RandomPitch(n_signal, rand_pitch)) 237 | 238 | if sr_dataset != sr: 239 | transform_list.append(transforms.Resample(sr_dataset, sr)) 240 | 241 | if normalize: 242 | transform_list.append(normalize_signal)
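        # normalize_signal (defined above) rescales each example so that its peak
        # reaches 0 dBFS, boosting quiet files by at most max_gain_db (30 dB by
        # default) and attenuating any signal whose peak already exceeds 1.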
243 | 244 | if derivative: 245 | transform_list.append(get_derivator_integrator(sr)[0]) 246 | 247 | if augmentations: 248 | transform_list.extend(augmentations) 249 | 250 | transform_list.append(lambda x: x.astype(np.float32)) 251 | 252 | transform_list = transforms.Compose(transform_list) 253 | 254 | if lazy: 255 | return LazyAudioDataset(db_path, n_signal, sr_dataset, transform_list, n_channels) 256 | else: 257 | return AudioDataset( 258 | db_path, 259 | transforms=transform_list, 260 | n_channels=n_channels 261 | ) 262 | 263 | 264 | @gin.configurable 265 | def split_dataset(dataset, percent, max_residual: Optional[int] = None): 266 | split1 = max((percent * len(dataset)) // 100, 1) 267 | split2 = len(dataset) - split1 268 | if max_residual is not None: 269 | split2 = min(max_residual, split2) 270 | split1 = len(dataset) - split2 271 | print(f'train set: {split1} examples') 272 | print(f'val set: {split2} examples') 273 | split1, split2 = data.random_split( 274 | dataset, 275 | [split1, split2], 276 | generator=torch.Generator().manual_seed(42), 277 | ) 278 | return split1, split2 279 | 280 | 281 | def random_angle(min_f=20, max_f=8000, sr=24000): 282 | min_f = np.log(min_f) 283 | max_f = np.log(max_f) 284 | rand = np.exp(random() * (max_f - min_f) + min_f) 285 | rand = 2 * np.pi * rand / sr 286 | return rand 287 | 288 | 289 | def pole_to_z_filter(omega, amplitude=.9): 290 | z0 = amplitude * np.exp(1j * omega) 291 | a = [1, -2 * np.real(z0), abs(z0)**2] 292 | b = [abs(z0)**2, -2 * np.real(z0), 1] 293 | return b, a 294 | 295 | 296 | def random_phase_mangle(x, min_f, max_f, amp, sr): 297 | angle = random_angle(min_f, max_f, sr) 298 | b, a = pole_to_z_filter(angle, amp) 299 | return lfilter(b, a, x) 300 | 301 | def extract_audio(path: str, n_signal: int, sr: int, 302 | start_sample: int, input_channels: int, channels: int) -> Iterable[np.ndarray]: 303 | # channel mapping 304 | channel_map = range(channels) 305 | if input_channels < channels: 306 | channel_map = (math.ceil(channels / input_channels) * list(range(input_channels)))[:channels] 307 | # time information 308 | start_sec = start_sample / sr 309 | length = (n_signal * 2) / sr 310 | chunks = [] 311 | for i in channel_map: 312 | process = subprocess.Popen( 313 | [ 314 | 'ffmpeg', '-v', 'error', 315 | '-ss', 316 | str(start_sec), 317 | '-i', 318 | path, 319 | '-ar', 320 | str(sr), 321 | '-filter_complex', 322 | 'channelmap=%d-0'%i, 323 | '-t', 324 | str(length), 325 | '-f', 326 | 's16le', 327 | '-' 328 | ], 329 | stdout=subprocess.PIPE, 330 | ) 331 | 332 | chunk = process.communicate()[0] 333 | chunk = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 2**15 334 | chunk = np.concatenate([chunk, np.zeros(n_signal)], -1) 335 | chunks.append(chunk) 336 | return np.stack(chunks)[:, :(n_signal*2)] 337 | -------------------------------------------------------------------------------- /rave/descript_discriminator.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/descriptinc/descript-audio-codec 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | from torch.nn.utils import weight_norm 9 | from torchaudio.transforms import Spectrogram 10 | 11 | from .pqmf import kaiser_filter 12 | 13 | 14 | def WNConv1d(*args, **kwargs): 15 | act = kwargs.pop("act", True) 16 | conv = weight_norm(nn.Conv1d(*args, **kwargs)) 17 | if not act: 18 | return conv 19 | return nn.Sequential(conv, 
nn.LeakyReLU(0.1)) 20 | 21 | 22 | def WNConv2d(*args, **kwargs): 23 | act = kwargs.pop("act", True) 24 | conv = weight_norm(nn.Conv2d(*args, **kwargs)) 25 | if not act: 26 | return conv 27 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 28 | 29 | 30 | class MPD(nn.Module): 31 | 32 | def __init__(self, period, n_channels: int = 1): 33 | super().__init__() 34 | self.period = period 35 | self.convs = nn.ModuleList([ 36 | WNConv2d(n_channels, 32, (5, 1), (3, 1), padding=(2, 0)), 37 | WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), 38 | WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), 39 | WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), 40 | WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), 41 | ]) 42 | self.conv_post = WNConv2d(1024, 43 | 1, 44 | kernel_size=(3, 1), 45 | padding=(1, 0), 46 | act=False) 47 | 48 | def pad_to_period(self, x): 49 | t = x.shape[-1] 50 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect") 51 | return x 52 | 53 | def forward(self, x): 54 | fmap = [] 55 | 56 | x = self.pad_to_period(x) 57 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period) 58 | 59 | for layer in self.convs: 60 | x = layer(x) 61 | fmap.append(x) 62 | 63 | x = self.conv_post(x) 64 | fmap.append(x) 65 | 66 | return fmap 67 | 68 | 69 | class MSD(nn.Module): 70 | 71 | def __init__(self, scale: int, n_channels: int = 1): 72 | super().__init__() 73 | self.convs = nn.ModuleList([ 74 | WNConv1d(n_channels, 16, 15, 1, padding=7), 75 | WNConv1d(16, 64, 41, 4, groups=4, padding=20), 76 | WNConv1d(64, 256, 41, 4, groups=16, padding=20), 77 | WNConv1d(256, 1024, 41, 4, groups=64, padding=20), 78 | WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), 79 | WNConv1d(1024, 1024, 5, 1, padding=2), 80 | ]) 81 | self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) 82 | 83 | self.scale = scale 84 | 85 | if self.scale != 1: 86 | wc = np.pi / self.scale 87 | filt = kaiser_filter(wc, 140) 88 | if not len(filt) % 2: 89 | filt = np.pad(filt, (1, 0)) 90 | 91 | self.register_buffer( 92 | "downsampler", 93 | torch.from_numpy(filt).reshape(1, 1, -1).float()) 94 | 95 | def forward(self, x): 96 | if self.scale != 1: 97 | x = nn.functional.conv1d( 98 | x, 99 | self.downsampler, 100 | padding=self.downsampler.shape[-1] // 2, 101 | stride=self.scale, 102 | ) 103 | 104 | fmap = [] 105 | 106 | for l in self.convs: 107 | x = l(x) 108 | fmap.append(x) 109 | x = self.conv_post(x) 110 | fmap.append(x) 111 | 112 | return fmap 113 | 114 | 115 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] 116 | 117 | 118 | class MRD(nn.Module): 119 | 120 | def __init__( 121 | self, 122 | window_length: int, 123 | hop_factor: float = 0.25, 124 | sample_rate: int = 44100, 125 | bands: list = BANDS, 126 | n_channels: int = 1 127 | ): 128 | super().__init__() 129 | 130 | self.window_length = window_length 131 | self.hop_factor = hop_factor 132 | self.sample_rate = sample_rate 133 | 134 | n_fft = window_length // 2 + 1 135 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] 136 | self.bands = bands 137 | 138 | ch = 32 139 | convs = lambda: nn.ModuleList([ 140 | WNConv2d(2 * n_channels, ch, (3, 9), (1, 1), padding=(1, 4)), 141 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 142 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 143 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), 144 | WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), 145 | ]) 146 | self.band_convs = nn.ModuleList( 147 | [convs() for _ in range(len(self.bands))]) 148 | self.conv_post = WNConv2d(ch, 149 | 1, (3, 
3), (1, 1), 150 | padding=(1, 1), 151 | act=False) 152 | 153 | self.stft = Spectrogram( 154 | n_fft=window_length, 155 | win_length=window_length, 156 | hop_length=int(hop_factor * window_length), 157 | center=True, 158 | return_complex=True, 159 | power=None, 160 | ) 161 | 162 | def spectrogram(self, x): 163 | x = torch.view_as_real(self.stft(x)) 164 | x = rearrange(x, "b c f t p -> b (c p) t f") 165 | # Split into bands 166 | x_bands = [x[..., b[0]:b[1]] for b in self.bands] 167 | return x_bands 168 | 169 | def forward(self, x): 170 | x_bands = self.spectrogram(x) 171 | fmap = [] 172 | 173 | x = [] 174 | for band, stack in zip(x_bands, self.band_convs): 175 | for layer in stack: 176 | band = layer(band) 177 | fmap.append(band) 178 | x.append(band) 179 | 180 | x = torch.cat(x, dim=-1) 181 | x = self.conv_post(x) 182 | fmap.append(x) 183 | 184 | return fmap 185 | 186 | 187 | class DescriptDiscriminator(nn.Module): 188 | 189 | def __init__( 190 | self, 191 | rates: list = [], 192 | periods: list = [2, 3, 5, 7, 11], 193 | fft_sizes: list = [2048, 1024, 512], 194 | sample_rate: int = 44100, 195 | bands: list = BANDS, 196 | n_channels: int = 1, 197 | ): 198 | super().__init__() 199 | discs = [] 200 | discs += [MPD(p, n_channels=n_channels) for p in periods] 201 | discs += [MSD(r, sample_rate=sample_rate, n_channels=n_channels) for r in rates] 202 | discs += [ 203 | MRD(f, sample_rate=sample_rate, bands=bands, n_channels=n_channels) for f in fft_sizes 204 | ] 205 | self.discriminators = nn.ModuleList(discs) 206 | 207 | def preprocess(self, y): 208 | # Remove DC offset 209 | y = y - y.mean(dim=-1, keepdims=True) 210 | # Peak normalize the volume of input audio 211 | y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) 212 | return y 213 | 214 | def forward(self, x): 215 | x = self.preprocess(x) 216 | fmaps = [d(x) for d in self.discriminators] 217 | return fmaps 218 | -------------------------------------------------------------------------------- /rave/discriminator.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Sequence, Tuple, Type 2 | 3 | import cached_conv as cc 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torchaudio 8 | 9 | from .blocks import normalization 10 | 11 | 12 | def spectrogram(n_fft: int): 13 | return torchaudio.transforms.Spectrogram( 14 | n_fft, 15 | hop_length=n_fft // 4, 16 | power=None, 17 | normalized=True, 18 | center=False, 19 | pad_mode=None, 20 | ) 21 | 22 | 23 | def rectified_2d_conv_block( 24 | capacity, 25 | kernel_sizes, 26 | strides: Optional[Tuple[int, int]] = None, 27 | dilations: Optional[Tuple[int, int]] = None, 28 | in_size: Optional[int] = None, 29 | out_size: Optional[int] = None, 30 | activation: bool = True, 31 | ): 32 | if dilations is None: 33 | paddings = kernel_sizes[0] // 2, kernel_sizes[1] // 2 34 | else: 35 | fks = (kernel_sizes[0] - 1) * dilations[0], (kernel_sizes[1] - 36 | 1) * dilations[1] 37 | paddings = fks[0] // 2, fks[1] // 2 38 | 39 | conv = normalization( 40 | nn.Conv2d( 41 | in_size or capacity, 42 | out_size or capacity, 43 | kernel_size=kernel_sizes, 44 | stride=strides or (1, 1), 45 | dilation=dilations or (1, 1), 46 | padding=paddings, 47 | )) 48 | 49 | if not activation: return conv 50 | 51 | return nn.Sequential(conv, nn.LeakyReLU(.2)) 52 | 53 | 54 | class EncodecConvNet(nn.Module): 55 | 56 | def __init__(self, capacity: int, n_channels: int = 1) -> None: 57 | super().__init__() 58 | self.net = nn.Sequential( 59 | 
rectified_2d_conv_block(capacity, (9, 3), in_size=2*n_channels), 60 | rectified_2d_conv_block(capacity, (9, 3), (2, 1), (1, 1)), 61 | rectified_2d_conv_block(capacity, (9, 3), (2, 1), (1, 2)), 62 | rectified_2d_conv_block(capacity, (9, 3), (2, 1), (1, 4)), 63 | rectified_2d_conv_block(capacity, (3, 3)), 64 | rectified_2d_conv_block(capacity, (3, 3), 65 | out_size=1, 66 | activation=False), 67 | ) 68 | 69 | def forward(self, x): 70 | features = [] 71 | for layer in self.net: 72 | x = layer(x) 73 | features.append(x) 74 | return features 75 | 76 | 77 | class ConvNet(nn.Module): 78 | 79 | def __init__(self, in_size, out_size, capacity, n_layers, kernel_size, 80 | stride, conv) -> None: 81 | super().__init__() 82 | channels = [in_size] 83 | channels += list(capacity * 2**np.arange(n_layers)) 84 | 85 | if isinstance(stride, int): 86 | stride = n_layers * [stride] 87 | 88 | net = [] 89 | for i in range(n_layers): 90 | if not isinstance(kernel_size, int): 91 | pad = (cc.get_padding(kernel_size[0], 92 | stride[i], 93 | mode="centered")[0], 0) 94 | s = (stride[i], 1) 95 | else: 96 | pad = cc.get_padding(kernel_size, stride[i], 97 | mode="centered")[0] 98 | s = stride[i] 99 | net.append( 100 | normalization( 101 | conv( 102 | channels[i], 103 | channels[i + 1], 104 | kernel_size, 105 | stride=s, 106 | padding=pad, 107 | ))) 108 | net.append(nn.LeakyReLU(.2)) 109 | net.append(conv(channels[-1], out_size, 1)) 110 | 111 | self.net = nn.Sequential(*net) 112 | 113 | def forward(self, x): 114 | features = [] 115 | for layer in self.net: 116 | x = layer(x) 117 | if isinstance(layer, nn.modules.conv._ConvNd): 118 | features.append(x) 119 | return features 120 | 121 | 122 | class MultiScaleDiscriminator(nn.Module): 123 | 124 | def __init__(self, n_discriminators, convnet, n_channels=1) -> None: 125 | super().__init__() 126 | layers = [] 127 | for i in range(n_discriminators): 128 | layers.append(convnet(in_size=n_channels)) 129 | self.layers = nn.ModuleList(layers) 130 | 131 | def forward(self, x): 132 | features = [] 133 | for layer in self.layers: 134 | features.append(layer(x)) 135 | x = nn.functional.avg_pool1d(x, 2) 136 | return features 137 | 138 | 139 | class MultiScaleSpectralDiscriminator(nn.Module): 140 | 141 | def __init__(self, scales: Sequence[int], 142 | convnet: Callable[[], nn.Module], n_channels: int = 1) -> None: 143 | super().__init__() 144 | self.specs = nn.ModuleList([spectrogram(n) for n in scales]) 145 | self.nets = nn.ModuleList([convnet(n_channels=n_channels) for _ in scales]) 146 | 147 | def forward(self, x): 148 | features = [] 149 | for spec, net in zip(self.specs, self.nets): 150 | spec_x = spec(x) 151 | spec_x = torch.cat([spec_x.real, spec_x.imag], 1) 152 | features.append(net(spec_x)) 153 | return features 154 | 155 | 156 | class MultiScaleSpectralDiscriminator1d(nn.Module): 157 | 158 | def __init__(self, scales: Sequence[int], 159 | convnet: Callable[[int], nn.Module], 160 | n_channels: int = 1) -> None: 161 | super().__init__() 162 | self.specs = nn.ModuleList([spectrogram(n) for n in scales]) 163 | self.nets = nn.ModuleList([convnet(n + 2, n_channels) for n in scales]) 164 | 165 | def forward(self, x): 166 | features = [] 167 | for spec, net in zip(self.specs, self.nets): 168 | spec_x = spec(x).squeeze(1) 169 | spec_x = torch.cat([spec_x.real, spec_x.imag], 1) 170 | features.append(net(spec_x)) 171 | return features 172 | 173 | 174 | class MultiPeriodDiscriminator(nn.Module): 175 | 176 | def __init__(self, periods, convnet, n_channels=1) -> None: 177 | super().__init__() 
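        # One convnet is instantiated per period; forward() folds the raw waveform
        # into a (batch, channels, time // period, period) grid via fold(), so each
        # 2D discriminator sees structure that repeats every `period` samples.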
178 | layers = [] 179 | self.periods = periods 180 | 181 | for _ in periods: 182 | layers.append(convnet(in_size=n_channels)) 183 | 184 | self.layers = nn.ModuleList(layers) 185 | 186 | def forward(self, x): 187 | features = [] 188 | for layer, n in zip(self.layers, self.periods): 189 | features.append(layer(self.fold(x, n))) 190 | return features 191 | 192 | def fold(self, x, n): 193 | pad = (n - (x.shape[-1] % n)) % n 194 | x = nn.functional.pad(x, (0, pad)) 195 | return x.reshape(*x.shape[:2], -1, n) 196 | 197 | 198 | class CombineDiscriminators(nn.Module): 199 | 200 | def __init__(self, discriminators: Sequence[Type[nn.Module]], n_channels=1) -> None: 201 | super().__init__() 202 | self.discriminators = nn.ModuleList(disc_cls(n_channels=n_channels) 203 | for disc_cls in discriminators) 204 | 205 | def forward(self, x): 206 | features = [] 207 | for disc in self.discriminators: 208 | features.extend(disc(x)) 209 | return features 210 | -------------------------------------------------------------------------------- /rave/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from time import time 3 | from typing import Callable, Optional, Iterable, Dict 4 | 5 | import gin, pdb 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | import torch 9 | import torch.nn as nn 10 | from einops import rearrange 11 | from sklearn.decomposition import PCA 12 | from pytorch_lightning.trainer.states import RunningStage 13 | 14 | 15 | import rave.core 16 | 17 | from . import blocks 18 | 19 | 20 | _default_loss_weights = { 21 | 'audio_distance': 1., 22 | 'multiband_audio_distance': 1., 23 | 'adversarial': 1., 24 | 'feature_matching' : 20, 25 | } 26 | 27 | class Profiler: 28 | 29 | def __init__(self): 30 | self.ticks = [[time(), None]] 31 | 32 | def tick(self, msg): 33 | self.ticks.append([time(), msg]) 34 | 35 | def __repr__(self): 36 | rep = 80 * "=" + "\n" 37 | for i in range(1, len(self.ticks)): 38 | msg = self.ticks[i][1] 39 | ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] 40 | rep += msg + f": {ellapsed*1000:.2f}ms\n" 41 | rep += 80 * "=" + "\n\n\n" 42 | return rep 43 | 44 | 45 | class WarmupCallback(pl.Callback): 46 | 47 | def __init__(self) -> None: 48 | super().__init__() 49 | self.state = {'training_steps': 0} 50 | 51 | def on_train_batch_start(self, trainer, pl_module, batch, 52 | batch_idx) -> None: 53 | if self.state['training_steps'] >= pl_module.warmup: 54 | pl_module.warmed_up = True 55 | self.state['training_steps'] += 1 56 | 57 | def state_dict(self): 58 | return self.state.copy() 59 | 60 | def load_state_dict(self, state_dict): 61 | self.state.update(state_dict) 62 | 63 | 64 | class QuantizeCallback(WarmupCallback): 65 | 66 | def on_train_batch_start(self, trainer, pl_module, batch, 67 | batch_idx) -> None: 68 | 69 | if pl_module.warmup_quantize is None: return 70 | 71 | if self.state['training_steps'] >= pl_module.warmup_quantize: 72 | if isinstance(pl_module.encoder, blocks.DiscreteEncoder): 73 | pl_module.encoder.enabled = torch.tensor(1).type_as( 74 | pl_module.encoder.enabled) 75 | self.state['training_steps'] += 1 76 | 77 | 78 | @gin.configurable 79 | class BetaWarmupCallback(pl.Callback): 80 | 81 | def __init__(self, initial_value: float = .2, 82 | target_value: float = .2, 83 | warmup_len: int = 1, 84 | log: bool = True) -> None: 85 | super().__init__() 86 | self.state = {'training_steps': 0} 87 | self.warmup_len = warmup_len 88 | self.initial_value = initial_value 89 | self.target_value = target_value 90 |
self.log_warmup = log 91 | 92 | def on_train_batch_start(self, trainer, pl_module, batch, 93 | batch_idx) -> None: 94 | self.state['training_steps'] += 1 95 | if self.state["training_steps"] >= self.warmup_len: 96 | pl_module.beta_factor = self.target_value 97 | return 98 | 99 | warmup_ratio = self.state["training_steps"] / self.warmup_len 100 | 101 | if self.log_warmup: 102 | beta = math.log(self.initial_value) * (1 - warmup_ratio) + math.log( 103 | self.target_value) * warmup_ratio 104 | pl_module.beta_factor = math.exp(beta) 105 | else: 106 | beta = warmup_ratio * (self.target_value - self.initial_value) + self.initial_value 107 | pl_module.beta_factor = min(beta, self.target_value) 108 | 109 | def state_dict(self): 110 | return self.state.copy() 111 | 112 | def load_state_dict(self, state_dict): 113 | self.state.update(state_dict) 114 | 115 | 116 | @torch.fx.wrap 117 | def _pqmf_encode(pqmf, x: torch.Tensor): 118 | batch_size = x.shape[:-2] 119 | x_multiband = x.reshape(-1, 1, x.shape[-1]) 120 | x_multiband = pqmf(x_multiband) 121 | x_multiband = x_multiband.reshape(*batch_size, -1, x_multiband.shape[-1]) 122 | return x_multiband 123 | 124 | 125 | @torch.fx.wrap 126 | def _pqmf_decode(pqmf, x: torch.Tensor, batch_size: Iterable[int], n_channels: int): 127 | x = x.reshape(x.shape[0] * n_channels, -1, x.shape[-1]) 128 | x = pqmf.inverse(x) 129 | x = x.reshape(*batch_size, n_channels, -1) 130 | return x 131 | 132 | 133 | @gin.configurable 134 | class RAVE(pl.LightningModule): 135 | 136 | def __init__( 137 | self, 138 | latent_size, 139 | sampling_rate, 140 | encoder, 141 | decoder, 142 | discriminator, 143 | phase_1_duration, 144 | gan_loss, 145 | valid_signal_crop, 146 | feature_matching_fun, 147 | num_skipped_features, 148 | audio_distance: Callable[[], nn.Module], 149 | multiband_audio_distance: Callable[[], nn.Module], 150 | n_bands: int = 16, 151 | balancer = None, 152 | weights: Optional[Dict[str, float]] = None, 153 | warmup_quantize: Optional[int] = None, 154 | pqmf: Optional[Callable[[], nn.Module]] = None, 155 | spectrogram: Optional[Callable] = None, 156 | update_discriminator_every: int = 2, 157 | n_channels: int = 1, 158 | input_mode: str = "pqmf", 159 | output_mode: str = "pqmf", 160 | audio_monitor_epochs: int = 1, 161 | # for retro-compatibility 162 | enable_pqmf_encode: Optional[bool] = None, 163 | enable_pqmf_decode: Optional[bool] = None, 164 | is_mel_input: Optional[bool] = None, 165 | loss_weights = None 166 | ): 167 | super().__init__() 168 | self.pqmf = pqmf(n_channels=n_channels) 169 | self.spectrogram = None 170 | if spectrogram is not None: 171 | self.spectrogram = spectrogram 172 | assert input_mode in ['pqmf', 'mel', 'raw'] 173 | assert output_mode in ['raw', 'pqmf'] 174 | self.input_mode = input_mode 175 | self.output_mode = output_mode 176 | # retro-compatibility 177 | if (enable_pqmf_encode is not None) or (enable_pqmf_decode is not None): 178 | self.input_mode = "pqmf" if enable_pqmf_encode else "raw" 179 | self.output_mode = "pqmf" if enable_pqmf_decode else "raw" 180 | if (is_mel_input) is not None: 181 | self.input_mode = "mel" 182 | if loss_weights is not None: 183 | weights = loss_weights 184 | assert weights is not None, "RAVE model requires either weights or loss_weights (depreciated) keyword" 185 | 186 | # setup model 187 | self.encoder = encoder(n_channels=n_channels) 188 | self.decoder = decoder(n_channels=n_channels) 189 | self.discriminator = discriminator(n_channels=n_channels) 190 | 191 | self.audio_distance = audio_distance() 192 | 
self.multiband_audio_distance = multiband_audio_distance() 193 | 194 | self.gan_loss = gan_loss 195 | 196 | self.register_buffer("latent_pca", torch.eye(latent_size)) 197 | self.register_buffer("latent_mean", torch.zeros(latent_size)) 198 | self.register_buffer("fidelity", torch.zeros(latent_size)) 199 | 200 | self.latent_size = latent_size 201 | 202 | self.automatic_optimization = False 203 | 204 | # SCHEDULE 205 | self.warmup = phase_1_duration 206 | self.warmup_quantize = warmup_quantize 207 | self.weights = _default_loss_weights 208 | self.weights.update(weights) 209 | self.warmed_up = False 210 | 211 | # CONSTANTS 212 | self.sr = sampling_rate 213 | self.valid_signal_crop = valid_signal_crop 214 | self.n_channels = n_channels 215 | self.feature_matching_fun = feature_matching_fun 216 | self.num_skipped_features = num_skipped_features 217 | self.update_discriminator_every = update_discriminator_every 218 | 219 | self.eval_number = 0 220 | self.beta_factor = 1. 221 | self.integrator = None 222 | 223 | self.register_buffer("receptive_field", torch.tensor([0, 0]).long()) 224 | self.audio_monitor_epochs = audio_monitor_epochs 225 | 226 | def configure_optimizers(self): 227 | gen_p = list(self.encoder.parameters()) 228 | gen_p += list(self.decoder.parameters()) 229 | dis_p = list(self.discriminator.parameters()) 230 | 231 | gen_opt = torch.optim.Adam(gen_p, 1e-3, (.5, .9)) 232 | dis_opt = torch.optim.Adam(dis_p, 1e-4, (.5, .9)) 233 | 234 | return ({'optimizer': gen_opt, 235 | 'lr_scheduler': {'scheduler': torch.optim.lr_scheduler.LinearLR(gen_opt, start_factor=1.0, end_factor=0.1, total_iters=self.warmup)}}, 236 | {'optimizer':dis_opt}) 237 | 238 | def _mel_encode(self, x: torch.Tensor): 239 | batch_size = x.shape[:-2] 240 | x = self.spectrogram(x)[..., :-1] 241 | x = torch.log1p(x).reshape(*batch_size, -1, x.shape[-1]) 242 | return x 243 | 244 | def encode(self, x, return_mb: bool = False): 245 | x_enc = x 246 | if self.input_mode == "pqmf": 247 | x_enc = _pqmf_encode(self.pqmf, x_enc) 248 | elif self.input_mode == "mel": 249 | x_enc = self._mel_encode(x) 250 | 251 | z = self.encoder(x_enc) 252 | if return_mb: 253 | if self.input_mode == "pqmf": 254 | return z, x_enc 255 | else: 256 | x_multiband = _pqmf_encode(self.pqmf, x_enc) 257 | return z, x_multiband 258 | return z 259 | 260 | def decode(self, z): 261 | batch_size = z.shape[:-2] 262 | y = self.decoder(z) 263 | if self.output_mode == "pqmf": 264 | y = _pqmf_decode(self.pqmf, y, batch_size=batch_size, n_channels=self.n_channels) 265 | return y 266 | 267 | def forward(self, x): 268 | z = self.encode(x, return_mb=False) 269 | z = self.encoder.reparametrize(z)[0] 270 | return self.decode(z) 271 | 272 | def on_train_batch_end(self, outputs, batch, batch_idx) -> None: 273 | self.lr_schedulers().step() 274 | return super().on_train_batch_end(outputs, batch, batch_idx) 275 | 276 | def split_features(self, features): 277 | feature_real = [] 278 | feature_fake = [] 279 | for scale in features: 280 | true, fake = zip(*map( 281 | lambda x: torch.split(x, x.shape[0] // 2, 0), 282 | scale, 283 | )) 284 | feature_real.append(true) 285 | feature_fake.append(fake) 286 | return feature_real, feature_fake 287 | 288 | def training_step(self, batch, batch_idx): 289 | p = Profiler() 290 | gen_opt, dis_opt = self.optimizers() 291 | x_raw = batch 292 | x_raw.requires_grad = True 293 | 294 | batch_size = x_raw.shape[:-2] 295 | self.encoder.set_warmed_up(self.warmed_up) 296 | self.decoder.set_warmed_up(self.warmed_up) 297 | 298 | # ENCODE INPUT 299 | # get 
multiband in case 300 | z, x_multiband = self.encode(x_raw, return_mb=True) 301 | 302 | z, reg = self.encoder.reparametrize(z)[:2] 303 | p.tick('encode') 304 | 305 | # DECODE LATENT 306 | y = self.decoder(z) 307 | if self.output_mode == "pqmf": 308 | y_multiband = y 309 | y_raw = _pqmf_decode(self.pqmf, y, batch_size=batch_size, n_channels=self.n_channels) 310 | else: 311 | y_raw = y 312 | y_multiband = _pqmf_encode(self.pqmf, y) 313 | 314 | # TODO this has been added for training with num_samples = 65536 samples, output padding seems to mess with output dimensions. 315 | # this may probably conflict with cached_conv 316 | y_raw = y_raw[..., :x_raw.shape[-1]] 317 | y_multiband = y_multiband[..., :x_multiband.shape[-1]] 318 | 319 | p.tick('decode') 320 | 321 | if self.valid_signal_crop and self.receptive_field.sum(): 322 | x_multiband = rave.core.valid_signal_crop( 323 | x_multiband, 324 | *self.receptive_field, 325 | ) 326 | y_multiband = rave.core.valid_signal_crop( 327 | y_multiband, 328 | *self.receptive_field, 329 | ) 330 | p.tick('crop') 331 | 332 | # DISTANCE BETWEEN INPUT AND OUTPUT 333 | distances = {} 334 | multiband_distance = self.multiband_audio_distance( 335 | x_multiband, y_multiband) 336 | p.tick('mb distance') 337 | for k, v in multiband_distance.items(): 338 | distances[f'multiband_{k}'] = self.weights['multiband_audio_distance'] * v 339 | 340 | fullband_distance = self.audio_distance(x_raw, y_raw) 341 | p.tick('fb distance') 342 | 343 | for k, v in fullband_distance.items(): 344 | distances[f'fullband_{k}'] = self.weights['audio_distance'] * v 345 | 346 | feature_matching_distance = 0. 347 | 348 | if self.warmed_up: # DISCRIMINATION 349 | xy = torch.cat([x_raw, y_raw], 0) 350 | features = self.discriminator(xy) 351 | 352 | feature_real, feature_fake = self.split_features(features) 353 | 354 | loss_dis = 0 355 | loss_adv = 0 356 | 357 | pred_real = 0 358 | pred_fake = 0 359 | 360 | for scale_real, scale_fake in zip(feature_real, feature_fake): 361 | current_feature_distance = sum( 362 | map( 363 | self.feature_matching_fun, 364 | scale_real[self.num_skipped_features:], 365 | scale_fake[self.num_skipped_features:], 366 | )) / len(scale_real[self.num_skipped_features:]) 367 | 368 | feature_matching_distance = feature_matching_distance + current_feature_distance 369 | 370 | _dis, _adv = self.gan_loss(scale_real[-1], scale_fake[-1]) 371 | 372 | pred_real = pred_real + scale_real[-1].mean() 373 | pred_fake = pred_fake + scale_fake[-1].mean() 374 | 375 | loss_dis = loss_dis + _dis 376 | loss_adv = loss_adv + _adv 377 | 378 | feature_matching_distance = feature_matching_distance / len( 379 | feature_real) 380 | 381 | else: 382 | pred_real = torch.tensor(0.).to(x_raw) 383 | pred_fake = torch.tensor(0.).to(x_raw) 384 | loss_dis = torch.tensor(0.).to(x_raw) 385 | loss_adv = torch.tensor(0.).to(x_raw) 386 | p.tick('discrimination') 387 | 388 | # COMPOSE GEN LOSS 389 | loss_gen = {} 390 | loss_gen.update(distances) 391 | p.tick('update loss gen dict') 392 | 393 | if reg.item(): 394 | loss_gen['regularization'] = reg * self.beta_factor 395 | 396 | if self.warmed_up: 397 | loss_gen['feature_matching'] = self.weights['feature_matching'] * feature_matching_distance 398 | loss_gen['adversarial'] = self.weights['adversarial'] * loss_adv 399 | 400 | # OPTIMIZATION 401 | if not (batch_idx % 402 | self.update_discriminator_every) and self.warmed_up: 403 | dis_opt.zero_grad() 404 | loss_dis.backward() 405 | dis_opt.step() 406 | p.tick('dis opt') 407 | else: 408 | gen_opt.zero_grad() 409 | 
loss_gen_value = 0. 410 | for k, v in loss_gen.items(): 411 | loss_gen_value += v * self.weights.get(k, 1.) 412 | loss_gen_value.backward() 413 | gen_opt.step() 414 | 415 | # LOGGING 416 | self.log("beta_factor", self.beta_factor) 417 | 418 | if self.warmed_up: 419 | self.log("loss_dis", loss_dis) 420 | self.log("pred_real", pred_real.mean()) 421 | self.log("pred_fake", pred_fake.mean()) 422 | 423 | self.log_dict(loss_gen) 424 | p.tick('logging') 425 | 426 | def validation_step(self, x, batch_idx): 427 | 428 | z = self.encode(x) 429 | if isinstance(self.encoder, blocks.VariationalEncoder): 430 | mean = torch.split(z, z.shape[1] // 2, 1)[0] 431 | else: 432 | mean = None 433 | 434 | z = self.encoder.reparametrize(z)[0] 435 | y = self.decode(z) 436 | 437 | distance = self.audio_distance(x, y) 438 | full_distance = sum(distance.values()) 439 | 440 | if self.trainer is not None: 441 | self.log('validation', full_distance) 442 | 443 | return torch.cat([x, y], -1), mean 444 | 445 | def validation_epoch_end(self, out): 446 | if not self.receptive_field.sum(): 447 | print("Computing receptive field for this configuration...") 448 | lrf, rrf = rave.core.get_rave_receptive_field(self, n_channels=self.n_channels) 449 | self.receptive_field[0] = lrf 450 | self.receptive_field[1] = rrf 451 | print( 452 | f"Receptive field: {1000*lrf/self.sr:.2f}ms <-- x --> {1000*rrf/self.sr:.2f}ms" 453 | ) 454 | 455 | if not len(out): return 456 | 457 | audio, z = list(zip(*out)) 458 | audio = list(map(lambda x: x.cpu(), audio)) 459 | 460 | if self.trainer.state.stage == RunningStage.SANITY_CHECKING: 461 | return 462 | 463 | # LATENT SPACE ANALYSIS 464 | if not self.warmed_up and isinstance(self.encoder, 465 | blocks.VariationalEncoder): 466 | z = torch.cat(z, 0) 467 | z = rearrange(z, "b c t -> (b t) c") 468 | 469 | self.latent_mean.copy_(z.mean(0)) 470 | z = z - self.latent_mean 471 | 472 | pca = PCA(z.shape[-1]).fit(z.cpu().numpy()) 473 | 474 | components = pca.components_ 475 | components = torch.from_numpy(components).to(z) 476 | self.latent_pca.copy_(components) 477 | 478 | var = pca.explained_variance_ / np.sum(pca.explained_variance_) 479 | var = np.cumsum(var) 480 | 481 | self.fidelity.copy_(torch.from_numpy(var).to(self.fidelity)) 482 | 483 | var_percent = [.8, .9, .95, .99] 484 | for p in var_percent: 485 | self.log( 486 | f"fidelity_{p}", 487 | np.argmax(var > p).astype(np.float32), 488 | ) 489 | 490 | y = torch.cat(audio, 0)[:8].reshape(-1).numpy() 491 | if self.integrator is not None: 492 | y = self.integrator(y) 493 | self.logger.experiment.add_audio("audio_val", y, self.eval_number, 494 | self.sr) 495 | self.eval_number += 1 496 | 497 | def on_fit_start(self): 498 | tb = self.logger.experiment 499 | 500 | config = gin.operative_config_str() 501 | config = config.split('\n') 502 | config = ['```'] + config + ['```'] 503 | config = '\n'.join(config) 504 | tb.add_text("config", config) 505 | 506 | model = str(self) 507 | model = model.split('\n') 508 | model = ['```'] + model + ['```'] 509 | model = '\n'.join(model) 510 | tb.add_text("model", model) 511 | 512 | -------------------------------------------------------------------------------- /rave/pqmf.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import cached_conv as cc 4 | import gin 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from einops import rearrange 9 | from scipy.optimize import fmin 10 | from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord 11 | 12 
| 13 | def reverse_half(x): 14 | mask = torch.ones_like(x) 15 | mask[..., 1::2, ::2] = -1 16 | 17 | return x * mask 18 | 19 | 20 | def center_pad_next_pow_2(x): 21 | next_2 = 2**math.ceil(math.log2(x.shape[-1])) 22 | pad = next_2 - x.shape[-1] 23 | return nn.functional.pad(x, (pad // 2, pad // 2 + int(pad % 2))) 24 | 25 | 26 | def make_odd(x): 27 | if not x.shape[-1] % 2: 28 | x = nn.functional.pad(x, (0, 1)) 29 | return x 30 | 31 | 32 | def get_qmf_bank(h, n_band): 33 | """ 34 | Modulates an input protoype filter into a bank of 35 | cosine modulated filters 36 | Parameters 37 | ---------- 38 | h: torch.Tensor 39 | prototype filter 40 | n_band: int 41 | number of sub-bands 42 | """ 43 | k = torch.arange(n_band).reshape(-1, 1) 44 | N = h.shape[-1] 45 | t = torch.arange(-(N // 2), N // 2 + 1) 46 | 47 | p = (-1)**k * math.pi / 4 48 | 49 | mod = torch.cos((2 * k + 1) * math.pi / (2 * n_band) * t + p) 50 | hk = 2 * h * mod 51 | 52 | return hk 53 | 54 | 55 | def kaiser_filter(wc, atten, N=None): 56 | """ 57 | Computes a kaiser lowpass filter 58 | Parameters 59 | ---------- 60 | wc: float 61 | Angular frequency 62 | 63 | atten: float 64 | Attenuation (dB, positive) 65 | """ 66 | N_, beta = kaiserord(atten, wc / np.pi) 67 | N_ = 2 * (N_ // 2) + 1 68 | N = N if N is not None else N_ 69 | h = firwin(N, wc, window=('kaiser', beta), scale=False, nyq=np.pi) 70 | return h 71 | 72 | 73 | def loss_wc(wc, atten, M, N): 74 | """ 75 | Computes the objective described in https://ieeexplore.ieee.org/document/681427 76 | """ 77 | h = kaiser_filter(wc, atten, N) 78 | g = np.convolve(h, h[::-1], "full") 79 | g = abs(g[g.shape[-1] // 2::2 * M][1:]) 80 | return np.max(g) 81 | 82 | 83 | def get_prototype(atten, M, N=None): 84 | """ 85 | Given an attenuation objective and the number of bands 86 | returns the corresponding lowpass filter 87 | """ 88 | wc = fmin(lambda w: loss_wc(w, atten, M, N), 1 / M, disp=0)[0] 89 | return kaiser_filter(wc, atten, N) 90 | 91 | 92 | def polyphase_forward(x, hk, rearrange_filter=True): 93 | """ 94 | Polyphase implementation of the analysis process (fast) 95 | Parameters 96 | ---------- 97 | x: torch.Tensor 98 | signal to analyse ( B x 1 x T ) 99 | 100 | hk: torch.Tensor 101 | filter bank ( M x T ) 102 | """ 103 | x = rearrange(x, "b c (t m) -> b (c m) t", m=hk.shape[0]) 104 | if rearrange_filter: 105 | hk = rearrange(hk, "c (t m) -> c m t", m=hk.shape[0]) 106 | x = nn.functional.conv1d(x, hk, padding=hk.shape[-1] // 2)[..., :-1] 107 | return x 108 | 109 | 110 | def polyphase_inverse(x, hk, rearrange_filter=True): 111 | """ 112 | Polyphase implementation of the synthesis process (fast) 113 | Parameters 114 | ---------- 115 | x: torch.Tensor 116 | signal to synthesize from ( B x 1 x T ) 117 | 118 | hk: torch.Tensor 119 | filter bank ( M x T ) 120 | """ 121 | 122 | m = hk.shape[0] 123 | 124 | if rearrange_filter: 125 | hk = hk.flip(-1) 126 | hk = rearrange(hk, "c (t m) -> m c t", m=m) # polyphase 127 | 128 | pad = hk.shape[-1] // 2 + 1 129 | x = nn.functional.conv1d(x, hk, padding=int(pad))[..., :-1] * m 130 | 131 | x = x.flip(1) 132 | x = rearrange(x, "b (c m) t -> b c (t m)", m=m) 133 | x = x[..., 2 * hk.shape[1]:] 134 | return x 135 | 136 | 137 | def classic_forward(x, hk): 138 | """ 139 | Naive implementation of the analysis process (slow) 140 | Parameters 141 | ---------- 142 | x: torch.Tensor 143 | signal to analyse ( B x 1 x T ) 144 | 145 | hk: torch.Tensor 146 | filter bank ( M x T ) 147 | """ 148 | x = nn.functional.conv1d( 149 | x, 150 | hk.unsqueeze(1), 151 | 
stride=hk.shape[0], 152 | padding=hk.shape[-1] // 2, 153 | )[..., :-1] 154 | return x 155 | 156 | 157 | def classic_inverse(x, hk): 158 | """ 159 | Naive implementation of the synthesis process (slow) 160 | Parameters 161 | ---------- 162 | x: torch.Tensor 163 | signal to synthesize from ( B x 1 x T ) 164 | 165 | hk: torch.Tensor 166 | filter bank ( M x T ) 167 | """ 168 | hk = hk.flip(-1) 169 | y = torch.zeros(*x.shape[:2], hk.shape[0] * x.shape[-1]).to(x) 170 | y[..., ::hk.shape[0]] = x * hk.shape[0] 171 | y = nn.functional.conv1d( 172 | y, 173 | hk.unsqueeze(0), 174 | padding=hk.shape[-1] // 2, 175 | )[..., 1:] 176 | return y 177 | 178 | 179 | @torch.fx.wrap 180 | class PQMF(nn.Module): 181 | """ 182 | Pseudo Quadrature Mirror Filter multiband decomposition / reconstruction 183 | Parameters 184 | ---------- 185 | attenuation: int 186 | Attenuation of the rejected bands (dB, 80 - 120) 187 | n_band: int 188 | Number of bands, must be a power of 2 if the polyphase implementation 189 | is needed 190 | """ 191 | 192 | def __init__(self, attenuation, n_band, polyphase=True, n_channels = 1): 193 | super().__init__() 194 | h = get_prototype(attenuation, n_band) 195 | 196 | if polyphase: 197 | power = math.log2(n_band) 198 | assert power == math.floor( 199 | power 200 | ), "when using the polyphase algorithm, n_band must be a power of 2" 201 | 202 | h = torch.from_numpy(h).float() 203 | hk = get_qmf_bank(h, n_band) 204 | hk = center_pad_next_pow_2(hk) 205 | 206 | self.register_buffer("hk", hk) 207 | self.register_buffer("h", h) 208 | self.n_band = n_band 209 | self.polyphase = polyphase 210 | self.n_channels = n_channels 211 | 212 | def forward(self, x): 213 | if x.ndim == 2: 214 | return torch.stack([self.forward(x[i]) for i in range(x.shape[0])]) 215 | if self.n_band == 1: 216 | return x 217 | elif self.polyphase: 218 | x = polyphase_forward(x, self.hk) 219 | else: 220 | x = classic_forward(x, self.hk) 221 | 222 | x = reverse_half(x) 223 | 224 | return x 225 | 226 | def inverse(self, x): 227 | if x.ndim == 2: 228 | if self.n_channels == 1: 229 | return self.inverse(x[0]).unsqueeze(0) 230 | else: 231 | x = x.split(self.n_channels, -2) 232 | return torch.stack([self.inverse(x[i]) for i in len(x)]) 233 | 234 | if self.n_band == 1: 235 | return x 236 | 237 | x = reverse_half(x) 238 | 239 | if self.polyphase: 240 | return polyphase_inverse(x, self.hk) 241 | else: 242 | return classic_inverse(x, self.hk) 243 | 244 | 245 | class CachedPQMF(PQMF): 246 | 247 | def __init__(self, *args, **kwargs): 248 | super().__init__(*args, **kwargs) 249 | 250 | hkf = make_odd(self.hk).unsqueeze(1) 251 | 252 | hki = self.hk.flip(-1) 253 | hki = rearrange(hki, "c (t m) -> m c t", m=self.hk.shape[0]) 254 | hki = make_odd(hki) 255 | 256 | self.forward_conv = cc.Conv1d( 257 | hkf.shape[1], 258 | hkf.shape[0], 259 | hkf.shape[2], 260 | padding=cc.get_padding(hkf.shape[-1]), 261 | stride=hkf.shape[0], 262 | bias=False, 263 | ) 264 | self.forward_conv.weight.data.copy_(hkf) 265 | 266 | self.inverse_conv = cc.Conv1d( 267 | hki.shape[1], 268 | hki.shape[0], 269 | hki.shape[-1], 270 | padding=cc.get_padding(hki.shape[-1]), 271 | bias=False, 272 | ) 273 | self.inverse_conv.weight.data.copy_(hki) 274 | 275 | def script_cache(self): 276 | self.forward_conv.script_cache() 277 | self.inverse_conv.script_cache() 278 | 279 | def forward(self, x): 280 | if self.n_band == 1: return x 281 | x = self.forward_conv(x) 282 | x = reverse_half(x) 283 | return x 284 | 285 | def inverse(self, x): 286 | if self.n_band == 1: return x 287 | x = 
reverse_half(x) 288 | m = self.hk.shape[0] 289 | x = self.inverse_conv(x) * m 290 | x = x.flip(1) 291 | x = x.permute(0, 2, 1) 292 | x = x.reshape(x.shape[0], x.shape[1], -1, m).permute(0, 2, 1, 3) 293 | x = x.reshape(x.shape[0], x.shape[1], -1) 294 | return x 295 | -------------------------------------------------------------------------------- /rave/prior/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import * -------------------------------------------------------------------------------- /rave/prior/core.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class QuantizedNormal(nn.Module): 7 | def __init__(self, resolution, dither=True): 8 | super().__init__() 9 | self.resolution = resolution 10 | self.dither = dither 11 | self.clamp = 4 12 | 13 | def from_normal(self, x): 14 | return .5 * (1 + torch.erf(x / math.sqrt(2))) 15 | 16 | def to_normal(self, x): 17 | x = torch.erfinv(2 * x - 1) * math.sqrt(2) 18 | return torch.clamp(x, -self.clamp, self.clamp) 19 | 20 | def encode(self, x): 21 | x = self.from_normal(x) 22 | x = torch.floor(x * self.resolution) 23 | x = torch.clamp(x, 0, self.resolution - 1) 24 | return self.to_stack_one_hot(x.long()) 25 | 26 | def to_stack_one_hot(self, x): 27 | x = nn.functional.one_hot(x, self.resolution) 28 | x = x.permute(0, 2, 1, 3) 29 | x = x.reshape(x.shape[0], x.shape[1], -1) 30 | x = x.permute(0, 2, 1).float() 31 | return x 32 | 33 | def decode(self, x): 34 | x = x.permute(0, 2, 1) 35 | x = x.reshape(x.shape[0], x.shape[1], -1, self.resolution) 36 | x = torch.argmax(x, -1) / self.resolution 37 | if self.dither: 38 | x = x + torch.rand_like(x) / self.resolution 39 | x = self.to_normal(x) 40 | x = x.permute(0, 2, 1) 41 | return x 42 | 43 | 44 | class DiagonalShift(nn.Module): 45 | def __init__(self, groups=1): 46 | super().__init__() 47 | assert isinstance(groups, int) 48 | assert groups > 0 49 | self.groups = groups 50 | 51 | def shift(self, x: torch.Tensor, i: int, n_dim: int): 52 | i = i // self.groups 53 | n_dim = n_dim // self.groups 54 | start = i 55 | end = -n_dim + i + 1 56 | end = end if end else None 57 | return x[..., start:end] 58 | 59 | def forward(self, x): 60 | n_dim = x.shape[1] 61 | x = torch.split(x, 1, 1) 62 | x = [ 63 | self.shift(_x, i, n_dim) for _x, i in zip( 64 | x, 65 | torch.arange(n_dim).flip(0), 66 | ) 67 | ] 68 | x = torch.cat(list(x), 1) 69 | return x 70 | 71 | def inverse(self, x): 72 | x = x.flip(1) 73 | x = self.forward(x) 74 | x = x.flip(1) 75 | return x -------------------------------------------------------------------------------- /rave/prior/model.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import pytorch_lightning as pl 6 | import gin 7 | from tqdm import tqdm 8 | import math 9 | import numpy as np 10 | 11 | from .residual_block import ResidualBlock 12 | from .core import DiagonalShift, QuantizedNormal 13 | 14 | 15 | import cached_conv as cc 16 | 17 | class Prior(pl.LightningModule): 18 | 19 | def __init__(self, resolution, res_size, skp_size, kernel_size, cycle_size, 20 | n_layers, pretrained_vae=None, fidelity=None, n_channels=1, latent_size=None, sr=44100): 21 | super().__init__() 22 | 23 | self.diagonal_shift = DiagonalShift() 24 | self.quantized_normal = QuantizedNormal(resolution) 25 | 26 | self.synth = 
pretrained_vae 27 | self.sr = sr 28 | 29 | if latent_size is not None: 30 | self.latent_size = 2**math.ceil(math.log2(latent_size)) 31 | elif fidelity is not None: 32 | assert pretrained_vae, "giving fidelity keyword needs the pretrained_vae keyword to be given" 33 | latent_size = torch.where(pretrained_vae.fidelity > fidelity)[0][0] 34 | self.latent_size = 2**math.ceil(math.log2(latent_size)) 35 | else: 36 | raise RuntimeError('please init Prior with either fidelity or latent_size keywords') 37 | 38 | self.pre_net = nn.Sequential( 39 | cc.Conv1d( 40 | resolution * self.latent_size, 41 | res_size, 42 | kernel_size, 43 | padding=cc.get_padding(kernel_size, mode="causal"), 44 | groups=self.latent_size, 45 | ), 46 | nn.LeakyReLU(.2), 47 | ) 48 | 49 | self.residuals = nn.ModuleList([ 50 | ResidualBlock( 51 | res_size, 52 | skp_size, 53 | kernel_size, 54 | 2**(i % cycle_size), 55 | ) for i in range(n_layers) 56 | ]) 57 | 58 | self.post_net = nn.Sequential( 59 | cc.Conv1d(skp_size, skp_size, 1), 60 | nn.LeakyReLU(.2), 61 | cc.Conv1d( 62 | skp_size, 63 | resolution * self.latent_size, 64 | 1, 65 | groups=self.latent_size, 66 | ), 67 | ) 68 | 69 | self.n_channels = n_channels 70 | self.val_idx = 0 71 | rf = (kernel_size - 1) * sum(2**(np.arange(n_layers) % cycle_size)) + 1 72 | if pretrained_vae is not None: 73 | ratio = self.get_model_ratio() 74 | self.min_receptive_field = 2**math.ceil(math.log2(rf * ratio)) 75 | 76 | def get_model_ratio(self): 77 | x_len = 2**14 78 | x = torch.zeros(1, self.n_channels, x_len) 79 | z = self.encode(x) 80 | ratio_encode = x_len // z.shape[-1] 81 | return ratio_encode 82 | 83 | def configure_optimizers(self): 84 | p = [] 85 | p.extend(list(self.pre_net.parameters())) 86 | p.extend(list(self.residuals.parameters())) 87 | p.extend(list(self.post_net.parameters())) 88 | return torch.optim.Adam(p, lr=1e-4) 89 | 90 | @torch.no_grad() 91 | def encode(self, x): 92 | self.synth.eval() 93 | z = self.synth.encode(x) 94 | z = self.post_process_latent(z) 95 | return z 96 | 97 | @torch.no_grad() 98 | def decode(self, z): 99 | self.synth.eval() 100 | z = self.pre_process_latent(z) 101 | return self.synth.decode(z) 102 | 103 | def forward(self, x): 104 | res = self.pre_net(x) 105 | skp = torch.tensor(0.).to(x) 106 | for layer in self.residuals: 107 | res, skp = layer(res, skp) 108 | x = self.post_net(skp) 109 | return x 110 | 111 | @torch.no_grad() 112 | def generate(self, x, argmax: bool = False): 113 | for i in tqdm(range(x.shape[-1] - 1)): 114 | if cc.USE_BUFFER_CONV: 115 | start = i 116 | else: 117 | start = None 118 | 119 | pred = self.forward(x[..., start:i + 1]) 120 | 121 | if not cc.USE_BUFFER_CONV: 122 | pred = pred[..., -1:] 123 | 124 | pred = self.post_process_prediction(pred, argmax=argmax) 125 | 126 | x[..., i + 1:i + 2] = pred 127 | return x 128 | 129 | def split_classes(self, x): 130 | # B x D*C x T 131 | x = x.permute(0, 2, 1) 132 | x = x.reshape(x.shape[0], x.shape[1], self.latent_size, -1) 133 | x = x.permute(0, 2, 1, 3) # B x D x T x C 134 | return x 135 | 136 | def post_process_prediction(self, x, argmax: bool = False): 137 | x = self.split_classes(x) 138 | shape = x.shape[:-1] 139 | x = x.reshape(-1, x.shape[-1]) 140 | 141 | if argmax: 142 | x = torch.argmax(x, -1) 143 | else: 144 | x = torch.softmax(x - torch.logsumexp(x, -1, keepdim=True), -1) 145 | x = torch.multinomial(x, 1, True).squeeze(-1) 146 | 147 | x = x.reshape(shape[0], shape[1], shape[2]) 148 | x = self.quantized_normal.to_stack_one_hot(x) 149 | return x 150 | 151 | def training_step(self, batch, 
batch_idx): 152 | x = self.encode(batch) 153 | x = self.quantized_normal.encode(self.diagonal_shift(x)) 154 | pred = self.forward(x) 155 | 156 | x = torch.argmax(self.split_classes(x[..., 1:]), -1) 157 | pred = self.split_classes(pred[..., :-1]) 158 | 159 | loss = nn.functional.cross_entropy( 160 | pred.reshape(-1, self.quantized_normal.resolution), 161 | x.reshape(-1), 162 | ) 163 | 164 | self.log("latent_prediction", loss) 165 | return loss 166 | 167 | def validation_step(self, batch, batch_idx): 168 | x = self.encode(batch) 169 | x = self.quantized_normal.encode(self.diagonal_shift(x)) 170 | pred = self.forward(x) 171 | 172 | x = torch.argmax(self.split_classes(x[..., 1:]), -1) 173 | pred = self.split_classes(pred[..., :-1]) 174 | 175 | loss = nn.functional.cross_entropy( 176 | pred.reshape(-1, self.quantized_normal.resolution), 177 | x.reshape(-1), 178 | ) 179 | 180 | self.log("validation", loss) 181 | return batch 182 | 183 | def validation_epoch_end(self, out): 184 | x = torch.randn_like(self.encode(out[0])) 185 | x = self.quantized_normal.encode(self.diagonal_shift(x)) 186 | z = self.generate(x) 187 | z = self.diagonal_shift.inverse(self.quantized_normal.decode(z)) 188 | 189 | y = self.decode(z) 190 | self.logger.experiment.add_audio( 191 | "generation", 192 | y.reshape(-1), 193 | self.val_idx, 194 | self.synth.sr, 195 | ) 196 | self.val_idx += 1 197 | 198 | @abc.abstractmethod 199 | def post_process_latent(self, z): 200 | raise NotImplementedError() 201 | 202 | @abc.abstractmethod 203 | def pre_process_latent(self, z): 204 | raise NotImplementedError() 205 | 206 | 207 | 208 | @gin.configurable 209 | class VariationalPrior(Prior): 210 | 211 | def post_process_latent(self, z): 212 | z = self.synth.encoder.reparametrize(z)[0] 213 | z = z - self.synth.latent_mean.unsqueeze(-1) 214 | z = F.conv1d(z, self.synth.latent_pca.unsqueeze(-1)) 215 | z = z[:, :self.latent_size] 216 | return z 217 | 218 | def pre_process_latent(self, z): 219 | noise = torch.randn( 220 | z.shape[0], 221 | self.synth.latent_size - z.shape[1], 222 | z.shape[-1], 223 | ).type_as(z) 224 | z = torch.cat([z, noise], 1) 225 | z = F.conv1d(z, self.synth.latent_pca.T.unsqueeze(-1)) 226 | z = z + self.synth.latent_mean.unsqueeze(-1) 227 | return z -------------------------------------------------------------------------------- /rave/prior/residual_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import cached_conv as cc 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | 8 | def __init__(self, res_size, skp_size, kernel_size, dilation): 9 | super().__init__() 10 | fks = (kernel_size - 1) * dilation + 1 11 | 12 | self.dconv = cc.Conv1d( 13 | res_size, 14 | 2 * res_size, 15 | kernel_size, 16 | padding=(fks - 1, 0), 17 | dilation=dilation, 18 | ) 19 | 20 | self.rconv = nn.Conv1d(res_size, res_size, 1) 21 | self.sconv = nn.Conv1d(res_size, skp_size, 1) 22 | 23 | def forward(self, x, skp): 24 | res = x.clone() 25 | 26 | x = self.dconv(x) 27 | xa, xb = torch.split(x, x.shape[1] // 2, 1) 28 | 29 | x = torch.sigmoid(xa) * torch.tanh(xb) 30 | res = res + self.rconv(x) 31 | skp = skp + self.sconv(x) 32 | return res, skp -------------------------------------------------------------------------------- /rave/quantization.py: -------------------------------------------------------------------------------- 1 | # Code adapted from https://github.com/lucidrains/vector-quantize-pytorch 2 | 3 | from typing import Any, Callable, Optional, Union 4 | 5 | import 
torch 6 | import torch.nn.functional as F 7 | from einops import repeat 8 | from torch import nn 9 | 10 | 11 | def ema_inplace(moving_avg, new, decay: float): 12 | moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) 13 | 14 | 15 | def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): 16 | return (x + epsilon) / (x.sum() + n_categories * epsilon) 17 | 18 | 19 | def uniform_init(*shape: int): 20 | t = torch.empty(shape) 21 | nn.init.kaiming_uniform_(t) 22 | return t 23 | 24 | 25 | def sample_vectors(samples, num: int): 26 | num_samples, device = samples.shape[0], samples.device 27 | 28 | if num_samples >= num: 29 | indices = torch.randperm(num_samples, device=device)[:num] 30 | else: 31 | indices = torch.randint(0, num_samples, (num, ), device=device) 32 | 33 | return samples[indices] 34 | 35 | 36 | def kmeans(samples, num_clusters: int, num_iters: int = 10): 37 | dim, dtype = samples.shape[-1], samples.dtype 38 | 39 | means = sample_vectors(samples, num_clusters) 40 | 41 | for _ in range(num_iters): 42 | diffs = samples[:, None] - means[None] 43 | dists = -(diffs**2).sum(dim=-1) 44 | 45 | buckets = dists.max(dim=-1).indices 46 | bins = torch.bincount(buckets, minlength=num_clusters) 47 | zero_mask = bins == 0 48 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 49 | 50 | new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) 51 | new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) 52 | new_means = new_means / bins_min_clamped[..., None] 53 | 54 | means = torch.where(zero_mask[..., None], means, new_means) 55 | 56 | return means, bins 57 | 58 | 59 | class EuclideanCodebook(nn.Module): 60 | """Codebook with Euclidean distance. 61 | Args: 62 | dim (int): Dimension. 63 | codebook_size (int): Codebook size. 64 | kmeans_init (bool): Whether to use k-means to initialize the codebooks. 65 | If set to true, run the k-means algorithm on the first training batch and use 66 | the learned centroids as initialization. 67 | kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. 68 | decay (float): Decay for exponential moving average over the codebooks. 69 | epsilon (float): Epsilon value for numerical stability. 70 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 71 | that have an exponential moving average cluster size less than the specified threshold with 72 | randomly selected vector from the current batch. 
73 | """ 74 | 75 | def __init__( 76 | self, 77 | dim: int, 78 | codebook_size: int, 79 | kmeans_init: int = False, 80 | kmeans_iters: int = 10, 81 | decay: float = 0.99, 82 | epsilon: float = 1e-5, 83 | threshold_ema_dead_code: int = 2, 84 | ): 85 | super().__init__() 86 | self.decay = decay 87 | init_fn: Union[Callable[..., torch.Tensor], 88 | Any] = uniform_init if not kmeans_init else torch.zeros 89 | embed = init_fn(codebook_size, dim) 90 | 91 | self.codebook_size = codebook_size 92 | 93 | self.kmeans_iters = kmeans_iters 94 | self.epsilon = epsilon 95 | self.threshold_ema_dead_code = threshold_ema_dead_code 96 | 97 | self.register_buffer("inited", torch.Tensor([not kmeans_init])) 98 | self.register_buffer("cluster_size", torch.zeros(codebook_size)) 99 | self.register_buffer("embed", embed) 100 | self.register_buffer("embed_avg", embed.clone()) 101 | 102 | @torch.jit.unused 103 | def init_embed_(self, data): 104 | embed, cluster_size = kmeans(data, self.codebook_size, 105 | self.kmeans_iters) 106 | self.embed.data.copy_(embed) 107 | self.embed_avg.data.copy_(embed.clone()) 108 | self.cluster_size.data.copy_(cluster_size) 109 | self.inited.data.copy_(torch.Tensor([True])) 110 | 111 | def replace_(self, samples, mask): 112 | modified_codebook = torch.where( 113 | mask[..., None], sample_vectors(samples, self.codebook_size), 114 | self.embed) 115 | self.embed.data.copy_(modified_codebook) 116 | 117 | def expire_codes_(self, batch_samples): 118 | if self.threshold_ema_dead_code == 0: 119 | return 120 | 121 | expired_codes = self.cluster_size < self.threshold_ema_dead_code 122 | if not torch.any(expired_codes): 123 | return 124 | 125 | batch_samples = batch_samples.reshape(-1, batch_samples.shape[-1]) 126 | self.replace_(batch_samples, mask=expired_codes) 127 | 128 | def preprocess(self, x): 129 | return x.reshape(-1, x.shape[-1]) 130 | 131 | def quantize(self, x): 132 | embed = self.embed.t() 133 | dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + 134 | embed.pow(2).sum(0, keepdim=True)) 135 | embed_ind = dist.max(dim=-1).indices 136 | return embed_ind 137 | 138 | def dequantize(self, embed_ind): 139 | quantize = F.embedding(embed_ind, self.embed) 140 | return quantize 141 | 142 | def encode(self, x): 143 | shape = x.shape 144 | # pre-process 145 | x = self.preprocess(x) 146 | # quantize 147 | embed_ind = self.quantize(x) 148 | # post-process 149 | embed_ind = embed_ind.reshape(shape[0], shape[1]) 150 | return embed_ind 151 | 152 | def decode(self, embed_ind): 153 | quantize = self.dequantize(embed_ind) 154 | return quantize 155 | 156 | def forward(self, x): 157 | shape, dtype = x.shape, x.dtype 158 | x = self.preprocess(x) 159 | 160 | if not self.inited: 161 | self.init_embed_(x) 162 | 163 | embed_ind = self.quantize(x) 164 | embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) 165 | embed_ind = embed_ind.reshape(shape[0], shape[1]) 166 | quantize = self.dequantize(embed_ind) 167 | 168 | if self.training: 169 | # We do the expiry of code at that point as buffers are in sync 170 | # and all the workers will take the same decision. 
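# Codebook maintenance performed below: rarely selected codes are first
# replaced with vectors drawn from the current batch (expire_codes_), then
# the codebook is updated without gradients, through exponential moving
# averages of (a) how often each code was selected (cluster_size) and
# (b) the sum of the vectors assigned to it (embed_avg), i.e.
# avg <- decay * avg + (1 - decay) * batch_statistic.
# Laplace smoothing keeps every count strictly positive, so the final
# division embed_avg / cluster_size, which re-centres each code on the
# running mean of its assigned vectors, never divides by zero.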
171 | self.expire_codes_(x) 172 | ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) 173 | embed_sum = x.t() @ embed_onehot 174 | ema_inplace(self.embed_avg, embed_sum.t(), self.decay) 175 | cluster_size = (laplace_smoothing( 176 | self.cluster_size, self.codebook_size, self.epsilon) * 177 | self.cluster_size.sum()) 178 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) 179 | self.embed.data.copy_(embed_normalized) 180 | 181 | return quantize, embed_ind 182 | 183 | 184 | class VectorQuantization(nn.Module): 185 | """Vector quantization implementation. 186 | Currently supports only euclidean distance. 187 | Args: 188 | dim (int): Dimension 189 | codebook_size (int): Codebook size 190 | codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. 191 | decay (float): Decay for exponential moving average over the codebooks. 192 | epsilon (float): Epsilon value for numerical stability. 193 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks. 194 | kmeans_iters (int): Number of iterations used for kmeans initialization. 195 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 196 | that have an exponential moving average cluster size less than the specified threshold with 197 | randomly selected vector from the current batch. 198 | commitment_weight (float): Weight for commitment loss. 199 | """ 200 | 201 | def __init__( 202 | self, 203 | dim: int, 204 | codebook_size: int, 205 | codebook_dim: Optional[int] = None, 206 | decay: float = 0.99, 207 | epsilon: float = 1e-5, 208 | kmeans_init: bool = True, 209 | kmeans_iters: int = 50, 210 | threshold_ema_dead_code: int = 2, 211 | commitment_weight: float = 1., 212 | ): 213 | super().__init__() 214 | _codebook_dim: int = codebook_dim or dim 215 | 216 | requires_projection = _codebook_dim != dim 217 | self.project_in = (nn.Linear(dim, _codebook_dim) 218 | if requires_projection else nn.Identity()) 219 | self.project_out = (nn.Linear(_codebook_dim, dim) 220 | if requires_projection else nn.Identity()) 221 | 222 | self.epsilon = epsilon 223 | self.commitment_weight = commitment_weight 224 | 225 | self._codebook = EuclideanCodebook( 226 | dim=_codebook_dim, 227 | codebook_size=codebook_size, 228 | kmeans_init=kmeans_init, 229 | kmeans_iters=kmeans_iters, 230 | decay=decay, 231 | epsilon=epsilon, 232 | threshold_ema_dead_code=threshold_ema_dead_code) 233 | self.codebook_size = codebook_size 234 | 235 | @property 236 | def codebook(self): 237 | return self._codebook.embed 238 | 239 | def encode(self, x): 240 | x = x.permute(0, 2, 1) 241 | x = self.project_in(x) 242 | embed_in = self._codebook.encode(x) 243 | return embed_in 244 | 245 | def decode(self, embed_ind): 246 | quantize = self._codebook.decode(embed_ind) 247 | quantize = self.project_out(quantize) 248 | quantize = quantize.permute(0, 2, 1) 249 | return quantize 250 | 251 | def forward(self, x): 252 | device = x.device 253 | x = x.permute(0, 2, 1) 254 | x = self.project_in(x) 255 | 256 | quantize, embed_ind = self._codebook(x) 257 | 258 | if self.training: 259 | quantize = x + (quantize - x).detach() 260 | 261 | loss = torch.tensor([0.0], device=device, requires_grad=self.training) 262 | 263 | if self.training: 264 | if self.commitment_weight > 0: 265 | commit_loss = F.mse_loss(quantize.detach(), x) 266 | loss = loss + commit_loss * self.commitment_weight 267 | 268 | quantize = self.project_out(quantize) 269 | quantize = quantize.permute(0, 2, 1) 270 | return quantize, embed_ind, loss 
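# Minimal usage sketch of the quantizer defined above (illustrative only:
# the latent dimension, codebook size and shapes are arbitrary choices, not
# values used elsewhere in this repository).
def _example_vector_quantization():
    vq = VectorQuantization(dim=16, codebook_size=64)
    z = torch.randn(2, 16, 100)  # (batch, dim, time), as produced by an encoder
    z_q, indices, commit_loss = vq(z)
    # z_q keeps the shape of z and stays differentiable w.r.t. z through the
    # straight-through estimator `x + (quantize - x).detach()` used in forward;
    # the codebook itself is only updated by the EMA rules of EuclideanCodebook.
    assert z_q.shape == z.shape and indices.shape == (2, 100)
    return commit_loss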
271 | 272 | 273 | class ResidualVectorQuantization(nn.Module): 274 | """Residual vector quantization implementation. 275 | Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf 276 | """ 277 | 278 | def __init__(self, num_quantizers, **kwargs): 279 | super().__init__() 280 | self.layers = nn.ModuleList( 281 | [VectorQuantization(**kwargs) for _ in range(num_quantizers)]) 282 | 283 | def forward(self, x): 284 | quantized_out = 0.0 285 | residual = x 286 | 287 | all_losses = [] 288 | all_indices = [] 289 | 290 | for layer in self.layers: 291 | quantized, indices, loss = layer(residual) 292 | residual = residual - quantized 293 | quantized_out = quantized_out + quantized 294 | 295 | all_indices.append(indices) 296 | all_losses.append(loss) 297 | 298 | out_losses = torch.stack(all_losses, 0).sum() 299 | all_indices = torch.stack(all_indices, 1) 300 | return quantized_out, out_losses, all_indices 301 | 302 | def encode(self, x: torch.Tensor) -> torch.Tensor: 303 | residual = x 304 | all_indices = [] 305 | for layer in self.layers: 306 | indices = layer.encode(residual) 307 | quantized = layer.decode(indices) 308 | residual = residual - quantized 309 | all_indices.append(indices) 310 | out_indices = torch.stack(all_indices, 1) 311 | return out_indices 312 | 313 | def decode(self, q_indices: torch.Tensor) -> torch.Tensor: 314 | quantized_out = torch.tensor(0.0, device=q_indices.device) 315 | for i, layer in enumerate(self.layers): 316 | quantized = layer.decode(q_indices[:, i]) 317 | quantized_out = quantized_out + quantized 318 | return quantized_out -------------------------------------------------------------------------------- /rave/resampler.py: -------------------------------------------------------------------------------- 1 | import cached_conv as cc 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .pqmf import kaiser_filter 7 | 8 | 9 | class Resampler(nn.Module): 10 | 11 | def __init__(self, target_sr, model_sr): 12 | super().__init__() 13 | assert target_sr != model_sr, "identical source and target rates" 14 | 15 | self.model_sr = model_sr 16 | self.taget_sr = target_sr 17 | 18 | ratio = target_sr // model_sr 19 | assert int(ratio) == ratio 20 | 21 | if ratio % 2 and cc.USE_BUFFER_CONV: 22 | raise ValueError( 23 | f"When using streaming mode, resampling ratio must be a power of 2, got {ratio}" 24 | ) 25 | 26 | wc = np.pi / ratio 27 | filt = kaiser_filter(wc, 140) 28 | filt = torch.from_numpy(filt).float() 29 | 30 | self.downsample = cc.Conv1d( 31 | 1, 32 | 1, 33 | len(filt), 34 | stride=ratio, 35 | padding=cc.get_padding(len(filt), ratio), 36 | bias=False, 37 | ) 38 | 39 | self.downsample.weight.data.copy_(filt.reshape(1, 1, -1)) 40 | 41 | pad = len(filt) % ratio 42 | 43 | filt = nn.functional.pad(filt, (pad, 0)) 44 | filt = filt.reshape(-1, ratio).permute(1, 0) 45 | 46 | pad = (filt.shape[-1] + 1) % 2 47 | filt = nn.functional.pad(filt, (pad, 0)).unsqueeze(1) 48 | 49 | self.upsample = cc.Conv1d(1, 50 | ratio, 51 | filt.shape[-1], 52 | stride=1, 53 | padding=cc.get_padding(filt.shape[-1]), 54 | bias=False) 55 | 56 | self.upsample.weight.data.copy_(filt) 57 | 58 | self.ratio = ratio 59 | 60 | def to_model_sampling_rate(self, x): 61 | x_down = x.reshape(-1, 1, x.shape[-1]) 62 | x_down = self.downsample(x_down) 63 | return x_down.reshape(x.shape[0], x.shape[1], -1) 64 | 65 | def from_model_sampling_rate(self, x): 66 | x_up = x.reshape(-1, 1, x.shape[-1]) 67 | x_up = self.upsample(x_up) # B x 2 x T 68 | x_up = x_up.permute(0, 2, 1).reshape(x_up.shape[0], 
-1).unsqueeze(1) 69 | x_up = x_up.reshape(x.shape[0], x.shape[1], -1) 70 | return x_up 71 | -------------------------------------------------------------------------------- /rave/transforms.py: -------------------------------------------------------------------------------- 1 | from random import choice, randint, random, randrange 2 | import bisect 3 | import torchaudio 4 | import gin.torch 5 | from typing import Tuple 6 | import librosa as li 7 | import numpy as np 8 | import torch 9 | import scipy.signal as signal 10 | from udls.transforms import * 11 | 12 | 13 | class Transform(object): 14 | def __call__(self, x: torch.Tensor): 15 | raise NotImplementedError 16 | 17 | 18 | class RandomApply(Transform): 19 | """ 20 | Apply transform with probability p 21 | """ 22 | def __init__(self, transform, p=.5): 23 | self.transform = transform 24 | self.p = p 25 | 26 | def __call__(self, x: np.ndarray): 27 | if random() < self.p: 28 | x = self.transform(x) 29 | return x 30 | 31 | class Resample(Transform): 32 | """ 33 | Resample target signal to target sample rate. 34 | """ 35 | def __init__(self, orig_sr: int, target_sr: int): 36 | self.orig_sr = orig_sr 37 | self.target_sr = target_sr 38 | 39 | def __call__(self, x: np.ndarray): 40 | return torchaudio.functional.resample(torch.from_numpy(x).float(), self.orig_sr, self.target_sr).numpy() 41 | 42 | 43 | class Compose(Transform): 44 | """ 45 | Apply a list of transform sequentially 46 | """ 47 | def __init__(self, transform_list): 48 | self.transform_list = transform_list 49 | 50 | def __call__(self, x: np.ndarray): 51 | for elm in self.transform_list: 52 | x = elm(x) 53 | return x 54 | 55 | 56 | class RandomPitch(Transform): 57 | def __init__(self, n_signal, pitch_range = [0.7, 1.3], max_factor: int = 20, prob: float = 0.5): 58 | self.n_signal = n_signal 59 | self.pitch_range = pitch_range 60 | self.factor_list, self.ratio_list = self._get_factors(max_factor, pitch_range) 61 | self.prob = prob 62 | 63 | def _get_factors(self, factor_limit, pitch_range): 64 | factor_list = [] 65 | ratio_list = [] 66 | for x in range(1, factor_limit): 67 | for y in range(1, factor_limit): 68 | if (x==y): 69 | continue 70 | factor = x / y 71 | if factor <= pitch_range[1] and factor >= pitch_range[0]: 72 | i = bisect.bisect_left(factor_list, factor) 73 | factor_list.insert(i, factor) 74 | ratio_list.insert(i, (x, y)) 75 | return factor_list, ratio_list 76 | 77 | def __call__(self, x: np.ndarray): 78 | perform_pitch = bool(torch.bernoulli(torch.tensor(self.prob))) 79 | if not perform_pitch: 80 | return x 81 | random_range = list(self.pitch_range) 82 | random_range[1] = min(random_range[1], x.shape[-1] / self.n_signal) 83 | random_pitch = random() * (random_range[1] - random_range[0]) + random_range[0] 84 | ratio_idx = bisect.bisect_left(self.factor_list, random_pitch) 85 | if ratio_idx == len(self.factor_list): 86 | ratio_idx -= 1 87 | up, down = self.ratio_list[ratio_idx] 88 | x_pitched = signal.resample_poly(x, up, down, padtype='mean', axis=-1) 89 | return x_pitched 90 | 91 | 92 | class RandomCrop(Transform): 93 | """ 94 | Randomly crops signal to fit n_signal samples 95 | """ 96 | def __init__(self, n_signal): 97 | self.n_signal = n_signal 98 | 99 | def __call__(self, x: np.ndarray): 100 | in_point = randint(0, x.shape[-1] - self.n_signal) 101 | x = x[..., in_point:in_point + self.n_signal] 102 | return x 103 | 104 | 105 | class Dequantize(Transform): 106 | def __init__(self, bit_depth): 107 | self.bit_depth = bit_depth 108 | 109 | def __call__(self, x: np.ndarray): 
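# Audio decoded from fixed-point samples only takes a discrete set of values;
# the uniform noise added below (at most 1 / 2**bit_depth) dithers those
# values so the model never sees an exactly quantized signal.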
110 | x += np.random.rand(*x.shape) / 2**self.bit_depth 111 | return x 112 | 113 | 114 | @gin.configurable 115 | class Compress(Transform): 116 | def __init__(self, time="0.1,0.1", lookup="6:-70,-60,-20 ", gain="0", sr=44100): 117 | self.sox_args = ['compand', time, lookup, gain] 118 | self.sr = sr 119 | 120 | def __call__(self, x: torch.Tensor): 121 | x = torchaudio.sox_effects.apply_effects_tensor(torch.from_numpy(x).float(), self.sr, [self.sox_args])[0].numpy() 122 | return x 123 | 124 | @gin.configurable 125 | class RandomCompress(Transform): 126 | def __init__(self, threshold = -40, amp_range = [-60, 0], attack=0.1, release=0.1, prob=0.8, sr=44100): 127 | assert prob >= 0. and prob <= 1., "prob must be between 0. and 1." 128 | self.amp_range = amp_range 129 | self.threshold = threshold 130 | self.attack = attack 131 | self.release = release 132 | self.prob = prob 133 | self.sr = sr 134 | 135 | def __call__(self, x: torch.Tensor): 136 | perform = bool(torch.bernoulli(torch.full((1,), self.prob))) 137 | if perform: 138 | amp_factor = torch.rand((1,)) * (self.amp_range[1] - self.amp_range[0]) + self.amp_range[0] 139 | x_aug = torchaudio.sox_effects.apply_effects_tensor(torch.from_numpy(x).float(), 140 | self.sr, 141 | [['compand', f'{self.attack},{self.release}', f'6:-80,{self.threshold},{float(amp_factor)}']] 142 | )[0].numpy() 143 | return x_aug 144 | else: 145 | return x 146 | 147 | @gin.configurable 148 | class RandomGain(Transform): 149 | def __init__(self, gain_range: Tuple[int, int] = [-6, 3], prob: float = 0.5, limit = True): 150 | assert prob >= 0. and prob <= 1., "prob must be between 0. and 1." 151 | self.gain_range = gain_range 152 | self.prob = prob 153 | self.limit = limit 154 | 155 | def __call__(self, x: torch.Tensor): 156 | perform = bool(torch.bernoulli(torch.full((1,), self.prob))) 157 | if perform: 158 | gain_factor = np.random.rand(1)[None, None][0] * (self.gain_range[1] - self.gain_range[0]) + self.gain_range[0] 159 | amp_factor = np.power(10, gain_factor / 20) 160 | x_amp = x * amp_factor 161 | if (self.limit) and (np.abs(x_amp).max() > 1): 162 | x_amp = x_amp / np.abs(x_amp).max() 163 | return x 164 | else: 165 | return x 166 | 167 | 168 | @gin.configurable 169 | class RandomMute(Transform): 170 | def __init__(self, prob: torch.Tensor = 0.1): 171 | assert prob >= 0. and prob <= 1., "prob must be between 0. and 1." 
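# prob is the probability that an example is silenced: __call__ draws a
# single Bernoulli mask and multiplies the whole signal by it.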
172 | self.prob = prob 173 | 174 | def __call__(self, x: torch.Tensor): 175 | mask = torch.bernoulli(torch.full((x.shape[0],), 1 - self.prob)) 176 | mask = np.random.binomial(1, 1-self.prob, size=1) 177 | return x * mask 178 | 179 | 180 | @gin.configurable 181 | class FrequencyMasking(Transform): 182 | def __init__(self, prob = 0.5, max_size: int = 80): 183 | self.prob = prob 184 | self.max_size = max_size 185 | 186 | def __call__(self, x: torch.Tensor): 187 | perform = bool(torch.bernoulli(torch.full((1,), self.prob))) 188 | if not perform: 189 | return x 190 | spectrogram = signal.stft(x, nperseg=4096)[2] 191 | mask_size = randrange(1, self.max_size) 192 | freq_idx = randrange(0, spectrogram.shape[-2] - mask_size) 193 | spectrogram[..., freq_idx:freq_idx+mask_size, :] = 0 194 | x_inv = signal.istft(spectrogram)[1] 195 | return x_inv 196 | 197 | 198 | 199 | # Utilitary for GIN recording of augmentations 200 | 201 | 202 | _augmentations = [] 203 | 204 | @gin.configurable() 205 | def add_augmentation(aug): 206 | global _augmentations 207 | _augmentations.append(aug) 208 | 209 | def get_augmentations(): 210 | return _augmentations -------------------------------------------------------------------------------- /rave/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "2.3.1" 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py>=1.2.0 2 | einops>=0.5.0 3 | gin-config 4 | GPUtil>=1.4.0 5 | librosa>=0.9.2 6 | numpy>=1.23.3 7 | pytorch_lightning==1.9.0 8 | PyYAML>=6.0 9 | scikit_learn>=1.1.2 10 | scipy==1.10.0 11 | torch 12 | tqdm>=4.64.1 13 | udls>=1.0.1 14 | cached-conv>=2.5.0 15 | nn-tilde>=1.5.2 16 | torchaudio 17 | tensorboard 18 | pytest>=7.2.2 19 | Flask>=2.2.3 -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/export_onnx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.set_grad_enabled(False) 4 | import os 5 | 6 | import cached_conv as cc 7 | import gin 8 | import torch.nn as nn 9 | from absl import app, flags 10 | from effortless_config import Config 11 | 12 | import rave 13 | 14 | flags.DEFINE_string('run', default=None, required=True, help='Run to export') 15 | FLAGS = flags.FLAGS 16 | 17 | 18 | def main(argv): 19 | gin.parse_config_file(os.path.join(FLAGS.run, "config.gin")) 20 | checkpoint = rave.core.search_for_run(FLAGS.run) 21 | 22 | print(f"using {checkpoint}") 23 | 24 | pretrained = rave.RAVE() 25 | pretrained.load_state_dict(torch.load(checkpoint)["state_dict"]) 26 | pretrained.eval() 27 | 28 | for m in pretrained.modules(): 29 | if hasattr(m, "weight_g"): 30 | nn.utils.remove_weight_norm(m) 31 | 32 | def recursive_replace(model: nn.Module): 33 | for name, child in model.named_children(): 34 | if isinstance(child, cc.convs.Conv1d): 35 | conv = nn.Conv1d( 36 | child.in_channels, 37 | child.out_channels, 38 | child.kernel_size, 39 | child.stride, 40 | child._pad[0], 41 | child.dilation, 42 | child.groups, 43 | child.bias, 44 | ) 45 | conv.weight.data.copy_(child.weight.data) 46 | if 
conv.bias is not None: 47 | conv.bias.data.copy_(child.bias.data) 48 | setattr(model, name, conv) 49 | elif isinstance(child, cc.convs.ConvTranspose1d): 50 | conv = nn.ConvTranspose1d( 51 | child.in_channels, 52 | child.out_channels, 53 | child.kernel_size, 54 | child.stride, 55 | child.padding, 56 | child.output_padding, 57 | child.groups, 58 | child.bias, 59 | child.dilation, 60 | child.padding_mode, 61 | ) 62 | conv.weight.data.copy_(child.weight.data) 63 | if conv.bias is not None: 64 | conv.bias.data.copy_(child.bias.data) 65 | setattr(model, name, conv) 66 | else: 67 | recursive_replace(child) 68 | 69 | recursive_replace(pretrained) 70 | 71 | x = torch.randn(1, pretrained.n_channels, 2**15) 72 | pretrained(x) 73 | 74 | name = os.path.basename(os.path.normpath(FLAGS.run)) 75 | export_path = os.path.join(FLAGS.run, name) 76 | torch.onnx.export( 77 | pretrained, 78 | x, 79 | f"{export_path}.onnx", 80 | export_params=True, 81 | opset_version=12, 82 | input_names=["audio_in"], 83 | output_names=["audio_out"], 84 | dynamic_axes={ 85 | "audio_in": { 86 | 2: "audio_length" 87 | }, 88 | "audio_out": [0], 89 | }, 90 | do_constant_folding=False, 91 | ) 92 | 93 | 94 | if __name__ == '__main__': 95 | app.run(main) -------------------------------------------------------------------------------- /scripts/generate.py: -------------------------------------------------------------------------------- 1 | from absl import app, flags, logging 2 | import pdb 3 | import torch, torchaudio, argparse, os, tqdm, re, gin 4 | import cached_conv as cc 5 | 6 | try: 7 | import rave 8 | except: 9 | import sys, os 10 | sys.path.append(os.path.abspath('.')) 11 | import rave 12 | 13 | 14 | FLAGS = flags.FLAGS 15 | flags.DEFINE_string('model', required=True, default=None, help="model path") 16 | flags.DEFINE_multi_string('input', required=True, default=None, help="model inputs (file or folder)") 17 | flags.DEFINE_string('out_path', 'generations', help="output path") 18 | flags.DEFINE_string('name', None, help="name of the model") 19 | flags.DEFINE_integer('gpu', default=-1, help='GPU to use') 20 | flags.DEFINE_bool('stream', default=False, help='simulates streaming mode') 21 | flags.DEFINE_integer('chunk_size', default=None, help="chunk size for encoding/decoding (default: full file)") 22 | 23 | 24 | def get_audio_files(path): 25 | audio_files = [] 26 | valid_exts = rave.core.get_valid_extensions() 27 | for root, _, files in os.walk(path): 28 | valid_files = list(filter(lambda x: os.path.splitext(x)[1] in valid_exts, files)) 29 | audio_files.extend([(path, os.path.join(root, f)) for f in valid_files]) 30 | return audio_files 31 | 32 | 33 | def main(argv): 34 | torch.set_float32_matmul_precision('high') 35 | cc.use_cached_conv(FLAGS.stream) 36 | 37 | model_path = FLAGS.model 38 | paths = FLAGS.input 39 | # load model 40 | logging.info("building rave") 41 | is_scripted = False 42 | if not os.path.exists(model_path): 43 | logging.error('path %s does not seem to exist.'%model_path) 44 | exit() 45 | if os.path.splitext(model_path)[1] == ".ts": 46 | model = torch.jit.load(model_path) 47 | is_scripted = True 48 | else: 49 | config_path = rave.core.search_for_config(model_path) 50 | if config_path is None: 51 | logging.error('config not found in folder %s'%model_path) 52 | gin.parse_config_file(config_path) 53 | model = rave.RAVE() 54 | run = rave.core.search_for_run(model_path) 55 | if run is None: 56 | logging.error("run not found in folder %s"%model_path) 57 | model = model.load_from_checkpoint(run) 58 | 59 | # device 60 | 
if FLAGS.gpu >= 0: 61 | device = torch.device('cuda:%d'%FLAGS.gpu) 62 | model = model.to(device) 63 | else: 64 | device = torch.device('cpu') 65 | 66 | 67 | # make output directories 68 | if FLAGS.name is None: 69 | FLAGS.name = "_".join(os.path.basename(model_path).split('_')[:-1]) 70 | out_path = os.path.join(FLAGS.out_path, FLAGS.name) 71 | os.makedirs(out_path, exist_ok=True) 72 | 73 | # parse inputs 74 | audio_files = sum([get_audio_files(f) for f in paths], []) 75 | receptive_field = rave.core.get_minimum_size(model) 76 | 77 | progress_bar = tqdm.tqdm(audio_files) 78 | cc.MAX_BATCH_SIZE = 8 79 | 80 | for i, (d, f) in enumerate(progress_bar): 81 | #TODO reset cache 82 | 83 | try: 84 | x, sr = torchaudio.load(f) 85 | except: 86 | logging.warning('could not open file %s.'%f) 87 | continue 88 | progress_bar.set_description(f) 89 | 90 | # load file 91 | if sr != model.sr: 92 | x = torchaudio.functional.resample(x, sr, model.sr) 93 | if model.n_channels != x.shape[0]: 94 | if model.n_channels < x.shape[0]: 95 | x = x[:model.n_channels] 96 | else: 97 | print('[Warning] file %s has %d channels, butt model has %d channels ; skipping'%(f, model.n_channels)) 98 | x = x.to(device) 99 | if FLAGS.stream: 100 | if FLAGS.chunk_size: 101 | assert FLAGS.chunk_size > receptive_field, "chunk_size must be higher than models' receptive field (here : %s)"%receptive_field 102 | x = list(x.split(FLAGS.chunk_size, dim=-1)) 103 | if x[-1].shape[0] < FLAGS.chunk_size: 104 | x[-1] = torch.nn.functional.pad(x[-1], (0, FLAGS.chunk_size - x[-1].shape[-1])) 105 | x = torch.stack(x, 0) 106 | else: 107 | x = x[None] 108 | 109 | # forward into model 110 | out = [] 111 | for x_chunk in x: 112 | x_chunk = x_chunk.to(device) 113 | out_tmp = model(x_chunk[None]) 114 | out.append(out_tmp) 115 | out = torch.cat(out, -1) 116 | else: 117 | out = model.forward(x[None]) 118 | 119 | # save file 120 | out_path = re.sub(d, "", f) 121 | out_path = os.path.join(FLAGS.out_path, f) 122 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 123 | torchaudio.save(out_path, out[0].cpu(), sample_rate=model.sr) 124 | 125 | if __name__ == "__main__": 126 | app.run(main) -------------------------------------------------------------------------------- /scripts/main_cli.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from absl import app 4 | 5 | AVAILABLE_SCRIPTS = [ 6 | 'preprocess', 'train', 'train_prior', 'export', 'export_onnx', 'remote_dataset', 'generate' 7 | ] 8 | 9 | 10 | def help(): 11 | print(f"""usage: rave [ {' | '.join(AVAILABLE_SCRIPTS)} ] 12 | 13 | positional arguments: 14 | command Command to launch with rave. 
15 | """) 16 | exit() 17 | 18 | 19 | def main(): 20 | if len(sys.argv) == 1: 21 | help() 22 | elif sys.argv[1] not in AVAILABLE_SCRIPTS: 23 | help() 24 | 25 | command = sys.argv[1] 26 | 27 | if command == 'train': 28 | from scripts import train 29 | sys.argv[0] = train.__name__ 30 | app.run(train.main) 31 | elif command == 'train_prior': 32 | from scripts import train_prior 33 | sys.argv[0] = train_prior.__name__ 34 | app.run(train_prior.main) 35 | elif command == 'export': 36 | from scripts import export 37 | sys.argv[0] = export.__name__ 38 | app.run(export.main) 39 | elif command == 'preprocess': 40 | from scripts import preprocess 41 | sys.argv[0] = preprocess.__name__ 42 | app.run(preprocess.main) 43 | elif command == 'export_onnx': 44 | from scripts import export_onnx 45 | sys.argv[0] = export_onnx.__name__ 46 | app.run(export_onnx.main) 47 | elif command == "generate": 48 | from scripts import generate 49 | sys.argv[0] = generate.__name__ 50 | app.run(generate.main) 51 | elif command == 'remote_dataset': 52 | from scripts import remote_dataset 53 | sys.argv[0] = remote_dataset.__name__ 54 | app.run(remote_dataset.main) 55 | else: 56 | raise Exception(f'Command {command} not found') 57 | -------------------------------------------------------------------------------- /scripts/preprocess.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import multiprocessing 3 | import os 4 | import pathlib 5 | import subprocess 6 | from datetime import timedelta 7 | from functools import partial 8 | from itertools import repeat 9 | from typing import Callable, Iterable, Sequence, Tuple 10 | 11 | import lmdb 12 | import numpy as np 13 | import torch 14 | import yaml 15 | import math 16 | from absl import app, flags 17 | from tqdm import tqdm 18 | from udls.generated import AudioExample 19 | 20 | torch.set_grad_enabled(False) 21 | 22 | FLAGS = flags.FLAGS 23 | 24 | flags.DEFINE_multi_string('input_path', 25 | None, 26 | help='Path to a directory containing audio files', 27 | required=True) 28 | flags.DEFINE_string('output_path', 29 | None, 30 | help='Output directory for the dataset', 31 | required=True) 32 | flags.DEFINE_integer('num_signal', 33 | 131072, 34 | help='Number of audio samples to use during training') 35 | flags.DEFINE_integer('channels', 1, help="Number of audio channels") 36 | flags.DEFINE_integer('sampling_rate', 37 | 44100, 38 | help='Sampling rate to use during training') 39 | flags.DEFINE_integer('max_db_size', 40 | 100, 41 | help='Maximum size (in GB) of the dataset') 42 | flags.DEFINE_multi_string( 43 | 'ext', 44 | default=['aif', 'aiff', 'wav', 'opus', 'mp3', 'aac', 'flac', 'ogg'], 45 | help='Extension to search for in the input directory') 46 | flags.DEFINE_bool('lazy', 47 | default=False, 48 | help='Decode and resample audio samples.') 49 | flags.DEFINE_bool('dyndb', 50 | default=True, 51 | help="Allow the database to grow dynamically") 52 | 53 | 54 | def float_array_to_int16_bytes(x): 55 | return np.floor(x * (2**15 - 1)).astype(np.int16).tobytes() 56 | 57 | 58 | def load_audio_chunk(path: str, n_signal: int, 59 | sr: int, channels: int = 1) -> Iterable[np.ndarray]: 60 | 61 | _, input_channels = get_audio_channels(path) 62 | channel_map = range(channels) 63 | if input_channels < channels: 64 | channel_map = (math.ceil(channels / input_channels) * list(range(input_channels)))[:channels] 65 | 66 | processes = [] 67 | for i in range(channels): 68 | process = subprocess.Popen( 69 | [ 70 | 'ffmpeg', '-hide_banner', 
'-loglevel', 'panic', '-i', path, 71 | '-ar', str(sr), 72 | '-f', 's16le', 73 | '-filter_complex', 'channelmap=%d-0'%channel_map[i], 74 | '-' 75 | ], 76 | stdout=subprocess.PIPE, 77 | ) 78 | processes.append(process) 79 | 80 | chunk = [p.stdout.read(n_signal * 4) for p in processes] 81 | while len(chunk[0]) == n_signal * 4: 82 | yield b''.join(chunk) 83 | chunk = [p.stdout.read(n_signal * 4) for p in processes] 84 | process.stdout.close() 85 | 86 | 87 | def get_audio_length(path: str) -> float: 88 | process = subprocess.Popen( 89 | [ 90 | 'ffprobe', '-i', path, '-v', 'error', '-show_entries', 91 | 'format=duration' 92 | ], 93 | stdout=subprocess.PIPE, 94 | stderr=subprocess.PIPE, 95 | ) 96 | stdout, _ = process.communicate() 97 | if process.returncode: return None 98 | try: 99 | stdout = stdout.decode().split('\n')[1].split('=')[-1] 100 | length = float(stdout) 101 | _, channels = get_audio_channels(path) 102 | return path, float(length), int(channels) 103 | except: 104 | return None 105 | 106 | def get_audio_channels(path: str) -> int: 107 | process = subprocess.Popen( 108 | [ 109 | 'ffprobe', '-i', path, '-v', 'error', '-show_entries', 110 | 'stream=channels' 111 | ], 112 | stdout=subprocess.PIPE, 113 | stderr=subprocess.PIPE, 114 | ) 115 | stdout, _ = process.communicate() 116 | if process.returncode: return None 117 | try: 118 | stdout = stdout.decode().split('\n')[1].split('=')[-1] 119 | channels = int(stdout) 120 | return path, int(channels) 121 | except: 122 | return None 123 | 124 | 125 | def flatten(iterator: Iterable): 126 | for elm in iterator: 127 | for sub_elm in elm: 128 | yield sub_elm 129 | 130 | def get_metadata(audio_samples, channels: int = 1): 131 | audio = np.frombuffer(audio_samples, dtype=np.int16) 132 | audio = audio.astype(float) / (2**15 - 1) 133 | audio = audio.reshape(channels, -1) 134 | peak_amplitude = np.amax(np.abs(audio)) 135 | rms_amplitude = np.sqrt(np.mean(audio**2)) 136 | return {'peak': peak_amplitude, 'rms_amplitude': rms_amplitude} 137 | 138 | 139 | def process_audio_array(audio: Tuple[int, bytes], 140 | env: lmdb.Environment, 141 | channels: int = 1) -> int: 142 | audio_id, audio_samples = audio 143 | buffers = {} 144 | buffers['waveform'] = AudioExample.AudioBuffer( 145 | shape=(channels, int(len(audio_samples) / channels)), 146 | sampling_rate=FLAGS.sampling_rate, 147 | data=audio_samples, 148 | precision=AudioExample.Precision.INT16, 149 | ) 150 | 151 | ae = AudioExample(buffers=buffers) 152 | key = f'{audio_id:08d}' 153 | with env.begin(write=True) as txn: 154 | txn.put( 155 | key.encode(), 156 | ae.SerializeToString(), 157 | ) 158 | return audio_id 159 | 160 | 161 | def process_audio_file(audio: Tuple[int, Tuple[str, float]], 162 | env: lmdb.Environment) -> int: 163 | audio_id, (path, length, channels) = audio 164 | ae = AudioExample(metadata={'path': path, 'length': str(length), 'channels': str(channels)}) 165 | key = f'{audio_id:08d}' 166 | with env.begin(write=True) as txn: 167 | txn.put( 168 | key.encode(), 169 | ae.SerializeToString(), 170 | ) 171 | return length 172 | 173 | 174 | def flatmap(pool: multiprocessing.Pool, 175 | func: Callable, 176 | iterable: Iterable, 177 | chunksize=None): 178 | queue = multiprocessing.Manager().Queue(maxsize=os.cpu_count()) 179 | pool.map_async( 180 | functools.partial(flat_mappper, func), 181 | zip(iterable, repeat(queue)), 182 | chunksize, 183 | lambda _: queue.put(None), 184 | lambda *e: print(e), 185 | ) 186 | 187 | item = queue.get() 188 | while item is not None: 189 | yield item 190 | item = 
queue.get() 191 | 192 | 193 | def flat_mappper(func, arg): 194 | data, queue = arg 195 | for item in func(data): 196 | queue.put(item) 197 | 198 | 199 | def search_for_audios(path_list: Sequence[str], extensions: Sequence[str]): 200 | paths = map(pathlib.Path, path_list) 201 | audios = [] 202 | for p in paths: 203 | for ext in extensions: 204 | audios.append(p.rglob(f'*.{ext}')) 205 | audios.append(p.rglob(f'*.{ext.upper()}')) 206 | audios = flatten(audios) 207 | return audios 208 | 209 | 210 | def main(argv): 211 | if FLAGS.lazy and os.name in ["nt", "posix"]: 212 | while (answer := input( 213 | "Using lazy datasets on Windows/macOS might result in slow training. Continue ? (y/n) " 214 | ).lower()) not in ["y", "n"]: 215 | print("Answer 'y' or 'n'.") 216 | if answer == "n": 217 | print("Aborting...") 218 | exit() 219 | 220 | 221 | chunk_load = partial(load_audio_chunk, 222 | n_signal=FLAGS.num_signal, 223 | sr=FLAGS.sampling_rate, 224 | channels=FLAGS.channels) 225 | 226 | output_dir = os.path.join(*os.path.split(FLAGS.output_path)[:-1]) 227 | if not os.path.isdir(output_dir): 228 | os.makedirs(output_dir) 229 | 230 | # create database 231 | env = lmdb.open( 232 | FLAGS.output_path, 233 | map_size=FLAGS.max_db_size * 1024**3, 234 | map_async=not FLAGS.dyndb, 235 | writemap=not FLAGS.dyndb, 236 | ) 237 | pool = multiprocessing.Pool() 238 | 239 | 240 | # search for audio files 241 | audios = search_for_audios(FLAGS.input_path, FLAGS.ext) 242 | audios = map(str, audios) 243 | audios = map(os.path.abspath, audios) 244 | audios = [*audios] 245 | if len(audios) == 0: 246 | print("No valid file found in %s. Aborting"%FLAGS.input_path) 247 | 248 | if not FLAGS.lazy: 249 | 250 | # load chunks 251 | chunks = flatmap(pool, chunk_load, audios) 252 | chunks = enumerate(chunks) 253 | 254 | processed_samples = map(partial(process_audio_array, env=env, channels=FLAGS.channels), chunks) 255 | 256 | pbar = tqdm(processed_samples) 257 | n_seconds = 0 258 | for audio_id in pbar: 259 | n_seconds = (FLAGS.num_signal * 2) / FLAGS.sampling_rate * audio_id 260 | pbar.set_description( 261 | f'dataset length: {timedelta(seconds=n_seconds)}') 262 | pbar.close() 263 | else: 264 | audio_lengths = pool.imap_unordered(get_audio_length, audios) 265 | audio_lengths = filter(lambda x: x is not None, audio_lengths) 266 | audio_lengths = enumerate(audio_lengths) 267 | processed_samples = map(partial(process_audio_file, env=env), 268 | audio_lengths) 269 | pbar = tqdm(processed_samples) 270 | n_seconds = 0 271 | for length in pbar: 272 | n_seconds += length 273 | pbar.set_description( 274 | f'dataset length: {timedelta(seconds=n_seconds)}') 275 | pbar.close() 276 | 277 | with open(os.path.join( 278 | FLAGS.output_path, 279 | 'metadata.yaml', 280 | ), 'w') as metadata: 281 | yaml.safe_dump({'lazy': FLAGS.lazy, 'channels': FLAGS.channels, 'n_seconds': n_seconds, 'sr': FLAGS.sampling_rate}, metadata) 282 | pool.close() 283 | env.close() 284 | 285 | 286 | if __name__ == '__main__': 287 | app.run(main) 288 | -------------------------------------------------------------------------------- /scripts/remote_dataset.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import logging 3 | import os 4 | 5 | import flask 6 | import numpy as np 7 | from absl import flags 8 | from udls import AudioExample 9 | 10 | from rave.dataset import get_dataset 11 | 12 | logging.basicConfig(level=logging.ERROR) 13 | log = logging.getLogger('werkzeug') 14 | log.setLevel(logging.ERROR) 15 | 16 | FLAGS 
= flags.FLAGS 17 | flags.DEFINE_string( 18 | "db_path", 19 | default=None, 20 | required=True, 21 | help="path to database.", 22 | ) 23 | flags.DEFINE_integer( 24 | "sr", 25 | default=44100, 26 | help="sampling rate.", 27 | ) 28 | flags.DEFINE_integer( 29 | "n_signal", 30 | default=2**16, 31 | help="sample size.", 32 | ) 33 | flags.DEFINE_integer( 34 | "port", 35 | default=5000, 36 | help="port to serve the dataset.", 37 | ) 38 | 39 | 40 | def main(argv): 41 | app = flask.Flask(__name__) 42 | dataset = get_dataset(db_path=FLAGS.db_path, 43 | sr=FLAGS.sr, 44 | n_signal=FLAGS.n_signal) 45 | 46 | @app.route("/") 47 | def main(): 48 | return ("
RAVE remote dataset\n" 49 |                 f"Serving: {os.path.abspath(FLAGS.db_path)}\n" 50 |                 f"Length: {len(dataset)}") 51 | 52 |     @app.route("/len") 53 |     def length(): 54 |         return flask.jsonify(len(dataset)) 55 | 56 |     @app.route("/get/<index>") 57 |     def get(index): 58 |         index = int(index) 59 |         ae = AudioExample() 60 |         ae.put("audio", dataset[index], np.float32) 61 |         ae = base64.b64encode(bytes(ae)) 62 |         return ae 63 | 64 |     app.run(host="0.0.0.0", port=FLAGS.port) 65 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import sys 4 | from typing import Any, Dict 5 | 6 | import gin 7 | import pytorch_lightning as pl 8 | import torch 9 | from absl import flags, app 10 | from torch.utils.data import DataLoader 11 | 12 | try: 13 |     import rave 14 | except: 15 |     import sys, os 16 |     sys.path.append(os.path.abspath('.')) 17 |     import rave 18 | 19 | import rave 20 | import rave.core 21 | import rave.dataset 22 | from rave.transforms import get_augmentations, add_augmentation 23 | 24 | 25 | FLAGS = flags.FLAGS 26 | 27 | flags.DEFINE_string('name', None, help='Name of the run', required=True) 28 | flags.DEFINE_multi_string('config', 29 |                           default='v2.gin', 30 |                           help='RAVE configuration to use') 31 | flags.DEFINE_multi_string('augment', 32 |                           default = [], 33 |                           help = 'augmentation configurations to use') 34 | flags.DEFINE_string('db_path', 35 |                     None, 36 |                     help='Preprocessed dataset path', 37 |                     required=True) 38 | flags.DEFINE_string('out_path', 39 |                     default="runs/", 40 |                     help='Output folder') 41 | flags.DEFINE_integer('max_steps', 42 |                      6000000, 43 |                      help='Maximum number of training steps') 44 | flags.DEFINE_integer('val_every', 10000, help='Checkpoint model every n steps') 45 | flags.DEFINE_integer('save_every', 46 |                      500000, 47 |                      help='save every n steps (default: just last)') 48 | flags.DEFINE_integer('n_signal', 49 |                      131072, 50 |                      help='Number of audio samples to use during training') 51 | flags.DEFINE_integer('channels', 0, help="number of audio channels") 52 | flags.DEFINE_integer('batch', 8, help='Batch size') 53 | flags.DEFINE_string('ckpt', 54 |                     None, 55 |                     help='Path to previous checkpoint of the run') 56 | flags.DEFINE_multi_string('override', default=[], help='Override gin binding') 57 | flags.DEFINE_integer('workers', 58 |                      default=8, 59 |                      help='Number of workers to spawn for dataset loading') 60 | flags.DEFINE_multi_integer('gpu', default=None, help='GPU to use') 61 | flags.DEFINE_bool('derivative', 62 |                   default=False, 63 |                   help='Train RAVE on the derivative of the signal') 64 | flags.DEFINE_bool('normalize', 65 |                   default=False, 66 |                   help='Train RAVE on normalized signals') 67 | flags.DEFINE_list('rand_pitch', 68 |                   default=None, 69 |                   help='activates random pitch') 70 | flags.DEFINE_float('ema', 71 |                    default=None, 72 |                    help='Exponential weight averaging factor (optional)') 73 | flags.DEFINE_bool('progress', 74 |                   default=True, 75 |                   help='Display training progress bar') 76 | flags.DEFINE_bool('smoke_test', 77 |                   default=False, 78 |                   help="Run training with n_batches=1 to test the model") 79 | 80 | 81 | class EMA(pl.Callback): 82 | 83 |     def __init__(self, factor=.999) -> None: 84 |         super().__init__() 85 |         self.weights = {} 86 |         self.factor = factor 87 | 88 |     def on_train_batch_end(self, trainer, pl_module, outputs, batch, 89 |                            batch_idx) -> None: 90 |         for n, p in pl_module.named_parameters(): 91 |             if n not in self.weights: 92 |                 self.weights[n] = p.data.clone() 93 |                 continue 94 | 95 |             self.weights[n] = self.weights[n] * self.factor + p.data * ( 96 |                 1 - 
self.factor) 97 | 98 |     def swap_weights(self, module): 99 |         for n, p in module.named_parameters(): 100 |             current = p.data.clone() 101 |             p.data.copy_(self.weights[n]) 102 |             self.weights[n] = current 103 | 104 |     def on_validation_epoch_start(self, trainer, pl_module) -> None: 105 |         if self.weights: 106 |             self.swap_weights(pl_module) 107 |         else: 108 |             print("no ema weights available") 109 | 110 |     def on_validation_epoch_end(self, trainer, pl_module) -> None: 111 |         if self.weights: 112 |             self.swap_weights(pl_module) 113 |         else: 114 |             print("no ema weights available") 115 | 116 |     def state_dict(self) -> Dict[str, Any]: 117 |         return self.weights.copy() 118 | 119 |     def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 120 |         self.weights.update(state_dict) 121 | 122 | def add_gin_extension(config_name: str) -> str: 123 |     if config_name[-4:] != '.gin': 124 |         config_name += '.gin' 125 |     return config_name 126 | 127 | def parse_augmentations(augmentations): 128 |     for a in augmentations: 129 |         gin.parse_config_file(a) 130 |         add_augmentation() 131 |     gin.clear_config() 132 |     return get_augmentations() 133 | 134 | def main(argv): 135 |     torch.set_float32_matmul_precision('high') 136 |     torch.backends.cudnn.benchmark = True 137 | 138 |     # check dataset channels 139 |     n_channels = rave.dataset.get_training_channels(FLAGS.db_path, FLAGS.channels) 140 |     gin.bind_parameter('RAVE.n_channels', n_channels) 141 | 142 |     # parse augmentations 143 |     augmentations = parse_augmentations(map(add_gin_extension, FLAGS.augment)) 144 |     gin.bind_parameter('dataset.get_dataset.augmentations', augmentations) 145 | 146 |     # parse configuration 147 |     if FLAGS.ckpt: 148 |         config_file = rave.core.search_for_config(FLAGS.ckpt) 149 |         if config_file is None: 150 |             print('Config file not found in %s'%FLAGS.ckpt) 151 |         gin.parse_config_file(config_file) 152 |     else: 153 |         gin.parse_config_files_and_bindings( 154 |             map(add_gin_extension, FLAGS.config), 155 |             FLAGS.override, 156 |         ) 157 | 158 |     # create model 159 |     model = rave.RAVE(n_channels=n_channels) 160 |     if FLAGS.derivative: 161 |         model.integrator = rave.dataset.get_derivator_integrator(model.sr)[1] 162 | 163 |     # parse dataset 164 |     dataset = rave.dataset.get_dataset(FLAGS.db_path, 165 |                                        model.sr, 166 |                                        FLAGS.n_signal, 167 |                                        derivative=FLAGS.derivative, 168 |                                        normalize=FLAGS.normalize, 169 |                                        rand_pitch=FLAGS.rand_pitch, 170 |                                        n_channels=n_channels) 171 |     train, val = rave.dataset.split_dataset(dataset, 98) 172 | 173 |     # get data-loader 174 |     num_workers = FLAGS.workers 175 |     if os.name == "nt" or sys.platform == "darwin": 176 |         num_workers = 0 177 |     train = DataLoader(train, 178 |                        FLAGS.batch, 179 |                        True, 180 |                        drop_last=True, 181 |                        num_workers=num_workers) 182 |     val = DataLoader(val, FLAGS.batch, False, num_workers=num_workers) 183 | 184 |     # CHECKPOINT CALLBACKS 185 |     validation_checkpoint = pl.callbacks.ModelCheckpoint(monitor="validation", 186 |                                                          filename="best") 187 |     last_filename = "last" if FLAGS.save_every is None else "epoch-{epoch:04d}" 188 |     last_checkpoint = rave.core.ModelCheckpoint(filename=last_filename, step_period=FLAGS.save_every) 189 | 190 |     val_check = {} 191 |     if len(train) >= FLAGS.val_every: 192 |         val_check["val_check_interval"] = 1 if FLAGS.smoke_test else FLAGS.val_every 193 |     else: 194 |         nepoch = FLAGS.val_every // len(train) 195 |         val_check["check_val_every_n_epoch"] = nepoch 196 | 197 |     if FLAGS.smoke_test: 198 |         val_check['limit_train_batches'] = 1 199 |         val_check['limit_val_batches'] = 1 200 | 201 |     gin_hash = hashlib.md5( 202 |         
gin.operative_config_str().encode()).hexdigest()[:10] 203 | 204 | RUN_NAME = f'{FLAGS.name}_{gin_hash}' 205 | 206 | os.makedirs(os.path.join(FLAGS.out_path, RUN_NAME), exist_ok=True) 207 | 208 | if FLAGS.gpu == [-1]: 209 | gpu = 0 210 | else: 211 | gpu = FLAGS.gpu or rave.core.setup_gpu() 212 | 213 | print('selected gpu:', gpu) 214 | 215 | accelerator = None 216 | devices = None 217 | if FLAGS.gpu == [-1]: 218 | pass 219 | elif torch.cuda.is_available(): 220 | accelerator = "cuda" 221 | devices = FLAGS.gpu or rave.core.setup_gpu() 222 | elif torch.backends.mps.is_available(): 223 | print( 224 | "Training on mac is not available yet. Use --gpu -1 to train on CPU (not recommended)." 225 | ) 226 | exit() 227 | accelerator = "mps" 228 | devices = 1 229 | 230 | callbacks = [ 231 | validation_checkpoint, 232 | last_checkpoint, 233 | rave.model.WarmupCallback(), 234 | rave.model.QuantizeCallback(), 235 | # rave.core.LoggerCallback(rave.core.ProgressLogger(RUN_NAME)), 236 | rave.model.BetaWarmupCallback(), 237 | ] 238 | 239 | if FLAGS.ema is not None: 240 | callbacks.append(EMA(FLAGS.ema)) 241 | 242 | trainer = pl.Trainer( 243 | logger=pl.loggers.TensorBoardLogger( 244 | FLAGS.out_path, 245 | name=RUN_NAME, 246 | ), 247 | accelerator=accelerator, 248 | devices=devices, 249 | callbacks=callbacks, 250 | max_epochs=300000, 251 | max_steps=FLAGS.max_steps, 252 | profiler="simple", 253 | enable_progress_bar=FLAGS.progress, 254 | **val_check, 255 | ) 256 | 257 | run = rave.core.search_for_run(FLAGS.ckpt) 258 | if run is not None: 259 | print('loading state from file %s'%run) 260 | loaded = torch.load(run, map_location='cpu') 261 | # model = model.load_state_dict(loaded) 262 | trainer.fit_loop.epoch_loop._batches_that_stepped = loaded['global_step'] 263 | # model = model.load_state_dict(loaded['state_dict']) 264 | 265 | with open(os.path.join(FLAGS.out_path, RUN_NAME, "config.gin"), "w") as config_out: 266 | config_out.write(gin.operative_config_str()) 267 | 268 | trainer.fit(model, train, val, ckpt_path=run) 269 | 270 | 271 | if __name__ == "__main__": 272 | app.run(main) 273 | -------------------------------------------------------------------------------- /scripts/train_prior.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import sys 4 | 5 | import gin 6 | import pytorch_lightning as pl 7 | import torch 8 | from absl import flags, app 9 | from torch.utils.data import DataLoader 10 | 11 | try: 12 | import rave 13 | except: 14 | import sys, os 15 | sys.path.append(os.path.abspath('.')) 16 | import rave 17 | 18 | import rave 19 | import rave.dataset 20 | import rave.prior 21 | 22 | FLAGS = flags.FLAGS 23 | 24 | flags.DEFINE_string('name', None, help='Name of the run') 25 | flags.DEFINE_string('model', default=None, required=True, help="pretrained RAVE path") 26 | flags.DEFINE_multi_string('config', default="prior/prior_v1.gin", help="config path") 27 | flags.DEFINE_string('db_path', default=None, required=True, help="Preprocessed dataset path") 28 | flags.DEFINE_string('out_path', default="runs/", help="out directory path") 29 | flags.DEFINE_multi_integer('gpu', default=None, help='GPU to use') 30 | flags.DEFINE_integer('batch', 8, help="batch size") 31 | flags.DEFINE_integer('n_signal', 0, help="chunk size (default: given by prior config)") 32 | flags.DEFINE_string('ckpt', default=None, help="checkpoint to resume") 33 | flags.DEFINE_integer('workers', 34 | default=8, 35 | help='Number of workers to spawn for dataset loading') 36 | 
flags.DEFINE_integer('val_every', 10000, help='Checkpoint model every n steps') 37 | flags.DEFINE_integer('save_every', 38 |                      None, 39 |                      help='save every n steps (default: just last)') 40 | flags.DEFINE_integer('max_steps', default=1000000, help="max training steps") 41 | flags.DEFINE_multi_string('override', default=[], help='Override gin binding') 42 | 43 | flags.DEFINE_bool('derivative', 44 |                   default=False, 45 |                   help='Train RAVE on the derivative of the signal') 46 | flags.DEFINE_bool('normalize', 47 |                   default=False, 48 |                   help='Train RAVE on normalized signals') 49 | flags.DEFINE_list('rand_pitch', 50 |                   default=None, 51 |                   help='activates random pitch') 52 | flags.DEFINE_bool('progress', 53 |                   default=True, 54 |                   help='Display training progress bar') 55 | flags.DEFINE_bool('smoke_test', 56 |                   default=False, 57 |                   help="Run training with n_batches=1 to test the model") 58 | 59 | def add_gin_extension(config_name: str) -> str: 60 |     if config_name[-4:] != '.gin': 61 |         config_name += '.gin' 62 |     return config_name 63 | 64 | 65 | def main(argv): 66 | 67 |     # load pretrained RAVE 68 |     config_file = rave.core.search_for_config(FLAGS.model) 69 |     if config_file is None: 70 |         print('no configuration file found at address : %s'%FLAGS.model) 71 |     gin.parse_config_file(config_file) 72 |     run = rave.core.search_for_run(FLAGS.model) 73 |     if run is None: 74 |         print('no checkpoint found in %s'%FLAGS.model) 75 |         exit() 76 |     pretrained = rave.RAVE() 77 |     print('model found : %s'%run) 78 |     checkpoint = torch.load(run, map_location='cpu') 79 |     if "EMA" in checkpoint["callbacks"]: 80 |         pretrained.load_state_dict( 81 |             checkpoint["callbacks"]["EMA"], 82 |             strict=False, 83 |         ) 84 |     else: 85 |         pretrained.load_state_dict( 86 |             checkpoint["state_dict"], 87 |             strict=False, 88 |         ) 89 |     pretrained.eval() 90 |     gin.clear_config() 91 | 92 |     # parse configuration 93 |     if FLAGS.ckpt: 94 |         config_file = rave.core.search_for_config(FLAGS.ckpt) 95 |         if config_file is None: 96 |             print('Config file not found in %s'%FLAGS.ckpt) 97 |         gin.parse_config_file(config_file) 98 |     else: 99 |         gin.parse_config_files_and_bindings( 100 |             map(add_gin_extension, FLAGS.config), 101 |             FLAGS.override 102 |         ) 103 | 104 |     # create model 105 |     if isinstance(pretrained.encoder, rave.blocks.VariationalEncoder): 106 |         prior = rave.prior.VariationalPrior(pretrained_vae=pretrained) 107 |     else: 108 |         raise NotImplementedError("prior not implemented for encoder of type %s"%(type(pretrained.encoder))) 109 | 110 |     dataset = rave.dataset.get_dataset(FLAGS.db_path, 111 |                                        pretrained.sr, 112 |                                        max(FLAGS.n_signal, prior.min_receptive_field), 113 |                                        derivative=FLAGS.derivative, 114 |                                        normalize=FLAGS.normalize, 115 |                                        rand_pitch=FLAGS.rand_pitch, 116 |                                        n_channels=pretrained.n_channels) 117 | 118 |     train, val = rave.dataset.split_dataset(dataset, 98) 119 | 120 |     # get data-loader 121 |     num_workers = FLAGS.workers 122 |     if os.name == "nt" or sys.platform == "darwin": 123 |         num_workers = 0 124 |     train = DataLoader(train, 125 |                        FLAGS.batch, 126 |                        True, 127 |                        drop_last=True, 128 |                        num_workers=num_workers) 129 |     val = DataLoader(val, FLAGS.batch, False, num_workers=num_workers) 130 | 131 |     # CHECKPOINT CALLBACKS 132 |     validation_checkpoint = pl.callbacks.ModelCheckpoint(monitor="validation", 133 |                                                          filename="best") 134 |     last_filename = "last" if FLAGS.save_every is None else "epoch-{epoch:04d}" 135 |     last_checkpoint = rave.core.ModelCheckpoint(filename=last_filename, step_period=FLAGS.save_every) 136 | 137 |     val_check = {} 138 |     if len(train) >= FLAGS.val_every: 139 |         
val_check["val_check_interval"] = 1 if FLAGS.smoke_test else FLAGS.val_every 140 | else: 141 | nepoch = FLAGS.val_every // len(train) 142 | val_check["check_val_every_n_epoch"] = nepoch 143 | 144 | if FLAGS.smoke_test: 145 | val_check['limit_train_batches'] = 1 146 | val_check['limit_val_batches'] = 1 147 | 148 | gin_hash = hashlib.md5( 149 | gin.operative_config_str().encode()).hexdigest()[:10] 150 | 151 | RUN_NAME = f'{FLAGS.name}_{gin_hash}' 152 | os.makedirs(os.path.join(FLAGS.out_path, RUN_NAME), exist_ok=True) 153 | 154 | if FLAGS.gpu == [-1]: 155 | gpu = 0 156 | else: 157 | gpu = FLAGS.gpu or rave.core.setup_gpu() 158 | 159 | print('selected gpu:', gpu) 160 | 161 | accelerator = None 162 | devices = None 163 | if FLAGS.gpu == [-1]: 164 | pass 165 | elif torch.cuda.is_available(): 166 | accelerator = "cuda" 167 | devices = FLAGS.gpu or rave.core.setup_gpu() 168 | elif torch.backends.mps.is_available(): 169 | print( 170 | "Training on mac is not available yet. Use --gpu -1 to train on CPU (not recommended)." 171 | ) 172 | exit() 173 | accelerator = "mps" 174 | devices = 1 175 | 176 | callbacks = [ 177 | validation_checkpoint, 178 | last_checkpoint, 179 | ] 180 | 181 | trainer = pl.Trainer( 182 | logger=pl.loggers.TensorBoardLogger( 183 | FLAGS.out_path, 184 | name=RUN_NAME, 185 | ), 186 | accelerator=accelerator, 187 | devices=devices, 188 | callbacks=callbacks, 189 | max_epochs=300000, 190 | max_steps=FLAGS.max_steps, 191 | profiler="simple", 192 | enable_progress_bar=FLAGS.progress, 193 | **val_check, 194 | ) 195 | 196 | run = rave.core.search_for_run(FLAGS.ckpt) 197 | if run is not None: 198 | print('loading state from file %s'%run) 199 | loaded = torch.load(run, map_location='cpu') 200 | trainer.fit_loop.epoch_loop._batches_that_stepped = loaded['global_step'] 201 | 202 | with open(os.path.join(FLAGS.out_path, RUN_NAME, "config.gin"), "w") as config_out: 203 | config_out.write(gin.operative_config_str()) 204 | 205 | trainer.fit(prior, train, val, ckpt_path=run) 206 | 207 | if __name__== "__main__": 208 | app.run(main) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | import setuptools 5 | 6 | # imports __version__ 7 | exec(open('rave/version.py').read()) 8 | 9 | with open("README.md", "r") as readme: 10 | readme = readme.read() 11 | 12 | with open("requirements.txt", "r") as requirements: 13 | requirements = requirements.read() 14 | 15 | setuptools.setup( 16 | name="acids-rave", 17 | version=__version__, # type: ignore 18 | author="Antoine CAILLON", 19 | author_email="caillon@ircam.fr", 20 | description="RAVE: a Realtime Audio Variatione autoEncoder", 21 | long_description=readme, 22 | long_description_content_type="text/markdown", 23 | packages=setuptools.find_packages(), 24 | package_data={ 25 | 'rave/configs': ['*.gin'], 26 | }, 27 | classifiers=[ 28 | "Programming Language :: Python :: 3", 29 | "License :: OSI Approved :: MIT License", 30 | "Operating System :: OS Independent", 31 | ], 32 | entry_points={"console_scripts": [ 33 | "rave = scripts.main_cli:main", 34 | ]}, 35 | install_requires=requirements.split("\n"), 36 | python_requires='>=3.9', 37 | include_package_data=True, 38 | ) 39 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_configs.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | import tempfile 4 | 5 | import gin 6 | import pytest 7 | import torch 8 | import torch.nn as nn 9 | 10 | import rave 11 | from scripts import export 12 | 13 | gin.enter_interactive_mode() 14 | 15 | configs = [ 16 | ["v1.gin"], 17 | ["v2.gin"], 18 | ["v2.gin", "adain.gin"], 19 | ["v2.gin", "wasserstein.gin"], 20 | ["v2.gin", "spherical.gin"], 21 | ["v2.gin", "hybrid.gin"], 22 | ["v2_small.gin", "adain.gin"], 23 | ["v2_small.gin", "wasserstein.gin"], 24 | ["v2_small.gin", "spherical.gin"], 25 | ["v2_small.gin", "hybrid.gin"], 26 | ["discrete.gin"], 27 | ["discrete.gin", "snake.gin"], 28 | ["discrete.gin", "snake.gin", "adain.gin"], 29 | ["discrete.gin", "snake.gin", "descript_discriminator.gin"], 30 | ["discrete.gin", "spectral_discriminator.gin"], 31 | ["discrete.gin", "noise.gin"], 32 | ["discrete.gin", "hybrid.gin"], 33 | ["v3.gin"], 34 | ["v3.gin", "hybrid.gin"] 35 | ] 36 | 37 | configs += [c + ["causal.gin"] for c in configs] 38 | 39 | model_sampling_rate = [44100, 22050] 40 | stereo = [True, False] 41 | 42 | configs = list(itertools.product(configs, model_sampling_rate, stereo)) 43 | 44 | 45 | @pytest.mark.parametrize( 46 | "config,sr,stereo", 47 | configs, 48 | ids=map( 49 | lambda e: " ".join(e[0]) + f" [{e[1]}] " + 50 | ("stereo" if e[2] else "mono"), configs), 51 | ) 52 | def test_config(config, sr, stereo): 53 | 54 | gin.clear_config() 55 | gin.parse_config_files_and_bindings(config, [ 56 | f"SAMPLING_RATE={sr}", 57 | "CAPACITY=2", 58 | ]) 59 | 60 | n_channels = 2 if stereo else 1 61 | model = rave.RAVE(n_channels=n_channels) 62 | 63 | x = torch.randn(1, n_channels, 2**15) 64 | z, _ = model.encode(x, return_mb=True) 65 | z, _ = model.encoder.reparametrize(z)[:2] 66 | y = model.decode(z) 67 | score = model.discriminator(y) 68 | 69 | assert x.shape == y.shape 70 | 71 | if isinstance(model.encoder, rave.blocks.VariationalEncoder): 72 | script_class = export.VariationalScriptedRAVE 73 | elif isinstance(model.encoder, rave.blocks.DiscreteEncoder): 74 | script_class = export.DiscreteScriptedRAVE 75 | elif isinstance(model.encoder, rave.blocks.WasserteinEncoder): 76 | script_class = export.WasserteinScriptedRAVE 77 | elif isinstance(model.encoder, rave.blocks.SphericalEncoder): 78 | script_class = export.SphericalScriptedRAVE 79 | else: 80 | raise ValueError(f"Encoder type {type(model.encoder)} " 81 | "not supported for export.") 82 | 83 | x = torch.zeros(1, n_channels, 2**14) 84 | 85 | model(x) 86 | 87 | for m in model.modules(): 88 | if hasattr(m, "weight_g"): 89 | nn.utils.remove_weight_norm(m) 90 | 91 | scripted_rave = script_class( 92 | pretrained=model, 93 | channels=n_channels, 94 | ) 95 | 96 | scripted_rave_resampled = script_class( 97 | pretrained=model, 98 | channels=n_channels, 99 | target_sr=44100, 100 | ) 101 | 102 | with tempfile.TemporaryDirectory() as tmpdir: 103 | scripted_rave.export_to_ts(os.path.join(tmpdir, "ori.ts")) 104 | scripted_rave_resampled.export_to_ts( 105 | os.path.join(tmpdir, "resampled.ts")) 106 | -------------------------------------------------------------------------------- /tests/test_resampler.py: -------------------------------------------------------------------------------- 1 | import cached_conv as cc 2 | import gin 
3 | import pytest 4 | import torch 5 | 6 | from rave.resampler import Resampler 7 | 8 | configs = [(44100, 22050), (48000, 16000)] 9 | 10 | 11 | @pytest.mark.parametrize("target_sr,model_sr", configs) 12 | def test_resampler(target_sr, model_sr): 13 | gin.clear_config() 14 | cc.use_cached_conv(False) 15 | 16 | resampler = Resampler(target_sr, model_sr) 17 | 18 | x = torch.randn(1, 1, 2**12 * 3) 19 | 20 | y = resampler.to_model_sampling_rate(x) 21 | z = resampler.from_model_sampling_rate(y) 22 | 23 | assert x.shape == z.shape 24 | 25 | cc.use_cached_conv(True) 26 | 27 | try: 28 | resampler = Resampler(target_sr, model_sr) 29 | 30 | x = torch.randn(1, 1, 2**12 * 3) 31 | 32 | y = resampler.to_model_sampling_rate(x) 33 | z = resampler.from_model_sampling_rate(y) 34 | 35 | assert x.shape == z.shape 36 | 37 | except ValueError: 38 | pass 39 | -------------------------------------------------------------------------------- /tests/test_residual.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import cached_conv as cc 4 | import gin 5 | import pytest 6 | import torch 7 | 8 | from rave.blocks import * 9 | 10 | gin.enter_interactive_mode() 11 | 12 | kernel_size = [ 13 | 1, 14 | 3, 15 | ] 16 | 17 | dilations = [[1, 1], [3, 1]] 18 | 19 | kernel_sizes = [ 20 | [3], 21 | [3, 5], 22 | [3, 5, 7], 23 | ] 24 | 25 | dilations_list = [ 26 | [[1, 1]], 27 | [[1, 1], [3, 1], [5, 1]], 28 | ] 29 | 30 | ratios = [ 31 | 2, 32 | 4, 33 | 8, 34 | ] 35 | 36 | 37 | @pytest.mark.parametrize('kernel_sizes,dilations_list', 38 | itertools.product(kernel_sizes, dilations_list)) 39 | def test_residual_stack(kernel_sizes, dilations_list): 40 | dim = 16 41 | x = torch.randn(1, dim, 32) 42 | cc.use_cached_conv(False) 43 | stack_regular = ResidualStack( 44 | dim=dim, 45 | kernel_sizes=[3], 46 | dilations_list=[[1, 1], [3, 1], [5, 1]], 47 | ) 48 | 49 | cc.use_cached_conv(True) 50 | stack_stream = ResidualStack( 51 | dim=dim, 52 | kernel_sizes=[3], 53 | dilations_list=[[1, 1], [3, 1], [5, 1]], 54 | ) 55 | 56 | for p1, p2 in zip(stack_regular.parameters(), stack_stream.parameters()): 57 | p2.data.copy_(p1.data) 58 | 59 | delay = stack_stream.cumulative_delay 60 | 61 | y_regular = stack_regular(x) 62 | y_stream = stack_stream(x) 63 | 64 | if delay: 65 | y_regular = y_regular[..., delay:-delay] 66 | y_stream = y_stream[..., delay + delay:] 67 | 68 | assert torch.allclose(y_regular, y_stream, 1e-4, 1e-4) 69 | 70 | 71 | @pytest.mark.parametrize('kernel_size,dilations_list', 72 | itertools.product(kernel_size, dilations)) 73 | def test_residual_layer(kernel_size, dilations_list): 74 | dim = 16 75 | x = torch.randn(1, dim, 32) 76 | 77 | cc.use_cached_conv(False) 78 | layer_regular = ResidualLayer(dim, kernel_size, dilations_list) 79 | 80 | cc.use_cached_conv(True) 81 | layer_stream = ResidualLayer(dim, kernel_size, dilations_list) 82 | 83 | for p1, p2 in zip(layer_regular.parameters(), layer_stream.parameters()): 84 | p2.data.copy_(p1.data) 85 | 86 | delay = layer_stream.cumulative_delay 87 | 88 | y_regular = layer_regular(x) 89 | y_stream = layer_stream(x) 90 | 91 | if delay: 92 | y_regular = y_regular[..., delay:-delay] 93 | y_stream = y_stream[..., delay + delay:] 94 | 95 | assert torch.allclose(y_regular, y_stream, 1e-3, 1e-4) 96 | 97 | 98 | @pytest.mark.parametrize('ratio,', ratios) 99 | def test_upsample_layer(ratio): 100 | dim = 16 101 | x = torch.randn(1, dim, 32) 102 | 103 | cc.use_cached_conv(False) 104 | upsample_regular = UpsampleLayer(dim, dim, ratio) 105 | 106 | 
cc.use_cached_conv(True) 107 | upsample_stream = UpsampleLayer(dim, dim, ratio) 108 | 109 | for p1, p2 in zip(upsample_regular.parameters(), 110 | upsample_stream.parameters()): 111 | p2.data.copy_(p1.data) 112 | 113 | delay = upsample_stream.cumulative_delay 114 | 115 | y_regular = upsample_regular(x) 116 | y_stream = upsample_stream(x) 117 | 118 | if delay: 119 | y_regular = y_regular[..., delay:-delay] 120 | y_stream = y_stream[..., delay + delay:] 121 | 122 | assert torch.allclose(y_regular, y_stream, 1e-3, 1e-4) 123 | --------------------------------------------------------------------------------