├── .github
│   └── workflows
│       └── actions.yml
├── .gitignore
├── .vscode
│   └── settings.json
├── LICENSE
├── MANIFEST.in
├── README.md
├── docs
│   ├── darbouka_prior.mp3
│   ├── log_distance.png
│   ├── log_fidelity.png
│   ├── log_gan.png
│   ├── maxmsp_screenshot.png
│   ├── rave.png
│   ├── rave_attribute.png
│   ├── rave_buffer.png
│   ├── rave_encode_decode.png
│   ├── rave_high_level.png
│   ├── rave_method_forward.png
│   ├── tensorboard_guide.md
│   └── training_setup.md
├── rave
│   ├── __init__.py
│   ├── balancer.py
│   ├── blocks.py
│   ├── configs
│   │   ├── adain.gin
│   │   ├── augmentations
│   │   │   ├── compress.gin
│   │   │   ├── gain.gin
│   │   │   └── mute.gin
│   │   ├── causal.gin
│   │   ├── descript_discriminator.gin
│   │   ├── discrete.gin
│   │   ├── discrete_v3.gin
│   │   ├── hybrid.gin
│   │   ├── noise.gin
│   │   ├── normalize_ambient.gin
│   │   ├── onnx.gin
│   │   ├── prior
│   │   │   └── prior_v1.gin
│   │   ├── raspberry.gin
│   │   ├── snake.gin
│   │   ├── spectral_discriminator.gin
│   │   ├── spherical.gin
│   │   ├── v1.gin
│   │   ├── v2.gin
│   │   ├── v2_nopqmf.gin
│   │   ├── v2_nopqmf_small.gin
│   │   ├── v2_small.gin
│   │   ├── v2_with_augs.gin
│   │   ├── v3.gin
│   │   └── wasserstein.gin
│   ├── core.py
│   ├── dataset.py
│   ├── descript_discriminator.py
│   ├── discriminator.py
│   ├── model.py
│   ├── pqmf.py
│   ├── prior
│   │   ├── __init__.py
│   │   ├── core.py
│   │   ├── model.py
│   │   └── residual_block.py
│   ├── quantization.py
│   ├── resampler.py
│   ├── transforms.py
│   └── version.py
├── requirements.txt
├── scripts
│   ├── __init__.py
│   ├── export.py
│   ├── export_onnx.py
│   ├── generate.py
│   ├── main_cli.py
│   ├── preprocess.py
│   ├── remote_dataset.py
│   ├── train.py
│   └── train_prior.py
├── setup.py
└── tests
    ├── __init__.py
    ├── test_configs.py
    ├── test_resampler.py
    └── test_residual.py
/.github/workflows/actions.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | permissions:
4 |   pull-requests: write
5 |   issues: write
6 |   repository-projects: write
7 |   contents: write
8 |
9 | on:
10 |   pull_request:
11 |   push:
12 |     branches: [master]
13 |     tags: v*
14 |
15 | jobs:
16 |   build:
17 |     runs-on: ubuntu-latest
18 |     steps:
19 |       - uses: actions/checkout@v3
20 |       - name: Set up Python
21 |         uses: actions/setup-python@v3
22 |         with:
23 |           python-version: "3.10"
24 |           cache: pip
25 |       - name: Install dependencies
26 |         run: |
27 |           python -m pip install --upgrade pip setuptools wheel build pytest
28 |           python -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu
29 |           python -m pip install -r requirements.txt
30 |       - name: Build package
31 |         run: python -m build
32 |       - name: Publish package
33 |         if: startsWith(github.ref, 'refs/tags/v')
34 |         uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
35 |         with:
36 |           user: __token__
37 |           password: ${{ secrets.PYPI_TOKEN }}
38 |
39 |   test:
40 |     runs-on: ubuntu-latest
41 |     steps:
42 |       - uses: actions/checkout@v3
43 |       - name: Set up Python
44 |         uses: actions/setup-python@v3
45 |         with:
46 |           python-version: "3.10"
47 |           cache: pip
48 |       - name: Install dependencies
49 |         run: |
50 |           python -m pip install --upgrade pip setuptools wheel build pytest
51 |           python -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu
52 |           python -m pip install -r requirements.txt
53 |       - name: Run tests
54 |         run: pytest --junitxml=.test-report.xml
55 |       - uses: actions/upload-artifact@v3
56 |         if: success() || failure()
57 |         with:
58 |           name: test-report
59 |           path: .test-report.xml
60 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *pycache*
2 | *DS_Store
3 | lightning_logs/
4 | *.ckpt
5 | *.ts
6 | *libtorch*
7 | *.wav
8 | *.txt
9 | runs
10 | *.npy
11 | *.yaml
12 | *.onnx
13 | __version__*
14 | PKG-INFO
15 | .junit-test-report.xml
16 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "editor.formatOnSave": true,
3 | "python.formatting.provider": "yapf",
4 | "python.testing.pytestArgs": [
5 | "."
6 | ],
7 | "python.testing.unittestEnabled": false,
8 | "python.testing.pytestEnabled": true
9 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Creative Commons Attribution-NonCommercial 4.0 International
2 |
3 | Creative Commons Corporation ("Creative Commons") is not a law firm and
4 | does not provide legal services or legal advice. Distribution of
5 | Creative Commons public licenses does not create a lawyer-client or
6 | other relationship. Creative Commons makes its licenses and related
7 | information available on an "as-is" basis. Creative Commons gives no
8 | warranties regarding its licenses, any material licensed under their
9 | terms and conditions, or any related information. Creative Commons
10 | disclaims all liability for damages resulting from their use to the
11 | fullest extent possible.
12 |
13 | Using Creative Commons Public Licenses
14 |
15 | Creative Commons public licenses provide a standard set of terms and
16 | conditions that creators and other rights holders may use to share
17 | original works of authorship and other material subject to copyright and
18 | certain other rights specified in the public license below. The
19 | following considerations are for informational purposes only, are not
20 | exhaustive, and do not form part of our licenses.
21 |
22 | - Considerations for licensors: Our public licenses are intended for
23 | use by those authorized to give the public permission to use
24 | material in ways otherwise restricted by copyright and certain other
25 | rights. Our licenses are irrevocable. Licensors should read and
26 | understand the terms and conditions of the license they choose
27 | before applying it. Licensors should also secure all rights
28 | necessary before applying our licenses so that the public can reuse
29 | the material as expected. Licensors should clearly mark any material
30 | not subject to the license. This includes other CC-licensed
31 | material, or material used under an exception or limitation to
32 | copyright. More considerations for licensors :
33 | wiki.creativecommons.org/Considerations_for_licensors
34 |
35 | - Considerations for the public: By using one of our public licenses,
36 | a licensor grants the public permission to use the licensed material
37 | under specified terms and conditions. If the licensor's permission
38 | is not necessary for any reason–for example, because of any
39 | applicable exception or limitation to copyright–then that use is not
40 | regulated by the license. Our licenses grant only permissions under
41 | copyright and certain other rights that a licensor has authority to
42 | grant. Use of the licensed material may still be restricted for
43 | other reasons, including because others have copyright or other
44 | rights in the material. A licensor may make special requests, such
45 | as asking that all changes be marked or described. Although not
46 | required by our licenses, you are encouraged to respect those
47 | requests where reasonable. More considerations for the public :
48 | wiki.creativecommons.org/Considerations_for_licensees
49 |
50 | Creative Commons Attribution-NonCommercial 4.0 International Public
51 | License
52 |
53 | By exercising the Licensed Rights (defined below), You accept and agree
54 | to be bound by the terms and conditions of this Creative Commons
55 | Attribution-NonCommercial 4.0 International Public License ("Public
56 | License"). To the extent this Public License may be interpreted as a
57 | contract, You are granted the Licensed Rights in consideration of Your
58 | acceptance of these terms and conditions, and the Licensor grants You
59 | such rights in consideration of benefits the Licensor receives from
60 | making the Licensed Material available under these terms and conditions.
61 |
62 | - Section 1 – Definitions.
63 |
64 | - a. Adapted Material means material subject to Copyright and
65 | Similar Rights that is derived from or based upon the Licensed
66 | Material and in which the Licensed Material is translated,
67 | altered, arranged, transformed, or otherwise modified in a
68 | manner requiring permission under the Copyright and Similar
69 | Rights held by the Licensor. For purposes of this Public
70 | License, where the Licensed Material is a musical work,
71 | performance, or sound recording, Adapted Material is always
72 | produced where the Licensed Material is synched in timed
73 | relation with a moving image.
74 | - b. Adapter's License means the license You apply to Your
75 | Copyright and Similar Rights in Your contributions to Adapted
76 | Material in accordance with the terms and conditions of this
77 | Public License.
78 | - c. Copyright and Similar Rights means copyright and/or similar
79 | rights closely related to copyright including, without
80 | limitation, performance, broadcast, sound recording, and Sui
81 | Generis Database Rights, without regard to how the rights are
82 | labeled or categorized. For purposes of this Public License, the
83 | rights specified in Section 2(b)(1)-(2) are not Copyright and
84 | Similar Rights.
85 | - d. Effective Technological Measures means those measures that,
86 | in the absence of proper authority, may not be circumvented
87 | under laws fulfilling obligations under Article 11 of the WIPO
88 | Copyright Treaty adopted on December 20, 1996, and/or similar
89 | international agreements.
90 | - e. Exceptions and Limitations means fair use, fair dealing,
91 | and/or any other exception or limitation to Copyright and
92 | Similar Rights that applies to Your use of the Licensed
93 | Material.
94 | - f. Licensed Material means the artistic or literary work,
95 | database, or other material to which the Licensor applied this
96 | Public License.
97 | - g. Licensed Rights means the rights granted to You subject to
98 | the terms and conditions of this Public License, which are
99 | limited to all Copyright and Similar Rights that apply to Your
100 | use of the Licensed Material and that the Licensor has authority
101 | to license.
102 | - h. Licensor means the individual(s) or entity(ies) granting
103 | rights under this Public License.
104 | - i. NonCommercial means not primarily intended for or directed
105 | towards commercial advantage or monetary compensation. For
106 | purposes of this Public License, the exchange of the Licensed
107 | Material for other material subject to Copyright and Similar
108 | Rights by digital file-sharing or similar means is NonCommercial
109 | provided there is no payment of monetary compensation in
110 | connection with the exchange.
111 | - j. Share means to provide material to the public by any means or
112 | process that requires permission under the Licensed Rights, such
113 | as reproduction, public display, public performance,
114 | distribution, dissemination, communication, or importation, and
115 | to make material available to the public including in ways that
116 | members of the public may access the material from a place and
117 | at a time individually chosen by them.
118 | - k. Sui Generis Database Rights means rights other than copyright
119 | resulting from Directive 96/9/EC of the European Parliament and
120 | of the Council of 11 March 1996 on the legal protection of
121 | databases, as amended and/or succeeded, as well as other
122 | essentially equivalent rights anywhere in the world.
123 | - l. You means the individual or entity exercising the Licensed
124 | Rights under this Public License. Your has a corresponding
125 | meaning.
126 |
127 | - Section 2 – Scope.
128 |
129 | - a. License grant.
130 | - 1. Subject to the terms and conditions of this Public
131 | License, the Licensor hereby grants You a worldwide,
132 | royalty-free, non-sublicensable, non-exclusive, irrevocable
133 | license to exercise the Licensed Rights in the Licensed
134 | Material to:
135 | - A. reproduce and Share the Licensed Material, in whole
136 | or in part, for NonCommercial purposes only; and
137 | - B. produce, reproduce, and Share Adapted Material for
138 | NonCommercial purposes only.
139 | - 2. Exceptions and Limitations. For the avoidance of doubt,
140 | where Exceptions and Limitations apply to Your use, this
141 | Public License does not apply, and You do not need to comply
142 | with its terms and conditions.
143 | - 3. Term. The term of this Public License is specified in
144 | Section 6(a).
145 | - 4. Media and formats; technical modifications allowed. The
146 | Licensor authorizes You to exercise the Licensed Rights in
147 | all media and formats whether now known or hereafter
148 | created, and to make technical modifications necessary to do
149 | so. The Licensor waives and/or agrees not to assert any
150 | right or authority to forbid You from making technical
151 | modifications necessary to exercise the Licensed Rights,
152 | including technical modifications necessary to circumvent
153 | Effective Technological Measures. For purposes of this
154 | Public License, simply making modifications authorized by
155 | this Section 2(a)(4) never produces Adapted Material.
156 | - 5. Downstream recipients.
157 | - A. Offer from the Licensor – Licensed Material. Every
158 | recipient of the Licensed Material automatically
159 | receives an offer from the Licensor to exercise the
160 | Licensed Rights under the terms and conditions of this
161 | Public License.
162 | - B. No downstream restrictions. You may not offer or
163 | impose any additional or different terms or conditions
164 | on, or apply any Effective Technological Measures to,
165 | the Licensed Material if doing so restricts exercise of
166 | the Licensed Rights by any recipient of the Licensed
167 | Material.
168 | - 6. No endorsement. Nothing in this Public License
169 | constitutes or may be construed as permission to assert or
170 | imply that You are, or that Your use of the Licensed
171 | Material is, connected with, or sponsored, endorsed, or
172 | granted official status by, the Licensor or others
173 | designated to receive attribution as provided in Section
174 | 3(a)(1)(A)(i).
175 | - b. Other rights.
176 | - 1. Moral rights, such as the right of integrity, are not
177 | licensed under this Public License, nor are publicity,
178 | privacy, and/or other similar personality rights; however,
179 | to the extent possible, the Licensor waives and/or agrees
180 | not to assert any such rights held by the Licensor to the
181 | limited extent necessary to allow You to exercise the
182 | Licensed Rights, but not otherwise.
183 | - 2. Patent and trademark rights are not licensed under this
184 | Public License.
185 | - 3. To the extent possible, the Licensor waives any right to
186 | collect royalties from You for the exercise of the Licensed
187 | Rights, whether directly or through a collecting society
188 | under any voluntary or waivable statutory or compulsory
189 | licensing scheme. In all other cases the Licensor expressly
190 | reserves any right to collect such royalties, including when
191 | the Licensed Material is used other than for NonCommercial
192 | purposes.
193 |
194 | - Section 3 – License Conditions.
195 |
196 | Your exercise of the Licensed Rights is expressly made subject to
197 | the following conditions.
198 |
199 | - a. Attribution.
200 | - 1. If You Share the Licensed Material (including in modified
201 | form), You must:
202 | - A. retain the following if it is supplied by the
203 | Licensor with the Licensed Material:
204 | - i. identification of the creator(s) of the Licensed
205 | Material and any others designated to receive
206 | attribution, in any reasonable manner requested by
207 | the Licensor (including by pseudonym if designated);
208 | - ii. a copyright notice;
209 | - iii. a notice that refers to this Public License;
210 | - iv. a notice that refers to the disclaimer of
211 | warranties;
212 | - v. a URI or hyperlink to the Licensed Material to
213 | the extent reasonably practicable;
214 | - B. indicate if You modified the Licensed Material and
215 | retain an indication of any previous modifications; and
216 | - C. indicate the Licensed Material is licensed under this
217 | Public License, and include the text of, or the URI or
218 | hyperlink to, this Public License.
219 | - 2. You may satisfy the conditions in Section 3(a)(1) in any
220 | reasonable manner based on the medium, means, and context in
221 | which You Share the Licensed Material. For example, it may
222 | be reasonable to satisfy the conditions by providing a URI
223 | or hyperlink to a resource that includes the required
224 | information.
225 | - 3. If requested by the Licensor, You must remove any of the
226 | information required by Section 3(a)(1)(A) to the extent
227 | reasonably practicable.
228 | - 4. If You Share Adapted Material You produce, the Adapter's
229 | License You apply must not prevent recipients of the Adapted
230 | Material from complying with this Public License.
231 |
232 | - Section 4 – Sui Generis Database Rights.
233 |
234 | Where the Licensed Rights include Sui Generis Database Rights that
235 | apply to Your use of the Licensed Material:
236 |
237 | - a. for the avoidance of doubt, Section 2(a)(1) grants You the
238 | right to extract, reuse, reproduce, and Share all or a
239 | substantial portion of the contents of the database for
240 | NonCommercial purposes only;
241 | - b. if You include all or a substantial portion of the database
242 | contents in a database in which You have Sui Generis Database
243 | Rights, then the database in which You have Sui Generis Database
244 | Rights (but not its individual contents) is Adapted Material;
245 | and
246 | - c. You must comply with the conditions in Section 3(a) if You
247 | Share all or a substantial portion of the contents of the
248 | database.
249 |
250 | For the avoidance of doubt, this Section 4 supplements and does not
251 | replace Your obligations under this Public License where the
252 | Licensed Rights include other Copyright and Similar Rights.
253 |
254 | - Section 5 – Disclaimer of Warranties and Limitation of Liability.
255 |
256 | - a. Unless otherwise separately undertaken by the Licensor, to
257 | the extent possible, the Licensor offers the Licensed Material
258 | as-is and as-available, and makes no representations or
259 | warranties of any kind concerning the Licensed Material, whether
260 | express, implied, statutory, or other. This includes, without
261 | limitation, warranties of title, merchantability, fitness for a
262 | particular purpose, non-infringement, absence of latent or other
263 | defects, accuracy, or the presence or absence of errors, whether
264 | or not known or discoverable. Where disclaimers of warranties
265 | are not allowed in full or in part, this disclaimer may not
266 | apply to You.
267 | - b. To the extent possible, in no event will the Licensor be
268 | liable to You on any legal theory (including, without
269 | limitation, negligence) or otherwise for any direct, special,
270 | indirect, incidental, consequential, punitive, exemplary, or
271 | other losses, costs, expenses, or damages arising out of this
272 | Public License or use of the Licensed Material, even if the
273 | Licensor has been advised of the possibility of such losses,
274 | costs, expenses, or damages. Where a limitation of liability is
275 | not allowed in full or in part, this limitation may not apply to
276 | You.
277 | - c. The disclaimer of warranties and limitation of liability
278 | provided above shall be interpreted in a manner that, to the
279 | extent possible, most closely approximates an absolute
280 | disclaimer and waiver of all liability.
281 |
282 | - Section 6 – Term and Termination.
283 |
284 | - a. This Public License applies for the term of the Copyright and
285 | Similar Rights licensed here. However, if You fail to comply
286 | with this Public License, then Your rights under this Public
287 | License terminate automatically.
288 | - b. Where Your right to use the Licensed Material has terminated
289 | under Section 6(a), it reinstates:
290 |
291 | - 1. automatically as of the date the violation is cured,
292 | provided it is cured within 30 days of Your discovery of the
293 | violation; or
294 | - 2. upon express reinstatement by the Licensor.
295 |
296 | For the avoidance of doubt, this Section 6(b) does not affect
297 | any right the Licensor may have to seek remedies for Your
298 | violations of this Public License.
299 |
300 | - c. For the avoidance of doubt, the Licensor may also offer the
301 | Licensed Material under separate terms or conditions or stop
302 | distributing the Licensed Material at any time; however, doing
303 | so will not terminate this Public License.
304 | - d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
305 | License.
306 |
307 | - Section 7 – Other Terms and Conditions.
308 |
309 | - a. The Licensor shall not be bound by any additional or
310 | different terms or conditions communicated by You unless
311 | expressly agreed.
312 | - b. Any arrangements, understandings, or agreements regarding the
313 | Licensed Material not stated herein are separate from and
314 | independent of the terms and conditions of this Public License.
315 |
316 | - Section 8 – Interpretation.
317 |
318 | - a. For the avoidance of doubt, this Public License does not, and
319 | shall not be interpreted to, reduce, limit, restrict, or impose
320 | conditions on any use of the Licensed Material that could
321 | lawfully be made without permission under this Public License.
322 | - b. To the extent possible, if any provision of this Public
323 | License is deemed unenforceable, it shall be automatically
324 | reformed to the minimum extent necessary to make it enforceable.
325 | If the provision cannot be reformed, it shall be severed from
326 | this Public License without affecting the enforceability of the
327 | remaining terms and conditions.
328 | - c. No term or condition of this Public License will be waived
329 | and no failure to comply consented to unless expressly agreed to
330 | by the Licensor.
331 | - d. Nothing in this Public License constitutes or may be
332 | interpreted as a limitation upon, or waiver of, any privileges
333 | and immunities that apply to the Licensor or You, including from
334 | the legal processes of any jurisdiction or authority.
335 |
336 | Creative Commons is not a party to its public licenses. Notwithstanding,
337 | Creative Commons may elect to apply one of its public licenses to
338 | material it publishes and in those instances will be considered the
339 | "Licensor." The text of the Creative Commons public licenses is
340 | dedicated to the public domain under the CC0 Public Domain Dedication.
341 | Except for the limited purpose of indicating that material is shared
342 | under a Creative Commons public license or as otherwise permitted by the
343 | Creative Commons policies published at creativecommons.org/policies,
344 | Creative Commons does not authorize the use of the trademark "Creative
345 | Commons" or any other trademark or logo of Creative Commons without its
346 | prior written consent including, without limitation, in connection with
347 | any unauthorized modifications to any of its public licenses or any
348 | other arrangements, understandings, or agreements concerning use of
349 | licensed material. For the avoidance of doubt, this paragraph does not
350 | form part of the public licenses.
351 |
352 | Creative Commons may be contacted at creativecommons.org.
353 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include rave/configs/*.gin
2 | include rave/configs/augmentations/*.gin
3 | include requirements.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # RAVE: Realtime Audio Variational autoEncoder
4 |
5 | Official implementation of _RAVE: A variational autoencoder for fast and high-quality neural audio synthesis_ ([article link](https://arxiv.org/abs/2111.05011)) by Antoine Caillon and Philippe Esling.
6 |
7 | If you use RAVE as a part of a music performance or installation, be sure to cite either this repository or the article !
8 |
9 | If you want to share / discuss / ask things about RAVE and other research from ACIDS, you can do so in our [discord server](https://discord.gg/r9umPrGEWv) !
10 |
11 | Please check the FAQ before posting an issue!
12 |
13 | **RAVE VST** RAVE VST for Windows, Mac and Linux is available as beta on the [corresponding Forum IRCAM webpage](https://forum.ircam.fr/projects/detail/rave-vst/). For problems, please write an issue here or [on the Forum IRCAM discussion page](https://discussion.forum.ircam.fr/c/rave-vst/651).
14 |
15 | **Tutorials** : new tutorials are available on the Forum IRCAM webpage, and video versions are coming soon!
16 | - [Tutorial: Neural Synthesis in a DAW with RAVE](https://forum.ircam.fr/article/detail/neural-synthesis-in-a-daw-with-rave/)
17 | - [Tutorial: Neural Synthesis in Max 8 with RAVE](https://forum.ircam.fr/article/detail/tutorial-neural-synthesis-in-max-8-with-rave/)
18 | - [Tutorial: Training RAVE models on custom data](https://forum.ircam.fr/article/detail/training-rave-models-on-custom-data/)
19 |
20 | ## Previous versions
21 |
22 | The original implementation of the RAVE model can be restored using
23 |
24 | ```bash
25 | git checkout v1
26 | ```
27 |
28 | ## Installation
29 |
30 | Install RAVE using
31 |
32 | ```bash
33 | pip install acids-rave
34 | ```
35 |
36 | **Warning** It is strongly advised to install `torch` and `torchaudio` before `acids-rave`, so you can choose the appropriate version of torch for your setup on the [library website](http://www.pytorch.org). For future compatibility with new devices (and modern Python environments), `acids-rave` does not enforce torch==1.13 anymore.
37 |
38 | You will need **ffmpeg** on your computer. You can install it locally inside your virtual environment using
39 |
40 | ```bash
41 | conda install ffmpeg
42 | ```
43 |
44 |
45 |
46 | ## Colab
47 |
48 | A colab to train RAVEv2 is now available thanks to [hexorcismos](https://github.com/moiseshorta) !
49 | [](https://colab.research.google.com/drive/1ih-gv1iHEZNuGhHPvCHrleLNXvooQMvI?usp=sharing)
50 |
51 | ## Usage
52 |
53 | Training a RAVE model usually involves 3 separate steps, namely _dataset preparation_, _training_ and _export_.
54 |
55 | ### Dataset preparation
56 |
57 | You can now prepare a dataset using two methods: regular and lazy. Lazy preprocessing allows RAVE to be trained directly on the raw files (i.e. mp3, ogg), without converting them first. **Warning**: lazy dataset loading will increase your CPU load by a large margin during training, especially on Windows. This can however be useful when training on a large audio corpus that would not fit on a hard drive when uncompressed. In any case, prepare your dataset using
58 |
59 | ```bash
60 | rave preprocess --input_path /audio/folder --output_path /dataset/path --channels X (--lazy)
61 | ```
62 |
63 | ### Training
64 |
65 | RAVEv2 has many different configurations. The improved version of v1 is called `v2`, and can be trained with
66 |
67 | ```bash
68 | rave train --config v2 --db_path /dataset/path --out_path /model/out --name give_a_name --channels X
69 | ```
70 |
71 | We also provide a discrete configuration, similar to SoundStream or EnCodec
72 |
73 | ```bash
74 | rave train --config discrete ...
75 | ```
76 |
77 | By default, RAVE is built with non-causal convolutions. If you want to make the model causal (hence lowering the overall latency of the model), you can use the causal mode
78 |
79 | ```bash
80 | rave train --config discrete --config causal ...
81 | ```
82 |
83 | New in 2.3, data augmentations are also available to improve the model's generalization in low data regimes. You can enable them by passing augmentation configuration files with the `--augment` keyword
84 |
85 | ```bash
86 | rave train --config v2 --augment mute --augment compress
87 | ```
88 |
89 | Many other configuration files are available in `rave/configs` and can be combined. Here is a list of all the available configurations & augmentations :
90 |
91 | | Type | Name | Description |
92 | | :--- | :--- | :--- |
93 | | Architecture | v1 | Original continuous model (minimum GPU memory: 8GB) |
94 | |  | v2 | Improved continuous model (faster, higher quality) (minimum GPU memory: 16GB) |
95 | |  | v2_small | v2 with a smaller receptive field, adapted adversarial training, and noise generator, adapted for timbre transfer for stationary signals (minimum GPU memory: 8GB) |
96 | |  | v2_nopqmf | (experimental) v2 without pqmf in generator (more efficient for bending purposes) (minimum GPU memory: 16GB) |
97 | |  | v3 | v2 with Snake activation, descript discriminator and Adaptive Instance Normalization for real style transfer (minimum GPU memory: 32GB) |
98 | |  | discrete | Discrete model (similar to SoundStream or EnCodec) (minimum GPU memory: 18GB) |
99 | |  | onnx | Noiseless v1 configuration for onnx usage (minimum GPU memory: 6GB) |
100 | |  | raspberry | Lightweight configuration compatible with realtime RaspberryPi 4 inference (minimum GPU memory: 5GB) |
101 | | Regularization (v2 only) | default | Variational Auto Encoder objective (ELBO) |
102 | |  | wasserstein | Wasserstein Auto Encoder objective (MMD) |
103 | |  | spherical | Spherical Auto Encoder objective |
104 | | Discriminator | spectral_discriminator | Use the MultiScale discriminator from EnCodec. |
105 | | Others | causal | Use causal convolutions |
106 | |  | noise | Enables noise synthesizer V2 |
107 | |  | hybrid | Enable mel-spectrogram input |
108 | | Augmentations | mute | Randomly mutes data batches (default prob: 0.1). Enforces the model to learn silence |
109 | |  | compress | Randomly compresses the waveform (equivalent to light non-linear amplification of batches) |
110 | |  | gain | Applies a random gain to the waveform (default range: [-6, 3]) |
198 |
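For reference, each `--config` flag corresponds to one gin file from `rave/configs`, and later files override the bindings of earlier ones. The following Python sketch illustrates that layering using the `gin` API directly; it is only an illustration of the mechanism (the `rave train` entry point does this for you, and its internals may differ):

```python
import gin
import rave  # importing rave registers the gin search paths for rave/configs

# Hedged sketch: layer several configuration files, mirroring
# `rave train --config v2 --config causal ...`.
gin.parse_config_files_and_bindings(
    ["v2.gin", "causal.gin"],  # later files override earlier bindings
    bindings=[],
)

model = rave.RAVE()  # constructor arguments are now filled in by gin
```
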
199 | ### Export
200 |
201 | Once trained, export your model to a torchscript file using
202 |
203 | ```bash
204 | rave export --run /path/to/your/run (--streaming)
205 | ```
206 |
207 | Setting the `--streaming` flag will enable cached convolutions, making the model compatible with realtime processing. **If you forget to use the streaming mode and try to load the model in Max, you will hear clicking artifacts.**
208 |
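Besides `nn~`, the exported torchscript file can also be loaded directly from Python. The snippet below is a minimal sketch, assuming a model exported as `my_model.ts` and the usual `(batch, channels, samples)` input layout; the file name and shapes are illustrative assumptions, not part of the export contract:

```python
import torch

# Hedged sketch: load an exported RAVE model and run a full pass on dummy audio.
model = torch.jit.load("my_model.ts")

x = torch.randn(1, 1, 2**16)   # mono noise; a power-of-two length avoids
                               # rounding issues with the model's downsampling ratio
z = model.encode(x)            # audio -> latent trajectory
y = model.decode(z)            # latent trajectory -> audio
y_direct = model(x)            # forward = encode followed by decode
```
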
209 | ## Prior
210 |
211 | For discrete models, we redirect the user to the `msprior` library [here](https://github.com/caillonantoine/msprior). However, as this library is still experimental, the prior from version 1.x has been re-integrated in v2.3.
212 |
213 | ### Training
214 |
215 | To train a prior for a pretrained RAVE model :
216 |
217 | ```bash
218 | rave train_prior --model /path/to/your/run --db_path /path/to/your_preprocessed_data --out_path /path/to/output
219 | ```
220 |
221 | This will train a prior over the latent space of the pretrained model `path/to/your/run`, and save the prior checkpoint and tensorboard logs to the folder `/path/to/output`.
222 |
223 | ### Scripting
224 |
225 | To script a prior along with a RAVE model, export your model and pass the path to your pretrained prior with the `--prior` keyword :
226 |
227 | ```bash
228 | rave export --run /path/to/your/run --prior /path/to/your/prior (--streaming)
229 | ```
230 |
231 | ## Pretrained models
232 |
233 | Several pretrained streaming models [are available here](https://acids-ircam.github.io/rave_models_download). We'll keep the list updated with new models.
234 |
235 | ## Realtime usage
236 |
237 | This section presents how RAVE can be loaded inside [`nn~`](https://acids-ircam.github.io/nn_tilde/) in order to be used live with Max/MSP or PureData.
238 |
239 | ### Reconstruction
240 |
241 | A pretrained RAVE model named `darbouka.gin` available on your computer can be loaded inside `nn~` using the following syntax, where the default method is set to forward (i.e. encode then decode)
242 |
243 |
244 |
245 | This does the same thing as the following patch, but slightly faster.
246 |
247 |
248 |
249 | ### High-level manipulation
250 |
251 | Having explicit access to the latent representation yielded by RAVE allows us to interact with the representation using Max/MSP or PureData signal processing tools:
252 |
253 |
254 |
255 | ### Style transfer
256 |
257 | By default, RAVE can be used as a style transfer tool, based on the large compression ratio of the model. We recently added a technique inspired by StyleGAN that includes Adaptive Instance Normalization in the reconstruction process, effectively allowing you to define _source_ and _target_ styles directly inside Max/MSP or PureData, using the attribute system of `nn~`.
258 |
259 |
260 |
261 | Other attributes, such as `enable` or `gpu`, can enable/disable computation, or use the GPU to speed things up (still experimental).
262 |
263 | ## Offline usage
264 |
265 | A batch generation script has been released in v2.3 to allow the transformation of large amounts of files
266 |
267 | ```bash
268 | rave generate model_path path_1 path_2 --out out_path
269 | ```
270 |
271 | where `model_path` is the path to your trained model (original or scripted), `path_X` a list of audio files or directories, and `out_path` the output directory for the generated audio.
272 |
273 | ## Discussion
274 |
275 | If you have questions, want to share your experience with RAVE or share musical pieces done with the model, you can use the [Discussion tab](https://github.com/acids-ircam/RAVE/discussions) !
276 |
277 | ## Demonstration
278 |
279 | ### RAVE x nn~
280 |
281 | Demonstration of what you can do with RAVE and the nn~ external for maxmsp !
282 |
283 | [](https://www.youtube.com/watch?v=dMZs04TzxUI)
284 |
285 | ### embedded RAVE
286 |
287 | Using nn~ for puredata, RAVE can be used in realtime on embedded platforms !
288 |
289 | [](https://www.youtube.com/watch?v=jAIRf4nGgYI)
290 |
291 | # Frequently Asked Questions (FAQ)
292 |
293 | **Question** : My preprocessing is stuck, showing `0it[00:00, ?it/s]`
294 | **Answer** : This means that the audio files in your dataset are too short to provide a sufficient temporal scope to RAVE. Try decreasing the signal window by passing `--num_signal XXX` (in samples) to `preprocess`, and don't forget to pass the matching `--n_signal XXX` (in samples) to `train` afterwards.
295 |
296 | **Question** : During training I got an exception resembling `ValueError: n_components=128 must be between 0 and min(n_samples, n_features)=64 with svd_solver='full'`
297 | **Answer** : This means that your dataset does not have enough batches to compute the internal latent PCA, which requires at least 128 examples (and therefore batches).
298 |
299 |
300 | # Funding
301 |
302 | This work is led at IRCAM, and has been funded by the following projects
303 |
304 | - [ANR MakiMono](https://acids.ircam.fr/course/makimono/)
305 | - [ACTOR](https://www.actorproject.org/)
306 | - [DAFNE+](https://dafneplus.eu/) N° 101061548
307 |
308 |
309 |
--------------------------------------------------------------------------------
/docs/darbouka_prior.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/darbouka_prior.mp3
--------------------------------------------------------------------------------
/docs/log_distance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/log_distance.png
--------------------------------------------------------------------------------
/docs/log_fidelity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/log_fidelity.png
--------------------------------------------------------------------------------
/docs/log_gan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/log_gan.png
--------------------------------------------------------------------------------
/docs/maxmsp_screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/maxmsp_screenshot.png
--------------------------------------------------------------------------------
/docs/rave.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/rave.png
--------------------------------------------------------------------------------
/docs/rave_attribute.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/rave_attribute.png
--------------------------------------------------------------------------------
/docs/rave_buffer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/rave_buffer.png
--------------------------------------------------------------------------------
/docs/rave_encode_decode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/rave_encode_decode.png
--------------------------------------------------------------------------------
/docs/rave_high_level.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/rave_high_level.png
--------------------------------------------------------------------------------
/docs/rave_method_forward.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/docs/rave_method_forward.png
--------------------------------------------------------------------------------
/docs/tensorboard_guide.md:
--------------------------------------------------------------------------------
1 | # Tensorboard guide
2 |
3 | ## Latent space size estimation
4 |
5 | During training, RAVE regularly estimates the **size** of the latent space given a specific dataset for a given *fidelity*. The fidelity parameter is a percentage that defines how well the model should be able to reconstruct an input audio sample.
6 |
7 | Usually values around 80% yield correct yet not accurate reconstructions. Values around 95% are most of the time sufficient to have both a compact latent space and correct reconstructions.
8 |
9 | We log the estimated size of the latent space for several values of fidelity in tensorboard (80, 90, 95 and 99%).
10 |
11 | 
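As a rough mental model of this estimate (not necessarily the exact procedure used inside RAVE), one can look at the cumulative explained variance of a PCA computed over a batch of latent vectors, and keep the smallest number of dimensions reaching the fidelity threshold:

```python
import numpy as np

def estimate_latent_size(z: np.ndarray, fidelity: float = 0.95) -> int:
    """Illustrative sketch: number of PCA dimensions needed to reach `fidelity`.

    `z` is a (num_examples, latent_dim) array of latent vectors.
    """
    z = z - z.mean(axis=0, keepdims=True)
    singular_values = np.linalg.svd(z, compute_uv=False)
    explained = np.cumsum(singular_values**2) / np.sum(singular_values**2)
    return int(np.searchsorted(explained, fidelity) + 1)
```
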
12 |
13 | ## Reconstruction error
14 |
15 | The values you should look at for tracking the reconstruction error of the model are the *distance* and *validation* logs.
16 |
17 | 
18 |
19 | When the second phase kicks in, those values increase - **that's usually normal**
20 |
21 | ## Adversarial losses
22 |
23 | The `loss_dis, loss_gen, pred_true, pred_fake` losses only appear during the second phase. They are usually harder to read, as most GAN losses are, but we include here an example of what *normal* logs should look like.
24 |
25 | 
--------------------------------------------------------------------------------
/docs/training_setup.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # Training setup
4 |
5 | 1. You should train on a _CUDA-enabled_ machine (i.e. with an NVIDIA card)
6 |    - You can use either **Linux** or **Windows**
7 |    - However, we advise using **Linux** if available
8 |    - Training RAVE without a hardware accelerator (GPU, TPU) will take ages, and is not recommended
9 | 2. Make sure that you have CUDA enabled
10 |    - Go to a terminal and enter `nvidia-smi`
11 |    - If a message appears with the name of your graphics card and the available memory, it's all good !
12 |    - Otherwise, you have to install **cuda** on your computer (we don't provide support for that, lots of guides are available online)
13 | 3. Let's install python !
14 |
15 | # Python installation
16 |
17 | Python is pre-installed on most computers, but we won't use this version. Instead, we will install a **conda** distribution on the machine. This keeps different versions of python separate for different projects, and allows regular users to install new packages without sudo access.
18 |
19 | You can follow the [instructions here](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) to install a miniconda environment on your computer.
20 |
21 | Once installed, you know that you are inside your miniconda environment if there's a "`(base)`" at the beginning of your terminal prompt.
22 |
23 | # RAVE installation
24 |
25 | We will create a new virtual environment for RAVE.
26 |
27 | ```bash
28 | conda create -n rave python=3.9
29 | ```
30 |
31 | Each time we want to use RAVE, we can (and **should**) activate this environment using
32 |
33 | ```bash
34 | conda activate rave
35 | ```
36 |
37 | Let's clone RAVE and install the requirements !
38 |
39 | ```bash
40 | git clone https://github.com/acids-ircam/RAVE
41 | cd RAVE
42 | pip install -r requirements.txt
43 | ```
44 |
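Optionally, you can check from Python that the installed torch build actually sees your GPU (a quick sanity check, not an official installation step):

```python
import torch

# On a correctly configured CUDA machine this prints True and the GPU name.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```
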
45 | You can now use `python cli_helper.py` to start a new training !
46 |
47 | # About the dataset
48 |
49 | A good rule of thumb is **more is better**. You might want to have _at least_ 3h of homogeneous recordings to train RAVE, more if your dataset is complex (e.g. mixtures of instruments, lots of variations...)
50 |
51 | If you have a folder filled with various audio files (any extension, any sampling rate), you can use the `resample` utility in this folder
52 |
53 | ```bash
54 | conda activate rave
55 | resample --sr TARGET_SAMPLING_RATE --augment
56 | ```
57 |
58 | It will convert, resample, crop and augment all audio files present in the current directory and write them to an output directory called `out_TARGET_SAMPLING_RATE/` (which is the one you should give to `cli_helper.py` when asked for the path of the .wav files).
59 |
--------------------------------------------------------------------------------
/rave/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import cached_conv as cc
4 | import gin
5 | import torch
6 |
7 |
8 | BASE_PATH: Path = Path(__file__).parent
9 |
10 | gin.add_config_file_search_path(BASE_PATH)
11 | gin.add_config_file_search_path(BASE_PATH.joinpath('configs'))
12 | gin.add_config_file_search_path(BASE_PATH.joinpath('configs', 'augmentations'))
13 |
14 |
15 | def __safe_configurable(name):
16 |     try:  # reuse the existing gin registration if cc.<name> is already configurable
17 |         setattr(cc, name, gin.get_configurable(f"cc.{name}"))
18 |     except ValueError:  # otherwise register the cached_conv symbol as an external configurable
19 |         setattr(cc, name, gin.external_configurable(getattr(cc, name), module="cc"))
20 |
21 | # cc.get_padding = gin.external_configurable(cc.get_padding, module="cc")
22 | # cc.Conv1d = gin.external_configurable(cc.Conv1d, module="cc")
23 | # cc.ConvTranspose1d = gin.external_configurable(cc.ConvTranspose1d, module="cc")
24 |
25 | __safe_configurable("get_padding")
26 | __safe_configurable("Conv1d")
27 | __safe_configurable("ConvTranspose1d")
28 |
29 | from .blocks import *
30 | from .discriminator import *
31 | from .model import RAVE, BetaWarmupCallback
32 | from .pqmf import *
33 | from .balancer import *
34 |
--------------------------------------------------------------------------------
/rave/balancer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import gin.torch
3 |
4 |
5 | @gin.configurable
6 | class Balancer(nn.Module):
7 |     def __init__(self):
8 |         super().__init__()
9 |
10 |     def forward(self, *args, **kwargs):
11 |         raise RuntimeError('Balancer has been disabled in the newest RAVE version. \n' \
12 |                            'If you try to import a checkpoint trained with a previous version, remove it from the configuration.')
--------------------------------------------------------------------------------
/rave/configs/adain.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | import rave
4 | from rave import blocks
5 |
6 | blocks.EncoderV2:
7 | adain = @blocks.AdaptiveInstanceNormalization
8 |
9 | blocks.GeneratorV2:
10 | adain = @blocks.AdaptiveInstanceNormalization
--------------------------------------------------------------------------------
/rave/configs/augmentations/compress.gin:
--------------------------------------------------------------------------------
1 | # dataset.get_dataset:
2 | # augmentations = [
3 | # @augmentations/transforms.RandomCompress(),
4 | # ]
5 |
6 | add_augmentation:
7 | aug = @augmentations/transforms.RandomCompress()
8 |
9 |
--------------------------------------------------------------------------------
/rave/configs/augmentations/gain.gin:
--------------------------------------------------------------------------------
1 | add_augmentation:
2 | aug = @augmentations/transforms.RandomGain()
3 |
4 |
--------------------------------------------------------------------------------
/rave/configs/augmentations/mute.gin:
--------------------------------------------------------------------------------
1 | add_augmentation:
2 | aug = @augmentations/transforms.RandomMute()
3 |
4 |
--------------------------------------------------------------------------------
/rave/configs/causal.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | import cached_conv as cc
4 |
5 | cc.get_padding.mode = 'causal'
6 |
--------------------------------------------------------------------------------
/rave/configs/descript_discriminator.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | import rave
4 | from rave import descript_discriminator
5 |
6 | rave.RAVE:
7 | discriminator = @descript_discriminator.DescriptDiscriminator
--------------------------------------------------------------------------------
/rave/configs/discrete.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | include "configs/v2.gin"
4 |
5 | import rave
6 | from rave import core
7 | from rave import blocks
8 | from rave import discriminator
9 | from rave import quantization
10 |
11 | import torch.nn as nn
12 |
13 | NUM_QUANTIZERS = 16
14 | RATIOS = [4, 4, 2, 2]
15 | LATENT_SIZE = 128
16 | CODEBOOK_SIZE = 1024
17 | DYNAMIC_MASKING = False
18 | CAPACITY = 96
19 | NOISE_AUGMENTATION = 128
20 | PHASE_1_DURATION = 200000
21 |
22 | core.AudioDistanceV1.log_epsilon = 1
23 |
24 | # ENCODER
25 |
26 | blocks.DiscreteEncoder:
27 | encoder_cls = @blocks.EncoderV2
28 | vq_cls = @quantization.ResidualVectorQuantization
29 | num_quantizers = %NUM_QUANTIZERS
30 | noise_augmentation = %NOISE_AUGMENTATION
31 |
32 | blocks.EncoderV2:
33 | n_out = 1
34 |
35 | quantization.ResidualVectorQuantization:
36 | num_quantizers = %NUM_QUANTIZERS
37 | dim = %LATENT_SIZE
38 | codebook_size = %CODEBOOK_SIZE
39 |
40 | # RAVE
41 | rave.RAVE:
42 | encoder = @blocks.DiscreteEncoder
43 | phase_1_duration = %PHASE_1_DURATION
44 | warmup_quantize = -1
45 | discriminator = @discriminator.CombineDiscriminators
46 | gan_loss = @core.hinge_gan
47 | valid_signal_crop = True
48 | num_skipped_features = 0
49 | update_discriminator_every = 4
50 |
51 | rave.BetaWarmupCallback:
52 | initial_value = .1
53 | target_value = .1
54 | warmup_len = 1
--------------------------------------------------------------------------------
/rave/configs/discrete_v3.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | include "configs/discrete.gin"
4 | include "configs/snake.gin"
5 | include "configs/descript_discriminator.gin"
6 |
7 | import rave
8 |
9 | rave.BetaWarmupCallback:
10 | initial_value = 1e-6
11 | target_value = 5e-2
12 | warmup_len = 20000
--------------------------------------------------------------------------------
/rave/configs/hybrid.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | from rave import blocks
4 | from rave import core
5 | from torchaudio import transforms
6 |
7 | import rave
8 |
9 | include "configs/v2.gin"
10 |
11 | N_FFT = 2048
12 | N_MELS = 128
13 | HOP_LENGTH = 256
14 | ENCODER_RATIOS = [2, 2, 2]
15 | NUM_GRU_LAYERS = 2
16 |
17 | blocks.EncoderV2:
18 | data_size = %N_MELS
19 | ratios = %ENCODER_RATIOS
20 | dilations = [1]
21 |
22 | core.n_fft_to_num_bands:
23 | n_fft = %N_FFT
24 |
25 | transforms.MelSpectrogram:
26 | sample_rate = %SAMPLING_RATE
27 | n_fft = %N_FFT
28 | win_length = %N_FFT
29 | hop_length = %HOP_LENGTH
30 | normalized = True
31 | n_mels = %N_MELS
32 |
33 | blocks.GeneratorV2:
34 | recurrent_layer = @blocks.GRU
35 |
36 | blocks.GRU:
37 | latent_size = %LATENT_SIZE
38 | num_layers = %NUM_GRU_LAYERS
39 |
40 | rave.RAVE:
41 | spectrogram = @transforms.MelSpectrogram()
42 | input_mode = "mel"
43 |
--------------------------------------------------------------------------------
/rave/configs/noise.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | from rave import blocks
4 |
5 | blocks.GeneratorV2:
6 | noise_module = @blocks.NoiseGeneratorV2
7 |
8 | blocks.NoiseGeneratorV2:
9 | hidden_size = 128
10 | data_size = %N_BAND
11 | ratios = [2, 2, 2]
12 | noise_bands = 5
--------------------------------------------------------------------------------
/rave/configs/normalize_ambient.gin:
--------------------------------------------------------------------------------
1 | dataset.get_dataset:
2 | augmentations = [
3 | @augmentations/transforms.Compress()
4 | ]
5 |
6 | augmentations/transforms.Compress:
7 | time='0.01,0.01'
8 | lookup='6:-30,-15,-10,-8,0,-5'
9 | sr=%SAMPLING_RATE
--------------------------------------------------------------------------------
/rave/configs/onnx.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | include "configs/v1.gin"
4 |
5 | import rave
6 | from rave import blocks
7 |
8 | CAPACITY = 32
9 |
10 | blocks.Generator.use_noise = False
--------------------------------------------------------------------------------
/rave/configs/prior/prior_v1.gin:
--------------------------------------------------------------------------------
1 | VariationalPrior:
2 | resolution = 32
3 | res_size = 512
4 | skp_size=256
5 | kernel_size=3
6 | cycle_size=4
7 | n_layers=10
8 | sr=@get_model_sr()
9 |
--------------------------------------------------------------------------------
/rave/configs/raspberry.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | include "configs/onnx.gin"
4 |
5 | CAPACITY = 16
--------------------------------------------------------------------------------
/rave/configs/snake.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | from rave import blocks
4 |
5 | ACTIVATION = @blocks.Snake
6 |
7 | blocks.ResidualLayer:
8 | activation = %ACTIVATION
9 |
10 | blocks.DilatedUnit:
11 | activation = %ACTIVATION
12 |
13 | blocks.UpsampleLayer:
14 | activation = %ACTIVATION
15 |
16 | blocks.NoiseGeneratorV2:
17 | activation = %ACTIVATION
18 |
19 | blocks.EncoderV2:
20 | activation = %ACTIVATION
21 |
22 | blocks.GeneratorV2:
23 | activation = %ACTIVATION
--------------------------------------------------------------------------------
/rave/configs/spectral_discriminator.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | import rave
4 | from rave import discriminator
5 |
6 | discriminator.MultiScaleSpectralDiscriminator:
7 | scales = [4096, 2048, 1024, 512, 256]
8 | convnet = @discriminator.EncodecConvNet
9 |
10 | discriminator.EncodecConvNet:
11 | capacity = 32
12 |
13 | discriminator.CombineDiscriminators:
14 | discriminators = [
15 | @discriminator.MultiScaleDiscriminator,
16 | @discriminator.MultiScaleSpectralDiscriminator
17 | ]
--------------------------------------------------------------------------------
/rave/configs/spherical.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | import rave
4 | from rave import blocks
5 |
6 | LATENT_SIZE = 16
7 |
8 | blocks.EncoderV2.n_out = 1
9 |
10 | blocks.SphericalEncoder:
11 | encoder_cls = @blocks.EncoderV2
12 |
13 | rave.RAVE:
14 | encoder = @blocks.SphericalEncoder
15 | phase_1_duration = 200000
16 |
--------------------------------------------------------------------------------
/rave/configs/v1.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | import rave
4 | from rave import pqmf
5 | from rave import core
6 | from rave import blocks
7 | from rave import discriminator
8 | from rave import dataset
9 |
10 | import cached_conv as cc
11 | import torch
12 |
13 | SAMPLING_RATE = 44100
14 | CAPACITY = 64
15 | N_BAND = 16
16 | LATENT_SIZE = 128
17 | RATIOS = [4, 4, 4, 2]
18 | PHASE_1_DURATION = 1000000
19 |
20 | # CORE CONFIGURATION
21 | core.AudioDistanceV1:
22 | multiscale_stft = @core.MultiScaleSTFT
23 | log_epsilon = 1e-7
24 |
25 | core.MultiScaleSTFT:
26 | scales = [2048, 1024, 512, 256, 128]
27 | sample_rate = %SAMPLING_RATE
28 | magnitude = True
29 |
30 | dataset.split_dataset.max_residual = 1000
31 |
32 | # CONVOLUTION CONFIGURATION
33 | cc.Conv1d.bias = False
34 | cc.ConvTranspose1d.bias = False
35 |
36 | # PQMF
37 | pqmf.CachedPQMF:
38 | attenuation = 100
39 | n_band = %N_BAND
40 |
41 | blocks.normalization.mode = 'weight_norm'
42 |
43 | # ENCODER
44 | blocks.Encoder:
45 | data_size = %N_BAND
46 | capacity = %CAPACITY
47 | latent_size = %LATENT_SIZE
48 | ratios = %RATIOS
49 | sample_norm = False
50 | repeat_layers = 1
51 |
52 | variational/blocks.Encoder.n_out = 2
53 |
54 | blocks.VariationalEncoder:
55 | encoder = @variational/blocks.Encoder
56 |
57 | # DECODER
58 | blocks.Generator:
59 | latent_size = %LATENT_SIZE
60 | capacity = %CAPACITY
61 | data_size = %N_BAND
62 | ratios = %RATIOS
63 | loud_stride = 1
64 | use_noise = True
65 |
66 | blocks.ResidualStack:
67 | kernel_sizes = [3]
68 | dilations_list = [[1, 1], [3, 1], [5, 1]]
69 |
70 | blocks.NoiseGenerator:
71 | ratios = [4, 4, 4]
72 | noise_bands = 5
73 |
74 | # DISCRIMINATOR
75 | discriminator.ConvNet:
76 | in_size = 1
77 | out_size = 1
78 | capacity = %CAPACITY
79 | n_layers = 4
80 | stride = 4
81 |
82 | scales/discriminator.ConvNet:
83 | conv = @torch.nn.Conv1d
84 | kernel_size = 15
85 |
86 | discriminator.MultiScaleDiscriminator:
87 | n_discriminators = 3
88 | convnet = @scales/discriminator.ConvNet
89 |
90 | feature_matching/core.mean_difference:
91 | norm = 'L1'
92 |
93 | # MODEL ASSEMBLING
94 | rave.RAVE:
95 | latent_size = %LATENT_SIZE
96 | pqmf = @pqmf.CachedPQMF
97 | sampling_rate = %SAMPLING_RATE
98 | encoder = @blocks.VariationalEncoder
99 | decoder = @blocks.Generator
100 | discriminator = @discriminator.MultiScaleDiscriminator
101 | phase_1_duration = %PHASE_1_DURATION
102 | gan_loss = @core.hinge_gan
103 | valid_signal_crop = False
104 | feature_matching_fun = @feature_matching/core.mean_difference
105 | num_skipped_features = 0
106 | audio_distance = @core.AudioDistanceV1
107 | multiband_audio_distance = @core.AudioDistanceV1
108 | weights = {
109 | 'feature_matching': 10
110 | }
111 |
112 | rave.BetaWarmupCallback:
113 | initial_value = .1
114 | target_value = .1
115 | warmup_len = 1
--------------------------------------------------------------------------------
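
Note: the following is a minimal usage sketch, not part of the repository; the config path and the bare constructor call are assumptions for illustration. With `from __gin__ import dynamic_registration`, gin imports the referenced modules while parsing, resolves `%MACRO` references against the constants defined at the top of the file, and injects every bound argument when a configurable class such as `rave.RAVE` is instantiated.

    import gin

    # Parsing the file registers all bindings above; the path is an assumption.
    gin.parse_config_file("rave/configs/v1.gin")

    import rave
    # All constructor arguments (encoder, decoder, discriminator, ...) are
    # supplied by the gin bindings, so no explicit arguments are needed here.
    model = rave.RAVE()
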
/rave/configs/v2.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | import rave
4 | from rave import core
5 | from rave import blocks
6 | from rave import discriminator
7 |
8 | import torch.nn as nn
9 |
10 | include "configs/v1.gin"
11 |
12 | KERNEL_SIZE = 3
13 | DILATIONS = [
14 | [1, 3, 9],
15 | [1, 3, 9],
16 | [1, 3, 9],
17 | [1, 3],
18 | ]
19 | RATIOS = [4, 4, 4, 2]
20 | CAPACITY = 96
21 | NOISE_AUGMENTATION = 0
22 |
23 | core.AudioDistanceV1.log_epsilon = 1e-7
24 |
25 | core.get_augmented_latent_size:
26 | latent_size = %LATENT_SIZE
27 | noise_augmentation = %NOISE_AUGMENTATION
28 |
29 | # ENCODER
30 | blocks.EncoderV2:
31 | data_size = %N_BAND
32 | capacity = %CAPACITY
33 | ratios = %RATIOS
34 | latent_size = %LATENT_SIZE
35 | n_out = 2
36 | kernel_size = %KERNEL_SIZE
37 | dilations = %DILATIONS
38 |
39 | blocks.VariationalEncoder:
40 | encoder = @variational/blocks.EncoderV2
41 |
42 | # GENERATOR
43 | blocks.GeneratorV2:
44 | data_size = %N_BAND
45 | capacity = %CAPACITY
46 | ratios = %RATIOS
47 | latent_size = @core.get_augmented_latent_size()
48 | kernel_size = %KERNEL_SIZE
49 | dilations = %DILATIONS
50 | amplitude_modulation = True
51 |
52 | # DISCRIMINATOR
53 | periods/discriminator.ConvNet:
54 | conv = @nn.Conv2d
55 | kernel_size = (5, 1)
56 |
57 | spectral/discriminator.ConvNet:
58 | conv = @nn.Conv1d
59 | kernel_size = 5
60 | stride = 2
61 |
62 | discriminator.MultiPeriodDiscriminator:
63 | periods = [2, 3, 5, 7, 11]
64 | convnet = @periods/discriminator.ConvNet
65 |
66 | discriminator.MultiScaleSpectralDiscriminator1d:
67 | scales = [4096, 2048, 1024, 512, 256]
68 | convnet = @spectral/discriminator.ConvNet
69 |
70 | discriminator.CombineDiscriminators:
71 | discriminators = [
72 | @discriminator.MultiPeriodDiscriminator,
73 | @discriminator.MultiScaleDiscriminator,
74 | # @discriminator.MultiScaleSpectralDiscriminator1d,
75 | ]
76 |
77 | feature_matching/core.mean_difference:
78 | relative = True
79 |
80 | # RAVE
81 | rave.RAVE:
82 | discriminator = @discriminator.CombineDiscriminators
83 | valid_signal_crop = True
84 | num_skipped_features = 1
85 | decoder = @blocks.GeneratorV2
86 | update_discriminator_every = 4
87 | weights = {
88 | 'feature_matching': 20,
89 | }
90 |
91 | rave.BetaWarmupCallback:
92 | initial_value = 1e-6
93 | target_value = 5e-2
94 | warmup_len = 20000
95 |
--------------------------------------------------------------------------------
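
The `periods/` and `spectral/` prefixes above are gin scopes: the same `discriminator.ConvNet` class is configured twice with different convolution types and kernel sizes, and each discriminator receives its own pre-bound constructor. A rough Python analogue (an illustration under assumed capacities and strides, not repository code) uses one `functools.partial` per scope:

    from functools import partial
    import torch.nn as nn
    from rave import discriminator

    # Mimic the periods/ and spectral/ scopes with partially-applied constructors.
    periods_convnet = partial(discriminator.ConvNet, out_size=1, capacity=96,
                              n_layers=4, conv=nn.Conv2d, kernel_size=(5, 1),
                              stride=4)
    spectral_convnet = partial(discriminator.ConvNet, out_size=1, capacity=96,
                               n_layers=4, conv=nn.Conv1d, kernel_size=5,
                               stride=2)

    # MultiPeriodDiscriminator fills in in_size (the channel count) per instance.
    mpd = discriminator.MultiPeriodDiscriminator(periods=[2, 3, 5, 7, 11],
                                                 convnet=periods_convnet)
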
/rave/configs/v2_nopqmf.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | import rave
4 | from rave import core
5 | from rave import dataset
6 | from rave import blocks
7 | from rave import discriminator
8 | from rave import transforms
9 |
10 | import torch.nn as nn
11 |
12 | include "configs/v1.gin"
13 |
14 | KERNEL_SIZE = 3
15 | DILATIONS = [
16 | [1, 3, 9],
17 | [1, 3, 9],
18 | [1, 3, 9],
19 | [1, 3],
20 | ]
21 | RATIOS = [4, 4, 4, 2]
22 | CAPACITY = 64
23 | NOISE_AUGMENTATION = 0
24 |
25 | core.AudioDistanceV1.log_epsilon = 1e-7
26 |
27 | core.get_augmented_latent_size:
28 | latent_size = %LATENT_SIZE
29 | noise_augmentation = %NOISE_AUGMENTATION
30 |
31 |
32 | # AUGMENTATIONS
33 | dataset.get_dataset:
34 | augmentations = [
35 | @augmentations/transforms.RandomCompress()
36 | ]
37 |
38 | augmentations/transforms.RandomCompress:
39 | 	amp_range = [-60, -10]
40 | 	threshold = -40
41 | 	prob = 0.5
42 | 	sr = %SAMPLING_RATE

43 |
44 | # ENCODER
45 | blocks.EncoderV2:
46 | data_size = %N_BAND
47 | capacity = %CAPACITY
48 | ratios = [4, 4, 4, 2]
49 | latent_size = %LATENT_SIZE
50 | n_out = 2
51 | kernel_size = %KERNEL_SIZE
52 | dilations = %DILATIONS
53 |
54 | blocks.VariationalEncoder:
55 | encoder = @variational/blocks.EncoderV2
56 |
57 | # GENERATOR
58 | blocks.GeneratorV2:
59 | capacity = %CAPACITY
60 | ratios = [8, 8, 8, 4]
61 | latent_size = @core.get_augmented_latent_size()
62 | kernel_size = %KERNEL_SIZE
63 | dilations = %DILATIONS
64 | amplitude_modulation = True
65 |
66 | # DISCRIMINATOR
67 | periods/discriminator.ConvNet:
68 | conv = @nn.Conv2d
69 | kernel_size = (5, 1)
70 |
71 | spectral/discriminator.ConvNet:
72 | conv = @nn.Conv1d
73 | kernel_size = 5
74 | stride = 2
75 |
76 | discriminator.MultiPeriodDiscriminator:
77 | periods = [2, 3, 5, 7, 11]
78 | convnet = @periods/discriminator.ConvNet
79 |
80 | discriminator.MultiScaleSpectralDiscriminator1d:
81 | scales = [4096, 2048, 1024, 512, 256]
82 | convnet = @spectral/discriminator.ConvNet
83 |
84 | discriminator.CombineDiscriminators:
85 | discriminators = [
86 | @discriminator.MultiPeriodDiscriminator,
87 | @discriminator.MultiScaleDiscriminator,
88 | # @discriminator.MultiScaleSpectralDiscriminator1d,
89 | ]
90 |
91 | feature_matching/core.mean_difference:
92 | relative = True
93 |
94 | # RAVE
95 | rave.RAVE:
96 | n_bands = %N_BAND
97 | discriminator = @discriminator.CombineDiscriminators
98 | valid_signal_crop = True
99 | num_skipped_features = 1
100 | decoder = @blocks.GeneratorV2
101 | phase_1_duration = 1000000
102 | weights = {
103 | 'feature_matching': 20
104 | }
105 | update_discriminator_every = 4
106 | output_mode = "raw"
107 | audio_monitor_epochs = 10
108 |
109 | rave.BetaWarmupCallback:
110 | initial_value = 1e-6
111 | target_value = 1e-2
112 | warmup_len = 500000
113 |
114 |
--------------------------------------------------------------------------------
/rave/configs/v2_nopqmf_small.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | import rave
4 | from rave import core
5 | from rave import dataset
6 | from rave import blocks
7 | from rave import discriminator
8 | from rave import balancer
9 | from rave import transforms
10 |
11 | import torch.nn as nn
12 |
13 | include "configs/v1.gin"
14 |
15 | KERNEL_SIZE = 3
16 | DILATIONS = [
17 | [1, 3, 9],
18 | [1, 3, 9],
19 | [1, 3, 9],
20 | [1, 3],
21 | ]
22 | RATIOS = [4, 4, 4, 2]
23 | CAPACITY = 64
24 | NOISE_AUGMENTATION = 0
25 |
26 | core.AudioDistanceV1.log_epsilon = 1e-7
27 |
28 | core.get_augmented_latent_size:
29 | latent_size = %LATENT_SIZE
30 | noise_augmentation = %NOISE_AUGMENTATION
31 |
32 |
33 | # AUGMENTATIONS
34 | dataset.get_dataset:
35 | augmentations = [
36 | @augmentations/transforms.Compress()
37 | ]
38 |
39 | augmentations/transforms.Compress:
40 | 	amp_range = [-60, -10]
41 | 	threshold = -40
42 | prob = 0.5
43 |
44 | # ENCODER
45 | blocks.EncoderV2:
46 | data_size = %N_BAND
47 | capacity = %CAPACITY
48 | ratios = [4, 4, 4, 2]
49 | latent_size = %LATENT_SIZE
50 | n_out = 2
51 | kernel_size = %KERNEL_SIZE
52 | dilations = %DILATIONS
53 |
54 | blocks.VariationalEncoder:
55 | encoder = @variational/blocks.EncoderV2
56 |
57 | # GENERATOR
58 | blocks.GeneratorV2:
59 | capacity = %CAPACITY
60 | ratios = [8, 8, 8, 4]
61 | latent_size = @core.get_augmented_latent_size()
62 | kernel_size = %KERNEL_SIZE
63 | dilations = %DILATIONS
64 | amplitude_modulation = True
65 |
66 | # DISCRIMINATOR
67 | periods/discriminator.ConvNet:
68 | conv = @nn.Conv2d
69 | kernel_size = (5, 1)
70 |
71 | spectral/discriminator.ConvNet:
72 | conv = @nn.Conv1d
73 | kernel_size = 5
74 | stride = 2
75 |
76 | discriminator.MultiPeriodDiscriminator:
77 | periods = [2, 3, 5, 7, 11]
78 | convnet = @periods/discriminator.ConvNet
79 |
80 | discriminator.MultiScaleSpectralDiscriminator1d:
81 | scales = [4096, 2048, 1024, 512, 256]
82 | convnet = @spectral/discriminator.ConvNet
83 |
84 | discriminator.CombineDiscriminators:
85 | discriminators = [
86 | @discriminator.MultiPeriodDiscriminator,
87 | @discriminator.MultiScaleDiscriminator,
88 | # @discriminator.MultiScaleSpectralDiscriminator1d,
89 | ]
90 |
91 | feature_matching/core.mean_difference:
92 | relative = True
93 |
94 | # RAVE
95 | rave.RAVE:
96 | n_bands = %N_BAND
97 | discriminator = @discriminator.CombineDiscriminators
98 | valid_signal_crop = True
99 | num_skipped_features = 1
100 | decoder = @blocks.GeneratorV2
101 | phase_1_duration = 500000
102 | loss_weights = {'reg': 0.02, 'feature_matching': 20}
103 | update_discriminator_every = 4
104 | enable_pqmf_encode = True
105 | enable_pqmf_decode = False
106 | audio_monitor_epochs = 10
107 |
108 |
--------------------------------------------------------------------------------
/rave/configs/v2_small.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | import rave
4 | from rave import core
5 | from rave import blocks
6 | from rave import discriminator
7 |
8 | import torch.nn as nn
9 |
10 | include "configs/v1.gin"
11 |
12 | KERNEL_SIZE = 3
13 | DILATIONS = [
14 | [1, 3, 9],
15 | [1, 3, 9],
16 | [1, 3, 9],
17 | [1, 3],
18 | ]
19 | RATIOS = [4, 2, 2, 2]
20 | CAPACITY = 48
21 | NOISE_AUGMENTATION = 0
22 |
23 | core.AudioDistanceV1.log_epsilon = 1e-7
24 |
25 | core.get_augmented_latent_size:
26 | latent_size = %LATENT_SIZE
27 | noise_augmentation = %NOISE_AUGMENTATION
28 |
29 | # ENCODER
30 | blocks.EncoderV2:
31 | data_size = %N_BAND
32 | capacity = %CAPACITY
33 | ratios = %RATIOS
34 | latent_size = %LATENT_SIZE
35 | n_out = 2
36 | kernel_size = %KERNEL_SIZE
37 | dilations = %DILATIONS
38 |
39 | blocks.VariationalEncoder:
40 | encoder = @variational/blocks.EncoderV2
41 |
42 | blocks.NoiseGeneratorV2:
43 | hidden_size = 64
44 | data_size = %N_BAND
45 | ratios = [2, 2, 2]
46 | noise_bands = 32
47 |
48 | # GENERATOR
49 | blocks.GeneratorV2:
50 | data_size = %N_BAND
51 | capacity = %CAPACITY
52 | ratios = %RATIOS
53 | latent_size = @core.get_augmented_latent_size()
54 | kernel_size = %KERNEL_SIZE
55 | dilations = %DILATIONS
56 | amplitude_modulation = True
57 | noise_module = @blocks.NoiseGeneratorV2
58 |
59 | # DISCRIMINATOR
60 | periods/discriminator.ConvNet:
61 | conv = @nn.Conv2d
62 | kernel_size = (5, 1)
63 |
64 | spectral/discriminator.ConvNet:
65 | conv = @nn.Conv1d
66 | kernel_size = 5
67 | stride = 2
68 |
69 | discriminator.MultiPeriodDiscriminator:
70 | periods = [2, 3, 5, 7, 11]
71 | convnet = @periods/discriminator.ConvNet
72 |
73 | discriminator.MultiScaleSpectralDiscriminator1d:
74 | scales = [4096, 2048, 1024, 512, 256]
75 | convnet = @spectral/discriminator.ConvNet
76 |
77 | discriminator.CombineDiscriminators:
78 | discriminators = [
79 | @discriminator.MultiPeriodDiscriminator,
80 | @discriminator.MultiScaleDiscriminator,
81 | # @discriminator.MultiScaleSpectralDiscriminator1d,
82 | ]
83 |
84 | feature_matching/core.mean_difference:
85 | relative = True
86 |
87 | # RAVE
88 | rave.RAVE:
89 | discriminator = @discriminator.CombineDiscriminators
90 | valid_signal_crop = True
91 | num_skipped_features = 1
92 | decoder = @blocks.GeneratorV2
93 | update_discriminator_every = 2
94 | weights = {
95 | 'feature_matching': 20,
96 | }
97 |
98 | rave.BetaWarmupCallback:
99 | initial_value = .01
100 | target_value = .01
101 | warmup_len = 300000
--------------------------------------------------------------------------------
/rave/configs/v2_with_augs.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | import rave
4 | from rave import core
5 | from rave import dataset
6 | from rave import blocks
7 | from rave import discriminator
8 | from rave import transforms
9 |
10 | from torchaudio import transforms as ta_transforms
11 |
12 | import torch.nn as nn
13 |
14 | include "configs/v1.gin"
15 |
16 | KERNEL_SIZE = 3
17 | DILATIONS = [
18 | [1, 3, 9],
19 | [1, 3, 9],
20 | [1, 3, 9],
21 | [1, 3],
22 | ]
23 | ENCODER_RATIOS = [2, 2, 2]
24 | RATIOS = [4, 4, 4, 2]
25 | CAPACITY = 96
26 | NOISE_AUGMENTATION = 0
27 |
28 | # MELSPEC PROPERTIES
29 | N_FFT = 2048
30 | N_MELS = 128
31 | HOP_LENGTH = 256
32 | NUM_GRU_LAYERS = 2
33 |
34 | core.AudioDistanceV1.log_epsilon = 1e-7
35 |
36 | core.get_augmented_latent_size:
37 | latent_size = %LATENT_SIZE
38 | noise_augmentation = %NOISE_AUGMENTATION
39 |
40 | # AUGMENTATIONS
41 | dataset.get_dataset:
42 | augmentations = [
43 | @augmentations/transforms.RandomCompress(),
44 | # @augmentations/transforms.FrequencyMasking()
45 | ]
46 |
47 | augmentations/transforms.RandomCompress:
48 | 	amp_range = [-60, -10]
49 | 	threshold = -40
50 | prob = 0.5
51 |
52 | ta_transforms.MelSpectrogram:
53 | sample_rate = %SAMPLING_RATE
54 | n_fft = %N_FFT
55 | win_length = %N_FFT
56 | hop_length = %HOP_LENGTH
57 | normalized = True
58 | n_mels = %N_MELS
59 |
60 | # ENCODER
61 | blocks.EncoderV2:
62 | data_size = %N_MELS
63 | ratios = %ENCODER_RATIOS
64 | capacity = %CAPACITY
65 | latent_size = %LATENT_SIZE
66 | n_out = 2
67 | kernel_size = %KERNEL_SIZE
68 | dilations = %DILATIONS
69 |
70 | blocks.VariationalEncoder:
71 | encoder = @variational/blocks.EncoderV2
72 |
73 | # GENERATOR
74 | blocks.GeneratorV2:
75 | data_size = %N_BAND
76 | capacity = %CAPACITY
77 | ratios = %RATIOS
78 | latent_size = @core.get_augmented_latent_size()
79 | kernel_size = %KERNEL_SIZE
80 | dilations = %DILATIONS
81 | amplitude_modulation = True
82 |
83 | # DISCRIMINATOR
84 | periods/discriminator.ConvNet:
85 | conv = @nn.Conv2d
86 | kernel_size = (5, 1)
87 |
88 | spectral/discriminator.ConvNet:
89 | conv = @nn.Conv1d
90 | kernel_size = 5
91 | stride = 2
92 |
93 | discriminator.MultiPeriodDiscriminator:
94 | periods = [2, 3, 5, 7, 11]
95 | convnet = @periods/discriminator.ConvNet
96 |
97 | discriminator.MultiScaleSpectralDiscriminator1d:
98 | scales = [4096, 2048, 1024, 512, 256]
99 | convnet = @spectral/discriminator.ConvNet
100 |
101 | discriminator.CombineDiscriminators:
102 | discriminators = [
103 | @discriminator.MultiPeriodDiscriminator,
104 | @discriminator.MultiScaleDiscriminator,
105 | # @discriminator.MultiScaleSpectralDiscriminator1d,
106 | ]
107 |
108 | feature_matching/core.mean_difference:
109 | relative = True
110 |
111 | # RAVE
112 | rave.RAVE:
113 | discriminator = @discriminator.CombineDiscriminators
114 | valid_signal_crop = True
115 | num_skipped_features = 1
116 | decoder = @blocks.GeneratorV2
117 | phase_1_duration = 1000000
118 | spectrogram = @ta_transforms.MelSpectrogram()
119 | update_discriminator_every = 4
120 | input_mode = "mel"
121 | output_mode = "pqmf"
122 | audio_monitor_epochs = 10
123 |
124 |
--------------------------------------------------------------------------------
/rave/configs/v3.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | include "configs/v2.gin"
4 | include "configs/adain.gin"
5 | include "configs/snake.gin"
6 | include "configs/descript_discriminator.gin"
7 |
8 | import rave
9 |
10 | rave.BetaWarmupCallback:
11 | initial_value = 1e-6
12 | target_value = 5e-2
13 | warmup_len = 20000
--------------------------------------------------------------------------------
/rave/configs/wasserstein.gin:
--------------------------------------------------------------------------------
1 | from __gin__ import dynamic_registration
2 |
3 | import rave
4 | from rave import blocks
5 |
6 | LATENT_SIZE = 16
7 | NOISE_AUGMENTATION = 128
8 | PHASE_1_DURATION = 200000
9 |
10 | blocks.EncoderV2.n_out = 1
11 |
12 | blocks.WasserteinEncoder:
13 | encoder_cls = @blocks.EncoderV2
14 | noise_augmentation = %NOISE_AUGMENTATION
15 |
16 | rave.RAVE:
17 | encoder = @blocks.WasserteinEncoder
18 | phase_1_duration = %PHASE_1_DURATION
19 | weights = {
20 | 'fullband_spectral_distance': 2,
21 | 'multiband_spectral_distance': 2,
22 | 'adversarial': 2,
23 | }
24 |
25 | rave.BetaWarmupCallback:
26 | initial_value = 100
27 | target_value = 100
28 | warmup_len = 1
--------------------------------------------------------------------------------
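
A small sketch (an assumption for illustration, not repository code) of what `NOISE_AUGMENTATION = 128` implies for the decoder input: the 16 latent dimensions produced by the encoder are concatenated with 128 channels of fresh Gaussian noise, which is why v2-style configs compute the generator's latent size via `core.get_augmented_latent_size` (16 + 128 = 144).

    import torch

    z = torch.randn(1, 16, 64)                         # (batch, latent, time) from the encoder
    noise = torch.randn(z.shape[0], 128, z.shape[-1])  # noise-augmentation channels
    z_aug = torch.cat([z, noise], dim=1)               # decoder input, 144 channels
    print(z_aug.shape)                                 # torch.Size([1, 144, 64])
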
/rave/core.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 | from random import random
5 | from typing import Callable, Optional, Sequence, Union
6 |
7 | import GPUtil as gpu
8 | import librosa as li
9 | import lmdb
10 | import numpy as np
11 | import pytorch_lightning as pl
12 | import torch
13 | import torch.fft as fft
14 | import torch.nn as nn
15 | import torchaudio
16 | from einops import rearrange
17 | from scipy.signal import lfilter
18 |
19 |
20 | def mod_sigmoid(x):
21 | return 2 * torch.sigmoid(x)**2.3 + 1e-7
22 |
23 |
24 | def random_angle(min_f=20, max_f=8000, sr=24000):
25 | min_f = np.log(min_f)
26 | max_f = np.log(max_f)
27 | rand = np.exp(random() * (max_f - min_f) + min_f)
28 | rand = 2 * np.pi * rand / sr
29 | return rand
30 |
31 |
32 | def get_augmented_latent_size(latent_size: int, noise_augmentation: int):
33 | return latent_size + noise_augmentation
34 |
35 |
36 | def pole_to_z_filter(omega, amplitude=.9):
37 | z0 = amplitude * np.exp(1j * omega)
38 | a = [1, -2 * np.real(z0), abs(z0)**2]
39 | b = [abs(z0)**2, -2 * np.real(z0), 1]
40 | return b, a
41 |
42 | def random_phase_mangle(x, min_f, max_f, amp, sr):
43 | angle = random_angle(min_f, max_f, sr)
44 | b, a = pole_to_z_filter(angle, amp)
45 | return lfilter(b, a, x)
46 |
47 |
48 | def amp_to_impulse_response(amp, target_size):
49 | """
50 | Transforms frequency-domain amplitudes into a windowed impulse response along the last dimension.
51 | """
52 | amp = torch.stack([amp, torch.zeros_like(amp)], -1)
53 | amp = torch.view_as_complex(amp)
54 | amp = fft.irfft(amp)
55 |
56 | filter_size = amp.shape[-1]
57 |
58 | amp = torch.roll(amp, filter_size // 2, -1)
59 | win = torch.hann_window(filter_size, dtype=amp.dtype, device=amp.device)
60 |
61 | amp = amp * win
62 |
63 | amp = nn.functional.pad(
64 | amp,
65 | (0, int(target_size) - int(filter_size)),
66 | )
67 | amp = torch.roll(amp, -filter_size // 2, -1)
68 |
69 | return amp
70 |
71 | def fft_convolve(signal, kernel):
72 | """
73 | Convolves signal with kernel along the last dimension via FFT.
74 | """
75 | signal = nn.functional.pad(signal, (0, signal.shape[-1]))
76 | kernel = nn.functional.pad(kernel, (kernel.shape[-1], 0))
77 |
78 | output = fft.irfft(fft.rfft(signal) * fft.rfft(kernel))
79 | output = output[..., output.shape[-1] // 2:]
80 |
81 | return output
82 |
83 |
84 | def get_ckpts(folder, name=None):
85 | ckpts = map(str, Path(folder).rglob("*.ckpt"))
86 | if name:
87 | ckpts = filter(lambda e: name in os.path.basename(str(e)), ckpts)
88 | ckpts = sorted(ckpts, key=os.path.getmtime)
89 | return ckpts
90 |
91 |
92 | def get_versions(folder):
93 | ckpts = map(str, Path(folder).rglob("version_*"))
94 | ckpts = filter(lambda x: os.path.isdir(x), ckpts)
95 | return sorted(ckpts, key=os.path.getmtime)
96 |
97 | def search_for_config(folder):
98 | if os.path.isfile(folder):
99 | folder = os.path.dirname(folder)
100 | configs = list(map(str, Path(folder).rglob("config.gin")))
101 | if configs != []:
102 | return os.path.abspath(os.path.join(folder, "config.gin"))
103 | configs = list(map(str, Path(folder).rglob("../config.gin")))
104 | if configs != []:
105 | return os.path.abspath(os.path.join(folder, "../config.gin"))
106 | configs = list(map(str, Path(folder).rglob("../../config.gin")))
107 | if configs != []:
108 | return os.path.abspath(os.path.join(folder, "../../config.gin"))
109 | else:
110 | return None
111 |
112 |
113 |
114 | def search_for_run(run_path, name=None):
115 | if run_path is None: return None
116 | if ".ckpt" in run_path: return run_path
117 | ckpts = get_ckpts(run_path, name)
118 | if len(ckpts) != 0:
119 | return ckpts[-1]
120 | else:
121 | print('No checkpoint found')
122 | return None
123 |
124 |
125 | def setup_gpu():
126 | return gpu.getAvailable(maxMemory=.05)
127 |
128 |
129 | def get_beta_kl(step, warmup, min_beta, max_beta):
130 | if step > warmup: return max_beta
131 | t = step / warmup
132 | min_beta_log = np.log(min_beta)
133 | max_beta_log = np.log(max_beta)
134 | beta_log = t * (max_beta_log - min_beta_log) + min_beta_log
135 | return np.exp(beta_log)
136 |
137 |
138 | def get_beta_kl_cyclic(step, cycle_size, min_beta, max_beta):
139 | return get_beta_kl(step % cycle_size, cycle_size // 2, min_beta, max_beta)
140 |
141 |
142 | def get_beta_kl_cyclic_annealed(step, cycle_size, warmup, min_beta, max_beta):
143 | min_beta = get_beta_kl(step, warmup, min_beta, max_beta)
144 | return get_beta_kl_cyclic(step, cycle_size, min_beta, max_beta)
145 |
146 |
147 | def n_fft_to_num_bands(n_fft: int) -> int:
148 | return n_fft // 2 + 1
149 |
150 |
151 | def hinge_gan(score_real, score_fake):
152 | loss_dis = torch.relu(1 - score_real) + torch.relu(1 + score_fake)
153 | loss_dis = loss_dis.mean()
154 | loss_gen = -score_fake.mean()
155 | return loss_dis, loss_gen
156 |
157 |
158 | def ls_gan(score_real, score_fake):
159 | loss_dis = (score_real - 1).pow(2) + score_fake.pow(2)
160 | loss_dis = loss_dis.mean()
161 | loss_gen = (score_fake - 1).pow(2).mean()
162 | return loss_dis, loss_gen
163 |
164 |
165 | def nonsaturating_gan(score_real, score_fake):
166 | score_real = torch.clamp(torch.sigmoid(score_real), 1e-7, 1 - 1e-7)
167 | score_fake = torch.clamp(torch.sigmoid(score_fake), 1e-7, 1 - 1e-7)
168 | loss_dis = -(torch.log(score_real) + torch.log(1 - score_fake)).mean()
169 | loss_gen = -torch.log(score_fake).mean()
170 | return loss_dis, loss_gen
171 |
172 | def get_minimum_size(model):
173 | N = 2**15
174 | device = next(iter(model.parameters())).device
175 | x = torch.randn(1, model.n_channels, N, requires_grad=True, device=device)
176 | z = model.encode(x)
177 | return int(x.shape[-1] / z.shape[-1])
178 |
179 |
180 | @torch.enable_grad()
181 | def get_rave_receptive_field(model, n_channels=1):
182 | N = 2**15
183 | model.eval()
184 | device = next(iter(model.parameters())).device
185 |
186 | for module in model.modules():
187 | if hasattr(module, 'gru_state') or hasattr(module, 'temporal'):
188 | module.disable()
189 |
190 | while True:
191 | x = torch.randn(1, model.n_channels, N, requires_grad=True, device=device)
192 |
193 | z = model.encode(x)
194 | z = model.encoder.reparametrize(z)[0]
195 | y = model.decode(z)
196 |
197 | y[0, 0, N // 2].backward()
198 | assert x.grad is not None, "input has no grad"
199 |
200 | grad = x.grad.data.reshape(-1)
201 | left_grad, right_grad = grad.chunk(2, 0)
202 | large_enough = (left_grad[0] == 0) and right_grad[-1] == 0
203 | if large_enough:
204 | break
205 | else:
206 | N *= 2
207 | left_receptive_field = len(left_grad[left_grad != 0])
208 | right_receptive_field = len(right_grad[right_grad != 0])
209 | model.zero_grad()
210 |
211 | for module in model.modules():
212 | if hasattr(module, 'gru_state') or hasattr(module, 'temporal'):
213 | module.enable()
214 | ratio = x.shape[-1] // z.shape[-1]
215 | rate = model.sr / ratio
216 | print(f"Compression ratio: {ratio}x (~{rate:.1f}Hz @ {model.sr}Hz)")
217 | return left_receptive_field, right_receptive_field
218 |
219 |
220 | def valid_signal_crop(x, left_rf, right_rf):
221 | dim = x.shape[1]
222 | x = x[..., left_rf.item() // dim:]
223 | if right_rf.item():
224 | x = x[..., :-right_rf.item() // dim]
225 | return x
226 |
227 |
228 | def relative_distance(
229 | x: torch.Tensor,
230 | y: torch.Tensor,
231 | norm: Callable[[torch.Tensor], torch.Tensor],
232 | ) -> torch.Tensor:
233 | return norm(x - y) / norm(x)
234 |
235 |
236 | def mean_difference(target: torch.Tensor,
237 | value: torch.Tensor,
238 | norm: str = 'L1',
239 | relative: bool = False):
240 | diff = target - value
241 | if norm == 'L1':
242 | diff = diff.abs().mean()
243 | if relative:
244 | diff = diff / target.abs().mean()
245 | return diff
246 | elif norm == 'L2':
247 | diff = (diff * diff).mean()
248 | if relative:
249 | diff = diff / (target * target).mean()
250 | return diff
251 | else:
252 | raise Exception(f'Norm must be either L1 or L2, got {norm}')
253 |
254 |
255 | class MelScale(nn.Module):
256 |
257 | def __init__(self, sample_rate: int, n_fft: int, n_mels: int) -> None:
258 | super().__init__()
259 | mel = li.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels)
260 | mel = torch.from_numpy(mel).float()
261 | self.register_buffer('mel', mel)
262 |
263 | def forward(self, x: torch.Tensor) -> torch.Tensor:
264 | mel = self.mel.type_as(x)
265 | y = torch.einsum('bft,mf->bmt', x, mel)
266 | return y
267 |
268 |
269 | class MultiScaleSTFT(nn.Module):
270 |
271 | def __init__(self,
272 | scales: Sequence[int],
273 | sample_rate: int,
274 | magnitude: bool = True,
275 | normalized: bool = False,
276 | num_mels: Optional[int] = None) -> None:
277 | super().__init__()
278 | self.scales = scales
279 | self.magnitude = magnitude
280 | self.num_mels = num_mels
281 |
282 | self.stfts = []
283 | self.mel_scales = []
284 | for scale in scales:
285 | self.stfts.append(
286 | torchaudio.transforms.Spectrogram(
287 | n_fft=scale,
288 | win_length=scale,
289 | hop_length=scale // 4,
290 | normalized=normalized,
291 | power=None,
292 | ))
293 | if num_mels is not None:
294 | self.mel_scales.append(
295 | MelScale(
296 | sample_rate=sample_rate,
297 | n_fft=scale,
298 | n_mels=num_mels,
299 | ))
300 | else:
301 | self.mel_scales.append(None)
302 |
303 | self.stfts = nn.ModuleList(self.stfts)
304 | self.mel_scales = nn.ModuleList(self.mel_scales)
305 |
306 | def forward(self, x: torch.Tensor) -> Sequence[torch.Tensor]:
307 | x = rearrange(x, "b c t -> (b c) t")
308 | stfts = []
309 | for stft, mel in zip(self.stfts, self.mel_scales):
310 | y = stft(x)
311 | if mel is not None:
312 | y = mel(y)
313 | if self.magnitude:
314 | y = y.abs()
315 | else:
316 | y = torch.stack([y.real, y.imag], -1)
317 | stfts.append(y)
318 |
319 | return stfts
320 |
321 |
322 | class AudioDistanceV1(nn.Module):
323 |
324 | def __init__(self, multiscale_stft: Callable[[], nn.Module],
325 | log_epsilon: float) -> None:
326 | super().__init__()
327 | self.multiscale_stft = multiscale_stft()
328 | self.log_epsilon = log_epsilon
329 |
330 | def forward(self, x: torch.Tensor, y: torch.Tensor):
331 | stfts_x = self.multiscale_stft(x)
332 | stfts_y = self.multiscale_stft(y)
333 | distance = 0.
334 |
335 | for x, y in zip(stfts_x, stfts_y):
336 | logx = torch.log(x + self.log_epsilon)
337 | logy = torch.log(y + self.log_epsilon)
338 |
339 | lin_distance = mean_difference(x, y, norm='L2', relative=True)
340 | log_distance = mean_difference(logx, logy, norm='L1')
341 |
342 | distance = distance + lin_distance + log_distance
343 |
344 | return {'spectral_distance': distance}
345 |
346 |
347 | class WeightedInstantaneousSpectralDistance(nn.Module):
348 |
349 | def __init__(self,
350 | multiscale_stft: Callable[[], MultiScaleSTFT],
351 | weighted: bool = False) -> None:
352 | super().__init__()
353 | self.multiscale_stft = multiscale_stft()
354 | self.weighted = weighted
355 |
356 | def phase_to_instantaneous_frequency(self,
357 | x: torch.Tensor) -> torch.Tensor:
358 | x = self.unwrap(x)
359 | x = self.derivative(x)
360 | return x
361 |
362 | def derivative(self, x: torch.Tensor) -> torch.Tensor:
363 | return x[..., 1:] - x[..., :-1]
364 |
365 | def unwrap(self, x: torch.Tensor) -> torch.Tensor:
366 | x = self.derivative(x)
367 | x = (x + np.pi) % (2 * np.pi)
368 | return (x - np.pi).cumsum(-1)
369 |
370 | def forward(self, target: torch.Tensor, pred: torch.Tensor):
371 | stfts_x = self.multiscale_stft(target)
372 | stfts_y = self.multiscale_stft(pred)
373 | spectral_distance = 0.
374 | phase_distance = 0.
375 |
376 | for x, y in zip(stfts_x, stfts_y):
377 | assert x.shape[-1] == 2
378 |
379 | x = torch.view_as_complex(x)
380 | y = torch.view_as_complex(y)
381 |
382 | # AMPLITUDE DISTANCE
383 | x_abs = x.abs()
384 | y_abs = y.abs()
385 |
386 | logx = torch.log1p(x_abs)
387 | logy = torch.log1p(y_abs)
388 |
389 | lin_distance = mean_difference(x_abs,
390 | y_abs,
391 | norm='L2',
392 | relative=True)
393 | log_distance = mean_difference(logx, logy, norm='L1')
394 |
395 | spectral_distance = spectral_distance + lin_distance + log_distance
396 |
397 | # PHASE DISTANCE
398 | x_if = self.phase_to_instantaneous_frequency(x.angle())
399 | y_if = self.phase_to_instantaneous_frequency(y.angle())
400 |
401 | if self.weighted:
402 | mask = torch.clip(torch.log1p(x_abs[..., 2:]), 0, 1)
403 | x_if = x_if * mask
404 | y_if = y_if * mask
405 |
406 | phase_distance = phase_distance + mean_difference(
407 | x_if, y_if, norm='L2')
408 |
409 | return {
410 | 'spectral_distance': spectral_distance,
411 | 'phase_distance': phase_distance
412 | }
413 |
414 |
415 | class EncodecAudioDistance(nn.Module):
416 |
417 | def __init__(self, scales: Sequence[int],
418 | spectral_distance: Callable[[int], nn.Module]) -> None:
419 | super().__init__()
420 | self.waveform_distance = WaveformDistance(norm='L1')
421 | self.spectral_distances = nn.ModuleList(
422 | [spectral_distance(scale) for scale in scales])
423 |
424 | def forward(self, x, y):
425 | waveform_distance = self.waveform_distance(x, y)
426 | spectral_distance = 0
427 | for dist in self.spectral_distances:
428 | spectral_distance = spectral_distance + dist(x, y)
429 |
430 | return {
431 | 'waveform_distance': waveform_distance,
432 | 'spectral_distance': spectral_distance
433 | }
434 |
435 |
436 | class WaveformDistance(nn.Module):
437 |
438 | def __init__(self, norm: str) -> None:
439 | super().__init__()
440 | self.norm = norm
441 |
442 | def forward(self, x, y):
443 | return mean_difference(y, x, self.norm)
444 |
445 |
446 | class SpectralDistance(nn.Module):
447 |
448 | def __init__(
449 | self,
450 | n_fft: int,
451 | sampling_rate: int,
452 | norm: Union[str, Sequence[str]],
453 | power: Union[int, None],
454 | normalized: bool,
455 | mel: Optional[int] = None,
456 | ) -> None:
457 | super().__init__()
458 | if mel:
459 | self.spec = torchaudio.transforms.MelSpectrogram(
460 | sampling_rate,
461 | n_fft,
462 | hop_length=n_fft // 4,
463 | n_mels=mel,
464 | power=power,
465 | normalized=normalized,
466 | center=False,
467 | pad_mode=None,
468 | )
469 | else:
470 | self.spec = torchaudio.transforms.Spectrogram(
471 | n_fft,
472 | hop_length=n_fft // 4,
473 | power=power,
474 | normalized=normalized,
475 | center=False,
476 | pad_mode=None,
477 | )
478 |
479 | if isinstance(norm, str):
480 | norm = (norm, )
481 | self.norm = norm
482 |
483 | def forward(self, x, y):
484 | x = self.spec(x)
485 | y = self.spec(y)
486 |
487 | distance = 0
488 | for norm in self.norm:
489 | distance = distance + mean_difference(y, x, norm)
490 | return distance
491 |
492 |
493 | class ProgressLogger(object):
494 |
495 | def __init__(self, name: str) -> None:
496 | self.env = lmdb.open("status")
497 | self.name = name
498 |
499 | def update(self, **new_state):
500 | current_state = self.__call__()
501 | with self.env.begin(write=True) as txn:
502 | current_state.update(new_state)
503 | current_state = json.dumps(current_state)
504 | txn.put(self.name.encode(), current_state.encode())
505 |
506 | def __call__(self):
507 | with self.env.begin(write=True) as txn:
508 | current_state = txn.get(self.name.encode())
509 | if current_state is not None:
510 | current_state = json.loads(current_state.decode())
511 | else:
512 | current_state = {}
513 | return current_state
514 |
515 |
516 | class LoggerCallback(pl.Callback):
517 |
518 | def __init__(self, logger: ProgressLogger) -> None:
519 | super().__init__()
520 | self.state = {'step': 0, 'warmed': False}
521 | self.logger = logger
522 |
523 | def on_train_batch_end(self, trainer, pl_module, outputs, batch,
524 | batch_idx) -> None:
525 | self.state['step'] += 1
526 | self.state['warmed'] = pl_module.warmed_up
527 |
528 | if not self.state['step'] % 100:
529 | self.logger.update(**self.state)
530 |
531 | def state_dict(self):
532 | return self.state.copy()
533 |
534 | def load_state_dict(self, state_dict):
535 | self.state.update(state_dict)
536 |
537 |
538 | class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
539 | def __init__(self, step_period: int = None, **kwargs):
540 | super().__init__(**kwargs)
541 | self.step_period = step_period
542 | self.__counter = 0
543 |
544 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
545 | self.__counter += 1
546 | if self.step_period:
547 | if self.__counter % self.step_period == 0:
548 | filename = os.path.join(self.dirpath, f"epoch_{self.__counter}{self.FILE_EXTENSION}")
549 | self._save_checkpoint(trainer, filename)
550 |
551 |
552 | def get_valid_extensions():
553 | import torchaudio
554 | backend = torchaudio.get_audio_backend()
555 | if backend in ["sox_io", "sox"]:
556 | return ['.'+f for f in torchaudio.utils.sox_utils.list_read_formats()]
557 | elif backend == "ffmpeg":
558 | return ['.'+f for f in torchaudio.utils.ffmpeg_utils.get_audio_decoders()]
559 | elif backend == "soundfile":
560 | return ['.wav', '.flac', '.ogg', '.aiff', '.aif', '.aifc']
561 |
562 |
--------------------------------------------------------------------------------
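
A minimal usage sketch (not repository code; shapes and scales are assumptions) of the spectral distance defined in this file: `MultiScaleSTFT` produces one magnitude spectrogram per scale, and `AudioDistanceV1` sums a relative L2 distance on the linear magnitudes with an L1 distance on their logs.

    import torch
    from rave.core import MultiScaleSTFT, AudioDistanceV1

    distance = AudioDistanceV1(
        multiscale_stft=lambda: MultiScaleSTFT(scales=[2048, 1024, 512],
                                               sample_rate=44100,
                                               magnitude=True),
        log_epsilon=1e-7,
    )

    x = torch.randn(4, 1, 65536)  # (batch, channels, samples)
    y = torch.randn(4, 1, 65536)
    print(distance(x, y)['spectral_distance'])  # scalar tensor
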
/rave/dataset.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import logging
3 | import math
4 | import os
5 | import subprocess
6 | from random import random
7 | from typing import Dict, Iterable, Optional, Sequence, Union, Callable
8 |
9 | import gin
10 | import lmdb
11 | import numpy as np
12 | import requests
13 | import torch
14 | import torchaudio
15 | import yaml
16 | from scipy.signal import lfilter
17 | from torch.utils import data
18 | from tqdm import tqdm
19 | from . import transforms
20 | from udls import AudioExample as AudioExampleWrapper
21 | from udls.generated import AudioExample
22 |
23 |
24 | def get_derivator_integrator(sr: int):
25 | alpha = 1 / (1 + 1 / sr * 2 * np.pi * 10)
26 | derivator = ([.5, -.5], [1])
27 | integrator = ([alpha**2, -alpha**2], [1, -2 * alpha, alpha**2])
28 |
29 | return lambda x: lfilter(*derivator, x), lambda x: lfilter(*integrator, x)
30 |
31 |
32 | class AudioDataset(data.Dataset):
33 |
34 | @property
35 | def env(self) -> lmdb.Environment:
36 | if self._env is None:
37 | self._env = lmdb.open(self._db_path, lock=False)
38 | return self._env
39 |
40 | @property
41 | def keys(self) -> Sequence[str]:
42 | if self._keys is None:
43 | with self.env.begin() as txn:
44 | self._keys = list(txn.cursor().iternext(values=False))
45 | return self._keys
46 |
47 | def __init__(self,
48 | db_path: str,
49 | audio_key: str = 'waveform',
50 | transforms: Optional[transforms.Transform] = None,
51 | n_channels: int = 1) -> None:
52 | super().__init__()
53 | self._db_path = db_path
54 | self._audio_key = audio_key
55 | self._env = None
56 | self._keys = None
57 | self._transforms = transforms
58 | self._n_channels = n_channels
59 | lens = []
60 | with self.env.begin() as txn:
61 | for k in self.keys:
62 | ae = AudioExample.FromString(txn.get(k))
63 | lens.append(np.frombuffer(ae.buffers['waveform'].data, dtype=np.int16).shape)
64 |
65 |
66 | def __len__(self):
67 | return len(self.keys)
68 |
69 | def __getitem__(self, index):
70 | with self.env.begin() as txn:
71 | ae = AudioExample.FromString(txn.get(self.keys[index]))
72 |
73 | buffer = ae.buffers[self._audio_key]
74 | assert buffer.precision == AudioExample.Precision.INT16
75 |
76 | audio = np.frombuffer(buffer.data, dtype=np.int16)
77 | audio = audio.astype(np.float32) / (2**15 - 1)
78 | audio = audio.reshape(self._n_channels, -1)
79 |
80 | if self._transforms is not None:
81 | audio = self._transforms(audio)
82 |
83 | return audio
84 |
85 |
86 | class LazyAudioDataset(data.Dataset):
87 |
88 | @property
89 | def env(self) -> lmdb.Environment:
90 | if self._env is None:
91 | self._env = lmdb.open(self._db_path, lock=False)
92 | return self._env
93 |
94 | @property
95 | def keys(self) -> Sequence[str]:
96 | if self._keys is None:
97 | with self.env.begin() as txn:
98 | self._keys = list(txn.cursor().iternext(values=False))
99 | return self._keys
100 |
101 | def __init__(self,
102 | db_path: str,
103 | n_signal: int,
104 | sampling_rate: int,
105 | transforms: Optional[transforms.Transform] = None,
106 | n_channels: int = 1) -> None:
107 | super().__init__()
108 | self._db_path = db_path
109 | self._env = None
110 | self._keys = None
111 | self._transforms = transforms
112 | self._n_signal = n_signal
113 | self._sampling_rate = sampling_rate
114 | self._n_channels = n_channels
115 |
116 | self.parse_dataset()
117 |
118 | def parse_dataset(self):
119 | items = []
120 | for key in tqdm(self.keys, desc='Discovering dataset'):
121 | with self.env.begin() as txn:
122 | ae = AudioExample.FromString(txn.get(key))
123 | length = float(ae.metadata['length'])
124 | n_signal = int(math.floor(length * self._sampling_rate))
125 | n_chunks = n_signal // self._n_signal
126 | items.append(n_chunks)
127 | items = np.asarray(items)
128 | items = np.cumsum(items)
129 | self.items = items
130 |
131 | def __len__(self):
132 | return self.items[-1]
133 |
134 | def __getitem__(self, index):
135 | audio_id = np.where(index < self.items)[0][0]
136 | if audio_id:
137 | index -= self.items[audio_id - 1]
138 |
139 | key = self.keys[audio_id]
140 |
141 | with self.env.begin() as txn:
142 | ae = AudioExample.FromString(txn.get(key))
143 |
144 | audio = extract_audio(
145 | ae.metadata['path'],
146 | self._n_signal,
147 | self._sampling_rate,
148 | index * self._n_signal,
149 | int(ae.metadata['channels']),
150 | self._n_channels
151 | )
152 |
153 | if self._transforms is not None:
154 | audio = self._transforms(audio)
155 |
156 | return audio
157 |
158 | def get_channels_from_dataset(db_path):
159 | with open(os.path.join(db_path, 'metadata.yaml'), 'r') as metadata:
160 | metadata = yaml.safe_load(metadata)
161 | return metadata.get('channels')
162 |
163 | def get_training_channels(db_path, target_channels):
164 | dataset_channels = get_channels_from_dataset(db_path)
165 | if dataset_channels is not None:
166 | if target_channels > dataset_channels:
167 | raise RuntimeError('[Error] Requested number of channels is %s, but dataset has %s channels' % (target_channels, dataset_channels))
168 | n_channels = target_channels or dataset_channels
169 | if n_channels is None:
170 | print('[Warning] channels not found in dataset, taking 1 by default')
171 | n_channels = 1
172 | return n_channels
173 |
174 | class HTTPAudioDataset(data.Dataset):
175 |
176 | def __init__(self, db_path: str):
177 | super().__init__()
178 | self.db_path = db_path
179 | logging.info("starting remote dataset session")
180 | self.length = int(requests.get("/".join([db_path, "len"])).text)
181 | logging.info("connection established !")
182 |
183 | def __len__(self):
184 | return self.length
185 |
186 | def __getitem__(self, index):
187 | example = requests.get("/".join([
188 | self.db_path,
189 | "get",
190 | f"{index}",
191 | ])).text
192 | example = AudioExampleWrapper(base64.b64decode(example)).get("audio")
193 | return example.copy()
194 |
195 |
196 | def normalize_signal(x: np.ndarray, max_gain_db: int = 30):
197 | peak = np.max(abs(x))
198 | if peak == 0: return x
199 |
200 | log_peak = 20 * np.log10(peak)
201 | log_gain = min(max_gain_db, -log_peak)
202 | gain = 10**(log_gain / 20)
203 |
204 | return x * gain
205 |
206 | @gin.configurable
207 | def get_dataset(db_path,
208 | sr,
209 | n_signal,
210 | derivative: bool = False,
211 | normalize: bool = False,
212 | rand_pitch: bool = False,
213 | augmentations: Union[None, Iterable[Callable]] = None,
214 | n_channels: int = 1):
215 | if db_path[:4] == "http":
216 | return HTTPAudioDataset(db_path=db_path)
217 | with open(os.path.join(db_path, 'metadata.yaml'), 'r') as metadata:
218 | metadata = yaml.safe_load(metadata)
219 |
220 | sr_dataset = metadata.get('sr', 44100)
221 | lazy = metadata['lazy']
222 |
223 | transform_list = [
224 | lambda x: x.astype(np.float32),
225 | transforms.RandomCrop(n_signal),
226 | transforms.RandomApply(
227 | lambda x: random_phase_mangle(x, 20, 2000, .99, sr_dataset),
228 | p=.8,
229 | ),
230 | transforms.Dequantize(16),
231 | ]
232 |
233 | if rand_pitch:
234 | rand_pitch = list(map(float, rand_pitch))
235 | assert len(rand_pitch) == 2, "rand_pitch must be given two floats"
236 | transform_list.insert(1, transforms.RandomPitch(n_signal, rand_pitch))
237 |
238 | if sr_dataset != sr:
239 | transform_list.append(transforms.Resample(sr_dataset, sr))
240 |
241 | if normalize:
242 | transform_list.append(normalize_signal)
243 |
244 | if derivative:
245 | transform_list.append(get_derivator_integrator(sr)[0])
246 |
247 | if augmentations:
248 | transform_list.extend(augmentations)
249 |
250 | transform_list.append(lambda x: x.astype(np.float32))
251 |
252 | transform_list = transforms.Compose(transform_list)
253 |
254 | if lazy:
255 | return LazyAudioDataset(db_path, n_signal, sr_dataset, transform_list, n_channels)
256 | else:
257 | return AudioDataset(
258 | db_path,
259 | transforms=transform_list,
260 | n_channels=n_channels
261 | )
262 |
263 |
264 | @gin.configurable
265 | def split_dataset(dataset, percent, max_residual: Optional[int] = None):
266 | split1 = max((percent * len(dataset)) // 100, 1)
267 | split2 = len(dataset) - split1
268 | if max_residual is not None:
269 | split2 = min(max_residual, split2)
270 | split1 = len(dataset) - split2
271 | print(f'train set: {split1} examples')
272 | print(f'val set: {split2} examples')
273 | split1, split2 = data.random_split(
274 | dataset,
275 | [split1, split2],
276 | generator=torch.Generator().manual_seed(42),
277 | )
278 | return split1, split2
279 |
280 |
281 | def random_angle(min_f=20, max_f=8000, sr=24000):
282 | min_f = np.log(min_f)
283 | max_f = np.log(max_f)
284 | rand = np.exp(random() * (max_f - min_f) + min_f)
285 | rand = 2 * np.pi * rand / sr
286 | return rand
287 |
288 |
289 | def pole_to_z_filter(omega, amplitude=.9):
290 | z0 = amplitude * np.exp(1j * omega)
291 | a = [1, -2 * np.real(z0), abs(z0)**2]
292 | b = [abs(z0)**2, -2 * np.real(z0), 1]
293 | return b, a
294 |
295 |
296 | def random_phase_mangle(x, min_f, max_f, amp, sr):
297 | angle = random_angle(min_f, max_f, sr)
298 | b, a = pole_to_z_filter(angle, amp)
299 | return lfilter(b, a, x)
300 |
301 | def extract_audio(path: str, n_signal: int, sr: int,
302 | start_sample: int, input_channels: int, channels: int) -> Iterable[np.ndarray]:
303 | # channel mapping
304 | channel_map = range(channels)
305 | if input_channels < channels:
306 | channel_map = (math.ceil(channels / input_channels) * list(range(input_channels)))[:channels]
307 | # time information
308 | start_sec = start_sample / sr
309 | length = (n_signal * 2) / sr
310 | chunks = []
311 | for i in channel_map:
312 | process = subprocess.Popen(
313 | [
314 | 'ffmpeg', '-v', 'error',
315 | '-ss',
316 | str(start_sec),
317 | '-i',
318 | path,
319 | '-ar',
320 | str(sr),
321 | '-filter_complex',
322 | 'channelmap=%d-0'%i,
323 | '-t',
324 | str(length),
325 | '-f',
326 | 's16le',
327 | '-'
328 | ],
329 | stdout=subprocess.PIPE,
330 | )
331 |
332 | chunk = process.communicate()[0]
333 | chunk = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 2**15
334 | chunk = np.concatenate([chunk, np.zeros(n_signal)], -1)
335 | chunks.append(chunk)
336 | return np.stack(chunks)[:, :(n_signal*2)]
337 |
--------------------------------------------------------------------------------
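
A worked example (the values are assumptions, for illustration only) of the `split_dataset` logic above, using the `dataset.split_dataset.max_residual = 1000` cap set in v1.gin: the residual split is clamped to at most 1000 examples and the remainder is returned to the training split.

    n, percent, max_residual = 100_000, 95, 1000
    split1 = max((percent * n) // 100, 1)   # 95_000 training examples before the cap
    split2 = n - split1                     # 5_000 candidate validation examples
    if max_residual is not None:
        split2 = min(max_residual, split2)  # capped at 1_000
        split1 = n - split2                 # 99_000 training examples after the cap
    print(split1, split2)                   # 99000 1000
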
/rave/descript_discriminator.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/descriptinc/descript-audio-codec
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from einops import rearrange
8 | from torch.nn.utils import weight_norm
9 | from torchaudio.transforms import Spectrogram
10 |
11 | from .pqmf import kaiser_filter
12 |
13 |
14 | def WNConv1d(*args, **kwargs):
15 | act = kwargs.pop("act", True)
16 | conv = weight_norm(nn.Conv1d(*args, **kwargs))
17 | if not act:
18 | return conv
19 | return nn.Sequential(conv, nn.LeakyReLU(0.1))
20 |
21 |
22 | def WNConv2d(*args, **kwargs):
23 | act = kwargs.pop("act", True)
24 | conv = weight_norm(nn.Conv2d(*args, **kwargs))
25 | if not act:
26 | return conv
27 | return nn.Sequential(conv, nn.LeakyReLU(0.1))
28 |
29 |
30 | class MPD(nn.Module):
31 |
32 | def __init__(self, period, n_channels: int = 1):
33 | super().__init__()
34 | self.period = period
35 | self.convs = nn.ModuleList([
36 | WNConv2d(n_channels, 32, (5, 1), (3, 1), padding=(2, 0)),
37 | WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
38 | WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
39 | WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
40 | WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
41 | ])
42 | self.conv_post = WNConv2d(1024,
43 | 1,
44 | kernel_size=(3, 1),
45 | padding=(1, 0),
46 | act=False)
47 |
48 | def pad_to_period(self, x):
49 | t = x.shape[-1]
50 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
51 | return x
52 |
53 | def forward(self, x):
54 | fmap = []
55 |
56 | x = self.pad_to_period(x)
57 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
58 |
59 | for layer in self.convs:
60 | x = layer(x)
61 | fmap.append(x)
62 |
63 | x = self.conv_post(x)
64 | fmap.append(x)
65 |
66 | return fmap
67 |
68 |
69 | class MSD(nn.Module):
70 |
71 | def __init__(self, scale: int, n_channels: int = 1):
72 | super().__init__()
73 | self.convs = nn.ModuleList([
74 | WNConv1d(n_channels, 16, 15, 1, padding=7),
75 | WNConv1d(16, 64, 41, 4, groups=4, padding=20),
76 | WNConv1d(64, 256, 41, 4, groups=16, padding=20),
77 | WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
78 | WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
79 | WNConv1d(1024, 1024, 5, 1, padding=2),
80 | ])
81 | self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
82 |
83 | self.scale = scale
84 |
85 | if self.scale != 1:
86 | wc = np.pi / self.scale
87 | filt = kaiser_filter(wc, 140)
88 | if not len(filt) % 2:
89 | filt = np.pad(filt, (1, 0))
90 |
91 | self.register_buffer(
92 | "downsampler",
93 | torch.from_numpy(filt).reshape(1, 1, -1).float())
94 |
95 | def forward(self, x):
96 | if self.scale != 1:
97 | x = nn.functional.conv1d(
98 | x,
99 | self.downsampler,
100 | padding=self.downsampler.shape[-1] // 2,
101 | stride=self.scale,
102 | )
103 |
104 | fmap = []
105 |
106 | for l in self.convs:
107 | x = l(x)
108 | fmap.append(x)
109 | x = self.conv_post(x)
110 | fmap.append(x)
111 |
112 | return fmap
113 |
114 |
115 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
116 |
117 |
118 | class MRD(nn.Module):
119 |
120 | def __init__(
121 | self,
122 | window_length: int,
123 | hop_factor: float = 0.25,
124 | sample_rate: int = 44100,
125 | bands: list = BANDS,
126 | n_channels: int = 1
127 | ):
128 | super().__init__()
129 |
130 | self.window_length = window_length
131 | self.hop_factor = hop_factor
132 | self.sample_rate = sample_rate
133 |
134 | n_fft = window_length // 2 + 1
135 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
136 | self.bands = bands
137 |
138 | ch = 32
139 | convs = lambda: nn.ModuleList([
140 | WNConv2d(2 * n_channels, ch, (3, 9), (1, 1), padding=(1, 4)),
141 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
142 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
143 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
144 | WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
145 | ])
146 | self.band_convs = nn.ModuleList(
147 | [convs() for _ in range(len(self.bands))])
148 | self.conv_post = WNConv2d(ch,
149 | 1, (3, 3), (1, 1),
150 | padding=(1, 1),
151 | act=False)
152 |
153 | self.stft = Spectrogram(
154 | n_fft=window_length,
155 | win_length=window_length,
156 | hop_length=int(hop_factor * window_length),
157 | center=True,
158 | return_complex=True,
159 | power=None,
160 | )
161 |
162 | def spectrogram(self, x):
163 | x = torch.view_as_real(self.stft(x))
164 | x = rearrange(x, "b c f t p -> b (c p) t f")
165 | # Split into bands
166 | x_bands = [x[..., b[0]:b[1]] for b in self.bands]
167 | return x_bands
168 |
169 | def forward(self, x):
170 | x_bands = self.spectrogram(x)
171 | fmap = []
172 |
173 | x = []
174 | for band, stack in zip(x_bands, self.band_convs):
175 | for layer in stack:
176 | band = layer(band)
177 | fmap.append(band)
178 | x.append(band)
179 |
180 | x = torch.cat(x, dim=-1)
181 | x = self.conv_post(x)
182 | fmap.append(x)
183 |
184 | return fmap
185 |
186 |
187 | class DescriptDiscriminator(nn.Module):
188 |
189 | def __init__(
190 | self,
191 | rates: list = [],
192 | periods: list = [2, 3, 5, 7, 11],
193 | fft_sizes: list = [2048, 1024, 512],
194 | sample_rate: int = 44100,
195 | bands: list = BANDS,
196 | n_channels: int = 1,
197 | ):
198 | super().__init__()
199 | discs = []
200 | discs += [MPD(p, n_channels=n_channels) for p in periods]
201 | discs += [MSD(r, n_channels=n_channels) for r in rates]
202 | discs += [
203 | MRD(f, sample_rate=sample_rate, bands=bands, n_channels=n_channels) for f in fft_sizes
204 | ]
205 | self.discriminators = nn.ModuleList(discs)
206 |
207 | def preprocess(self, y):
208 | # Remove DC offset
209 | y = y - y.mean(dim=-1, keepdims=True)
210 | # Peak normalize the volume of input audio
211 | y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
212 | return y
213 |
214 | def forward(self, x):
215 | x = self.preprocess(x)
216 | fmaps = [d(x) for d in self.discriminators]
217 | return fmaps
218 |
--------------------------------------------------------------------------------
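
A small numeric sketch (for illustration; the window length is an assumption) of how `MRD` turns the fractional `BANDS` above into STFT bin ranges: with `window_length = 2048` there are 2048 // 2 + 1 = 1025 frequency bins, and each band keeps a contiguous slice of them.

    window_length = 2048
    n_bins = window_length // 2 + 1  # 1025
    BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
    bands = [(int(lo * n_bins), int(hi * n_bins)) for lo, hi in BANDS]
    print(bands)  # [(0, 102), (102, 256), (256, 512), (512, 768), (768, 1025)]
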
/rave/discriminator.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional, Sequence, Tuple, Type
2 |
3 | import cached_conv as cc
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torchaudio
8 |
9 | from .blocks import normalization
10 |
11 |
12 | def spectrogram(n_fft: int):
13 | return torchaudio.transforms.Spectrogram(
14 | n_fft,
15 | hop_length=n_fft // 4,
16 | power=None,
17 | normalized=True,
18 | center=False,
19 | pad_mode=None,
20 | )
21 |
22 |
23 | def rectified_2d_conv_block(
24 | capacity,
25 | kernel_sizes,
26 | strides: Optional[Tuple[int, int]] = None,
27 | dilations: Optional[Tuple[int, int]] = None,
28 | in_size: Optional[int] = None,
29 | out_size: Optional[int] = None,
30 | activation: bool = True,
31 | ):
32 | if dilations is None:
33 | paddings = kernel_sizes[0] // 2, kernel_sizes[1] // 2
34 | else:
35 | fks = (kernel_sizes[0] - 1) * dilations[0], (kernel_sizes[1] -
36 | 1) * dilations[1]
37 | paddings = fks[0] // 2, fks[1] // 2
38 |
39 | conv = normalization(
40 | nn.Conv2d(
41 | in_size or capacity,
42 | out_size or capacity,
43 | kernel_size=kernel_sizes,
44 | stride=strides or (1, 1),
45 | dilation=dilations or (1, 1),
46 | padding=paddings,
47 | ))
48 |
49 | if not activation: return conv
50 |
51 | return nn.Sequential(conv, nn.LeakyReLU(.2))
52 |
53 |
54 | class EncodecConvNet(nn.Module):
55 |
56 | def __init__(self, capacity: int, n_channels: int = 1) -> None:
57 | super().__init__()
58 | self.net = nn.Sequential(
59 | rectified_2d_conv_block(capacity, (9, 3), in_size=2*n_channels),
60 | rectified_2d_conv_block(capacity, (9, 3), (2, 1), (1, 1)),
61 | rectified_2d_conv_block(capacity, (9, 3), (2, 1), (1, 2)),
62 | rectified_2d_conv_block(capacity, (9, 3), (2, 1), (1, 4)),
63 | rectified_2d_conv_block(capacity, (3, 3)),
64 | rectified_2d_conv_block(capacity, (3, 3),
65 | out_size=1,
66 | activation=False),
67 | )
68 |
69 | def forward(self, x):
70 | features = []
71 | for layer in self.net:
72 | x = layer(x)
73 | features.append(x)
74 | return features
75 |
76 |
77 | class ConvNet(nn.Module):
78 |
79 | def __init__(self, in_size, out_size, capacity, n_layers, kernel_size,
80 | stride, conv) -> None:
81 | super().__init__()
82 | channels = [in_size]
83 | channels += list(capacity * 2**np.arange(n_layers))
84 |
85 | if isinstance(stride, int):
86 | stride = n_layers * [stride]
87 |
88 | net = []
89 | for i in range(n_layers):
90 | if not isinstance(kernel_size, int):
91 | pad = (cc.get_padding(kernel_size[0],
92 | stride[i],
93 | mode="centered")[0], 0)
94 | s = (stride[i], 1)
95 | else:
96 | pad = cc.get_padding(kernel_size, stride[i],
97 | mode="centered")[0]
98 | s = stride[i]
99 | net.append(
100 | normalization(
101 | conv(
102 | channels[i],
103 | channels[i + 1],
104 | kernel_size,
105 | stride=s,
106 | padding=pad,
107 | )))
108 | net.append(nn.LeakyReLU(.2))
109 | net.append(conv(channels[-1], out_size, 1))
110 |
111 | self.net = nn.Sequential(*net)
112 |
113 | def forward(self, x):
114 | features = []
115 | for layer in self.net:
116 | x = layer(x)
117 | if isinstance(layer, nn.modules.conv._ConvNd):
118 | features.append(x)
119 | return features
120 |
121 |
122 | class MultiScaleDiscriminator(nn.Module):
123 |
124 | def __init__(self, n_discriminators, convnet, n_channels=1) -> None:
125 | super().__init__()
126 | layers = []
127 | for i in range(n_discriminators):
128 | layers.append(convnet(in_size=n_channels))
129 | self.layers = nn.ModuleList(layers)
130 |
131 | def forward(self, x):
132 | features = []
133 | for layer in self.layers:
134 | features.append(layer(x))
135 | x = nn.functional.avg_pool1d(x, 2)
136 | return features
137 |
138 |
139 | class MultiScaleSpectralDiscriminator(nn.Module):
140 |
141 | def __init__(self, scales: Sequence[int],
142 | convnet: Callable[[], nn.Module], n_channels: int = 1) -> None:
143 | super().__init__()
144 | self.specs = nn.ModuleList([spectrogram(n) for n in scales])
145 | self.nets = nn.ModuleList([convnet(n_channels=n_channels) for _ in scales])
146 |
147 | def forward(self, x):
148 | features = []
149 | for spec, net in zip(self.specs, self.nets):
150 | spec_x = spec(x)
151 | spec_x = torch.cat([spec_x.real, spec_x.imag], 1)
152 | features.append(net(spec_x))
153 | return features
154 |
155 |
156 | class MultiScaleSpectralDiscriminator1d(nn.Module):
157 |
158 | def __init__(self, scales: Sequence[int],
159 | convnet: Callable[[int], nn.Module],
160 | n_channels: int = 1) -> None:
161 | super().__init__()
162 | self.specs = nn.ModuleList([spectrogram(n) for n in scales])
163 | self.nets = nn.ModuleList([convnet(n + 2, n_channels) for n in scales])
164 |
165 | def forward(self, x):
166 | features = []
167 | for spec, net in zip(self.specs, self.nets):
168 | spec_x = spec(x).squeeze(1)
169 | spec_x = torch.cat([spec_x.real, spec_x.imag], 1)
170 | features.append(net(spec_x))
171 | return features
172 |
173 |
174 | class MultiPeriodDiscriminator(nn.Module):
175 |
176 | def __init__(self, periods, convnet, n_channels=1) -> None:
177 | super().__init__()
178 | layers = []
179 | self.periods = periods
180 |
181 | for _ in periods:
182 | layers.append(convnet(in_size=n_channels))
183 |
184 | self.layers = nn.ModuleList(layers)
185 |
186 | def forward(self, x):
187 | features = []
188 | for layer, n in zip(self.layers, self.periods):
189 | features.append(layer(self.fold(x, n)))
190 | return features
191 |
192 | def fold(self, x, n):
193 | pad = (n - (x.shape[-1] % n)) % n
194 | x = nn.functional.pad(x, (0, pad))
195 | return x.reshape(*x.shape[:2], -1, n)
196 |
197 |
198 | class CombineDiscriminators(nn.Module):
199 |
200 | def __init__(self, discriminators: Sequence[Type[nn.Module]], n_channels=1) -> None:
201 | super().__init__()
202 | self.discriminators = nn.ModuleList(disc_cls(n_channels=n_channels)
203 | for disc_cls in discriminators)
204 |
205 | def forward(self, x):
206 | features = []
207 | for disc in self.discriminators:
208 | features.extend(disc(x))
209 | return features
210 |
--------------------------------------------------------------------------------
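
A minimal sketch (not repository code) of the `fold` reshaping used by `MultiPeriodDiscriminator`: the waveform is right-padded to a multiple of the period and viewed as a 2D map, so the 2D convolutions see one column per position within the period.

    import torch
    import torch.nn as nn

    x = torch.randn(1, 1, 1000)                        # (batch, channels, time)
    period = 3
    pad = (period - (x.shape[-1] % period)) % period   # 2 samples of padding
    x = nn.functional.pad(x, (0, pad))
    folded = x.reshape(*x.shape[:2], -1, period)
    print(folded.shape)                                # torch.Size([1, 1, 334, 3])
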
/rave/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | from time import time
3 | from typing import Callable, Optional, Iterable, Dict
4 |
5 | import gin, pdb
6 | import numpy as np
7 | import pytorch_lightning as pl
8 | import torch
9 | import torch.nn as nn
10 | from einops import rearrange
11 | from sklearn.decomposition import PCA
12 | from pytorch_lightning.trainer.states import RunningStage
13 |
14 |
15 | import rave.core
16 |
17 | from . import blocks
18 |
19 |
20 | _default_loss_weights = {
21 | 'audio_distance': 1.,
22 | 'multiband_audio_distance': 1.,
23 | 'adversarial': 1.,
24 | 'feature_matching' : 20,
25 | }
26 |
27 | class Profiler:
28 |
29 | def __init__(self):
30 | self.ticks = [[time(), None]]
31 |
32 | def tick(self, msg):
33 | self.ticks.append([time(), msg])
34 |
35 | def __repr__(self):
36 | rep = 80 * "=" + "\n"
37 | for i in range(1, len(self.ticks)):
38 | msg = self.ticks[i][1]
39 | elapsed = self.ticks[i][0] - self.ticks[i - 1][0]
40 | rep += msg + f": {elapsed*1000:.2f}ms\n"
41 | rep += 80 * "=" + "\n\n\n"
42 | return rep
43 |
44 |
45 | class WarmupCallback(pl.Callback):
46 |
47 | def __init__(self) -> None:
48 | super().__init__()
49 | self.state = {'training_steps': 0}
50 |
51 | def on_train_batch_start(self, trainer, pl_module, batch,
52 | batch_idx) -> None:
53 | if self.state['training_steps'] >= pl_module.warmup:
54 | pl_module.warmed_up = True
55 | self.state['training_steps'] += 1
56 |
57 | def state_dict(self):
58 | return self.state.copy()
59 |
60 | def load_state_dict(self, state_dict):
61 | self.state.update(state_dict)
62 |
63 |
64 | class QuantizeCallback(WarmupCallback):
65 |
66 |     def on_train_batch_start(self, trainer, pl_module, batch,
67 | batch_idx) -> None:
68 |
69 | if pl_module.warmup_quantize is None: return
70 |
71 | if self.state['training_steps'] >= pl_module.warmup_quantize:
72 | if isinstance(pl_module.encoder, blocks.DiscreteEncoder):
73 | pl_module.encoder.enabled = torch.tensor(1).type_as(
74 | pl_module.encoder.enabled)
75 | self.state['training_steps'] += 1
76 |
77 |
78 | @gin.configurable
79 | class BetaWarmupCallback(pl.Callback):
80 |
81 | def __init__(self, initial_value: float = .2,
82 | target_value: float = .2,
83 | warmup_len: int = 1,
84 | log: bool = True) -> None:
85 | super().__init__()
86 | self.state = {'training_steps': 0}
87 | self.warmup_len = warmup_len
88 | self.initial_value = initial_value
89 | self.target_value = target_value
90 | self.log_warmup = log
91 |
92 | def on_train_batch_start(self, trainer, pl_module, batch,
93 | batch_idx) -> None:
94 | self.state['training_steps'] += 1
95 | if self.state["training_steps"] >= self.warmup_len:
96 | pl_module.beta_factor = self.target_value
97 | return
98 |
99 | warmup_ratio = self.state["training_steps"] / self.warmup_len
100 |
101 | if self.log_warmup:
102 | beta = math.log(self.initial_value) * (1 - warmup_ratio) + math.log(
103 | self.target_value) * warmup_ratio
104 | pl_module.beta_factor = math.exp(beta)
105 | else:
106 | beta = warmup_ratio * (self.target_value - self.initial_value) + self.initial_value
107 | pl_module.beta_factor = min(beta, self.target_value)
108 |
109 | def state_dict(self):
110 | return self.state.copy()
111 |
112 | def load_state_dict(self, state_dict):
113 | self.state.update(state_dict)
114 |
115 |
116 | @torch.fx.wrap
117 | def _pqmf_encode(pqmf, x: torch.Tensor):
118 | batch_size = x.shape[:-2]
119 | x_multiband = x.reshape(-1, 1, x.shape[-1])
120 | x_multiband = pqmf(x_multiband)
121 | x_multiband = x_multiband.reshape(*batch_size, -1, x_multiband.shape[-1])
122 | return x_multiband
123 |
124 |
125 | @torch.fx.wrap
126 | def _pqmf_decode(pqmf, x: torch.Tensor, batch_size: Iterable[int], n_channels: int):
127 | x = x.reshape(x.shape[0] * n_channels, -1, x.shape[-1])
128 | x = pqmf.inverse(x)
129 | x = x.reshape(*batch_size, n_channels, -1)
130 | return x
131 |
132 |
133 | @gin.configurable
134 | class RAVE(pl.LightningModule):
135 |
136 | def __init__(
137 | self,
138 | latent_size,
139 | sampling_rate,
140 | encoder,
141 | decoder,
142 | discriminator,
143 | phase_1_duration,
144 | gan_loss,
145 | valid_signal_crop,
146 | feature_matching_fun,
147 | num_skipped_features,
148 | audio_distance: Callable[[], nn.Module],
149 | multiband_audio_distance: Callable[[], nn.Module],
150 | n_bands: int = 16,
151 | balancer = None,
152 | weights: Optional[Dict[str, float]] = None,
153 | warmup_quantize: Optional[int] = None,
154 | pqmf: Optional[Callable[[], nn.Module]] = None,
155 | spectrogram: Optional[Callable] = None,
156 | update_discriminator_every: int = 2,
157 | n_channels: int = 1,
158 | input_mode: str = "pqmf",
159 | output_mode: str = "pqmf",
160 | audio_monitor_epochs: int = 1,
161 |             # for backward compatibility
162 | enable_pqmf_encode: Optional[bool] = None,
163 | enable_pqmf_decode: Optional[bool] = None,
164 | is_mel_input: Optional[bool] = None,
165 | loss_weights = None
166 | ):
167 | super().__init__()
168 | self.pqmf = pqmf(n_channels=n_channels)
169 | self.spectrogram = None
170 | if spectrogram is not None:
171 | self.spectrogram = spectrogram
172 | assert input_mode in ['pqmf', 'mel', 'raw']
173 | assert output_mode in ['raw', 'pqmf']
174 | self.input_mode = input_mode
175 | self.output_mode = output_mode
176 |         # backward compatibility
177 | if (enable_pqmf_encode is not None) or (enable_pqmf_decode is not None):
178 | self.input_mode = "pqmf" if enable_pqmf_encode else "raw"
179 | self.output_mode = "pqmf" if enable_pqmf_decode else "raw"
180 |         if is_mel_input:
181 | self.input_mode = "mel"
182 | if loss_weights is not None:
183 | weights = loss_weights
184 |         assert weights is not None, "RAVE model requires either the weights or the loss_weights (deprecated) keyword"
185 |
186 | # setup model
187 | self.encoder = encoder(n_channels=n_channels)
188 | self.decoder = decoder(n_channels=n_channels)
189 | self.discriminator = discriminator(n_channels=n_channels)
190 |
191 | self.audio_distance = audio_distance()
192 | self.multiband_audio_distance = multiband_audio_distance()
193 |
194 | self.gan_loss = gan_loss
195 |
196 | self.register_buffer("latent_pca", torch.eye(latent_size))
197 | self.register_buffer("latent_mean", torch.zeros(latent_size))
198 | self.register_buffer("fidelity", torch.zeros(latent_size))
199 |
200 | self.latent_size = latent_size
201 |
202 | self.automatic_optimization = False
203 |
204 | # SCHEDULE
205 | self.warmup = phase_1_duration
206 | self.warmup_quantize = warmup_quantize
207 |         self.weights = dict(_default_loss_weights)  # copy, so the module-level defaults are not mutated
208 | self.weights.update(weights)
209 | self.warmed_up = False
210 |
211 | # CONSTANTS
212 | self.sr = sampling_rate
213 | self.valid_signal_crop = valid_signal_crop
214 | self.n_channels = n_channels
215 | self.feature_matching_fun = feature_matching_fun
216 | self.num_skipped_features = num_skipped_features
217 | self.update_discriminator_every = update_discriminator_every
218 |
219 | self.eval_number = 0
220 | self.beta_factor = 1.
221 | self.integrator = None
222 |
223 | self.register_buffer("receptive_field", torch.tensor([0, 0]).long())
224 | self.audio_monitor_epochs = audio_monitor_epochs
225 |
226 | def configure_optimizers(self):
227 | gen_p = list(self.encoder.parameters())
228 | gen_p += list(self.decoder.parameters())
229 | dis_p = list(self.discriminator.parameters())
230 |
231 | gen_opt = torch.optim.Adam(gen_p, 1e-3, (.5, .9))
232 | dis_opt = torch.optim.Adam(dis_p, 1e-4, (.5, .9))
233 |
234 | return ({'optimizer': gen_opt,
235 | 'lr_scheduler': {'scheduler': torch.optim.lr_scheduler.LinearLR(gen_opt, start_factor=1.0, end_factor=0.1, total_iters=self.warmup)}},
236 | {'optimizer':dis_opt})
237 |
238 | def _mel_encode(self, x: torch.Tensor):
239 | batch_size = x.shape[:-2]
240 | x = self.spectrogram(x)[..., :-1]
241 | x = torch.log1p(x).reshape(*batch_size, -1, x.shape[-1])
242 | return x
243 |
244 | def encode(self, x, return_mb: bool = False):
245 | x_enc = x
246 | if self.input_mode == "pqmf":
247 | x_enc = _pqmf_encode(self.pqmf, x_enc)
248 | elif self.input_mode == "mel":
249 | x_enc = self._mel_encode(x)
250 |
251 | z = self.encoder(x_enc)
252 | if return_mb:
253 | if self.input_mode == "pqmf":
254 | return z, x_enc
255 | else:
256 | x_multiband = _pqmf_encode(self.pqmf, x_enc)
257 | return z, x_multiband
258 | return z
259 |
260 | def decode(self, z):
261 | batch_size = z.shape[:-2]
262 | y = self.decoder(z)
263 | if self.output_mode == "pqmf":
264 | y = _pqmf_decode(self.pqmf, y, batch_size=batch_size, n_channels=self.n_channels)
265 | return y
266 |
267 | def forward(self, x):
268 | z = self.encode(x, return_mb=False)
269 | z = self.encoder.reparametrize(z)[0]
270 | return self.decode(z)
271 |
272 | def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
273 | self.lr_schedulers().step()
274 | return super().on_train_batch_end(outputs, batch, batch_idx)
275 |
276 | def split_features(self, features):
277 | feature_real = []
278 | feature_fake = []
279 | for scale in features:
280 | true, fake = zip(*map(
281 | lambda x: torch.split(x, x.shape[0] // 2, 0),
282 | scale,
283 | ))
284 | feature_real.append(true)
285 | feature_fake.append(fake)
286 | return feature_real, feature_fake
287 |
288 | def training_step(self, batch, batch_idx):
289 | p = Profiler()
290 | gen_opt, dis_opt = self.optimizers()
291 | x_raw = batch
292 | x_raw.requires_grad = True
293 |
294 | batch_size = x_raw.shape[:-2]
295 | self.encoder.set_warmed_up(self.warmed_up)
296 | self.decoder.set_warmed_up(self.warmed_up)
297 |
298 | # ENCODE INPUT
299 | # get multiband in case
300 | z, x_multiband = self.encode(x_raw, return_mb=True)
301 |
302 | z, reg = self.encoder.reparametrize(z)[:2]
303 | p.tick('encode')
304 |
305 | # DECODE LATENT
306 | y = self.decoder(z)
307 | if self.output_mode == "pqmf":
308 | y_multiband = y
309 | y_raw = _pqmf_decode(self.pqmf, y, batch_size=batch_size, n_channels=self.n_channels)
310 | else:
311 | y_raw = y
312 | y_multiband = _pqmf_encode(self.pqmf, y)
313 |
314 |         # TODO: added for training with num_samples = 65536; output padding seems to mess with output dimensions.
315 |         # this may conflict with cached_conv
316 | y_raw = y_raw[..., :x_raw.shape[-1]]
317 | y_multiband = y_multiband[..., :x_multiband.shape[-1]]
318 |
319 | p.tick('decode')
320 |
321 | if self.valid_signal_crop and self.receptive_field.sum():
322 | x_multiband = rave.core.valid_signal_crop(
323 | x_multiband,
324 | *self.receptive_field,
325 | )
326 | y_multiband = rave.core.valid_signal_crop(
327 | y_multiband,
328 | *self.receptive_field,
329 | )
330 | p.tick('crop')
331 |
332 | # DISTANCE BETWEEN INPUT AND OUTPUT
333 | distances = {}
334 | multiband_distance = self.multiband_audio_distance(
335 | x_multiband, y_multiband)
336 | p.tick('mb distance')
337 | for k, v in multiband_distance.items():
338 | distances[f'multiband_{k}'] = self.weights['multiband_audio_distance'] * v
339 |
340 | fullband_distance = self.audio_distance(x_raw, y_raw)
341 | p.tick('fb distance')
342 |
343 | for k, v in fullband_distance.items():
344 | distances[f'fullband_{k}'] = self.weights['audio_distance'] * v
345 |
346 | feature_matching_distance = 0.
347 |
348 | if self.warmed_up: # DISCRIMINATION
349 | xy = torch.cat([x_raw, y_raw], 0)
350 | features = self.discriminator(xy)
351 |
352 | feature_real, feature_fake = self.split_features(features)
353 |
354 | loss_dis = 0
355 | loss_adv = 0
356 |
357 | pred_real = 0
358 | pred_fake = 0
359 |
360 | for scale_real, scale_fake in zip(feature_real, feature_fake):
361 | current_feature_distance = sum(
362 | map(
363 | self.feature_matching_fun,
364 | scale_real[self.num_skipped_features:],
365 | scale_fake[self.num_skipped_features:],
366 | )) / len(scale_real[self.num_skipped_features:])
367 |
368 | feature_matching_distance = feature_matching_distance + current_feature_distance
369 |
370 | _dis, _adv = self.gan_loss(scale_real[-1], scale_fake[-1])
371 |
372 | pred_real = pred_real + scale_real[-1].mean()
373 | pred_fake = pred_fake + scale_fake[-1].mean()
374 |
375 | loss_dis = loss_dis + _dis
376 | loss_adv = loss_adv + _adv
377 |
378 | feature_matching_distance = feature_matching_distance / len(
379 | feature_real)
380 |
381 | else:
382 | pred_real = torch.tensor(0.).to(x_raw)
383 | pred_fake = torch.tensor(0.).to(x_raw)
384 | loss_dis = torch.tensor(0.).to(x_raw)
385 | loss_adv = torch.tensor(0.).to(x_raw)
386 | p.tick('discrimination')
387 |
388 | # COMPOSE GEN LOSS
389 | loss_gen = {}
390 | loss_gen.update(distances)
391 | p.tick('update loss gen dict')
392 |
393 | if reg.item():
394 | loss_gen['regularization'] = reg * self.beta_factor
395 |
396 | if self.warmed_up:
397 | loss_gen['feature_matching'] = self.weights['feature_matching'] * feature_matching_distance
398 | loss_gen['adversarial'] = self.weights['adversarial'] * loss_adv
399 |
400 | # OPTIMIZATION
401 | if not (batch_idx %
402 | self.update_discriminator_every) and self.warmed_up:
403 | dis_opt.zero_grad()
404 | loss_dis.backward()
405 | dis_opt.step()
406 | p.tick('dis opt')
407 | else:
408 | gen_opt.zero_grad()
409 | loss_gen_value = 0.
410 | for k, v in loss_gen.items():
411 | loss_gen_value += v * self.weights.get(k, 1.)
412 | loss_gen_value.backward()
413 | gen_opt.step()
414 |
415 | # LOGGING
416 | self.log("beta_factor", self.beta_factor)
417 |
418 | if self.warmed_up:
419 | self.log("loss_dis", loss_dis)
420 | self.log("pred_real", pred_real.mean())
421 | self.log("pred_fake", pred_fake.mean())
422 |
423 | self.log_dict(loss_gen)
424 | p.tick('logging')
425 |
426 | def validation_step(self, x, batch_idx):
427 |
428 | z = self.encode(x)
429 | if isinstance(self.encoder, blocks.VariationalEncoder):
430 | mean = torch.split(z, z.shape[1] // 2, 1)[0]
431 | else:
432 | mean = None
433 |
434 | z = self.encoder.reparametrize(z)[0]
435 | y = self.decode(z)
436 |
437 | distance = self.audio_distance(x, y)
438 | full_distance = sum(distance.values())
439 |
440 | if self.trainer is not None:
441 | self.log('validation', full_distance)
442 |
443 | return torch.cat([x, y], -1), mean
444 |
445 | def validation_epoch_end(self, out):
446 | if not self.receptive_field.sum():
447 | print("Computing receptive field for this configuration...")
448 | lrf, rrf = rave.core.get_rave_receptive_field(self, n_channels=self.n_channels)
449 | self.receptive_field[0] = lrf
450 | self.receptive_field[1] = rrf
451 | print(
452 | f"Receptive field: {1000*lrf/self.sr:.2f}ms <-- x --> {1000*rrf/self.sr:.2f}ms"
453 | )
454 |
455 | if not len(out): return
456 |
457 | audio, z = list(zip(*out))
458 | audio = list(map(lambda x: x.cpu(), audio))
459 |
460 | if self.trainer.state.stage == RunningStage.SANITY_CHECKING:
461 | return
462 |
463 | # LATENT SPACE ANALYSIS
464 | if not self.warmed_up and isinstance(self.encoder,
465 | blocks.VariationalEncoder):
466 | z = torch.cat(z, 0)
467 | z = rearrange(z, "b c t -> (b t) c")
468 |
469 | self.latent_mean.copy_(z.mean(0))
470 | z = z - self.latent_mean
471 |
472 | pca = PCA(z.shape[-1]).fit(z.cpu().numpy())
473 |
474 | components = pca.components_
475 | components = torch.from_numpy(components).to(z)
476 | self.latent_pca.copy_(components)
477 |
478 | var = pca.explained_variance_ / np.sum(pca.explained_variance_)
479 | var = np.cumsum(var)
480 |
481 | self.fidelity.copy_(torch.from_numpy(var).to(self.fidelity))
482 |
483 | var_percent = [.8, .9, .95, .99]
484 | for p in var_percent:
485 | self.log(
486 | f"fidelity_{p}",
487 | np.argmax(var > p).astype(np.float32),
488 | )
489 |
490 | y = torch.cat(audio, 0)[:8].reshape(-1).numpy()
491 | if self.integrator is not None:
492 | y = self.integrator(y)
493 | self.logger.experiment.add_audio("audio_val", y, self.eval_number,
494 | self.sr)
495 | self.eval_number += 1
496 |
497 | def on_fit_start(self):
498 | tb = self.logger.experiment
499 |
500 | config = gin.operative_config_str()
501 | config = config.split('\n')
502 | config = ['```'] + config + ['```']
503 | config = '\n'.join(config)
504 | tb.add_text("config", config)
505 |
506 | model = str(self)
507 | model = model.split('\n')
508 | model = ['```'] + model + ['```']
509 | model = '\n'.join(model)
510 | tb.add_text("model", model)
511 |
512 |
--------------------------------------------------------------------------------
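
For reference, the KL warmup implemented by BetaWarmupCallback above amounts to a geometric interpolation between initial_value and target_value over warmup_len steps when log=True, after which beta is held at target_value. A minimal sketch of that schedule as a standalone function (the numeric values here are illustrative, not the configured defaults):

import math

def beta_at(step: int, initial_value: float = 1e-4,
            target_value: float = 0.2, warmup_len: int = 50_000) -> float:
    # held at target_value once the warmup is over
    if step >= warmup_len:
        return target_value
    ratio = step / warmup_len
    # log-space (geometric) interpolation, as in the log=True branch
    log_beta = (1 - ratio) * math.log(initial_value) + ratio * math.log(target_value)
    return math.exp(log_beta)

print(beta_at(0), beta_at(25_000), beta_at(50_000))
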
/rave/pqmf.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import cached_conv as cc
4 | import gin
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | from einops import rearrange
9 | from scipy.optimize import fmin
10 | from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord
11 |
12 |
13 | def reverse_half(x):
14 | mask = torch.ones_like(x)
15 | mask[..., 1::2, ::2] = -1
16 |
17 | return x * mask
18 |
19 |
20 | def center_pad_next_pow_2(x):
21 | next_2 = 2**math.ceil(math.log2(x.shape[-1]))
22 | pad = next_2 - x.shape[-1]
23 | return nn.functional.pad(x, (pad // 2, pad // 2 + int(pad % 2)))
24 |
25 |
26 | def make_odd(x):
27 | if not x.shape[-1] % 2:
28 | x = nn.functional.pad(x, (0, 1))
29 | return x
30 |
31 |
32 | def get_qmf_bank(h, n_band):
33 | """
34 |     Modulates an input prototype filter into a bank of
35 | cosine modulated filters
36 | Parameters
37 | ----------
38 | h: torch.Tensor
39 | prototype filter
40 | n_band: int
41 | number of sub-bands
42 | """
43 | k = torch.arange(n_band).reshape(-1, 1)
44 | N = h.shape[-1]
45 | t = torch.arange(-(N // 2), N // 2 + 1)
46 |
47 | p = (-1)**k * math.pi / 4
48 |
49 | mod = torch.cos((2 * k + 1) * math.pi / (2 * n_band) * t + p)
50 | hk = 2 * h * mod
51 |
52 | return hk
53 |
54 |
55 | def kaiser_filter(wc, atten, N=None):
56 | """
57 | Computes a kaiser lowpass filter
58 | Parameters
59 | ----------
60 | wc: float
61 | Angular frequency
62 |
63 | atten: float
64 | Attenuation (dB, positive)
65 | """
66 | N_, beta = kaiserord(atten, wc / np.pi)
67 | N_ = 2 * (N_ // 2) + 1
68 | N = N if N is not None else N_
69 | h = firwin(N, wc, window=('kaiser', beta), scale=False, nyq=np.pi)
70 | return h
71 |
72 |
73 | def loss_wc(wc, atten, M, N):
74 | """
75 | Computes the objective described in https://ieeexplore.ieee.org/document/681427
76 | """
77 | h = kaiser_filter(wc, atten, N)
78 | g = np.convolve(h, h[::-1], "full")
79 | g = abs(g[g.shape[-1] // 2::2 * M][1:])
80 | return np.max(g)
81 |
82 |
83 | def get_prototype(atten, M, N=None):
84 | """
85 | Given an attenuation objective and the number of bands
86 | returns the corresponding lowpass filter
87 | """
88 | wc = fmin(lambda w: loss_wc(w, atten, M, N), 1 / M, disp=0)[0]
89 | return kaiser_filter(wc, atten, N)
90 |
91 |
92 | def polyphase_forward(x, hk, rearrange_filter=True):
93 | """
94 | Polyphase implementation of the analysis process (fast)
95 | Parameters
96 | ----------
97 | x: torch.Tensor
98 | signal to analyse ( B x 1 x T )
99 |
100 | hk: torch.Tensor
101 | filter bank ( M x T )
102 | """
103 | x = rearrange(x, "b c (t m) -> b (c m) t", m=hk.shape[0])
104 | if rearrange_filter:
105 | hk = rearrange(hk, "c (t m) -> c m t", m=hk.shape[0])
106 | x = nn.functional.conv1d(x, hk, padding=hk.shape[-1] // 2)[..., :-1]
107 | return x
108 |
109 |
110 | def polyphase_inverse(x, hk, rearrange_filter=True):
111 | """
112 | Polyphase implementation of the synthesis process (fast)
113 | Parameters
114 | ----------
115 | x: torch.Tensor
116 | signal to synthesize from ( B x 1 x T )
117 |
118 | hk: torch.Tensor
119 | filter bank ( M x T )
120 | """
121 |
122 | m = hk.shape[0]
123 |
124 | if rearrange_filter:
125 | hk = hk.flip(-1)
126 | hk = rearrange(hk, "c (t m) -> m c t", m=m) # polyphase
127 |
128 | pad = hk.shape[-1] // 2 + 1
129 | x = nn.functional.conv1d(x, hk, padding=int(pad))[..., :-1] * m
130 |
131 | x = x.flip(1)
132 | x = rearrange(x, "b (c m) t -> b c (t m)", m=m)
133 | x = x[..., 2 * hk.shape[1]:]
134 | return x
135 |
136 |
137 | def classic_forward(x, hk):
138 | """
139 | Naive implementation of the analysis process (slow)
140 | Parameters
141 | ----------
142 | x: torch.Tensor
143 | signal to analyse ( B x 1 x T )
144 |
145 | hk: torch.Tensor
146 | filter bank ( M x T )
147 | """
148 | x = nn.functional.conv1d(
149 | x,
150 | hk.unsqueeze(1),
151 | stride=hk.shape[0],
152 | padding=hk.shape[-1] // 2,
153 | )[..., :-1]
154 | return x
155 |
156 |
157 | def classic_inverse(x, hk):
158 | """
159 | Naive implementation of the synthesis process (slow)
160 | Parameters
161 | ----------
162 | x: torch.Tensor
163 | signal to synthesize from ( B x 1 x T )
164 |
165 | hk: torch.Tensor
166 | filter bank ( M x T )
167 | """
168 | hk = hk.flip(-1)
169 | y = torch.zeros(*x.shape[:2], hk.shape[0] * x.shape[-1]).to(x)
170 | y[..., ::hk.shape[0]] = x * hk.shape[0]
171 | y = nn.functional.conv1d(
172 | y,
173 | hk.unsqueeze(0),
174 | padding=hk.shape[-1] // 2,
175 | )[..., 1:]
176 | return y
177 |
178 |
179 | @torch.fx.wrap
180 | class PQMF(nn.Module):
181 | """
182 | Pseudo Quadrature Mirror Filter multiband decomposition / reconstruction
183 | Parameters
184 | ----------
185 | attenuation: int
186 | Attenuation of the rejected bands (dB, 80 - 120)
187 | n_band: int
188 | Number of bands, must be a power of 2 if the polyphase implementation
189 | is needed
190 | """
191 |
192 | def __init__(self, attenuation, n_band, polyphase=True, n_channels = 1):
193 | super().__init__()
194 | h = get_prototype(attenuation, n_band)
195 |
196 | if polyphase:
197 | power = math.log2(n_band)
198 | assert power == math.floor(
199 | power
200 | ), "when using the polyphase algorithm, n_band must be a power of 2"
201 |
202 | h = torch.from_numpy(h).float()
203 | hk = get_qmf_bank(h, n_band)
204 | hk = center_pad_next_pow_2(hk)
205 |
206 | self.register_buffer("hk", hk)
207 | self.register_buffer("h", h)
208 | self.n_band = n_band
209 | self.polyphase = polyphase
210 | self.n_channels = n_channels
211 |
212 | def forward(self, x):
213 | if x.ndim == 2:
214 | return torch.stack([self.forward(x[i]) for i in range(x.shape[0])])
215 | if self.n_band == 1:
216 | return x
217 | elif self.polyphase:
218 | x = polyphase_forward(x, self.hk)
219 | else:
220 | x = classic_forward(x, self.hk)
221 |
222 | x = reverse_half(x)
223 |
224 | return x
225 |
226 | def inverse(self, x):
227 | if x.ndim == 2:
228 | if self.n_channels == 1:
229 | return self.inverse(x[0]).unsqueeze(0)
230 | else:
231 | x = x.split(self.n_channels, -2)
232 |                 return torch.stack([self.inverse(x[i]) for i in range(len(x))])
233 |
234 | if self.n_band == 1:
235 | return x
236 |
237 | x = reverse_half(x)
238 |
239 | if self.polyphase:
240 | return polyphase_inverse(x, self.hk)
241 | else:
242 | return classic_inverse(x, self.hk)
243 |
244 |
245 | class CachedPQMF(PQMF):
246 |
247 | def __init__(self, *args, **kwargs):
248 | super().__init__(*args, **kwargs)
249 |
250 | hkf = make_odd(self.hk).unsqueeze(1)
251 |
252 | hki = self.hk.flip(-1)
253 | hki = rearrange(hki, "c (t m) -> m c t", m=self.hk.shape[0])
254 | hki = make_odd(hki)
255 |
256 | self.forward_conv = cc.Conv1d(
257 | hkf.shape[1],
258 | hkf.shape[0],
259 | hkf.shape[2],
260 | padding=cc.get_padding(hkf.shape[-1]),
261 | stride=hkf.shape[0],
262 | bias=False,
263 | )
264 | self.forward_conv.weight.data.copy_(hkf)
265 |
266 | self.inverse_conv = cc.Conv1d(
267 | hki.shape[1],
268 | hki.shape[0],
269 | hki.shape[-1],
270 | padding=cc.get_padding(hki.shape[-1]),
271 | bias=False,
272 | )
273 | self.inverse_conv.weight.data.copy_(hki)
274 |
275 | def script_cache(self):
276 | self.forward_conv.script_cache()
277 | self.inverse_conv.script_cache()
278 |
279 | def forward(self, x):
280 | if self.n_band == 1: return x
281 | x = self.forward_conv(x)
282 | x = reverse_half(x)
283 | return x
284 |
285 | def inverse(self, x):
286 | if self.n_band == 1: return x
287 | x = reverse_half(x)
288 | m = self.hk.shape[0]
289 | x = self.inverse_conv(x) * m
290 | x = x.flip(1)
291 | x = x.permute(0, 2, 1)
292 | x = x.reshape(x.shape[0], x.shape[1], -1, m).permute(0, 2, 1, 3)
293 | x = x.reshape(x.shape[0], x.shape[1], -1)
294 | return x
295 |
--------------------------------------------------------------------------------
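
A minimal analysis/synthesis round trip for the PQMF class above, assuming the package and its pinned requirements are installed so that rave.pqmf is importable; reconstruction is only near-perfect (pseudo-QMF), and the signal length is assumed divisible by n_band:

import torch
from rave.pqmf import PQMF

pqmf = PQMF(attenuation=100, n_band=16)  # illustrative values
x = torch.randn(1, 1, 2**14)             # B x 1 x T
sub = pqmf(x)                            # B x 16 x T/16 sub-band signals
y = pqmf.inverse(sub)                    # B x 1 x ~T (up to filter delay / trimming)
print(sub.shape, y.shape)
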
/rave/prior/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import *
--------------------------------------------------------------------------------
/rave/prior/core.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class QuantizedNormal(nn.Module):
7 | def __init__(self, resolution, dither=True):
8 | super().__init__()
9 | self.resolution = resolution
10 | self.dither = dither
11 | self.clamp = 4
12 |
13 | def from_normal(self, x):
14 | return .5 * (1 + torch.erf(x / math.sqrt(2)))
15 |
16 | def to_normal(self, x):
17 | x = torch.erfinv(2 * x - 1) * math.sqrt(2)
18 | return torch.clamp(x, -self.clamp, self.clamp)
19 |
20 | def encode(self, x):
21 | x = self.from_normal(x)
22 | x = torch.floor(x * self.resolution)
23 | x = torch.clamp(x, 0, self.resolution - 1)
24 | return self.to_stack_one_hot(x.long())
25 |
26 | def to_stack_one_hot(self, x):
27 | x = nn.functional.one_hot(x, self.resolution)
28 | x = x.permute(0, 2, 1, 3)
29 | x = x.reshape(x.shape[0], x.shape[1], -1)
30 | x = x.permute(0, 2, 1).float()
31 | return x
32 |
33 | def decode(self, x):
34 | x = x.permute(0, 2, 1)
35 | x = x.reshape(x.shape[0], x.shape[1], -1, self.resolution)
36 | x = torch.argmax(x, -1) / self.resolution
37 | if self.dither:
38 | x = x + torch.rand_like(x) / self.resolution
39 | x = self.to_normal(x)
40 | x = x.permute(0, 2, 1)
41 | return x
42 |
43 |
44 | class DiagonalShift(nn.Module):
45 | def __init__(self, groups=1):
46 | super().__init__()
47 | assert isinstance(groups, int)
48 | assert groups > 0
49 | self.groups = groups
50 |
51 | def shift(self, x: torch.Tensor, i: int, n_dim: int):
52 | i = i // self.groups
53 | n_dim = n_dim // self.groups
54 | start = i
55 | end = -n_dim + i + 1
56 | end = end if end else None
57 | return x[..., start:end]
58 |
59 | def forward(self, x):
60 | n_dim = x.shape[1]
61 | x = torch.split(x, 1, 1)
62 | x = [
63 | self.shift(_x, i, n_dim) for _x, i in zip(
64 | x,
65 | torch.arange(n_dim).flip(0),
66 | )
67 | ]
68 | x = torch.cat(list(x), 1)
69 | return x
70 |
71 | def inverse(self, x):
72 | x = x.flip(1)
73 | x = self.forward(x)
74 | x = x.flip(1)
75 | return x
--------------------------------------------------------------------------------
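
A quick round trip through QuantizedNormal above: latents are mapped through the normal CDF, quantized to `resolution` bins, one-hot encoded and stacked along the channel axis, then decoded back (approximately, because of dithering). Resolution and shapes are illustrative:

import torch
from rave.prior.core import QuantizedNormal

qn = QuantizedNormal(resolution=32)
z = torch.randn(1, 8, 100)          # B x D x T latents, roughly N(0, 1)
one_hot = qn.encode(z)              # B x (D * 32) x T stacked one-hots
z_hat = qn.decode(one_hot)          # B x D x T, quantized (and dithered) latents
print(one_hot.shape, z_hat.shape)
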
/rave/prior/model.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import pytorch_lightning as pl
6 | import gin
7 | from tqdm import tqdm
8 | import math
9 | import numpy as np
10 |
11 | from .residual_block import ResidualBlock
12 | from .core import DiagonalShift, QuantizedNormal
13 |
14 |
15 | import cached_conv as cc
16 |
17 | class Prior(pl.LightningModule):
18 |
19 | def __init__(self, resolution, res_size, skp_size, kernel_size, cycle_size,
20 | n_layers, pretrained_vae=None, fidelity=None, n_channels=1, latent_size=None, sr=44100):
21 | super().__init__()
22 |
23 | self.diagonal_shift = DiagonalShift()
24 | self.quantized_normal = QuantizedNormal(resolution)
25 |
26 | self.synth = pretrained_vae
27 | self.sr = sr
28 |
29 | if latent_size is not None:
30 | self.latent_size = 2**math.ceil(math.log2(latent_size))
31 | elif fidelity is not None:
32 |             assert pretrained_vae, "the fidelity keyword requires pretrained_vae to be provided"
33 | latent_size = torch.where(pretrained_vae.fidelity > fidelity)[0][0]
34 | self.latent_size = 2**math.ceil(math.log2(latent_size))
35 | else:
36 | raise RuntimeError('please init Prior with either fidelity or latent_size keywords')
37 |
38 | self.pre_net = nn.Sequential(
39 | cc.Conv1d(
40 | resolution * self.latent_size,
41 | res_size,
42 | kernel_size,
43 | padding=cc.get_padding(kernel_size, mode="causal"),
44 | groups=self.latent_size,
45 | ),
46 | nn.LeakyReLU(.2),
47 | )
48 |
49 | self.residuals = nn.ModuleList([
50 | ResidualBlock(
51 | res_size,
52 | skp_size,
53 | kernel_size,
54 | 2**(i % cycle_size),
55 | ) for i in range(n_layers)
56 | ])
57 |
58 | self.post_net = nn.Sequential(
59 | cc.Conv1d(skp_size, skp_size, 1),
60 | nn.LeakyReLU(.2),
61 | cc.Conv1d(
62 | skp_size,
63 | resolution * self.latent_size,
64 | 1,
65 | groups=self.latent_size,
66 | ),
67 | )
68 |
69 | self.n_channels = n_channels
70 | self.val_idx = 0
71 | rf = (kernel_size - 1) * sum(2**(np.arange(n_layers) % cycle_size)) + 1
72 | if pretrained_vae is not None:
73 | ratio = self.get_model_ratio()
74 | self.min_receptive_field = 2**math.ceil(math.log2(rf * ratio))
75 |
76 | def get_model_ratio(self):
77 | x_len = 2**14
78 | x = torch.zeros(1, self.n_channels, x_len)
79 | z = self.encode(x)
80 | ratio_encode = x_len // z.shape[-1]
81 | return ratio_encode
82 |
83 | def configure_optimizers(self):
84 | p = []
85 | p.extend(list(self.pre_net.parameters()))
86 | p.extend(list(self.residuals.parameters()))
87 | p.extend(list(self.post_net.parameters()))
88 | return torch.optim.Adam(p, lr=1e-4)
89 |
90 | @torch.no_grad()
91 | def encode(self, x):
92 | self.synth.eval()
93 | z = self.synth.encode(x)
94 | z = self.post_process_latent(z)
95 | return z
96 |
97 | @torch.no_grad()
98 | def decode(self, z):
99 | self.synth.eval()
100 | z = self.pre_process_latent(z)
101 | return self.synth.decode(z)
102 |
103 | def forward(self, x):
104 | res = self.pre_net(x)
105 | skp = torch.tensor(0.).to(x)
106 | for layer in self.residuals:
107 | res, skp = layer(res, skp)
108 | x = self.post_net(skp)
109 | return x
110 |
111 | @torch.no_grad()
112 | def generate(self, x, argmax: bool = False):
113 | for i in tqdm(range(x.shape[-1] - 1)):
114 | if cc.USE_BUFFER_CONV:
115 | start = i
116 | else:
117 | start = None
118 |
119 | pred = self.forward(x[..., start:i + 1])
120 |
121 | if not cc.USE_BUFFER_CONV:
122 | pred = pred[..., -1:]
123 |
124 | pred = self.post_process_prediction(pred, argmax=argmax)
125 |
126 | x[..., i + 1:i + 2] = pred
127 | return x
128 |
129 | def split_classes(self, x):
130 | # B x D*C x T
131 | x = x.permute(0, 2, 1)
132 | x = x.reshape(x.shape[0], x.shape[1], self.latent_size, -1)
133 | x = x.permute(0, 2, 1, 3) # B x D x T x C
134 | return x
135 |
136 | def post_process_prediction(self, x, argmax: bool = False):
137 | x = self.split_classes(x)
138 | shape = x.shape[:-1]
139 | x = x.reshape(-1, x.shape[-1])
140 |
141 | if argmax:
142 | x = torch.argmax(x, -1)
143 | else:
144 | x = torch.softmax(x - torch.logsumexp(x, -1, keepdim=True), -1)
145 | x = torch.multinomial(x, 1, True).squeeze(-1)
146 |
147 | x = x.reshape(shape[0], shape[1], shape[2])
148 | x = self.quantized_normal.to_stack_one_hot(x)
149 | return x
150 |
151 | def training_step(self, batch, batch_idx):
152 | x = self.encode(batch)
153 | x = self.quantized_normal.encode(self.diagonal_shift(x))
154 | pred = self.forward(x)
155 |
156 | x = torch.argmax(self.split_classes(x[..., 1:]), -1)
157 | pred = self.split_classes(pred[..., :-1])
158 |
159 | loss = nn.functional.cross_entropy(
160 | pred.reshape(-1, self.quantized_normal.resolution),
161 | x.reshape(-1),
162 | )
163 |
164 | self.log("latent_prediction", loss)
165 | return loss
166 |
167 | def validation_step(self, batch, batch_idx):
168 | x = self.encode(batch)
169 | x = self.quantized_normal.encode(self.diagonal_shift(x))
170 | pred = self.forward(x)
171 |
172 | x = torch.argmax(self.split_classes(x[..., 1:]), -1)
173 | pred = self.split_classes(pred[..., :-1])
174 |
175 | loss = nn.functional.cross_entropy(
176 | pred.reshape(-1, self.quantized_normal.resolution),
177 | x.reshape(-1),
178 | )
179 |
180 | self.log("validation", loss)
181 | return batch
182 |
183 | def validation_epoch_end(self, out):
184 | x = torch.randn_like(self.encode(out[0]))
185 | x = self.quantized_normal.encode(self.diagonal_shift(x))
186 | z = self.generate(x)
187 | z = self.diagonal_shift.inverse(self.quantized_normal.decode(z))
188 |
189 | y = self.decode(z)
190 | self.logger.experiment.add_audio(
191 | "generation",
192 | y.reshape(-1),
193 | self.val_idx,
194 | self.synth.sr,
195 | )
196 | self.val_idx += 1
197 |
198 | @abc.abstractmethod
199 | def post_process_latent(self, z):
200 | raise NotImplementedError()
201 |
202 | @abc.abstractmethod
203 | def pre_process_latent(self, z):
204 | raise NotImplementedError()
205 |
206 |
207 |
208 | @gin.configurable
209 | class VariationalPrior(Prior):
210 |
211 | def post_process_latent(self, z):
212 | z = self.synth.encoder.reparametrize(z)[0]
213 | z = z - self.synth.latent_mean.unsqueeze(-1)
214 | z = F.conv1d(z, self.synth.latent_pca.unsqueeze(-1))
215 | z = z[:, :self.latent_size]
216 | return z
217 |
218 | def pre_process_latent(self, z):
219 | noise = torch.randn(
220 | z.shape[0],
221 | self.synth.latent_size - z.shape[1],
222 | z.shape[-1],
223 | ).type_as(z)
224 | z = torch.cat([z, noise], 1)
225 | z = F.conv1d(z, self.synth.latent_pca.T.unsqueeze(-1))
226 | z = z + self.synth.latent_mean.unsqueeze(-1)
227 | return z
--------------------------------------------------------------------------------
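
The key detail of training_step / validation_step above is the one-frame shift between inputs and predictions: the network's output at time t is scored against the quantized latent at time t + 1. A minimal sketch of that alignment with stand-in tensors (shapes are illustrative):

import torch

x = torch.randn(2, 8, 16)        # B x D x T stand-in for the quantized latents
pred = torch.randn(2, 8, 16)     # stand-in for the network output
targets = x[..., 1:]             # frames 1 .. T-1
aligned_pred = pred[..., :-1]    # frames 0 .. T-2 predict the next frame
print(targets.shape, aligned_pred.shape)
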
/rave/prior/residual_block.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import cached_conv as cc
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 |
8 | def __init__(self, res_size, skp_size, kernel_size, dilation):
9 | super().__init__()
10 | fks = (kernel_size - 1) * dilation + 1
11 |
12 | self.dconv = cc.Conv1d(
13 | res_size,
14 | 2 * res_size,
15 | kernel_size,
16 | padding=(fks - 1, 0),
17 | dilation=dilation,
18 | )
19 |
20 | self.rconv = nn.Conv1d(res_size, res_size, 1)
21 | self.sconv = nn.Conv1d(res_size, skp_size, 1)
22 |
23 | def forward(self, x, skp):
24 | res = x.clone()
25 |
26 | x = self.dconv(x)
27 | xa, xb = torch.split(x, x.shape[1] // 2, 1)
28 |
29 | x = torch.sigmoid(xa) * torch.tanh(xb)
30 | res = res + self.rconv(x)
31 | skp = skp + self.sconv(x)
32 | return res, skp
--------------------------------------------------------------------------------
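
A minimal forward pass through the gated ResidualBlock above, assuming cached_conv runs in its default (non-streaming) mode; channel sizes and dilation are illustrative:

import torch
from rave.prior.residual_block import ResidualBlock

block = ResidualBlock(res_size=64, skp_size=64, kernel_size=3, dilation=2)
x = torch.randn(1, 64, 128)              # B x res_size x T
res, skp = block(x, torch.tensor(0.0))   # skip connection starts at zero
print(res.shape, skp.shape)              # both 1 x 64 x 128
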
/rave/quantization.py:
--------------------------------------------------------------------------------
1 | # Code adapted from https://github.com/lucidrains/vector-quantize-pytorch
2 |
3 | from typing import Any, Callable, Optional, Union
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from einops import repeat
8 | from torch import nn
9 |
10 |
11 | def ema_inplace(moving_avg, new, decay: float):
12 | moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
13 |
14 |
15 | def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
16 | return (x + epsilon) / (x.sum() + n_categories * epsilon)
17 |
18 |
19 | def uniform_init(*shape: int):
20 | t = torch.empty(shape)
21 | nn.init.kaiming_uniform_(t)
22 | return t
23 |
24 |
25 | def sample_vectors(samples, num: int):
26 | num_samples, device = samples.shape[0], samples.device
27 |
28 | if num_samples >= num:
29 | indices = torch.randperm(num_samples, device=device)[:num]
30 | else:
31 | indices = torch.randint(0, num_samples, (num, ), device=device)
32 |
33 | return samples[indices]
34 |
35 |
36 | def kmeans(samples, num_clusters: int, num_iters: int = 10):
37 | dim, dtype = samples.shape[-1], samples.dtype
38 |
39 | means = sample_vectors(samples, num_clusters)
40 |
41 | for _ in range(num_iters):
42 | diffs = samples[:, None] - means[None]
43 | dists = -(diffs**2).sum(dim=-1)
44 |
45 | buckets = dists.max(dim=-1).indices
46 | bins = torch.bincount(buckets, minlength=num_clusters)
47 | zero_mask = bins == 0
48 | bins_min_clamped = bins.masked_fill(zero_mask, 1)
49 |
50 | new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
51 | new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
52 | new_means = new_means / bins_min_clamped[..., None]
53 |
54 | means = torch.where(zero_mask[..., None], means, new_means)
55 |
56 | return means, bins
57 |
58 |
59 | class EuclideanCodebook(nn.Module):
60 | """Codebook with Euclidean distance.
61 | Args:
62 | dim (int): Dimension.
63 | codebook_size (int): Codebook size.
64 | kmeans_init (bool): Whether to use k-means to initialize the codebooks.
65 | If set to true, run the k-means algorithm on the first training batch and use
66 | the learned centroids as initialization.
67 | kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
68 | decay (float): Decay for exponential moving average over the codebooks.
69 | epsilon (float): Epsilon value for numerical stability.
70 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
71 | that have an exponential moving average cluster size less than the specified threshold with
72 | randomly selected vector from the current batch.
73 | """
74 |
75 | def __init__(
76 | self,
77 | dim: int,
78 | codebook_size: int,
79 |         kmeans_init: bool = False,
80 | kmeans_iters: int = 10,
81 | decay: float = 0.99,
82 | epsilon: float = 1e-5,
83 | threshold_ema_dead_code: int = 2,
84 | ):
85 | super().__init__()
86 | self.decay = decay
87 | init_fn: Union[Callable[..., torch.Tensor],
88 | Any] = uniform_init if not kmeans_init else torch.zeros
89 | embed = init_fn(codebook_size, dim)
90 |
91 | self.codebook_size = codebook_size
92 |
93 | self.kmeans_iters = kmeans_iters
94 | self.epsilon = epsilon
95 | self.threshold_ema_dead_code = threshold_ema_dead_code
96 |
97 | self.register_buffer("inited", torch.Tensor([not kmeans_init]))
98 | self.register_buffer("cluster_size", torch.zeros(codebook_size))
99 | self.register_buffer("embed", embed)
100 | self.register_buffer("embed_avg", embed.clone())
101 |
102 | @torch.jit.unused
103 | def init_embed_(self, data):
104 | embed, cluster_size = kmeans(data, self.codebook_size,
105 | self.kmeans_iters)
106 | self.embed.data.copy_(embed)
107 | self.embed_avg.data.copy_(embed.clone())
108 | self.cluster_size.data.copy_(cluster_size)
109 | self.inited.data.copy_(torch.Tensor([True]))
110 |
111 | def replace_(self, samples, mask):
112 | modified_codebook = torch.where(
113 | mask[..., None], sample_vectors(samples, self.codebook_size),
114 | self.embed)
115 | self.embed.data.copy_(modified_codebook)
116 |
117 | def expire_codes_(self, batch_samples):
118 | if self.threshold_ema_dead_code == 0:
119 | return
120 |
121 | expired_codes = self.cluster_size < self.threshold_ema_dead_code
122 | if not torch.any(expired_codes):
123 | return
124 |
125 | batch_samples = batch_samples.reshape(-1, batch_samples.shape[-1])
126 | self.replace_(batch_samples, mask=expired_codes)
127 |
128 | def preprocess(self, x):
129 | return x.reshape(-1, x.shape[-1])
130 |
131 | def quantize(self, x):
132 | embed = self.embed.t()
133 | dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed +
134 | embed.pow(2).sum(0, keepdim=True))
135 | embed_ind = dist.max(dim=-1).indices
136 | return embed_ind
137 |
138 | def dequantize(self, embed_ind):
139 | quantize = F.embedding(embed_ind, self.embed)
140 | return quantize
141 |
142 | def encode(self, x):
143 | shape = x.shape
144 | # pre-process
145 | x = self.preprocess(x)
146 | # quantize
147 | embed_ind = self.quantize(x)
148 | # post-process
149 | embed_ind = embed_ind.reshape(shape[0], shape[1])
150 | return embed_ind
151 |
152 | def decode(self, embed_ind):
153 | quantize = self.dequantize(embed_ind)
154 | return quantize
155 |
156 | def forward(self, x):
157 | shape, dtype = x.shape, x.dtype
158 | x = self.preprocess(x)
159 |
160 | if not self.inited:
161 | self.init_embed_(x)
162 |
163 | embed_ind = self.quantize(x)
164 | embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
165 | embed_ind = embed_ind.reshape(shape[0], shape[1])
166 | quantize = self.dequantize(embed_ind)
167 |
168 | if self.training:
169 | # We do the expiry of code at that point as buffers are in sync
170 | # and all the workers will take the same decision.
171 | self.expire_codes_(x)
172 | ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
173 | embed_sum = x.t() @ embed_onehot
174 | ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
175 | cluster_size = (laplace_smoothing(
176 | self.cluster_size, self.codebook_size, self.epsilon) *
177 | self.cluster_size.sum())
178 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
179 | self.embed.data.copy_(embed_normalized)
180 |
181 | return quantize, embed_ind
182 |
183 |
184 | class VectorQuantization(nn.Module):
185 | """Vector quantization implementation.
186 | Currently supports only euclidean distance.
187 | Args:
188 | dim (int): Dimension
189 | codebook_size (int): Codebook size
190 | codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
191 | decay (float): Decay for exponential moving average over the codebooks.
192 | epsilon (float): Epsilon value for numerical stability.
193 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
194 | kmeans_iters (int): Number of iterations used for kmeans initialization.
195 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
196 | that have an exponential moving average cluster size less than the specified threshold with
197 | randomly selected vector from the current batch.
198 | commitment_weight (float): Weight for commitment loss.
199 | """
200 |
201 | def __init__(
202 | self,
203 | dim: int,
204 | codebook_size: int,
205 | codebook_dim: Optional[int] = None,
206 | decay: float = 0.99,
207 | epsilon: float = 1e-5,
208 | kmeans_init: bool = True,
209 | kmeans_iters: int = 50,
210 | threshold_ema_dead_code: int = 2,
211 | commitment_weight: float = 1.,
212 | ):
213 | super().__init__()
214 | _codebook_dim: int = codebook_dim or dim
215 |
216 | requires_projection = _codebook_dim != dim
217 | self.project_in = (nn.Linear(dim, _codebook_dim)
218 | if requires_projection else nn.Identity())
219 | self.project_out = (nn.Linear(_codebook_dim, dim)
220 | if requires_projection else nn.Identity())
221 |
222 | self.epsilon = epsilon
223 | self.commitment_weight = commitment_weight
224 |
225 | self._codebook = EuclideanCodebook(
226 | dim=_codebook_dim,
227 | codebook_size=codebook_size,
228 | kmeans_init=kmeans_init,
229 | kmeans_iters=kmeans_iters,
230 | decay=decay,
231 | epsilon=epsilon,
232 | threshold_ema_dead_code=threshold_ema_dead_code)
233 | self.codebook_size = codebook_size
234 |
235 | @property
236 | def codebook(self):
237 | return self._codebook.embed
238 |
239 | def encode(self, x):
240 | x = x.permute(0, 2, 1)
241 | x = self.project_in(x)
242 | embed_in = self._codebook.encode(x)
243 | return embed_in
244 |
245 | def decode(self, embed_ind):
246 | quantize = self._codebook.decode(embed_ind)
247 | quantize = self.project_out(quantize)
248 | quantize = quantize.permute(0, 2, 1)
249 | return quantize
250 |
251 | def forward(self, x):
252 | device = x.device
253 | x = x.permute(0, 2, 1)
254 | x = self.project_in(x)
255 |
256 | quantize, embed_ind = self._codebook(x)
257 |
258 | if self.training:
259 | quantize = x + (quantize - x).detach()
260 |
261 | loss = torch.tensor([0.0], device=device, requires_grad=self.training)
262 |
263 | if self.training:
264 | if self.commitment_weight > 0:
265 | commit_loss = F.mse_loss(quantize.detach(), x)
266 | loss = loss + commit_loss * self.commitment_weight
267 |
268 | quantize = self.project_out(quantize)
269 | quantize = quantize.permute(0, 2, 1)
270 | return quantize, embed_ind, loss
271 |
272 |
273 | class ResidualVectorQuantization(nn.Module):
274 | """Residual vector quantization implementation.
275 | Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
276 | """
277 |
278 | def __init__(self, num_quantizers, **kwargs):
279 | super().__init__()
280 | self.layers = nn.ModuleList(
281 | [VectorQuantization(**kwargs) for _ in range(num_quantizers)])
282 |
283 | def forward(self, x):
284 | quantized_out = 0.0
285 | residual = x
286 |
287 | all_losses = []
288 | all_indices = []
289 |
290 | for layer in self.layers:
291 | quantized, indices, loss = layer(residual)
292 | residual = residual - quantized
293 | quantized_out = quantized_out + quantized
294 |
295 | all_indices.append(indices)
296 | all_losses.append(loss)
297 |
298 | out_losses = torch.stack(all_losses, 0).sum()
299 | all_indices = torch.stack(all_indices, 1)
300 | return quantized_out, out_losses, all_indices
301 |
302 | def encode(self, x: torch.Tensor) -> torch.Tensor:
303 | residual = x
304 | all_indices = []
305 | for layer in self.layers:
306 | indices = layer.encode(residual)
307 | quantized = layer.decode(indices)
308 | residual = residual - quantized
309 | all_indices.append(indices)
310 | out_indices = torch.stack(all_indices, 1)
311 | return out_indices
312 |
313 | def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
314 | quantized_out = torch.tensor(0.0, device=q_indices.device)
315 | for i, layer in enumerate(self.layers):
316 | quantized = layer.decode(q_indices[:, i])
317 | quantized_out = quantized_out + quantized
318 | return quantized_out
--------------------------------------------------------------------------------
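
A minimal sketch of the residual vector quantizer above on (B, D, T) features, assuming rave.quantization is importable; dim, codebook_size and num_quantizers are illustrative, not the values used by the discrete configs:

import torch
from rave.quantization import ResidualVectorQuantization

rvq = ResidualVectorQuantization(num_quantizers=4, dim=16, codebook_size=64)
x = torch.randn(8, 16, 128)                  # B x D x T
quantized, commit_loss, indices = rvq(x)     # indices: B x 4 x T
codes = rvq.encode(x)                        # inference-time code extraction
x_hat = rvq.decode(codes)                    # B x D x T reconstruction
print(quantized.shape, indices.shape, x_hat.shape)
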
/rave/resampler.py:
--------------------------------------------------------------------------------
1 | import cached_conv as cc
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .pqmf import kaiser_filter
7 |
8 |
9 | class Resampler(nn.Module):
10 |
11 | def __init__(self, target_sr, model_sr):
12 | super().__init__()
13 | assert target_sr != model_sr, "identical source and target rates"
14 |
15 | self.model_sr = model_sr
16 |         self.target_sr = target_sr
17 |
18 | ratio = target_sr // model_sr
19 | assert int(ratio) == ratio
20 |
21 | if ratio % 2 and cc.USE_BUFFER_CONV:
22 | raise ValueError(
23 | f"When using streaming mode, resampling ratio must be a power of 2, got {ratio}"
24 | )
25 |
26 | wc = np.pi / ratio
27 | filt = kaiser_filter(wc, 140)
28 | filt = torch.from_numpy(filt).float()
29 |
30 | self.downsample = cc.Conv1d(
31 | 1,
32 | 1,
33 | len(filt),
34 | stride=ratio,
35 | padding=cc.get_padding(len(filt), ratio),
36 | bias=False,
37 | )
38 |
39 | self.downsample.weight.data.copy_(filt.reshape(1, 1, -1))
40 |
41 | pad = len(filt) % ratio
42 |
43 | filt = nn.functional.pad(filt, (pad, 0))
44 | filt = filt.reshape(-1, ratio).permute(1, 0)
45 |
46 | pad = (filt.shape[-1] + 1) % 2
47 | filt = nn.functional.pad(filt, (pad, 0)).unsqueeze(1)
48 |
49 | self.upsample = cc.Conv1d(1,
50 | ratio,
51 | filt.shape[-1],
52 | stride=1,
53 | padding=cc.get_padding(filt.shape[-1]),
54 | bias=False)
55 |
56 | self.upsample.weight.data.copy_(filt)
57 |
58 | self.ratio = ratio
59 |
60 | def to_model_sampling_rate(self, x):
61 | x_down = x.reshape(-1, 1, x.shape[-1])
62 | x_down = self.downsample(x_down)
63 | return x_down.reshape(x.shape[0], x.shape[1], -1)
64 |
65 | def from_model_sampling_rate(self, x):
66 | x_up = x.reshape(-1, 1, x.shape[-1])
67 | x_up = self.upsample(x_up) # B x 2 x T
68 | x_up = x_up.permute(0, 2, 1).reshape(x_up.shape[0], -1).unsqueeze(1)
69 | x_up = x_up.reshape(x.shape[0], x.shape[1], -1)
70 | return x_up
71 |
--------------------------------------------------------------------------------
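
A minimal sketch of the Resampler above for a 2x ratio (44.1 kHz target rate, 22.05 kHz model rate), using the default non-streaming cached_conv mode; rates and lengths are illustrative:

import torch
from rave.resampler import Resampler

resampler = Resampler(target_sr=44100, model_sr=22050)
x = torch.randn(1, 1, 44100)                       # one second at target_sr
x_model = resampler.to_model_sampling_rate(x)      # ~22050 samples at model_sr
x_back = resampler.from_model_sampling_rate(x_model)
print(x_model.shape, x_back.shape)
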
/rave/transforms.py:
--------------------------------------------------------------------------------
1 | from random import choice, randint, random, randrange
2 | import bisect
3 | import torchaudio
4 | import gin.torch
5 | from typing import Tuple
6 | import librosa as li
7 | import numpy as np
8 | import torch
9 | import scipy.signal as signal
10 | from udls.transforms import *
11 |
12 |
13 | class Transform(object):
14 | def __call__(self, x: torch.Tensor):
15 | raise NotImplementedError
16 |
17 |
18 | class RandomApply(Transform):
19 | """
20 | Apply transform with probability p
21 | """
22 | def __init__(self, transform, p=.5):
23 | self.transform = transform
24 | self.p = p
25 |
26 | def __call__(self, x: np.ndarray):
27 | if random() < self.p:
28 | x = self.transform(x)
29 | return x
30 |
31 | class Resample(Transform):
32 | """
33 | Resample target signal to target sample rate.
34 | """
35 | def __init__(self, orig_sr: int, target_sr: int):
36 | self.orig_sr = orig_sr
37 | self.target_sr = target_sr
38 |
39 | def __call__(self, x: np.ndarray):
40 | return torchaudio.functional.resample(torch.from_numpy(x).float(), self.orig_sr, self.target_sr).numpy()
41 |
42 |
43 | class Compose(Transform):
44 | """
45 | Apply a list of transform sequentially
46 | """
47 | def __init__(self, transform_list):
48 | self.transform_list = transform_list
49 |
50 | def __call__(self, x: np.ndarray):
51 | for elm in self.transform_list:
52 | x = elm(x)
53 | return x
54 |
55 |
56 | class RandomPitch(Transform):
57 | def __init__(self, n_signal, pitch_range = [0.7, 1.3], max_factor: int = 20, prob: float = 0.5):
58 | self.n_signal = n_signal
59 | self.pitch_range = pitch_range
60 | self.factor_list, self.ratio_list = self._get_factors(max_factor, pitch_range)
61 | self.prob = prob
62 |
63 | def _get_factors(self, factor_limit, pitch_range):
64 | factor_list = []
65 | ratio_list = []
66 | for x in range(1, factor_limit):
67 | for y in range(1, factor_limit):
68 | if (x==y):
69 | continue
70 | factor = x / y
71 | if factor <= pitch_range[1] and factor >= pitch_range[0]:
72 | i = bisect.bisect_left(factor_list, factor)
73 | factor_list.insert(i, factor)
74 | ratio_list.insert(i, (x, y))
75 | return factor_list, ratio_list
76 |
77 | def __call__(self, x: np.ndarray):
78 | perform_pitch = bool(torch.bernoulli(torch.tensor(self.prob)))
79 | if not perform_pitch:
80 | return x
81 | random_range = list(self.pitch_range)
82 | random_range[1] = min(random_range[1], x.shape[-1] / self.n_signal)
83 | random_pitch = random() * (random_range[1] - random_range[0]) + random_range[0]
84 | ratio_idx = bisect.bisect_left(self.factor_list, random_pitch)
85 | if ratio_idx == len(self.factor_list):
86 | ratio_idx -= 1
87 | up, down = self.ratio_list[ratio_idx]
88 | x_pitched = signal.resample_poly(x, up, down, padtype='mean', axis=-1)
89 | return x_pitched
90 |
91 |
92 | class RandomCrop(Transform):
93 | """
94 | Randomly crops signal to fit n_signal samples
95 | """
96 | def __init__(self, n_signal):
97 | self.n_signal = n_signal
98 |
99 | def __call__(self, x: np.ndarray):
100 | in_point = randint(0, x.shape[-1] - self.n_signal)
101 | x = x[..., in_point:in_point + self.n_signal]
102 | return x
103 |
104 |
105 | class Dequantize(Transform):
106 | def __init__(self, bit_depth):
107 | self.bit_depth = bit_depth
108 |
109 | def __call__(self, x: np.ndarray):
110 | x += np.random.rand(*x.shape) / 2**self.bit_depth
111 | return x
112 |
113 |
114 | @gin.configurable
115 | class Compress(Transform):
116 | def __init__(self, time="0.1,0.1", lookup="6:-70,-60,-20 ", gain="0", sr=44100):
117 | self.sox_args = ['compand', time, lookup, gain]
118 | self.sr = sr
119 |
120 |     def __call__(self, x: np.ndarray):
121 | x = torchaudio.sox_effects.apply_effects_tensor(torch.from_numpy(x).float(), self.sr, [self.sox_args])[0].numpy()
122 | return x
123 |
124 | @gin.configurable
125 | class RandomCompress(Transform):
126 | def __init__(self, threshold = -40, amp_range = [-60, 0], attack=0.1, release=0.1, prob=0.8, sr=44100):
127 | assert prob >= 0. and prob <= 1., "prob must be between 0. and 1."
128 | self.amp_range = amp_range
129 | self.threshold = threshold
130 | self.attack = attack
131 | self.release = release
132 | self.prob = prob
133 | self.sr = sr
134 |
135 |     def __call__(self, x: np.ndarray):
136 | perform = bool(torch.bernoulli(torch.full((1,), self.prob)))
137 | if perform:
138 | amp_factor = torch.rand((1,)) * (self.amp_range[1] - self.amp_range[0]) + self.amp_range[0]
139 | x_aug = torchaudio.sox_effects.apply_effects_tensor(torch.from_numpy(x).float(),
140 | self.sr,
141 | [['compand', f'{self.attack},{self.release}', f'6:-80,{self.threshold},{float(amp_factor)}']]
142 | )[0].numpy()
143 | return x_aug
144 | else:
145 | return x
146 |
147 | @gin.configurable
148 | class RandomGain(Transform):
149 |     def __init__(self, gain_range: Tuple[float, float] = (-6, 3), prob: float = 0.5, limit = True):
150 | assert prob >= 0. and prob <= 1., "prob must be between 0. and 1."
151 | self.gain_range = gain_range
152 | self.prob = prob
153 | self.limit = limit
154 |
155 |     def __call__(self, x: np.ndarray):
156 | perform = bool(torch.bernoulli(torch.full((1,), self.prob)))
157 | if perform:
158 | gain_factor = np.random.rand(1)[None, None][0] * (self.gain_range[1] - self.gain_range[0]) + self.gain_range[0]
159 | amp_factor = np.power(10, gain_factor / 20)
160 | x_amp = x * amp_factor
161 | if (self.limit) and (np.abs(x_amp).max() > 1):
162 | x_amp = x_amp / np.abs(x_amp).max()
163 |             return x_amp
164 | else:
165 | return x
166 |
167 |
168 | @gin.configurable
169 | class RandomMute(Transform):
170 |     def __init__(self, prob: float = 0.1):
171 | assert prob >= 0. and prob <= 1., "prob must be between 0. and 1."
172 | self.prob = prob
173 |
174 |     def __call__(self, x: np.ndarray):
175 |         # draw a single keep (1) / mute (0) decision for the example
176 |         mask = np.random.binomial(1, 1 - self.prob, size=1)
177 | return x * mask
178 |
179 |
180 | @gin.configurable
181 | class FrequencyMasking(Transform):
182 | def __init__(self, prob = 0.5, max_size: int = 80):
183 | self.prob = prob
184 | self.max_size = max_size
185 |
186 |     def __call__(self, x: np.ndarray):
187 | perform = bool(torch.bernoulli(torch.full((1,), self.prob)))
188 | if not perform:
189 | return x
190 | spectrogram = signal.stft(x, nperseg=4096)[2]
191 | mask_size = randrange(1, self.max_size)
192 | freq_idx = randrange(0, spectrogram.shape[-2] - mask_size)
193 | spectrogram[..., freq_idx:freq_idx+mask_size, :] = 0
194 | x_inv = signal.istft(spectrogram)[1]
195 | return x_inv
196 |
197 |
198 |
199 | # Utilitary for GIN recording of augmentations
200 |
201 |
202 | _augmentations = []
203 |
204 | @gin.configurable()
205 | def add_augmentation(aug):
206 | global _augmentations
207 | _augmentations.append(aug)
208 |
209 | def get_augmentations():
210 | return _augmentations
--------------------------------------------------------------------------------
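
A minimal preprocessing chain built from the transforms above, in the spirit of the augmentation configs and assuming the repository's requirements are installed; n_signal, bit depth and gain range are illustrative, and the input is a NumPy array as produced by the dataset pipeline:

import numpy as np
from rave.transforms import Compose, RandomCrop, Dequantize, RandomGain

transform = Compose([
    RandomCrop(n_signal=65536),
    Dequantize(bit_depth=16),
    RandomGain(gain_range=(-6, 3), prob=0.5),
])
x = 0.1 * np.random.randn(1, 131072).astype(np.float32)
y = transform(x)
print(y.shape)    # (1, 65536)
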
/rave/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "2.3.1"
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py>=1.2.0
2 | einops>=0.5.0
3 | gin-config
4 | GPUtil>=1.4.0
5 | librosa>=0.9.2
6 | numpy>=1.23.3
7 | pytorch_lightning==1.9.0
8 | PyYAML>=6.0
9 | scikit_learn>=1.1.2
10 | scipy==1.10.0
11 | torch
12 | tqdm>=4.64.1
13 | udls>=1.0.1
14 | cached-conv>=2.5.0
15 | nn-tilde>=1.5.2
16 | torchaudio
17 | tensorboard
18 | pytest>=7.2.2
19 | Flask>=2.2.3
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/export_onnx.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | torch.set_grad_enabled(False)
4 | import os
5 |
6 | import cached_conv as cc
7 | import gin
8 | import torch.nn as nn
9 | from absl import app, flags
10 | from effortless_config import Config
11 |
12 | import rave
13 |
14 | flags.DEFINE_string('run', default=None, required=True, help='Run to export')
15 | FLAGS = flags.FLAGS
16 |
17 |
18 | def main(argv):
19 | gin.parse_config_file(os.path.join(FLAGS.run, "config.gin"))
20 | checkpoint = rave.core.search_for_run(FLAGS.run)
21 |
22 | print(f"using {checkpoint}")
23 |
24 | pretrained = rave.RAVE()
25 | pretrained.load_state_dict(torch.load(checkpoint)["state_dict"])
26 | pretrained.eval()
27 |
28 | for m in pretrained.modules():
29 | if hasattr(m, "weight_g"):
30 | nn.utils.remove_weight_norm(m)
31 |
32 | def recursive_replace(model: nn.Module):
33 | for name, child in model.named_children():
34 | if isinstance(child, cc.convs.Conv1d):
35 | conv = nn.Conv1d(
36 | child.in_channels,
37 | child.out_channels,
38 | child.kernel_size,
39 | child.stride,
40 | child._pad[0],
41 | child.dilation,
42 | child.groups,
43 | child.bias,
44 | )
45 | conv.weight.data.copy_(child.weight.data)
46 | if conv.bias is not None:
47 | conv.bias.data.copy_(child.bias.data)
48 | setattr(model, name, conv)
49 | elif isinstance(child, cc.convs.ConvTranspose1d):
50 | conv = nn.ConvTranspose1d(
51 | child.in_channels,
52 | child.out_channels,
53 | child.kernel_size,
54 | child.stride,
55 | child.padding,
56 | child.output_padding,
57 | child.groups,
58 | child.bias,
59 | child.dilation,
60 | child.padding_mode,
61 | )
62 | conv.weight.data.copy_(child.weight.data)
63 | if conv.bias is not None:
64 | conv.bias.data.copy_(child.bias.data)
65 | setattr(model, name, conv)
66 | else:
67 | recursive_replace(child)
68 |
69 | recursive_replace(pretrained)
70 |
71 | x = torch.randn(1, pretrained.n_channels, 2**15)
72 | pretrained(x)
73 |
74 | name = os.path.basename(os.path.normpath(FLAGS.run))
75 | export_path = os.path.join(FLAGS.run, name)
76 | torch.onnx.export(
77 | pretrained,
78 | x,
79 | f"{export_path}.onnx",
80 | export_params=True,
81 | opset_version=12,
82 | input_names=["audio_in"],
83 | output_names=["audio_out"],
84 | dynamic_axes={
85 | "audio_in": {
86 | 2: "audio_length"
87 | },
88 | "audio_out": [0],
89 | },
90 | do_constant_folding=False,
91 | )
92 |
93 |
94 | if __name__ == '__main__':
95 | app.run(main)
--------------------------------------------------------------------------------
/scripts/generate.py:
--------------------------------------------------------------------------------
1 | from absl import app, flags, logging
2 | import pdb
3 | import torch, torchaudio, argparse, os, tqdm, re, gin
4 | import cached_conv as cc
5 |
6 | try:
7 | import rave
8 | except:
9 | import sys, os
10 | sys.path.append(os.path.abspath('.'))
11 | import rave
12 |
13 |
14 | FLAGS = flags.FLAGS
15 | flags.DEFINE_string('model', required=True, default=None, help="model path")
16 | flags.DEFINE_multi_string('input', required=True, default=None, help="model inputs (file or folder)")
17 | flags.DEFINE_string('out_path', 'generations', help="output path")
18 | flags.DEFINE_string('name', None, help="name of the model")
19 | flags.DEFINE_integer('gpu', default=-1, help='GPU to use')
20 | flags.DEFINE_bool('stream', default=False, help='simulates streaming mode')
21 | flags.DEFINE_integer('chunk_size', default=None, help="chunk size for encoding/decoding (default: full file)")
22 |
23 |
24 | def get_audio_files(path):
25 | audio_files = []
26 | valid_exts = rave.core.get_valid_extensions()
27 | for root, _, files in os.walk(path):
28 | valid_files = list(filter(lambda x: os.path.splitext(x)[1] in valid_exts, files))
29 | audio_files.extend([(path, os.path.join(root, f)) for f in valid_files])
30 | return audio_files
31 |
32 |
33 | def main(argv):
34 | torch.set_float32_matmul_precision('high')
35 | cc.use_cached_conv(FLAGS.stream)
36 |
37 | model_path = FLAGS.model
38 | paths = FLAGS.input
39 | # load model
40 | logging.info("building rave")
41 | is_scripted = False
42 | if not os.path.exists(model_path):
43 | logging.error('path %s does not seem to exist.'%model_path)
44 | exit()
45 | if os.path.splitext(model_path)[1] == ".ts":
46 | model = torch.jit.load(model_path)
47 | is_scripted = True
48 | else:
49 | config_path = rave.core.search_for_config(model_path)
50 | if config_path is None:
51 | logging.error('config not found in folder %s'%model_path)
52 | gin.parse_config_file(config_path)
53 | model = rave.RAVE()
54 | run = rave.core.search_for_run(model_path)
55 | if run is None:
56 | logging.error("run not found in folder %s"%model_path)
57 | model = model.load_from_checkpoint(run)
58 |
59 | # device
60 | if FLAGS.gpu >= 0:
61 | device = torch.device('cuda:%d'%FLAGS.gpu)
62 | model = model.to(device)
63 | else:
64 | device = torch.device('cpu')
65 |
66 |
67 | # make output directories
68 | if FLAGS.name is None:
69 | FLAGS.name = "_".join(os.path.basename(model_path).split('_')[:-1])
70 | out_path = os.path.join(FLAGS.out_path, FLAGS.name)
71 | os.makedirs(out_path, exist_ok=True)
72 |
73 | # parse inputs
74 | audio_files = sum([get_audio_files(f) for f in paths], [])
75 | receptive_field = rave.core.get_minimum_size(model)
76 |
77 | progress_bar = tqdm.tqdm(audio_files)
78 | cc.MAX_BATCH_SIZE = 8
79 |
80 | for i, (d, f) in enumerate(progress_bar):
81 | #TODO reset cache
82 |
83 | try:
84 | x, sr = torchaudio.load(f)
85 | except:
86 | logging.warning('could not open file %s.'%f)
87 | continue
88 | progress_bar.set_description(f)
89 |
90 | # load file
91 | if sr != model.sr:
92 | x = torchaudio.functional.resample(x, sr, model.sr)
93 | if model.n_channels != x.shape[0]:
94 | if model.n_channels < x.shape[0]:
95 | x = x[:model.n_channels]
96 | else:
97 |                 print('[Warning] file %s has %d channels, but model has %d channels; skipping'%(f, x.shape[0], model.n_channels)); continue
98 | x = x.to(device)
99 | if FLAGS.stream:
100 | if FLAGS.chunk_size:
101 |                 assert FLAGS.chunk_size > receptive_field, "chunk_size must be higher than the model's receptive field (here: %s)"%receptive_field
102 | x = list(x.split(FLAGS.chunk_size, dim=-1))
103 |                 if x[-1].shape[-1] < FLAGS.chunk_size:
104 | x[-1] = torch.nn.functional.pad(x[-1], (0, FLAGS.chunk_size - x[-1].shape[-1]))
105 | x = torch.stack(x, 0)
106 | else:
107 | x = x[None]
108 |
109 | # forward into model
110 | out = []
111 | for x_chunk in x:
112 | x_chunk = x_chunk.to(device)
113 | out_tmp = model(x_chunk[None])
114 | out.append(out_tmp)
115 | out = torch.cat(out, -1)
116 | else:
117 | out = model.forward(x[None])
118 |
119 | # save file
120 |         relative_path = re.sub(d, "", f).lstrip(os.sep)
121 |         out_path = os.path.join(FLAGS.out_path, FLAGS.name, relative_path)
122 | os.makedirs(os.path.dirname(out_path), exist_ok=True)
123 | torchaudio.save(out_path, out[0].cpu(), sample_rate=model.sr)
124 |
125 | if __name__ == "__main__":
126 | app.run(main)
--------------------------------------------------------------------------------
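For a single file, the same offline reconstruction can be done directly with an exported TorchScript model, mirroring the ".ts" branch of the script above. File names below are hypothetical; the resampling and save calls are the same ones generate.py uses.

# Illustrative sketch mirroring the ".ts" branch of scripts/generate.py.
# File names are hypothetical.
import torch
import torchaudio

torch.set_grad_enabled(False)
model = torch.jit.load("exported_model.ts")
x, sr = torchaudio.load("input.wav")
if sr != model.sr:  # same resampling step as in generate.py
    x = torchaudio.functional.resample(x, sr, model.sr)
y = model(x[None])  # (1, channels, time) in, reconstruction out
torchaudio.save("output.wav", y[0].cpu(), sample_rate=model.sr)
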
/scripts/main_cli.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from absl import app
4 |
5 | AVAILABLE_SCRIPTS = [
6 | 'preprocess', 'train', 'train_prior', 'export', 'export_onnx', 'remote_dataset', 'generate'
7 | ]
8 |
9 |
10 | def help():
11 | print(f"""usage: rave [ {' | '.join(AVAILABLE_SCRIPTS)} ]
12 |
13 | positional arguments:
14 | command Command to launch with rave.
15 | """)
16 | exit()
17 |
18 |
19 | def main():
20 | if len(sys.argv) == 1:
21 | help()
22 | elif sys.argv[1] not in AVAILABLE_SCRIPTS:
23 | help()
24 |
25 | command = sys.argv[1]
26 |
27 | if command == 'train':
28 | from scripts import train
29 | sys.argv[0] = train.__name__
30 | app.run(train.main)
31 | elif command == 'train_prior':
32 | from scripts import train_prior
33 | sys.argv[0] = train_prior.__name__
34 | app.run(train_prior.main)
35 | elif command == 'export':
36 | from scripts import export
37 | sys.argv[0] = export.__name__
38 | app.run(export.main)
39 | elif command == 'preprocess':
40 | from scripts import preprocess
41 | sys.argv[0] = preprocess.__name__
42 | app.run(preprocess.main)
43 | elif command == 'export_onnx':
44 | from scripts import export_onnx
45 | sys.argv[0] = export_onnx.__name__
46 | app.run(export_onnx.main)
47 | elif command == "generate":
48 | from scripts import generate
49 | sys.argv[0] = generate.__name__
50 | app.run(generate.main)
51 | elif command == 'remote_dataset':
52 | from scripts import remote_dataset
53 | sys.argv[0] = remote_dataset.__name__
54 | app.run(remote_dataset.main)
55 | else:
56 | raise Exception(f'Command {command} not found')
57 |
--------------------------------------------------------------------------------
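Since every command name maps one-to-one onto a module in the scripts package, the elif chain above could equivalently be expressed with importlib. This is only an illustrative alternative, not the repository's implementation:

# Illustrative alternative to the elif chain above (not the repository's code):
# every command name matches a module in the "scripts" package.
import importlib
import sys

from absl import app


def run_command(command: str) -> None:
    module = importlib.import_module(f"scripts.{command}")
    sys.argv[0] = module.__name__
    app.run(module.main)
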
/scripts/preprocess.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import multiprocessing
3 | import os
4 | import pathlib
5 | import subprocess
6 | from datetime import timedelta
7 | from functools import partial
8 | from itertools import repeat
9 | from typing import Callable, Iterable, Sequence, Tuple
10 |
11 | import lmdb
12 | import numpy as np
13 | import torch
14 | import yaml
15 | import math
16 | from absl import app, flags
17 | from tqdm import tqdm
18 | from udls.generated import AudioExample
19 |
20 | torch.set_grad_enabled(False)
21 |
22 | FLAGS = flags.FLAGS
23 |
24 | flags.DEFINE_multi_string('input_path',
25 | None,
26 | help='Path to a directory containing audio files',
27 | required=True)
28 | flags.DEFINE_string('output_path',
29 | None,
30 | help='Output directory for the dataset',
31 | required=True)
32 | flags.DEFINE_integer('num_signal',
33 | 131072,
34 | help='Number of audio samples to use during training')
35 | flags.DEFINE_integer('channels', 1, help="Number of audio channels")
36 | flags.DEFINE_integer('sampling_rate',
37 | 44100,
38 | help='Sampling rate to use during training')
39 | flags.DEFINE_integer('max_db_size',
40 | 100,
41 | help='Maximum size (in GB) of the dataset')
42 | flags.DEFINE_multi_string(
43 | 'ext',
44 | default=['aif', 'aiff', 'wav', 'opus', 'mp3', 'aac', 'flac', 'ogg'],
45 | help='Extension to search for in the input directory')
46 | flags.DEFINE_bool('lazy',
47 | default=False,
48 | help='Decode and resample audio samples.')
49 | flags.DEFINE_bool('dyndb',
50 | default=True,
51 | help="Allow the database to grow dynamically")
52 |
53 |
54 | def float_array_to_int16_bytes(x):
55 | return np.floor(x * (2**15 - 1)).astype(np.int16).tobytes()
56 |
57 |
58 | def load_audio_chunk(path: str, n_signal: int,
59 | sr: int, channels: int = 1) -> Iterable[np.ndarray]:
60 |
61 | _, input_channels = get_audio_channels(path)
62 | channel_map = range(channels)
63 | if input_channels < channels:
64 | channel_map = (math.ceil(channels / input_channels) * list(range(input_channels)))[:channels]
65 |
66 | processes = []
67 | for i in range(channels):
68 | process = subprocess.Popen(
69 | [
70 | 'ffmpeg', '-hide_banner', '-loglevel', 'panic', '-i', path,
71 | '-ar', str(sr),
72 | '-f', 's16le',
73 | '-filter_complex', 'channelmap=%d-0'%channel_map[i],
74 | '-'
75 | ],
76 | stdout=subprocess.PIPE,
77 | )
78 | processes.append(process)
79 |
80 | chunk = [p.stdout.read(n_signal * 4) for p in processes]
81 | while len(chunk[0]) == n_signal * 4:
82 | yield b''.join(chunk)
83 | chunk = [p.stdout.read(n_signal * 4) for p in processes]
84 |     for p in processes: p.stdout.close()  # close every ffmpeg pipe, not just the last one
85 |
86 |
87 | def get_audio_length(path: str) -> float:
88 | process = subprocess.Popen(
89 | [
90 | 'ffprobe', '-i', path, '-v', 'error', '-show_entries',
91 | 'format=duration'
92 | ],
93 | stdout=subprocess.PIPE,
94 | stderr=subprocess.PIPE,
95 | )
96 | stdout, _ = process.communicate()
97 | if process.returncode: return None
98 | try:
99 | stdout = stdout.decode().split('\n')[1].split('=')[-1]
100 | length = float(stdout)
101 | _, channels = get_audio_channels(path)
102 | return path, float(length), int(channels)
103 | except:
104 | return None
105 |
106 | def get_audio_channels(path: str) -> int:
107 | process = subprocess.Popen(
108 | [
109 | 'ffprobe', '-i', path, '-v', 'error', '-show_entries',
110 | 'stream=channels'
111 | ],
112 | stdout=subprocess.PIPE,
113 | stderr=subprocess.PIPE,
114 | )
115 | stdout, _ = process.communicate()
116 | if process.returncode: return None
117 | try:
118 | stdout = stdout.decode().split('\n')[1].split('=')[-1]
119 | channels = int(stdout)
120 | return path, int(channels)
121 | except:
122 | return None
123 |
124 |
125 | def flatten(iterator: Iterable):
126 | for elm in iterator:
127 | for sub_elm in elm:
128 | yield sub_elm
129 |
130 | def get_metadata(audio_samples, channels: int = 1):
131 | audio = np.frombuffer(audio_samples, dtype=np.int16)
132 | audio = audio.astype(float) / (2**15 - 1)
133 | audio = audio.reshape(channels, -1)
134 | peak_amplitude = np.amax(np.abs(audio))
135 | rms_amplitude = np.sqrt(np.mean(audio**2))
136 | return {'peak': peak_amplitude, 'rms_amplitude': rms_amplitude}
137 |
138 |
139 | def process_audio_array(audio: Tuple[int, bytes],
140 | env: lmdb.Environment,
141 | channels: int = 1) -> int:
142 | audio_id, audio_samples = audio
143 | buffers = {}
144 | buffers['waveform'] = AudioExample.AudioBuffer(
145 | shape=(channels, int(len(audio_samples) / channels)),
146 | sampling_rate=FLAGS.sampling_rate,
147 | data=audio_samples,
148 | precision=AudioExample.Precision.INT16,
149 | )
150 |
151 | ae = AudioExample(buffers=buffers)
152 | key = f'{audio_id:08d}'
153 | with env.begin(write=True) as txn:
154 | txn.put(
155 | key.encode(),
156 | ae.SerializeToString(),
157 | )
158 | return audio_id
159 |
160 |
161 | def process_audio_file(audio: Tuple[int, Tuple[str, float]],
162 | env: lmdb.Environment) -> int:
163 | audio_id, (path, length, channels) = audio
164 | ae = AudioExample(metadata={'path': path, 'length': str(length), 'channels': str(channels)})
165 | key = f'{audio_id:08d}'
166 | with env.begin(write=True) as txn:
167 | txn.put(
168 | key.encode(),
169 | ae.SerializeToString(),
170 | )
171 | return length
172 |
173 |
174 | def flatmap(pool: multiprocessing.Pool,
175 | func: Callable,
176 | iterable: Iterable,
177 | chunksize=None):
178 | queue = multiprocessing.Manager().Queue(maxsize=os.cpu_count())
179 | pool.map_async(
180 | functools.partial(flat_mappper, func),
181 | zip(iterable, repeat(queue)),
182 | chunksize,
183 | lambda _: queue.put(None),
184 | lambda *e: print(e),
185 | )
186 |
187 | item = queue.get()
188 | while item is not None:
189 | yield item
190 | item = queue.get()
191 |
192 |
193 | def flat_mappper(func, arg):
194 | data, queue = arg
195 | for item in func(data):
196 | queue.put(item)
197 |
198 |
199 | def search_for_audios(path_list: Sequence[str], extensions: Sequence[str]):
200 | paths = map(pathlib.Path, path_list)
201 | audios = []
202 | for p in paths:
203 | for ext in extensions:
204 | audios.append(p.rglob(f'*.{ext}'))
205 | audios.append(p.rglob(f'*.{ext.upper()}'))
206 | audios = flatten(audios)
207 | return audios
208 |
209 |
210 | def main(argv):
211 | if FLAGS.lazy and os.name in ["nt", "posix"]:
212 | while (answer := input(
213 | "Using lazy datasets on Windows/macOS might result in slow training. Continue ? (y/n) "
214 | ).lower()) not in ["y", "n"]:
215 | print("Answer 'y' or 'n'.")
216 | if answer == "n":
217 | print("Aborting...")
218 | exit()
219 |
220 |
221 | chunk_load = partial(load_audio_chunk,
222 | n_signal=FLAGS.num_signal,
223 | sr=FLAGS.sampling_rate,
224 | channels=FLAGS.channels)
225 |
226 | output_dir = os.path.join(*os.path.split(FLAGS.output_path)[:-1])
227 | if not os.path.isdir(output_dir):
228 | os.makedirs(output_dir)
229 |
230 | # create database
231 | env = lmdb.open(
232 | FLAGS.output_path,
233 | map_size=FLAGS.max_db_size * 1024**3,
234 | map_async=not FLAGS.dyndb,
235 | writemap=not FLAGS.dyndb,
236 | )
237 | pool = multiprocessing.Pool()
238 |
239 |
240 | # search for audio files
241 | audios = search_for_audios(FLAGS.input_path, FLAGS.ext)
242 | audios = map(str, audios)
243 | audios = map(os.path.abspath, audios)
244 | audios = [*audios]
245 | if len(audios) == 0:
246 |         exit("No valid file found in %s. Aborting"%FLAGS.input_path)
247 |
248 | if not FLAGS.lazy:
249 |
250 | # load chunks
251 | chunks = flatmap(pool, chunk_load, audios)
252 | chunks = enumerate(chunks)
253 |
254 | processed_samples = map(partial(process_audio_array, env=env, channels=FLAGS.channels), chunks)
255 |
256 | pbar = tqdm(processed_samples)
257 | n_seconds = 0
258 | for audio_id in pbar:
259 | n_seconds = (FLAGS.num_signal * 2) / FLAGS.sampling_rate * audio_id
260 | pbar.set_description(
261 | f'dataset length: {timedelta(seconds=n_seconds)}')
262 | pbar.close()
263 | else:
264 | audio_lengths = pool.imap_unordered(get_audio_length, audios)
265 | audio_lengths = filter(lambda x: x is not None, audio_lengths)
266 | audio_lengths = enumerate(audio_lengths)
267 | processed_samples = map(partial(process_audio_file, env=env),
268 | audio_lengths)
269 | pbar = tqdm(processed_samples)
270 | n_seconds = 0
271 | for length in pbar:
272 | n_seconds += length
273 | pbar.set_description(
274 | f'dataset length: {timedelta(seconds=n_seconds)}')
275 | pbar.close()
276 |
277 | with open(os.path.join(
278 | FLAGS.output_path,
279 | 'metadata.yaml',
280 | ), 'w') as metadata:
281 | yaml.safe_dump({'lazy': FLAGS.lazy, 'channels': FLAGS.channels, 'n_seconds': n_seconds, 'sr': FLAGS.sampling_rate}, metadata)
282 | pool.close()
283 | env.close()
284 |
285 |
286 | if __name__ == '__main__':
287 | app.run(main)
288 |
--------------------------------------------------------------------------------
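To check what the non-lazy preprocessing wrote, the database can be read back entry by entry. The sketch below is an assumption-laden illustration: the database path is hypothetical, while the key format, the 'waveform' buffer name and the INT16 layout follow process_audio_array above.

# Illustrative read-back of the non-lazy database written above. The key format,
# buffer name and int16 layout follow process_audio_array; the path is hypothetical.
import lmdb
import numpy as np
from udls.generated import AudioExample

env = lmdb.open("/path/to/preprocessed_db", readonly=True, lock=False)
with env.begin() as txn:
    ae = AudioExample()
    ae.ParseFromString(txn.get(f"{0:08d}".encode()))

waveform = ae.buffers["waveform"]
audio = np.frombuffer(waveform.data, dtype=np.int16).astype(np.float32) / (2**15 - 1)
audio = audio.reshape(waveform.shape[0], -1)  # (channels, num_samples)
print(waveform.sampling_rate, audio.shape)
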
/scripts/remote_dataset.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import logging
3 | import os
4 |
5 | import flask
6 | import numpy as np
7 | from absl import flags
8 | from udls import AudioExample
9 |
10 | from rave.dataset import get_dataset
11 |
12 | logging.basicConfig(level=logging.ERROR)
13 | log = logging.getLogger('werkzeug')
14 | log.setLevel(logging.ERROR)
15 |
16 | FLAGS = flags.FLAGS
17 | flags.DEFINE_string(
18 | "db_path",
19 | default=None,
20 | required=True,
21 | help="path to database.",
22 | )
23 | flags.DEFINE_integer(
24 | "sr",
25 | default=44100,
26 | help="sampling rate.",
27 | )
28 | flags.DEFINE_integer(
29 | "n_signal",
30 | default=2**16,
31 | help="sample size.",
32 | )
33 | flags.DEFINE_integer(
34 | "port",
35 | default=5000,
36 | help="port to serve the dataset.",
37 | )
38 |
39 |
40 | def main(argv):
41 | app = flask.Flask(__name__)
42 | dataset = get_dataset(db_path=FLAGS.db_path,
43 | sr=FLAGS.sr,
44 | n_signal=FLAGS.n_signal)
45 |
46 | @app.route("/")
47 | def main():
48 |         return ("RAVE remote dataset\n"
49 |                 f"Serving: {os.path.abspath(FLAGS.db_path)}\n"
50 |                 f"Length: {len(dataset)}")
51 |
52 | @app.route("/len")
53 | def length():
54 | return flask.jsonify(len(dataset))
55 |
56 |     @app.route("/get/<index>")
57 | def get(index):
58 | index = int(index)
59 | ae = AudioExample()
60 | ae.put("audio", dataset[index], np.float32)
61 | ae = base64.b64encode(bytes(ae))
62 | return ae
63 |
64 | app.run(host="0.0.0.0", port=FLAGS.port)
65 |
--------------------------------------------------------------------------------
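A minimal client for the endpoints above might look like the following; the host and port are hypothetical, and parsing the payload back into a udls AudioExample is left aside since that API is not shown in this file.

# Illustrative client for the Flask endpoints above (host/port hypothetical).
import base64
import json
import urllib.request

base_url = "http://localhost:5000"

with urllib.request.urlopen(f"{base_url}/len") as response:
    length = json.loads(response.read())

with urllib.request.urlopen(f"{base_url}/get/0") as response:
    raw = base64.b64decode(response.read())

# "raw" holds the serialized AudioExample built by ae.put("audio", ..., np.float32)
# above; decoding it back is left to the udls API, which is not shown here.
print(length, len(raw))
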
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import sys
4 | from typing import Any, Dict
5 |
6 | import gin
7 | import pytorch_lightning as pl
8 | import torch
9 | from absl import flags, app
10 | from torch.utils.data import DataLoader
11 |
12 | try:
13 | import rave
14 | except:
15 | import sys, os
16 | sys.path.append(os.path.abspath('.'))
17 | import rave
18 |
19 | import rave
20 | import rave.core
21 | import rave.dataset
22 | from rave.transforms import get_augmentations, add_augmentation
23 |
24 |
25 | FLAGS = flags.FLAGS
26 |
27 | flags.DEFINE_string('name', None, help='Name of the run', required=True)
28 | flags.DEFINE_multi_string('config',
29 | default='v2.gin',
30 | help='RAVE configuration to use')
31 | flags.DEFINE_multi_string('augment',
32 | default = [],
33 | help = 'augmentation configurations to use')
34 | flags.DEFINE_string('db_path',
35 | None,
36 | help='Preprocessed dataset path',
37 | required=True)
38 | flags.DEFINE_string('out_path',
39 | default="runs/",
40 | help='Output folder')
41 | flags.DEFINE_integer('max_steps',
42 | 6000000,
43 | help='Maximum number of training steps')
44 | flags.DEFINE_integer('val_every', 10000, help='Checkpoint model every n steps')
45 | flags.DEFINE_integer('save_every',
46 | 500000,
47 | help='save every n steps (default: just last)')
48 | flags.DEFINE_integer('n_signal',
49 | 131072,
50 | help='Number of audio samples to use during training')
51 | flags.DEFINE_integer('channels', 0, help="number of audio channels")
52 | flags.DEFINE_integer('batch', 8, help='Batch size')
53 | flags.DEFINE_string('ckpt',
54 | None,
55 | help='Path to previous checkpoint of the run')
56 | flags.DEFINE_multi_string('override', default=[], help='Override gin binding')
57 | flags.DEFINE_integer('workers',
58 | default=8,
59 | help='Number of workers to spawn for dataset loading')
60 | flags.DEFINE_multi_integer('gpu', default=None, help='GPU to use')
61 | flags.DEFINE_bool('derivative',
62 | default=False,
63 | help='Train RAVE on the derivative of the signal')
64 | flags.DEFINE_bool('normalize',
65 | default=False,
66 | help='Train RAVE on normalized signals')
67 | flags.DEFINE_list('rand_pitch',
68 | default=None,
69 | help='activates random pitch')
70 | flags.DEFINE_float('ema',
71 | default=None,
72 | help='Exponential weight averaging factor (optional)')
73 | flags.DEFINE_bool('progress',
74 | default=True,
75 | help='Display training progress bar')
76 | flags.DEFINE_bool('smoke_test',
77 | default=False,
78 | help="Run training with n_batches=1 to test the model")
79 |
80 |
81 | class EMA(pl.Callback):
82 |
83 | def __init__(self, factor=.999) -> None:
84 | super().__init__()
85 | self.weights = {}
86 | self.factor = factor
87 |
88 | def on_train_batch_end(self, trainer, pl_module, outputs, batch,
89 | batch_idx) -> None:
90 | for n, p in pl_module.named_parameters():
91 | if n not in self.weights:
92 | self.weights[n] = p.data.clone()
93 | continue
94 |
95 | self.weights[n] = self.weights[n] * self.factor + p.data * (
96 | 1 - self.factor)
97 |
98 | def swap_weights(self, module):
99 | for n, p in module.named_parameters():
100 | current = p.data.clone()
101 | p.data.copy_(self.weights[n])
102 | self.weights[n] = current
103 |
104 | def on_validation_epoch_start(self, trainer, pl_module) -> None:
105 | if self.weights:
106 | self.swap_weights(pl_module)
107 | else:
108 | print("no ema weights available")
109 |
110 | def on_validation_epoch_end(self, trainer, pl_module) -> None:
111 | if self.weights:
112 | self.swap_weights(pl_module)
113 | else:
114 | print("no ema weights available")
115 |
116 | def state_dict(self) -> Dict[str, Any]:
117 | return self.weights.copy()
118 |
119 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
120 | self.weights.update(state_dict)
121 |
122 | def add_gin_extension(config_name: str) -> str:
123 | if config_name[-4:] != '.gin':
124 | config_name += '.gin'
125 | return config_name
126 |
127 | def parse_augmentations(augmentations):
128 | for a in augmentations:
129 | gin.parse_config_file(a)
130 | add_augmentation()
131 | gin.clear_config()
132 | return get_augmentations()
133 |
134 | def main(argv):
135 | torch.set_float32_matmul_precision('high')
136 | torch.backends.cudnn.benchmark = True
137 |
138 | # check dataset channels
139 | n_channels = rave.dataset.get_training_channels(FLAGS.db_path, FLAGS.channels)
140 | gin.bind_parameter('RAVE.n_channels', n_channels)
141 |
142 | # parse augmentations
143 | augmentations = parse_augmentations(map(add_gin_extension, FLAGS.augment))
144 | gin.bind_parameter('dataset.get_dataset.augmentations', augmentations)
145 |
146 | # parse configuration
147 | if FLAGS.ckpt:
148 | config_file = rave.core.search_for_config(FLAGS.ckpt)
149 | if config_file is None:
150 |             exit('Config file not found in %s'%FLAGS.ckpt)
151 | gin.parse_config_file(config_file)
152 | else:
153 | gin.parse_config_files_and_bindings(
154 | map(add_gin_extension, FLAGS.config),
155 | FLAGS.override,
156 | )
157 |
158 | # create model
159 |     model = rave.RAVE(n_channels=n_channels)
160 | if FLAGS.derivative:
161 | model.integrator = rave.dataset.get_derivator_integrator(model.sr)[1]
162 |
163 |     # parse dataset
164 | dataset = rave.dataset.get_dataset(FLAGS.db_path,
165 | model.sr,
166 | FLAGS.n_signal,
167 | derivative=FLAGS.derivative,
168 | normalize=FLAGS.normalize,
169 | rand_pitch=FLAGS.rand_pitch,
170 | n_channels=n_channels)
171 | train, val = rave.dataset.split_dataset(dataset, 98)
172 |
173 | # get data-loader
174 | num_workers = FLAGS.workers
175 | if os.name == "nt" or sys.platform == "darwin":
176 | num_workers = 0
177 | train = DataLoader(train,
178 | FLAGS.batch,
179 | True,
180 | drop_last=True,
181 | num_workers=num_workers)
182 | val = DataLoader(val, FLAGS.batch, False, num_workers=num_workers)
183 |
184 | # CHECKPOINT CALLBACKS
185 | validation_checkpoint = pl.callbacks.ModelCheckpoint(monitor="validation",
186 | filename="best")
187 | last_filename = "last" if FLAGS.save_every is None else "epoch-{epoch:04d}"
188 | last_checkpoint = rave.core.ModelCheckpoint(filename=last_filename, step_period=FLAGS.save_every)
189 |
190 | val_check = {}
191 | if len(train) >= FLAGS.val_every:
192 | val_check["val_check_interval"] = 1 if FLAGS.smoke_test else FLAGS.val_every
193 | else:
194 | nepoch = FLAGS.val_every // len(train)
195 | val_check["check_val_every_n_epoch"] = nepoch
196 |
197 | if FLAGS.smoke_test:
198 | val_check['limit_train_batches'] = 1
199 | val_check['limit_val_batches'] = 1
200 |
201 | gin_hash = hashlib.md5(
202 | gin.operative_config_str().encode()).hexdigest()[:10]
203 |
204 | RUN_NAME = f'{FLAGS.name}_{gin_hash}'
205 |
206 | os.makedirs(os.path.join(FLAGS.out_path, RUN_NAME), exist_ok=True)
207 |
208 | if FLAGS.gpu == [-1]:
209 | gpu = 0
210 | else:
211 | gpu = FLAGS.gpu or rave.core.setup_gpu()
212 |
213 | print('selected gpu:', gpu)
214 |
215 | accelerator = None
216 | devices = None
217 | if FLAGS.gpu == [-1]:
218 | pass
219 | elif torch.cuda.is_available():
220 | accelerator = "cuda"
221 | devices = FLAGS.gpu or rave.core.setup_gpu()
222 | elif torch.backends.mps.is_available():
223 | print(
224 | "Training on mac is not available yet. Use --gpu -1 to train on CPU (not recommended)."
225 | )
226 | exit()
227 | accelerator = "mps"
228 | devices = 1
229 |
230 | callbacks = [
231 | validation_checkpoint,
232 | last_checkpoint,
233 | rave.model.WarmupCallback(),
234 | rave.model.QuantizeCallback(),
235 | # rave.core.LoggerCallback(rave.core.ProgressLogger(RUN_NAME)),
236 | rave.model.BetaWarmupCallback(),
237 | ]
238 |
239 | if FLAGS.ema is not None:
240 | callbacks.append(EMA(FLAGS.ema))
241 |
242 | trainer = pl.Trainer(
243 | logger=pl.loggers.TensorBoardLogger(
244 | FLAGS.out_path,
245 | name=RUN_NAME,
246 | ),
247 | accelerator=accelerator,
248 | devices=devices,
249 | callbacks=callbacks,
250 | max_epochs=300000,
251 | max_steps=FLAGS.max_steps,
252 | profiler="simple",
253 | enable_progress_bar=FLAGS.progress,
254 | **val_check,
255 | )
256 |
257 | run = rave.core.search_for_run(FLAGS.ckpt)
258 | if run is not None:
259 | print('loading state from file %s'%run)
260 | loaded = torch.load(run, map_location='cpu')
261 | # model = model.load_state_dict(loaded)
262 | trainer.fit_loop.epoch_loop._batches_that_stepped = loaded['global_step']
263 | # model = model.load_state_dict(loaded['state_dict'])
264 |
265 | with open(os.path.join(FLAGS.out_path, RUN_NAME, "config.gin"), "w") as config_out:
266 | config_out.write(gin.operative_config_str())
267 |
268 | trainer.fit(model, train, val, ckpt_path=run)
269 |
270 |
271 | if __name__ == "__main__":
272 | app.run(main)
273 |
--------------------------------------------------------------------------------
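The EMA callback above keeps a shadow copy of every parameter and updates it after each training batch with w_ema = factor * w_ema + (1 - factor) * w, swapping the shadow weights in for validation. A toy illustration of that update rule:

# Toy illustration of the update rule used by the EMA callback above.
import torch

factor = 0.999
w = torch.tensor([1.0])  # stand-in for a model parameter
w_ema = w.clone()        # shadow copy kept by the callback

for step in range(3):
    w = w + 0.1                                # stand-in for an optimizer step
    w_ema = w_ema * factor + w * (1 - factor)  # same rule as EMA.on_train_batch_end
    print(step, w.item(), w_ema.item())
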
/scripts/train_prior.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import sys
4 |
5 | import gin
6 | import pytorch_lightning as pl
7 | import torch
8 | from absl import flags, app
9 | from torch.utils.data import DataLoader
10 |
11 | try:
12 | import rave
13 | except:
14 | import sys, os
15 | sys.path.append(os.path.abspath('.'))
16 | import rave
17 |
18 | import rave
19 | import rave.dataset
20 | import rave.prior
21 |
22 | FLAGS = flags.FLAGS
23 |
24 | flags.DEFINE_string('name', None, help='Name of the run')
25 | flags.DEFINE_string('model', default=None, required=True, help="pretrained RAVE path")
26 | flags.DEFINE_multi_string('config', default="prior/prior_v1.gin", help="config path")
27 | flags.DEFINE_string('db_path', default=None, required=True, help="Preprocessed dataset path")
28 | flags.DEFINE_string('out_path', default="runs/", help="out directory path")
29 | flags.DEFINE_multi_integer('gpu', default=None, help='GPU to use')
30 | flags.DEFINE_integer('batch', 8, help="batch size")
31 | flags.DEFINE_integer('n_signal', 0, help="chunk size (default: given by prior config)")
32 | flags.DEFINE_string('ckpt', default=None, help="checkpoint to resume")
33 | flags.DEFINE_integer('workers',
34 | default=8,
35 | help='Number of workers to spawn for dataset loading')
36 | flags.DEFINE_integer('val_every', 10000, help='Checkpoint model every n steps')
37 | flags.DEFINE_integer('save_every',
38 | None,
39 | help='save every n steps (default: just last)')
40 | flags.DEFINE_integer('max_steps', default=1000000, help="max training steps")
41 | flags.DEFINE_multi_string('override', default=[], help='Override gin binding')
42 |
43 | flags.DEFINE_bool('derivative',
44 | default=False,
45 | help='Train RAVE on the derivative of the signal')
46 | flags.DEFINE_bool('normalize',
47 | default=False,
48 | help='Train RAVE on normalized signals')
49 | flags.DEFINE_list('rand_pitch',
50 | default=None,
51 | help='activates random pitch')
52 | flags.DEFINE_bool('progress',
53 | default=True,
54 | help='Display training progress bar')
55 | flags.DEFINE_bool('smoke_test',
56 | default=False,
57 | help="Run training with n_batches=1 to test the model")
58 |
59 | def add_gin_extension(config_name: str) -> str:
60 | if config_name[-4:] != '.gin':
61 | config_name += '.gin'
62 | return config_name
63 |
64 |
65 | def main(argv):
66 |
67 | # load pretrained RAVE
68 | config_file = rave.core.search_for_config(FLAGS.model)
69 | if config_file is None:
70 |         exit('no configuration file found at address: %s'%FLAGS.model)
71 | gin.parse_config_file(config_file)
72 | run = rave.core.search_for_run(FLAGS.model)
73 | if run is None:
74 | print('no checkpoint found in %s'%FLAGS.model)
75 | exit()
76 | pretrained = rave.RAVE()
77 | print('model found : %s'%run)
78 | checkpoint = torch.load(run, map_location='cpu')
79 | if "EMA" in checkpoint["callbacks"]:
80 | pretrained.load_state_dict(
81 | checkpoint["callbacks"]["EMA"],
82 | strict=False,
83 | )
84 | else:
85 | pretrained.load_state_dict(
86 | checkpoint["state_dict"],
87 | strict=False,
88 | )
89 | pretrained.eval()
90 | gin.clear_config()
91 |
92 | # parse configuration
93 | if FLAGS.ckpt:
94 | config_file = rave.core.search_for_config(FLAGS.ckpt)
95 | if config_file is None:
96 |             exit('Config file not found in %s'%FLAGS.ckpt)
97 | gin.parse_config_file(config_file)
98 | else:
99 | gin.parse_config_files_and_bindings(
100 | map(add_gin_extension, FLAGS.config),
101 | FLAGS.override
102 | )
103 |
104 | # create model
105 | if isinstance(pretrained.encoder, rave.blocks.VariationalEncoder):
106 | prior = rave.prior.VariationalPrior(pretrained_vae=pretrained)
107 | else:
108 | raise NotImplementedError("prior not implemented for encoder of type %s"%(type(pretrained.encoder)))
109 |
110 | dataset = rave.dataset.get_dataset(FLAGS.db_path,
111 | pretrained.sr,
112 | max(FLAGS.n_signal, prior.min_receptive_field),
113 | derivative=FLAGS.derivative,
114 | normalize=FLAGS.normalize,
115 | rand_pitch=FLAGS.rand_pitch,
116 | n_channels=pretrained.n_channels)
117 |
118 | train, val = rave.dataset.split_dataset(dataset, 98)
119 |
120 | # get data-loader
121 | num_workers = FLAGS.workers
122 | if os.name == "nt" or sys.platform == "darwin":
123 | num_workers = 0
124 | train = DataLoader(train,
125 | FLAGS.batch,
126 | True,
127 | drop_last=True,
128 | num_workers=num_workers)
129 | val = DataLoader(val, FLAGS.batch, False, num_workers=num_workers)
130 |
131 | # CHECKPOINT CALLBACKS
132 | validation_checkpoint = pl.callbacks.ModelCheckpoint(monitor="validation",
133 | filename="best")
134 | last_filename = "last" if FLAGS.save_every is None else "epoch-{epoch:04d}"
135 | last_checkpoint = rave.core.ModelCheckpoint(filename=last_filename, step_period=FLAGS.save_every)
136 |
137 | val_check = {}
138 | if len(train) >= FLAGS.val_every:
139 | val_check["val_check_interval"] = 1 if FLAGS.smoke_test else FLAGS.val_every
140 | else:
141 | nepoch = FLAGS.val_every // len(train)
142 | val_check["check_val_every_n_epoch"] = nepoch
143 |
144 | if FLAGS.smoke_test:
145 | val_check['limit_train_batches'] = 1
146 | val_check['limit_val_batches'] = 1
147 |
148 | gin_hash = hashlib.md5(
149 | gin.operative_config_str().encode()).hexdigest()[:10]
150 |
151 | RUN_NAME = f'{FLAGS.name}_{gin_hash}'
152 | os.makedirs(os.path.join(FLAGS.out_path, RUN_NAME), exist_ok=True)
153 |
154 | if FLAGS.gpu == [-1]:
155 | gpu = 0
156 | else:
157 | gpu = FLAGS.gpu or rave.core.setup_gpu()
158 |
159 | print('selected gpu:', gpu)
160 |
161 | accelerator = None
162 | devices = None
163 | if FLAGS.gpu == [-1]:
164 | pass
165 | elif torch.cuda.is_available():
166 | accelerator = "cuda"
167 | devices = FLAGS.gpu or rave.core.setup_gpu()
168 | elif torch.backends.mps.is_available():
169 | print(
170 | "Training on mac is not available yet. Use --gpu -1 to train on CPU (not recommended)."
171 | )
172 | exit()
173 | accelerator = "mps"
174 | devices = 1
175 |
176 | callbacks = [
177 | validation_checkpoint,
178 | last_checkpoint,
179 | ]
180 |
181 | trainer = pl.Trainer(
182 | logger=pl.loggers.TensorBoardLogger(
183 | FLAGS.out_path,
184 | name=RUN_NAME,
185 | ),
186 | accelerator=accelerator,
187 | devices=devices,
188 | callbacks=callbacks,
189 | max_epochs=300000,
190 | max_steps=FLAGS.max_steps,
191 | profiler="simple",
192 | enable_progress_bar=FLAGS.progress,
193 | **val_check,
194 | )
195 |
196 | run = rave.core.search_for_run(FLAGS.ckpt)
197 | if run is not None:
198 | print('loading state from file %s'%run)
199 | loaded = torch.load(run, map_location='cpu')
200 | trainer.fit_loop.epoch_loop._batches_that_stepped = loaded['global_step']
201 |
202 | with open(os.path.join(FLAGS.out_path, RUN_NAME, "config.gin"), "w") as config_out:
203 | config_out.write(gin.operative_config_str())
204 |
205 | trainer.fit(prior, train, val, ckpt_path=run)
206 |
207 | if __name__ == "__main__":
208 | app.run(main)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 |
4 | import setuptools
5 |
6 | # imports __version__
7 | exec(open('rave/version.py').read())
8 |
9 | with open("README.md", "r") as readme:
10 | readme = readme.read()
11 |
12 | with open("requirements.txt", "r") as requirements:
13 | requirements = requirements.read()
14 |
15 | setuptools.setup(
16 | name="acids-rave",
17 | version=__version__, # type: ignore
18 | author="Antoine CAILLON",
19 | author_email="caillon@ircam.fr",
20 |     description="RAVE: a Realtime Audio Variational autoEncoder",
21 | long_description=readme,
22 | long_description_content_type="text/markdown",
23 | packages=setuptools.find_packages(),
24 | package_data={
25 | 'rave/configs': ['*.gin'],
26 | },
27 | classifiers=[
28 | "Programming Language :: Python :: 3",
29 | "License :: OSI Approved :: MIT License",
30 | "Operating System :: OS Independent",
31 | ],
32 | entry_points={"console_scripts": [
33 | "rave = scripts.main_cli:main",
34 | ]},
35 | install_requires=requirements.split("\n"),
36 | python_requires='>=3.9',
37 | include_package_data=True,
38 | )
39 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acids-ircam/RAVE/f048ec4569afba7c6ba38d590be8a07b8ab24840/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_configs.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import os
3 | import tempfile
4 |
5 | import gin
6 | import pytest
7 | import torch
8 | import torch.nn as nn
9 |
10 | import rave
11 | from scripts import export
12 |
13 | gin.enter_interactive_mode()
14 |
15 | configs = [
16 | ["v1.gin"],
17 | ["v2.gin"],
18 | ["v2.gin", "adain.gin"],
19 | ["v2.gin", "wasserstein.gin"],
20 | ["v2.gin", "spherical.gin"],
21 | ["v2.gin", "hybrid.gin"],
22 | ["v2_small.gin", "adain.gin"],
23 | ["v2_small.gin", "wasserstein.gin"],
24 | ["v2_small.gin", "spherical.gin"],
25 | ["v2_small.gin", "hybrid.gin"],
26 | ["discrete.gin"],
27 | ["discrete.gin", "snake.gin"],
28 | ["discrete.gin", "snake.gin", "adain.gin"],
29 | ["discrete.gin", "snake.gin", "descript_discriminator.gin"],
30 | ["discrete.gin", "spectral_discriminator.gin"],
31 | ["discrete.gin", "noise.gin"],
32 | ["discrete.gin", "hybrid.gin"],
33 | ["v3.gin"],
34 | ["v3.gin", "hybrid.gin"]
35 | ]
36 |
37 | configs += [c + ["causal.gin"] for c in configs]
38 |
39 | model_sampling_rate = [44100, 22050]
40 | stereo = [True, False]
41 |
42 | configs = list(itertools.product(configs, model_sampling_rate, stereo))
43 |
44 |
45 | @pytest.mark.parametrize(
46 | "config,sr,stereo",
47 | configs,
48 | ids=map(
49 | lambda e: " ".join(e[0]) + f" [{e[1]}] " +
50 | ("stereo" if e[2] else "mono"), configs),
51 | )
52 | def test_config(config, sr, stereo):
53 |
54 | gin.clear_config()
55 | gin.parse_config_files_and_bindings(config, [
56 | f"SAMPLING_RATE={sr}",
57 | "CAPACITY=2",
58 | ])
59 |
60 | n_channels = 2 if stereo else 1
61 | model = rave.RAVE(n_channels=n_channels)
62 |
63 | x = torch.randn(1, n_channels, 2**15)
64 | z, _ = model.encode(x, return_mb=True)
65 | z, _ = model.encoder.reparametrize(z)[:2]
66 | y = model.decode(z)
67 | score = model.discriminator(y)
68 |
69 | assert x.shape == y.shape
70 |
71 | if isinstance(model.encoder, rave.blocks.VariationalEncoder):
72 | script_class = export.VariationalScriptedRAVE
73 | elif isinstance(model.encoder, rave.blocks.DiscreteEncoder):
74 | script_class = export.DiscreteScriptedRAVE
75 | elif isinstance(model.encoder, rave.blocks.WasserteinEncoder):
76 | script_class = export.WasserteinScriptedRAVE
77 | elif isinstance(model.encoder, rave.blocks.SphericalEncoder):
78 | script_class = export.SphericalScriptedRAVE
79 | else:
80 | raise ValueError(f"Encoder type {type(model.encoder)} "
81 | "not supported for export.")
82 |
83 | x = torch.zeros(1, n_channels, 2**14)
84 |
85 | model(x)
86 |
87 | for m in model.modules():
88 | if hasattr(m, "weight_g"):
89 | nn.utils.remove_weight_norm(m)
90 |
91 | scripted_rave = script_class(
92 | pretrained=model,
93 | channels=n_channels,
94 | )
95 |
96 | scripted_rave_resampled = script_class(
97 | pretrained=model,
98 | channels=n_channels,
99 | target_sr=44100,
100 | )
101 |
102 | with tempfile.TemporaryDirectory() as tmpdir:
103 | scripted_rave.export_to_ts(os.path.join(tmpdir, "ori.ts"))
104 | scripted_rave_resampled.export_to_ts(
105 | os.path.join(tmpdir, "resampled.ts"))
106 |
--------------------------------------------------------------------------------
/tests/test_resampler.py:
--------------------------------------------------------------------------------
1 | import cached_conv as cc
2 | import gin
3 | import pytest
4 | import torch
5 |
6 | from rave.resampler import Resampler
7 |
8 | configs = [(44100, 22050), (48000, 16000)]
9 |
10 |
11 | @pytest.mark.parametrize("target_sr,model_sr", configs)
12 | def test_resampler(target_sr, model_sr):
13 | gin.clear_config()
14 | cc.use_cached_conv(False)
15 |
16 | resampler = Resampler(target_sr, model_sr)
17 |
18 | x = torch.randn(1, 1, 2**12 * 3)
19 |
20 | y = resampler.to_model_sampling_rate(x)
21 | z = resampler.from_model_sampling_rate(y)
22 |
23 | assert x.shape == z.shape
24 |
25 | cc.use_cached_conv(True)
26 |
27 | try:
28 | resampler = Resampler(target_sr, model_sr)
29 |
30 | x = torch.randn(1, 1, 2**12 * 3)
31 |
32 | y = resampler.to_model_sampling_rate(x)
33 | z = resampler.from_model_sampling_rate(y)
34 |
35 | assert x.shape == z.shape
36 |
37 | except ValueError:
38 | pass
39 |
--------------------------------------------------------------------------------
/tests/test_residual.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | import cached_conv as cc
4 | import gin
5 | import pytest
6 | import torch
7 |
8 | from rave.blocks import *
9 |
10 | gin.enter_interactive_mode()
11 |
12 | kernel_size = [
13 | 1,
14 | 3,
15 | ]
16 |
17 | dilations = [[1, 1], [3, 1]]
18 |
19 | kernel_sizes = [
20 | [3],
21 | [3, 5],
22 | [3, 5, 7],
23 | ]
24 |
25 | dilations_list = [
26 | [[1, 1]],
27 | [[1, 1], [3, 1], [5, 1]],
28 | ]
29 |
30 | ratios = [
31 | 2,
32 | 4,
33 | 8,
34 | ]
35 |
36 |
37 | @pytest.mark.parametrize('kernel_sizes,dilations_list',
38 | itertools.product(kernel_sizes, dilations_list))
39 | def test_residual_stack(kernel_sizes, dilations_list):
40 | dim = 16
41 | x = torch.randn(1, dim, 32)
42 | cc.use_cached_conv(False)
43 | stack_regular = ResidualStack(
44 | dim=dim,
45 | kernel_sizes=[3],
46 | dilations_list=[[1, 1], [3, 1], [5, 1]],
47 | )
48 |
49 | cc.use_cached_conv(True)
50 | stack_stream = ResidualStack(
51 | dim=dim,
52 | kernel_sizes=[3],
53 | dilations_list=[[1, 1], [3, 1], [5, 1]],
54 | )
55 |
56 | for p1, p2 in zip(stack_regular.parameters(), stack_stream.parameters()):
57 | p2.data.copy_(p1.data)
58 |
59 | delay = stack_stream.cumulative_delay
60 |
61 | y_regular = stack_regular(x)
62 | y_stream = stack_stream(x)
63 |
64 | if delay:
65 | y_regular = y_regular[..., delay:-delay]
66 | y_stream = y_stream[..., delay + delay:]
67 |
68 | assert torch.allclose(y_regular, y_stream, 1e-4, 1e-4)
69 |
70 |
71 | @pytest.mark.parametrize('kernel_size,dilations_list',
72 | itertools.product(kernel_size, dilations))
73 | def test_residual_layer(kernel_size, dilations_list):
74 | dim = 16
75 | x = torch.randn(1, dim, 32)
76 |
77 | cc.use_cached_conv(False)
78 | layer_regular = ResidualLayer(dim, kernel_size, dilations_list)
79 |
80 | cc.use_cached_conv(True)
81 | layer_stream = ResidualLayer(dim, kernel_size, dilations_list)
82 |
83 | for p1, p2 in zip(layer_regular.parameters(), layer_stream.parameters()):
84 | p2.data.copy_(p1.data)
85 |
86 | delay = layer_stream.cumulative_delay
87 |
88 | y_regular = layer_regular(x)
89 | y_stream = layer_stream(x)
90 |
91 | if delay:
92 | y_regular = y_regular[..., delay:-delay]
93 | y_stream = y_stream[..., delay + delay:]
94 |
95 | assert torch.allclose(y_regular, y_stream, 1e-3, 1e-4)
96 |
97 |
98 | @pytest.mark.parametrize('ratio,', ratios)
99 | def test_upsample_layer(ratio):
100 | dim = 16
101 | x = torch.randn(1, dim, 32)
102 |
103 | cc.use_cached_conv(False)
104 | upsample_regular = UpsampleLayer(dim, dim, ratio)
105 |
106 | cc.use_cached_conv(True)
107 | upsample_stream = UpsampleLayer(dim, dim, ratio)
108 |
109 | for p1, p2 in zip(upsample_regular.parameters(),
110 | upsample_stream.parameters()):
111 | p2.data.copy_(p1.data)
112 |
113 | delay = upsample_stream.cumulative_delay
114 |
115 | y_regular = upsample_regular(x)
116 | y_stream = upsample_stream(x)
117 |
118 | if delay:
119 | y_regular = y_regular[..., delay:-delay]
120 | y_stream = y_stream[..., delay + delay:]
121 |
122 | assert torch.allclose(y_regular, y_stream, 1e-3, 1e-4)
123 |
--------------------------------------------------------------------------------
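The trimming in the three tests above compensates for the latency introduced by cached (streaming) convolutions: if the streamed output equals the regular output shifted right by cumulative_delay samples, then regular[..., delay:-delay] and stream[..., 2 * delay:] cover the same samples. A toy check of that index arithmetic, under that shift assumption:

# Toy check of the alignment used in the tests above: a signal delayed by `delay`
# samples matches the undelayed one once both are cropped as in the assertions.
import torch

delay = 4
regular = torch.arange(32, dtype=torch.float32)
stream = torch.roll(regular, delay)  # stand-in for the cached-conv delay
stream[:delay] = 0                   # warm-up region of the streamed model

assert torch.equal(regular[delay:-delay], stream[2 * delay:])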