├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── conf
│   ├── 1gpu.yml
│   ├── ablations
│   │   ├── baseline.yml
│   │   ├── diff-mb.yml
│   │   ├── equal-mb.yml
│   │   ├── no-adv.yml
│   │   ├── no-data-balance.yml
│   │   ├── no-low-hop.yml
│   │   ├── no-mb.yml
│   │   ├── no-mpd-msd.yml
│   │   ├── no-mpd.yml
│   │   └── only-speech.yml
│   ├── base.yml
│   ├── downsampling
│   │   ├── 1024x.yml
│   │   ├── 128x.yml
│   │   ├── 1536x.yml
│   │   └── 768x.yml
│   ├── final
│   │   ├── 16khz.yml
│   │   ├── 24khz.yml
│   │   ├── 44khz-16kbps.yml
│   │   └── 44khz.yml
│   ├── neuronic.yml
│   ├── quantizer
│   │   ├── 24kbps.yml
│   │   ├── 256d.yml
│   │   ├── 2d.yml
│   │   ├── 32d.yml
│   │   ├── 4d.yml
│   │   ├── 512d.yml
│   │   ├── dropout-0.0.yml
│   │   ├── dropout-0.25.yml
│   │   └── dropout-0.5.yml
│   └── size
│       ├── medium.yml
│       └── small.yml
├── jobs
│   ├── benchmark.slurm
│   └── simple.slurm
├── pyproject.toml
├── scripts
│   ├── benchmark.py
│   ├── compute_entropy.py
│   ├── evaluate.py
│   ├── get_samples.py
│   ├── input_pipeline.py
│   ├── mushra.py
│   ├── organize_daps.py
│   ├── save_test_set.py
│   └── train.py
├── setup.cfg
├── src
│   └── dac_jax
│       ├── __init__.py
│       ├── __main__.py
│       ├── audio_utils.py
│       ├── compare
│       │   ├── __init__.py
│       │   └── encodec.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── core.py
│       │   ├── dac.py
│       │   ├── discriminator.py
│       │   └── encodec.py
│       ├── nn
│       │   ├── __init__.py
│       │   ├── encodec_layers.py
│       │   ├── encodec_quantize.py
│       │   ├── layers.py
│       │   ├── loss.py
│       │   └── quantize.py
│       └── utils
│           ├── __init__.py
│           ├── decode.py
│           ├── encode.py
│           ├── load_torch_weights.py
│           └── load_torch_weights_encodec.py
└── tests
    ├── README.md
    ├── __init__.py
    ├── test_audio_utils.py
    ├── test_binding.py
    ├── test_cli.py
    ├── test_dac_equivalence.py
    ├── test_encodec_equivalence.py
    └── test_train.py
/.gitattributes:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DBraun/DAC-JAX/919ce4a2a9ec4c5c3fa7d10dcb2944259da00865/.gitattributes
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/env.sh
108 | venv/
109 | env.bak/
110 | venv.bak/
111 |
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 |
116 | # Rope project settings
117 | .ropeproject
118 |
119 | # mkdocs documentation
120 | /site
121 |
122 | # mypy
123 | .mypy_cache/
124 | .dmypy.json
125 | dmypy.json
126 |
127 | # Pyre type checker
128 | .pyre/
129 |
130 | # PyCharm
131 | .idea
132 |
133 | # Files created by experiments
134 | output/
135 | snapshot/
136 | *.m4a
137 | *.wav
138 | notebooks/scratch.ipynb
139 | notebooks/inspect.ipynb
140 | notebooks/effects.ipynb
141 | notebooks/*.ipynb
142 | notebooks/*.gif
143 | notebooks/*.wav
144 | notebooks/*.mp4
145 | *runs/
146 | boards/
147 | samples/
148 | *.ipynb
149 | tmp/
150 |
151 | results.json
152 | metrics.csv
153 | mprofile_*
154 | mem.png
155 |
156 | results/
157 | mprofile*
158 | *.png
159 | # do not ignore the test wav file
160 | !tests/audio/short_test_audio.wav
161 | !tests/audio/output.wav
162 | */.DS_Store
163 | .DS_Store
164 | env.sh
165 | _codebraid/
166 | **/*.html
167 | **/*.exec.md
168 | flagged/
169 | log.txt
170 | ckpt/
171 | .syncthing*
172 | tests/assets/
173 | archived/
174 |
175 | *_remote_module_*
176 | *.zip
177 | *.pth
178 | encoded_out/
179 | recon/
180 | recons/
181 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023-present, Descript
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DAC-JAX and EnCodec-JAX
2 |
3 | This repository holds **unofficial** JAX implementations of Descript's DAC and Meta's EnCodec.
4 | We are not affiliated with Descript or Meta.
5 |
6 | You can read the DAC-JAX paper [here](https://arxiv.org/abs/2405.11554).
7 |
8 | ## Background
9 |
10 | In 2022, Meta published "[High Fidelity Neural Audio Compression](https://arxiv.org/abs/2210.13438)".
11 | They eventually open-sourced the code inside [AudioCraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/ENCODEC.md).
12 |
13 | In 2023, Descript published a related work "[High-Fidelity Audio Compression with Improved RVQGAN](https://arxiv.org/abs/2306.06546)"
14 | and released their code under the name [DAC](https://github.com/descriptinc/descript-audio-codec/) (Descript Audio Codec).
15 |
16 | Both EnCodec and DAC are neural audio codecs which use residual vector quantization inside a fully convolutional
17 | encoder-decoder architecture.
18 |
19 | ## Usage
20 |
21 | ### Installation
22 |
23 | 1. Upgrade `pip` and `setuptools`:
24 | ```bash
25 | pip install --upgrade pip setuptools
26 | ```
27 |
28 | 2. Install the **CPU** version of [PyTorch](https://pytorch.org/).
29 | We strongly suggest the CPU version because installing a GPU build of PyTorch can conflict with JAX's CUDA installation.
30 | PyTorch is required because it's used to load pretrained model weights.
31 |
32 | 3. Install [JAX](https://jax.readthedocs.io/en/latest/installation.html) (with GPU support).
33 |
34 | 4. Install DAC-JAX with one of the following:
35 |
36 |
40 |
41 | ```bash
42 | pip install git+https://github.com/DBraun/DAC-JAX
43 | ```
44 |
45 | Or,
46 |
47 | ```bash
48 | python -m pip install .
49 | ```
50 |
51 | Or, if you intend to contribute, clone and do an [editable install](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs):
52 | ```bash
53 | python -m pip install -e ".[dev]"
54 | ```
55 |
56 | ### Weights
57 | The original Descript repository releases model weights under the MIT license. These weights are for models that natively support 16 kHz, 24 kHz, and 44.1 kHz sampling rates. Our scripts download these PyTorch weights and load them into JAX.
58 | Weights are automatically downloaded when you first run an `encode` or `decode` command. You can download them in advance with one of the following commands:
59 | ```bash
60 | python -m dac_jax download_model # downloads the default 44kHz variant
61 | python -m dac_jax download_model --model_type 44khz --model_bitrate 16kbps # downloads the 44kHz 16 kbps variant
62 | python -m dac_jax download_model --model_type 44khz # downloads the 44kHz variant
63 | python -m dac_jax download_model --model_type 24khz # downloads the 24kHz variant
64 | python -m dac_jax download_model --model_type 16khz # downloads the 16kHz variant
65 | ```
66 |
67 | EnCodec weights can be downloaded similarly. This will download the 32 kHz EnCodec used in [MusicGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md).
68 | ```bash
69 | python -m dac_jax download_encodec
70 | ```
71 |
72 | For both DAC and EnCodec, the default download location is `~/.cache/dac_jax`. You can change the location by setting the environment variable `DAC_JAX_CACHE` to an **absolute path**. For example, on macOS/Linux:
73 | ```bash
74 | export DAC_JAX_CACHE=/Users/admin/my-project/dac_jax_models
75 | ```
76 |
77 | If you do this, remember to still have `DAC_JAX_CACHE` set before you use the `load_model` function.
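
The same can be done from Python. Below is a minimal sketch (the path is only an example); setting the variable before importing `dac_jax` guarantees it is visible regardless of when the library reads it:

```python
import os

# Hypothetical cache location; use any absolute path you like.
os.environ["DAC_JAX_CACHE"] = "/Users/admin/my-project/dac_jax_models"

import dac_jax  # imported after setting the variable

# Weights are downloaded to (or loaded from) $DAC_JAX_CACHE.
model, variables = dac_jax.load_model(model_type="44khz")
```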
78 |
79 | ### Compress audio
80 | ```
81 | python -m dac_jax encode /path/to/input --output /path/to/output/codes
82 | ```
83 |
84 | This command will create `.dac` files with the same name as the input files.
85 | It will also preserve the directory structure relative to the input root and
86 | re-create it in the output directory. Please use `python -m dac_jax encode --help`
87 | for more options.
88 |
89 | ### Reconstruct audio from compressed codes
90 | ```
91 | python -m dac_jax decode /path/to/output/codes --output /path/to/reconstructed_input
92 | ```
93 |
94 | This command will create `.wav` files with the same name as the input files.
95 | It will also preserve the directory structure relative to the input root and
96 | re-create it in the output directory. Please use `python -m dac_jax decode --help`
97 | for more options.
98 |
99 | ### Programmatic usage (DAC and EnCodec)
100 |
101 | Here we use `jax.jit` for optimized encoding and decoding.
102 | This does not do sample-rate conversion or volume normalization in the encoder or decoder.
103 |
104 | ```python
105 | from functools import partial
106 |
107 | import jax
108 | from jax import numpy as jnp
109 | import librosa
110 |
111 | import dac_jax
112 |
113 | model, variables = dac_jax.load_model(model_type="44khz")
114 |
115 | # If you want to use pretrained 32 kHz EnCodec from Meta's MusicGen, use this:
116 | # model, variables = dac_jax.load_encodec_model()
117 |
118 | @jax.jit
119 | def encode_to_codes(x: jnp.ndarray):
120 | codes, scale = model.apply(
121 | variables,
122 | x,
123 | method="encode",
124 | )
125 | return codes, scale
126 |
127 | @partial(jax.jit, static_argnums=(1, 2))
128 | def decode_from_codes(codes: jnp.ndarray, scale, length: int = None):
129 | recons = model.apply(
130 | variables,
131 | codes,
132 | scale,
133 | length,
134 | method="decode",
135 | )
136 | return recons
137 |
138 | # Load a mono audio file with the correct sample rate
139 | signal, sample_rate = librosa.load('input.wav', sr=model.sample_rate, mono=True, duration=.5)
140 |
141 | signal = jnp.array(signal, dtype=jnp.float32)
142 | while signal.ndim < 3:
143 | signal = jnp.expand_dims(signal, axis=0)
144 |
145 | original_length = signal.shape[-1]
146 |
147 | codes, scale = encode_to_codes(signal)
148 | assert codes.shape[1] == model.num_codebooks
149 |
150 | recons = decode_from_codes(codes, scale, original_length)
151 | ```
152 |
153 | ### DAC with Binding
154 |
155 | Here we use DAC-JAX as a "[bound](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html#bind)" module, freeing us from repeatedly passing variables as an argument and using `.apply`. Note that bound modules are not meant to be used in fine-tuning.
156 |
157 | ```python
158 | import dac_jax
159 | from dac_jax import DACFile
160 |
161 | from jax import numpy as jnp
162 | import librosa
163 |
164 | # Download a model and bind variables to it.
165 | model, variables = dac_jax.load_model(model_type="44khz")
166 | model = model.bind(variables)
167 |
168 | # Load a mono audio file
169 | signal, sample_rate = librosa.load('input.wav', sr=44100, mono=True, duration=.5)
170 |
171 | signal = jnp.array(signal, dtype=jnp.float32)
172 | while signal.ndim < 3:
173 | signal = jnp.expand_dims(signal, axis=0)
174 |
175 | # Encode audio signal as one long file (may run out of GPU memory on long files).
176 | # This performs resampling to the codec's sample rate and volume normalization.
177 | dac_file = model.encode_to_dac(signal, sample_rate)
178 |
179 | # Save to a file
180 | dac_file.save("dac_file_001.dac")
181 |
182 | # Load a file
183 | dac_file = DACFile.load("dac_file_001.dac")
184 |
185 | # Decode audio signal. Since we're passing a dac_file, this undoes the
186 | # previous sample rate conversion and volume normalization.
187 | y = model.decode(dac_file)
188 |
189 | # Calculate mean-square error of reconstruction in time-domain
190 | mse = jnp.square(y-signal).mean()
191 | ```
192 |
193 | ### DAC compression with constant GPU memory regardless of input length:
194 |
195 | ```python
196 | import dac_jax
197 |
198 | import jax
199 | import jax.numpy as jnp
200 | import librosa
201 |
202 | # Download a model and set padding to False because we will use the chunk functions.
203 | model, variables = dac_jax.load_model(model_type="44khz", padding=False)
204 |
205 | # Load a mono audio file at any sample rate
206 | signal, sample_rate = librosa.load('input.wav', sr=None, mono=True)
207 |
208 | signal = jnp.array(signal, dtype=jnp.float32)
209 | while signal.ndim < 3:
210 | # signal will eventually be shaped [B, C, T]
211 | signal = jnp.expand_dims(signal, axis=0)
212 |
213 | # Jit-compile these functions because they're used inside a loop over chunks.
214 | @jax.jit
215 | def compress_chunk(x):
216 | return model.apply(variables, x, method='compress_chunk')
217 |
218 | @jax.jit
219 | def decompress_chunk(c):
220 | return model.apply(variables, c, method='decompress_chunk')
221 |
222 | win_duration = 0.5 # Adjust based on your GPU's memory size
223 | dac_file = model.compress(compress_chunk, signal, sample_rate, win_duration=win_duration)
224 |
225 | # Save and load to and from disk
226 | dac_file.save("compressed.dac")
227 | dac_file = dac_jax.DACFile.load("compressed.dac")
228 |
229 | # Decompress it back to audio
230 | y = model.decompress(decompress_chunk, dac_file)
231 | ```
232 |
233 | ## DAC Training
234 | The baseline model configuration can be trained using the following command.
235 |
236 | ```bash
237 | python scripts/train.py --args.load conf/final/44khz.yml --train.ckpt_dir="/tmp/dac_jax_runs"
238 | ```
239 |
240 | From the repository root, monitor with TensorBoard by pointing `--logdir` at the checkpoint directory (if you do not override `--train.ckpt_dir`, a `runs` directory will appear next to `scripts`):
241 | ```bash
242 | tensorboard --logdir="/tmp/dac_jax_runs"
243 | ```
244 |
245 | ## Testing
246 |
247 | ```
248 | python -m pytest tests
249 | ```
250 |
251 | ## Limitations
252 |
253 | Pull requests—especially ones which address any of the limitations below—are welcome.
254 |
255 | * We implement the "chunked" `compress`/`decompress` methods from the PyTorch repository, although this technique has some problems outlined [here](https://github.com/descriptinc/descript-audio-codec/issues/39).
256 | * We have not run all evaluation scripts in the `scripts` directory. For some of them, it makes sense to just keep using PyTorch instead of JAX.
257 | * The model architecture code (`model/dac.py`) has many static methods to help with finding DAC's `delay` and `output_length`. Please help us refactor it so that the code is less duplicated and less prone to typos.
258 | * In `audio_utils.py` we use [DM_AUX's](https://github.com/google-deepmind/dm_aux) STFT function instead of `jax.scipy.signal.stft`. We believe this is faster but requires more memory.
259 | * The source code of DAC-JAX has some `todo:` markings which indicate (mostly minor) improvements we'd like to have.
260 | * We don't have a Docker image yet like the original [DAC repository](https://github.com/descriptinc/descript-audio-codec) does.
261 | * Please check the limitations of [argbind](https://github.com/pseeth/argbind?tab=readme-ov-file#limitations-and-known-issues).
262 | * We don't provide a training script for EnCodec.
263 |
264 | ## Citation
265 |
266 | If you use this repository in your work, please cite EnCodec:
267 | ```
268 | @article{defossez2022high,
269 | title={High fidelity neural audio compression},
270 | author={D{\'e}fossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi},
271 | journal={arXiv preprint arXiv:2210.13438},
272 | year={2022}
273 | }
274 | ```
275 |
276 | DAC:
277 |
278 | ```
279 | @article{kumar2024high,
280 | title={High-fidelity audio compression with improved rvqgan},
281 | author={Kumar, Rithesh and Seetharaman, Prem and Luebs, Alejandro and Kumar, Ishaan and Kumar, Kundan},
282 | journal={Advances in Neural Information Processing Systems},
283 | volume={36},
284 | year={2024}
285 | }
286 | ```
287 |
288 |
289 |
290 | and DAC-JAX:
291 |
292 | ```
293 | @misc{braun2024dacjax,
294 | title={{DAC-JAX}: A {JAX} Implementation of the Descript Audio Codec},
295 | author={David Braun},
296 | year={2024},
297 | eprint={2405.11554},
298 | archivePrefix={arXiv},
299 | primaryClass={cs.SD}
300 | }
301 | ```
302 |
--------------------------------------------------------------------------------
/conf/1gpu.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 |
4 | train.batch_size: 12
5 | train.val_batch_size: 12
6 |
--------------------------------------------------------------------------------
/conf/ablations/baseline.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
--------------------------------------------------------------------------------
/conf/ablations/diff-mb.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | Discriminator.sample_rate: 44100
6 | Discriminator.fft_sizes: [2048, 1024, 512]
7 | Discriminator.bands:
8 | - [0.0, 0.05]
9 | - [0.05, 0.1]
10 | - [0.1, 0.25]
11 | - [0.25, 0.5]
12 | - [0.5, 1.0]
13 |
14 |
15 | # re-weight lambdas to make up for
16 | # lost discriminators vs baseline
17 | lambdas:
18 | mel/loss: 15.0
19 | adv/feat_loss: 5.0
20 | adv/gen_loss: 1.0
21 | vq/commitment_loss: 0.25
22 | vq/codebook_loss: 1.0
--------------------------------------------------------------------------------
/conf/ablations/equal-mb.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | Discriminator.sample_rate: 44100
6 | Discriminator.fft_sizes: [2048, 1024, 512]
7 | Discriminator.bands:
8 | - [0.0, 0.2]
9 | - [0.2, 0.4]
10 | - [0.4, 0.6]
11 | - [0.6, 0.8]
12 | - [0.8, 1.0]
13 |
14 |
15 | # re-weight lambdas to make up for
16 | # lost discriminators vs baseline
17 | lambdas:
18 | mel/loss: 15.0
19 | adv/feat_loss: 5.0
20 | adv/gen_loss: 1.0
21 | vq/commitment_loss: 0.25
22 | vq/codebook_loss: 1.0
--------------------------------------------------------------------------------
/conf/ablations/no-adv.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | lambdas:
6 | mel/loss: 1.0
7 | waveform/loss: 1.0
8 | vq/commitment_loss: 0.25
9 | vq/codebook_loss: 1.0
--------------------------------------------------------------------------------
/conf/ablations/no-data-balance.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | train/build_dataset.folders:
6 | speech:
7 | - /data/daps/train
8 | - /data/vctk
9 | - /data/vocalset
10 | - /data/read_speech
11 | - /data/french_speech
12 | - /data/emotional_speech/
13 | - /data/common_voice/
14 | - /data/german_speech/
15 | - /data/russian_speech/
16 | - /data/spanish_speech/
17 | music:
18 | - /data/musdb/train
19 | - /data/jamendo
20 | general:
21 | - /data/audioset/data/unbalanced_train_segments/
22 | - /data/audioset/data/balanced_train_segments/
23 |
--------------------------------------------------------------------------------
/conf/ablations/no-low-hop.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | mel_spectrogram_loss.n_mels: [80]
6 | mel_spectrogram_loss.window_lengths: [512]
7 | mel_spectrogram_loss.lower_edge_hz: [0]
8 | mel_spectrogram_loss.upper_edge_hz: [null]
9 | mel_spectrogram_loss.pow: 1.0
10 | mel_spectrogram_loss.clamp_eps: 1.0e-5
11 | mel_spectrogram_loss.mag_weight: 0.0
12 |
13 | lambdas:
14 | mel/loss: 100.0
15 | adv/feat_loss: 2.0
16 | adv/gen_loss: 1.0
17 | vq/commitment_loss: 0.25
18 | vq/codebook_loss: 1.0
--------------------------------------------------------------------------------
/conf/ablations/no-mb.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | Discriminator.sample_rate: 44100
6 | Discriminator.fft_sizes: [2048, 1024, 512]
7 | Discriminator.bands:
8 | - [0.0, 1.0]
9 |
10 | # re-weight lambdas to make up for
11 | # lost discriminators vs baseline
12 | lambdas:
13 | mel/loss: 15.0
14 | adv/feat_loss: 5.0
15 | adv/gen_loss: 1.0
16 | vq/commitment_loss: 0.25
17 | vq/codebook_loss: 1.0
--------------------------------------------------------------------------------
/conf/ablations/no-mpd-msd.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | Discriminator.sample_rate: 44100
6 | Discriminator.rates: []
7 | Discriminator.periods: []
8 | Discriminator.fft_sizes: [2048, 1024, 512]
9 | Discriminator.bands:
10 | - [0.0, 0.1]
11 | - [0.1, 0.25]
12 | - [0.25, 0.5]
13 | - [0.5, 0.75]
14 | - [0.75, 1.0]
15 |
16 | lambdas:
17 | mel/loss: 15.0
18 | adv/feat_loss: 2.66
19 | adv/gen_loss: 1.0
20 | vq/commitment_loss: 0.25
21 | vq/codebook_loss: 1.0
--------------------------------------------------------------------------------
/conf/ablations/no-mpd.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | Discriminator.sample_rate: 44100
6 | Discriminator.rates: [1]
7 | Discriminator.periods: []
8 | Discriminator.fft_sizes: [2048, 1024, 512]
9 | Discriminator.bands:
10 | - [0.0, 0.1]
11 | - [0.1, 0.25]
12 | - [0.25, 0.5]
13 | - [0.5, 0.75]
14 | - [0.75, 1.0]
15 |
16 | lambdas:
17 | mel/loss: 15.0
18 | adv/feat_loss: 2.5
19 | adv/gen_loss: 1.0
20 | vq/commitment_loss: 0.25
21 | vq/codebook_loss: 1.0
--------------------------------------------------------------------------------
/conf/ablations/only-speech.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | train/build_dataset.folders:
6 | speech_fb:
7 | - /data/daps/train
8 | speech_hq:
9 | - /data/vctk
10 | - /data/vocalset
11 | - /data/read_speech
12 | - /data/french_speech
13 | speech_uq:
14 | - /data/emotional_speech/
15 | - /data/common_voice/
16 | - /data/german_speech/
17 | - /data/russian_speech/
18 | - /data/spanish_speech/
19 |
20 | val/build_dataset.folders:
21 | speech_hq:
22 | - /data/daps/val
23 |
--------------------------------------------------------------------------------
/conf/base.yml:
--------------------------------------------------------------------------------
1 | # Model setup
2 | DAC.sample_rate: 44100
3 | DAC.encoder_dim: 64
4 | DAC.encoder_rates: [2, 4, 8, 8]
5 | DAC.decoder_dim: 1536
6 | DAC.decoder_rates: [8, 8, 4, 2]
7 |
8 | # Quantization
9 | DAC.num_codebooks: 9
10 | DAC.codebook_size: 1024
11 | DAC.codebook_dim: 8
12 | DAC.quantizer_dropout: 1.0
13 |
14 | # Discriminator
15 | Discriminator.sample_rate: 44100
16 | Discriminator.rates: []
17 | Discriminator.periods: [2, 3, 5, 7, 11]
18 | Discriminator.fft_sizes: [2048, 1024, 512]
19 | Discriminator.bands:
20 | - [0.0, 0.1]
21 | - [0.1, 0.25]
22 | - [0.25, 0.5]
23 | - [0.5, 0.75]
24 | - [0.75, 1.0]
25 |
26 | # Schedules
27 | create_generator_schedule.learning_rate: 1e-4
28 | create_generator_schedule.lr_gamma: 0.999996
29 |
30 | create_discriminator_schedule.learning_rate: 1e-4
31 | create_discriminator_schedule.lr_gamma: 0.999996
32 |
33 | # Optimization
34 | create_generator_optimizer.adam_b1: 0.8
35 | create_generator_optimizer.adam_b2: 0.99
36 | create_generator_optimizer.adam_weight_decay: .01
37 | create_generator_optimizer.grad_clip: 1e3
38 |
39 | create_discriminator_optimizer.adam_b1: 0.8
40 | create_discriminator_optimizer.adam_b2: 0.99
41 | create_discriminator_optimizer.adam_weight_decay: .01
42 | create_discriminator_optimizer.grad_clip: 10
43 |
44 | #lambdas:
45 | # mel/loss: 15.0
46 | # adv/feat_loss: 200 # 2.0 * (5+5+5+5+5+25+25+25) = 2.0 * 100
47 | # # 2.0 comes from the PyTorch DAC base.yml Then we multiply since we normalized the magnitude
48 | # # of our feature loss differently than the PyTorch version.
49 | # # 5 is (6-1) where 6 is number of convs in MPD. Then there are 5 of these because the
50 | # # number of periods is 5.
51 | # # 25 is number of bands (5) times the number of convs (5) in MRD. Then there are 3 of these
52 | # # because of the number of fft sizes is 3.
53 | # adv/gen_loss: 8 # 1.0 * 8 where 8 is number of Discriminator rates+periods+fft sizes = (0+5+3)
54 | # # 1.0 comes from the PyTorch DAC base.yml
55 | # vq/commitment_loss: 2.25 # 0.25 * 9 since we normalize based on the number of codebooks.
56 | # vq/codebook_loss: 9 # 1 * 9 since we normalized based on the number of codebooks.
57 |
58 | lambdas:
59 | mel/loss: 15.0
60 | adv/feat_loss: 2
61 | adv/gen_loss: 1
62 | vq/commitment_loss: 0.25
63 | vq/codebook_loss: 1
64 |
65 | train.batch_size: 72
66 | train.val_batch_size: 100
67 | train.sample_batch_size: 100
68 | train.num_iterations: 250000
69 | train.valid_freq: 1000
70 | train.sample_freq: 10000
71 | train.ckpt_max_keep: 4
72 | train.seed: 0
73 | train.tabulate: 1
74 |
75 | EarlyStopping.min_delta: .001
76 | EarlyStopping.patience: 4
77 |
78 | log_training.log_every_steps: 10
79 |
80 | # Loss setup
81 | multiscale_stft_loss.window_lengths: [2048, 512]
82 |
83 | mel_spectrogram_loss.n_mels: [5, 10, 20, 40, 80, 160, 320]
84 | mel_spectrogram_loss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048]
85 | mel_spectrogram_loss.lower_edge_hz: [0, 0, 0, 0, 0, 0, 0]
86 | mel_spectrogram_loss.upper_edge_hz: [null, null, null, null, null, null, null]
87 | mel_spectrogram_loss.pow: 1.0
88 | mel_spectrogram_loss.clamp_eps: 1.0e-5
89 | mel_spectrogram_loss.mag_weight: 0.0
90 |
91 | # Data Augmentation
92 | VolumeNorm.config:
93 | min_db: -16
94 | max_db: -16
95 |
96 | train/augment_batch.transforms:
97 | - VolumeNorm
98 | - RescaleAudio
99 | - ShiftPhase
100 |
101 | val/augment_batch.transforms:
102 | - VolumeNorm
103 | - RescaleAudio
104 |
105 | sample/augment_batch.transforms:
106 | - VolumeNorm
107 | - RescaleAudio
108 |
109 | # Data
110 | # This should be equivalent to how DAC used salient_excerpt from AudioTools.
111 | SaliencyParams.enabled: 1
112 | SaliencyParams.num_tries: 8
113 | SaliencyParams.loudness_cutoff: -40
114 | SaliencyParams.search_function: SaliencyParams.search_uniform
115 |
116 | # Data
117 | create_dataset.worker_count: 0
118 | create_dataset.worker_buffer_size: 1
119 |
120 | create_dataset.extensions:
121 | - .wav
122 | - .flac
123 | - .ogg
124 | # - .mp3
125 |
126 | train/create_dataset.duration: 0.38
127 | val/create_dataset.duration: 5.0
128 | sample/create_dataset.duration: 5.0
129 | test/create_dataset.duration: 10.0
130 |
131 | val/create_dataset.num_steps: 4
132 |
133 | train/create_dataset.sources:
134 | speech_fb:
135 | - /data/daps/train
136 | speech_hq:
137 | - /data/vctk
138 | - /data/vocalset
139 | - /data/read_speech
140 | - /data/french_speech
141 | speech_uq:
142 | - /data/emotional_speech/
143 | - /data/common_voice/
144 | - /data/german_speech/
145 | - /data/russian_speech/
146 | - /data/spanish_speech/
147 | music_hq:
148 | - /data/musdb/train
149 | music_uq:
150 | - /data/jamendo
151 | general:
152 | - /data/audioset/data/unbalanced_train_segments/
153 | - /data/audioset/data/balanced_train_segments/
154 |
155 | val/create_dataset.sources:
156 | speech_hq:
157 | - /data/daps/val
158 | music_hq:
159 | - /data/musdb/test
160 | general:
161 | - /data/audioset/data/eval_segments/
162 |
163 | sample/create_dataset.sources:
164 | speech_hq:
165 | - /data/daps/val
166 | music_hq:
167 | - /data/musdb/test
168 | general:
169 | - /data/audioset/data/eval_segments/
170 |
171 | test/create_dataset.sources:
172 | speech_hq:
173 | - /data/daps/test
174 | music_hq:
175 | - /data/musdb/test
176 | general:
177 | - /data/audioset/data/eval_segments/
178 |
--------------------------------------------------------------------------------
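
The dotted keys in this config (for example `DAC.sample_rate` or `train.batch_size`) follow argbind's convention: a YAML key `Name.arg` sets the keyword argument `arg` of a callable bound under `Name`, and `--args.load` (as in the README's training command) loads the whole file. A minimal, hypothetical sketch of that mechanism; the `DAC` function below is only a stand-in, not the real class from `src/dac_jax/model/dac.py`:

```python
import argbind

def DAC(sample_rate: int = 44100, encoder_dim: int = 64, num_codebooks: int = 9):
    return dict(sample_rate=sample_rate, encoder_dim=encoder_dim, num_codebooks=num_codebooks)

# Bind under the name "DAC" so that the DAC.* keys in conf/base.yml map onto its kwargs.
DAC = argbind.bind(DAC)

if __name__ == "__main__":
    # Run as: python sketch.py --args.load conf/base.yml
    args = argbind.parse_args()
    with argbind.scope(args):
        print(DAC())  # keyword defaults are replaced by the values from the YAML
```
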
/conf/downsampling/1024x.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | # Model setup
6 | DAC.sample_rate: 44100
7 | DAC.encoder_dim: 64
8 | DAC.encoder_rates: [2, 8, 8, 8]
9 | DAC.decoder_dim: 1536
10 | DAC.decoder_rates: [8, 4, 4, 2, 2, 2]
11 |
12 | # Quantization
13 | DAC.num_codebooks: 19
14 | DAC.codebook_size: 1024
15 | DAC.codebook_dim: 8
16 | DAC.quantizer_dropout: 1.0
17 |
--------------------------------------------------------------------------------
/conf/downsampling/128x.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | # Model setup
6 | DAC.sample_rate: 44100
7 | DAC.encoder_dim: 64
8 | DAC.encoder_rates: [2, 4, 4, 4]
9 | DAC.decoder_dim: 1536
10 | DAC.decoder_rates: [4, 4, 2, 2, 2, 1]
11 |
12 | # Quantization
13 | DAC.num_codebooks: 2
14 | DAC.codebook_size: 1024
15 | DAC.codebook_dim: 8
16 | DAC.quantizer_dropout: 1.0
17 |
--------------------------------------------------------------------------------
/conf/downsampling/1536x.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | # Model setup
6 | DAC.sample_rate: 44100
7 | DAC.encoder_dim: 96
8 | DAC.encoder_rates: [2, 8, 8, 12]
9 | DAC.decoder_dim: 1536
10 | DAC.decoder_rates: [12, 4, 4, 2, 2, 2]
11 |
12 | # Quantization
13 | DAC.num_codebooks: 28
14 | DAC.codebook_size: 1024
15 | DAC.codebook_dim: 8
16 | DAC.quantizer_dropout: 1.0
17 |
--------------------------------------------------------------------------------
/conf/downsampling/768x.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | # Model setup
6 | DAC.sample_rate: 44100
7 | DAC.encoder_dim: 64
8 | DAC.encoder_rates: [2, 6, 8, 8]
9 | DAC.decoder_dim: 1536
10 | DAC.decoder_rates: [6, 4, 4, 2, 2, 2]
11 |
12 | # Quantization
13 | DAC.num_codebooks: 14
14 | DAC.codebook_size: 1024
15 | DAC.codebook_dim: 8
16 | DAC.quantizer_dropout: 1.0
17 |
--------------------------------------------------------------------------------
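
The directory names in `conf/downsampling/` (128x, 768x, 1024x, 1536x) appear to denote the overall temporal downsampling factor, i.e. the product of `DAC.encoder_rates` in each config. A small sanity check under that assumption:

```python
import math

# encoder_rates copied from the four configs in conf/downsampling/.
configs = {
    "128x": [2, 4, 4, 4],
    "768x": [2, 6, 8, 8],
    "1024x": [2, 8, 8, 8],
    "1536x": [2, 8, 8, 12],
}
for name, rates in configs.items():
    # e.g. 2 * 6 * 8 * 8 == 768 for conf/downsampling/768x.yml
    assert math.prod(rates) == int(name.rstrip("x"))
```
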
/conf/final/16khz.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 |
4 | DAC.sample_rate: 16000
5 |
6 | DAC.encoder_rates: [2, 4, 5, 8]
7 |
8 | DAC.decoder_rates: [8, 5, 4, 2]
9 |
10 | DAC.num_codebooks: 12
11 |
12 | DAC.quantizer_dropout: 0.5
13 |
14 | Discriminator.sample_rate: 16000
15 |
16 | train.num_iterations: 400000
--------------------------------------------------------------------------------
/conf/final/24khz.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 |
4 | DAC.sample_rate: 24000
5 |
6 | DAC.encoder_rates: [2, 4, 5, 8]
7 |
8 | DAC.decoder_rates: [8, 5, 4, 2]
9 |
10 | DAC.num_codebooks: 32
11 |
12 | DAC.quantizer_dropout: 0.5
13 |
14 | Discriminator.sample_rate: 24000
15 |
16 | train.num_iterations: 400000
--------------------------------------------------------------------------------
/conf/final/44khz-16kbps.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 |
4 | DAC.num_codebooks: 18 # Max bitrate of 16kbps
5 |
6 | DAC.quantizer_dropout: 0.5
7 |
8 | train.num_iterations: 400000
--------------------------------------------------------------------------------
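
A quick back-of-the-envelope check of the `# Max bitrate of 16kbps` comment above, assuming (as is conventional for this architecture) that the frame hop equals the product of `DAC.encoder_rates` and that each codebook contributes `log2(codebook_size)` bits per frame; the remaining values come from `conf/base.yml`:

```python
import math

sample_rate = 44100
hop = math.prod([2, 4, 8, 8])          # 512 samples per frame (DAC.encoder_rates)
bits_per_codebook = math.log2(1024)    # 10 bits (DAC.codebook_size)
frame_rate = sample_rate / hop         # ~86.1 frames per second

for num_codebooks in (9, 18):
    kbps = frame_rate * num_codebooks * bits_per_codebook / 1000
    # 9 codebooks -> ~7.8 kbps (the "8kbps" variant), 18 -> ~15.5 kbps ("16kbps")
    print(num_codebooks, f"{kbps:.1f} kbps")
```
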
/conf/final/44khz.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 |
4 | DAC.quantizer_dropout: 0.5
5 |
6 | train.num_iterations: 400000
--------------------------------------------------------------------------------
/conf/neuronic.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/final/44khz.yml
3 |
4 | train.batch_size: 4
5 | train.val_batch_size: 4
6 | train.sample_batch_size: 1
7 | train.valid_freq: 4000
8 | train.sample_freq: 4000
9 | train.ckpt_max_keep: 1
10 | train.tabulate: 0
11 |
12 | EarlyStopping.patience: 10
13 |
14 | # Data
15 | create_dataset.worker_count: 0
16 |
17 | # Data Augmentation
18 | VolumeChange.config:
19 | min_db: -10
20 | max_db: 0
21 |
22 | #train/build_transforms.augment:
23 | # - VolumeChange
24 | # - RescaleAudio
25 | # - ShiftPhase
26 | #
27 | #val/build_transforms.augment:
28 | # - VolumeChange
29 | # - RescaleAudio
30 | #
31 | #sample/build_transforms.augment:
32 | # - VolumeNorm
33 | # - RescaleAudio
34 |
35 | # Data
36 | # This should be equivalent to how DAC used salient_excerpt from AudioTools.
37 | SaliencyParams.enabled: 1
38 | SaliencyParams.num_tries: 8
39 | SaliencyParams.loudness_cutoff: -40
40 | SaliencyParams.search_function: SaliencyParams.search_bias_early
41 |
42 | train/create_dataset.sources:
43 | musdb18hq:
44 | - /scratch/$USER/datasets/musdb18hq/train/*/mixture.wav
45 | # nsynth:
46 | # - /scratch/$USER/datasets/nsynth/nsynth-train/audio
47 |
48 | val/create_dataset.num_steps: 100
49 | val/create_dataset.duration: 2
50 | val/create_dataset.sources:
51 | musdb18hq:
52 | - /scratch/$USER/datasets/musdb18hq/test/*/mixture.wav
53 | # nsynth:
54 | # - /scratch/$USER/datasets/nsynth/nsynth-valid/audio
55 |
56 | sample/create_dataset.duration: 2
57 | sample/create_dataset.sources:
58 | musdb18hq:
59 | - /scratch/$USER/datasets/musdb18hq/test/*/mixture.wav
60 | # nsynth:
61 | # - /scratch/$USER/datasets/nsynth/nsynth-valid/audio
62 |
63 | test/create_dataset.duration: 4
64 | test/create_dataset.sources:
65 | musdb18hq:
66 | - /scratch/$USER/datasets/musdb18hq/test/*/mixture.wav
67 | # nsynth:
68 | # - /scratch/$USER/datasets/nsynth/nsynth-test/audio
--------------------------------------------------------------------------------
/conf/quantizer/24kbps.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | DAC.num_codebooks: 28
6 |
--------------------------------------------------------------------------------
/conf/quantizer/256d.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | DAC.codebook_dim: 256
6 |
--------------------------------------------------------------------------------
/conf/quantizer/2d.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | DAC.codebook_dim: 2
6 |
--------------------------------------------------------------------------------
/conf/quantizer/32d.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | DAC.codebook_dim: 32
6 |
--------------------------------------------------------------------------------
/conf/quantizer/4d.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | DAC.codebook_dim: 4
6 |
--------------------------------------------------------------------------------
/conf/quantizer/512d.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | DAC.codebook_dim: 512
6 |
--------------------------------------------------------------------------------
/conf/quantizer/dropout-0.0.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | DAC.quantizer_dropout: 0.0
6 |
--------------------------------------------------------------------------------
/conf/quantizer/dropout-0.25.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | DAC.quantizer_dropout: 0.25
6 |
--------------------------------------------------------------------------------
/conf/quantizer/dropout-0.5.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | DAC.quantizer_dropout: 0.5
6 |
--------------------------------------------------------------------------------
/conf/size/medium.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | DAC.decoder_dim: 1024
6 |
--------------------------------------------------------------------------------
/conf/size/small.yml:
--------------------------------------------------------------------------------
1 | $include:
2 | - conf/base.yml
3 | - conf/1gpu.yml
4 |
5 | DAC.decoder_dim: 512
6 |
--------------------------------------------------------------------------------
/jobs/benchmark.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=jax-gpu # create a short name for your job
3 | #SBATCH --nodes=1 # node count
4 | #SBATCH --ntasks=1 # total number of tasks across all nodes
5 | #SBATCH --cpus-per-task=1 # cpu-cores per task (>1 if multi-threaded tasks)
6 | #SBATCH --mem-per-cpu=16G # RAM usage per cpu-core
7 | #SBATCH --gres=gpu:1 # number of gpus per node
8 | #SBATCH --time=03:00:00 # total run time limit (HH:MM:SS)
9 | #SBATCH --mail-type=END # choice could be 'fail'
10 | #SBATCH --mail-user=db1224@princeton.edu
11 |
12 | module purge
13 | module load anaconda3/2024.02
14 |
15 | eval "$(conda shell.bash hook)"
16 | conda activate jax-env
17 |
18 | python scripts/benchmark.py --model_type=16khz
19 | python scripts/benchmark.py --model_type=24khz
20 | python scripts/benchmark.py --model_type=44khz
21 |
--------------------------------------------------------------------------------
/jobs/simple.slurm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=DAC-JAX # create a short name for your job
3 | #SBATCH --nodes=1 # node count
4 | #SBATCH --ntasks=1 # total number of tasks across all nodes
5 | #SBATCH --cpus-per-task=2 # cpu-cores per task (>1 if multi-threaded tasks)
6 | #SBATCH --mem-per-cpu=16G # RAM usage per cpu-core
7 | #SBATCH --gres=gpu:2 # number of gpus per node
8 | #SBATCH --time=60:00:00 # total run time limit (HH:MM:SS)
9 | #SBATCH --signal=B:USR1@120 # 120 sec grace period for cleanup after timeout
10 | #SBATCH --signal=B:SIGTERM@120 # 120 sec grace period for cleanup after scancel is sent
11 | #SBATCH --mail-type=END # choice could be 'fail'
12 | #SBATCH --mail-user=db1224@princeton.edu
13 |
14 | function cleanup() {
15 | echo 'Running cleanup script'
16 | kill $TRAIN_PID
17 | kill $TB_PID
18 | cp -r "/scratch/$USER/runs" "/n/fs/audiovis/$USER/DAC-JAX/runs"
19 | rm -rf "/scratch/$USER"
20 | exit 0
21 | }
22 |
23 | ## Trap the SIGTERM signal (sent by scancel) and call the cleanup function
24 | trap cleanup EXIT SIGINT SIGTERM
25 |
26 | module purge
27 | module load anaconda3/2024.02
28 |
29 | eval "$(conda shell.bash hook)"
30 |
31 | conda activate ../Terrapin/.env/jax-env
32 | export PYTHONPATH=$PWD
33 |
34 | ## prepare data
35 | ##echo "$(date '+%H:%M:%S'): Copying data to /scratch"
36 | ##mkdir -p "/scratch/$USER/datasets"
37 | ##rsync -a --info=progress2 --no-i-r "/n/fs/audiovis/$USER/datasets/nsynth" "/scratch/$USER/datasets"
38 | #
39 | ##cd "/scratch/$USER/datasets/nsynth" || exit
40 | ##echo "$(date '+%H:%M:%S'): Unzipping test"
41 | ##tar -xzf nsynth-test.jsonwav.tar.gz
42 | ##echo "$(date '+%H:%M:%S'): Unzipping valid"
43 | ##tar -xzf nsynth-valid.jsonwav.tar.gz
44 | ##echo "$(date '+%H:%M:%S'): Unzipping train"
45 | ##tar -xzf nsynth-train.jsonwav.tar.gz
46 | ##echo "$(date '+%H:%M:%S'): Copied data to /scratch"
47 |
48 | ## prepare data
49 | echo "$(date '+%H:%M:%S'): Copying data to /scratch"
50 | mkdir -p "/scratch/$USER/datasets"
51 | rsync -a --info=progress2 --no-i-r "/n/fs/audiovis/$USER/datasets/musdb18hq" "/scratch/$USER/datasets"
52 |
53 | cd "/scratch/$USER/datasets/musdb18hq" || exit
54 | echo "$(date '+%H:%M:%S'): Unzipping musdb18hq"
55 | unzip -q musdb18hq.zip
56 | echo "$(date '+%H:%M:%S'): Copied data to /scratch"
57 |
58 | ## Launch TensorBoard and get the process ID of TensorBoard
59 | tensorboard --logdir="/scratch/$USER/runs" --port=10013 --samples_per_plugin audio=20 --bind_all & TB_PID=$!
60 |
61 | cd "/n/fs/audiovis/$USER/DAC-JAX" || exit
62 | python scripts/train.py \
63 | --args.load conf/neuronic.yml \
64 | --train.name="slurm_$SLURM_JOB_ID" \
65 | --train.ckpt_dir="/scratch/$USER/runs" \
66 | & TRAIN_PID=$!
67 |
68 | wait $TRAIN_PID
69 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/scripts/benchmark.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import argbind
4 | import jax
5 | from jax import random
6 |
7 | from dac_jax import load_model
8 |
9 |
10 | @argbind.bind(without_prefix=True)
11 | def benchmark_dac(model_type="44khz", model_bitrate='8kbps', win_durations: List[str] = None):
12 |
13 | if win_durations is None:
14 | win_durations = [0.37, 0.38, 0.42, 0.46, 0.5, 1, 5, 10, 20]
15 | else:
16 | win_durations = [float(x) for x in win_durations]
17 |
18 | # Set padding to False since we're using chunk functions.
19 | model, variables = load_model(model_type=model_type, model_bitrate=model_bitrate, padding=False)
20 |
21 | @jax.jit
22 | def compress_chunk(x):
23 | return model.apply(variables, x, method='compress_chunk')
24 |
25 | @jax.jit
26 | def decompress_chunk(c):
27 | return model.apply(variables, c, method='decompress_chunk')
28 |
29 | audio_sr = model.sample_rate # not always a valid assumption, in case you copy-paste this elsewhere
30 |
31 | print(f'Benchmarking model: {model_type}, {model_bitrate}')
32 |
33 | for win_duration in win_durations:
34 | # Force chunk-encoding by making duration 1 more than win_duration:
35 | # (one day the compress function will default to unchunked if the audio length is <= win_duration)
36 | T = 1 + int(win_duration * model.sample_rate)
37 | x = random.normal(random.key(0), shape=(1, 1, T))
38 | try:
39 | dac_file = model.compress(compress_chunk, x, audio_sr, win_duration=win_duration, benchmark=True)
40 | recons = model.decompress(decompress_chunk, dac_file, benchmark=True)
41 | except Exception as e:
42 | print(f'Exception for win duration "{win_duration}": {e}')
43 |
44 |
45 | if __name__ == "__main__":
46 | # example usage:
47 | # python3 benchmark.py --model_type=16khz --win_durations="0.5 1 5 10 20"
48 | print(f'devices: {jax.devices()}')
49 |
50 | args = argbind.parse_args()
51 | with argbind.scope(args):
52 | benchmark_dac()
53 |
54 |
55 | # @argbind.bind(without_prefix=True)
56 | # def benchmark_dac_encode(model_type="44khz", model_bitrate='8kbps', batch_size: int = 1, durations: List[str] = None):
57 | #
58 | # if durations is None:
59 | # durations = [1, 2, 4, 8, 16, 32]
60 | # else:
61 | # durations = [float(x) for x in durations]
62 | #
63 | # model, variables = load_model(model_type=model_type, model_bitrate=model_bitrate)
64 | #
65 | # @jax.jit
66 | # def encode(audio):
67 | # audio = model.apply(variables, audio, model.sample_rate, method="preprocess")
68 | # _, codes, _, _, _ = model.apply(variables, audio, train=False, method="encode")
69 | # return codes
70 | #
71 | # for duration in durations:
72 | # print(f'Benchmarking encode for model: {model_type}, {model_bitrate} with duration {duration} sec and batch size {batch_size}.')
73 | #
74 | # T = int(duration * model.sample_rate)
75 | # x = random.normal(random.key(0), shape=(batch_size, 1, T))
76 | # import tqdm
77 | # for _ in tqdm.trange(100):
78 | # try:
79 | # encode(x)
80 | # except Exception as e:
81 | # print(f'Exception for duration "{duration}": {e}')
82 |
83 |
84 | # if __name__ == "__main__":
85 | # # example usage:
86 | # # python3 benchmark.py --model_type=44khz --durations="5" --batch_size=8
87 | # print(f'devices: {jax.devices()}')
88 | #
89 | # args = argbind.parse_args()
90 | # with argbind.scope(args):
91 | # benchmark_dac_encode()
92 |
--------------------------------------------------------------------------------
/scripts/compute_entropy.py:
--------------------------------------------------------------------------------
1 | import argbind
2 | import jax
3 | from audiotools import AudioSignal
4 | import numpy as np
5 | import tqdm
6 |
7 | from dac_jax import load_model
8 | from dac_jax.audio_utils import find_audio
9 |
10 |
11 | @argbind.bind(without_prefix=True, positional=True)
12 | def main(
13 | folder: str,
14 | model_path: str,
15 | metadata_path: str,
16 | n_samples: int = 1024,
17 | ):
18 | files = find_audio(folder)[:n_samples]
19 | key = jax.random.key(0)
20 | key, subkey = jax.random.split(key)
21 | signals = [
22 | AudioSignal.salient_excerpt(f, subkey, loudness_cutoff=-20, duration=1.0)
23 | for f in files
24 | ]
25 |
26 | assert model_path is not None
27 | assert metadata_path is not None
28 |
29 | model, variables = load_model(load_path=model_path, metadata_path=metadata_path)
30 | model = model.bind(variables)
31 |
32 | codes = []
33 | for x in tqdm.tqdm(signals):
34 | x = jax.device_put(x, model.device)
35 | o = model.encode(x.audio_data, x.sample_rate)
36 | codes.append(np.array(o["codes"]))
37 |
38 | codes = np.concatenate(codes, axis=-1)
39 | entropy = []
40 |
41 | for i in range(codes.shape[1]):
42 | codes_ = codes[0, i, :]
43 | counts = np.bincount(codes_)
44 | counts = (counts / counts.sum())
45 | counts = np.maximum(counts, 1e-10)
46 | entropy.append(-(counts * np.log(counts)).sum().item() * np.log2(np.e))
47 |
48 | pct = sum(entropy) / (10 * len(entropy))
49 | print(f"Entropy for each codebook: {entropy}")
50 | print(f"Effective percentage: {pct * 100}%")
51 |
52 |
53 | if __name__ == "__main__":
54 | args = argbind.parse_args()
55 | with argbind.scope(args):
56 | main()
57 |
--------------------------------------------------------------------------------
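
In the script above, per-codebook entropy is computed in nats and converted to bits by multiplying by `np.log2(np.e)`; the "effective percentage" then divides by 10 bits, i.e. `log2(1024)` for the default codebook size. A tiny self-contained illustration of the same conversion on a toy distribution:

```python
import numpy as np

codes = np.array([0, 0, 1, 2, 2, 2, 3, 3])    # toy codebook indices
probs = np.bincount(codes) / codes.size        # empirical distribution over codes
entropy_nats = -(probs * np.log(probs)).sum()
entropy_bits = entropy_nats * np.log2(np.e)    # identical to -(p * log2(p)).sum()
usage = entropy_bits / np.log2(1024)           # fraction of a 10-bit codebook actually used
print(f"{entropy_bits:.2f} bits, {usage * 100:.1f}%")
```
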
/scripts/evaluate.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import multiprocessing as mp
3 | from concurrent.futures import ProcessPoolExecutor
4 | from dataclasses import dataclass
5 | from pathlib import Path
6 |
7 | import argbind
8 | from audiotools import AudioSignal
9 | from audiotools import metrics
10 | from audiotools.core import util
11 | from audiotools.ml.decorators import Tracker
12 | import jax.numpy as jnp
13 | import numpy as np
14 |
15 | from dac_jax.nn.loss import multiscale_stft_loss, mel_spectrogram_loss, sisdr_loss, l1_loss
16 |
17 |
18 | @dataclass
19 | class State:
20 | stft_loss: multiscale_stft_loss
21 | mel_loss: mel_spectrogram_loss
22 | waveform_loss: l1_loss
23 | sisdr_loss: sisdr_loss
24 |
25 |
26 | def get_metrics(signal_path, recons_path, state):
27 | output = {}
28 | signal = AudioSignal(signal_path)
29 | recons = AudioSignal(recons_path)
30 | for sr in [22050, 44100]:
31 | x = signal.clone().resample(sr)
32 | y = recons.clone().resample(sr)
33 | k = "22k" if sr == 22050 else "44k"
34 | output.update(
35 | {
36 | f"mel-{k}": state.mel_loss(x, y),
37 | f"stft-{k}": state.stft_loss(x, y),
38 | f"waveform-{k}": state.waveform_loss(x, y),
39 | f"sisdr-{k}": state.sisdr_loss(x, y),
40 | f"visqol-audio-{k}": metrics.quality.visqol(x, y),
41 | f"visqol-speech-{k}": metrics.quality.visqol(x, y, "speech"),
42 | }
43 | )
44 | output["path"] = signal.path_to_file
45 | output.update(signal.metadata)
46 | return output
47 |
48 |
49 | @argbind.bind(without_prefix=True)
50 | def evaluate(
51 | input: str = "samples/input",
52 | output: str = "samples/output",
53 | n_proc: int = 50,
54 | ):
55 | tracker = Tracker()
56 |
57 | state = State(
58 | waveform_loss=l1_loss,
59 | stft_loss=multiscale_stft_loss,
60 | mel_loss=mel_spectrogram_loss,
61 | sisdr_loss=sisdr_loss,
62 | )
63 |
64 | audio_files = util.find_audio(input)
65 | output = Path(output)
66 | output.mkdir(parents=True, exist_ok=True)
67 |
68 | @tracker.track("metrics", len(audio_files))
69 | def record(future, writer):
70 | o = future.result()
71 | for k, v in o.items():
72 | if isinstance(v, jnp.ndarray): # todo:
73 | o[k] = np.array(v).item() # todo:
74 | writer.writerow(o)
75 | o.pop("path")
76 | return o
77 |
78 | futures = []
79 | with tracker.live:
80 | with open(output / "metrics.csv", "w") as csvfile:
81 | with ProcessPoolExecutor(n_proc, mp.get_context("fork")) as pool:
82 | for i in range(len(audio_files)):
83 | future = pool.submit(
84 | get_metrics, audio_files[i], output / audio_files[i].name, state
85 | )
86 | futures.append(future)
87 |
88 | keys = list(futures[0].result().keys())
89 | writer = csv.DictWriter(csvfile, fieldnames=keys)
90 | writer.writeheader()
91 |
92 | for future in futures:
93 | record(future, writer)
94 |
95 | tracker.done("test", f"N={len(audio_files)}")
96 |
97 |
98 | if __name__ == "__main__":
99 | args = argbind.parse_args()
100 | with argbind.scope(args):
101 | evaluate()
102 |
--------------------------------------------------------------------------------
/scripts/get_samples.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import argbind
4 | from audiotools import AudioSignal
5 | from audiotools.ml.decorators import Tracker
6 | from train import Accelerator
7 | from train import DAC
8 |
9 | from dac_jax.audio_utils import find_audio
10 | from dac_jax.compare.encodec import Encodec
11 |
12 |
13 | Encodec = argbind.bind(Encodec)
14 |
15 |
16 | def load_state(
17 | accel: Accelerator,
18 | tracker: Tracker,
19 | save_path: str,
20 | tag: str = "latest",
21 | load_weights: bool = False,
22 | model_type: str = "dac",
23 | bandwidth: float = 24.0,
24 | ):
25 | kwargs = {
26 | "folder": f"{save_path}/{tag}",
27 | "map_location": "cpu",
28 | "package": not load_weights,
29 | }
30 | tracker.print(f"Resuming from {str(Path('.').absolute())}/{kwargs['folder']}")
31 |
32 | if model_type == "dac":
33 | generator, _ = DAC.load_from_folder(**kwargs)
34 | elif model_type == "encodec":
35 | generator = Encodec(bandwidth=bandwidth)
36 |
37 | generator = accel.prepare_model(generator)
38 | return generator
39 |
40 |
41 | def process(signal, accel, generator, **kwargs):
42 | signal = signal.to(accel.device)
43 | recons = generator(signal.audio_data, signal.sample_rate, **kwargs)["audio"]
44 | recons = AudioSignal(recons, signal.sample_rate)
45 | recons = recons.normalize(signal.loudness())
46 | return recons.cpu()
47 |
48 |
49 | @argbind.bind(without_prefix=True)
50 | def get_samples(
51 | accel,
52 | path: str = "ckpt",
53 | input: str = "samples/input",
54 | output: str = "samples/output",
55 | model_type: str = "dac",
56 | model_tag: str = "latest",
57 | bandwidth: float = 24.0,
58 | n_quantizers: int = None,
59 | ):
60 | tracker = Tracker(log_file=f"{path}/eval.txt", rank=accel.local_rank)
61 | generator = load_state(
62 | accel,
63 | tracker,
64 | save_path=path,
65 | model_type=model_type,
66 | bandwidth=bandwidth,
67 | tag=model_tag,
68 | )
69 | kwargs = {"n_quantizers": n_quantizers} if model_type == "dac" else {}
70 |
71 | audio_files = find_audio(input)
72 |
73 | global process
74 | process = tracker.track("process", len(audio_files))(process)
75 |
76 | output = Path(output)
77 | output.mkdir(parents=True, exist_ok=True)
78 |
79 | with tracker.live:
80 | for i in range(len(audio_files)):
81 | signal = AudioSignal(audio_files[i])
82 | recons = process(signal, accel, generator, **kwargs)
83 | recons.write(output / audio_files[i].name)
84 |
85 | tracker.done("test", f"N={len(audio_files)}")
86 |
87 |
88 | if __name__ == "__main__":
89 | args = argbind.parse_args()
90 | with argbind.scope(args):
91 | with Accelerator() as accel:
92 | get_samples(accel)
93 |
--------------------------------------------------------------------------------
/scripts/input_pipeline.py:
--------------------------------------------------------------------------------
1 | from typing import List, Mapping
2 |
3 | import argbind
4 | from audiotree import SaliencyParams
5 | from audiotree.datasources import (
6 | AudioDataSimpleSource,
7 | AudioDataBalancedSource,
8 | )
9 | from audiotree.transforms import ReduceBatchTransform
10 | from grain import python as grain
11 |
12 | SaliencyParams = argbind.bind(SaliencyParams, "train", "val", "test", "sample")
13 |
14 |
15 | @argbind.bind("train", "val", "test", "sample")
16 | def create_dataset(
17 | batch_size: int,
18 | sample_rate: int,
19 | duration: float = 0.2,
20 | sources: Mapping[str, List[str]] = None,
21 | extensions: List[str] = None,
22 | mono: int = 1, # bool
23 | train: int = 0, # bool
24 | num_steps: int = None,
25 | seed: int = 0,
26 | worker_count: int = 0,
27 | worker_buffer_size: int = 2,
28 | enable_profiling: int = 0, # bool
29 | num_epochs: int = 1, # for train/val use 1, but for sample set it to None so that it loops forever.
30 | ):
31 |
32 | assert sources is not None
33 |
34 | if train:
35 | assert num_steps is not None and num_steps > 0
36 | datasource = AudioDataBalancedSource(
37 | sources=sources,
38 | num_records=num_steps * batch_size,
39 | sample_rate=sample_rate,
40 | mono=mono,
41 | duration=duration,
42 | extensions=extensions,
43 | saliency_params=SaliencyParams(), # rely on argbind,
44 | )
45 | else:
46 | datasource = AudioDataSimpleSource(
47 | sources=sources,
48 | num_records=num_steps * batch_size if num_steps is not None else None,
49 | sample_rate=sample_rate,
50 | mono=mono,
51 | duration=duration,
52 | extensions=extensions,
53 | )
54 |
55 | shard_options = grain.NoSharding() # todo:
56 |
57 | index_sampler = grain.IndexSampler(
58 | num_records=len(datasource),
59 | num_epochs=num_epochs,
60 | shard_options=shard_options,
61 | shuffle=bool(train),
62 | seed=seed,
63 | )
64 |
65 | pygrain_ops = [
66 | grain.Batch(batch_size=batch_size, drop_remainder=True),
67 | ReduceBatchTransform(),
68 | ]
69 |
70 | dataloader = grain.DataLoader(
71 | data_source=datasource,
72 | sampler=index_sampler,
73 | operations=pygrain_ops,
74 | worker_count=worker_count,
75 | worker_buffer_size=worker_buffer_size,
76 | shard_options=shard_options,
77 | enable_profiling=bool(enable_profiling),
78 | )
79 |
80 | return dataloader
81 |
82 |
83 | if __name__ == "__main__":
84 |
85 | from tqdm import tqdm
86 | from absl import logging
87 |
88 | logging.set_verbosity(logging.INFO)
89 |
90 | folder1 = "/mnt/d/Datasets/dx7/patches-DX7-AllTheWeb-Bridge-Music-Recording-Studio-Sysex-Set-4-Instruments-Bass-Bass3-bass-10-syx-01-SUPERBASS2-note69"
91 | folder2 = "/mnt/d/Datasets/dx7/patches-DX7-AllTheWeb-Bridge-Music-Recording-Studio-Sysex-Set-4-Instruments-Accordion-ACCORD01-SYX-06-AKKORDEON-note69"
92 |
93 | sources = {
94 | "a": [folder1],
95 | "b": [folder2],
96 | }
97 |
98 | num_steps = 1000
99 |
100 | ds = create_dataset(
101 | batch_size=32,
102 | sample_rate=44_100,
103 | sources=sources,
104 | duration=0.5,
105 | train=True,
106 | mono=True,
107 | seed=0,
108 | num_steps=num_steps,
109 | extensions=None,
110 | worker_count=0,
111 | worker_buffer_size=1,
112 | saliency_params=SaliencyParams(False, 8, -70),
113 | )
114 |
115 | for x in tqdm(ds, total=num_steps, desc="Grain Dataset"):
116 | pass
117 |
--------------------------------------------------------------------------------
/scripts/mushra.py:
--------------------------------------------------------------------------------
1 | import string
2 | from dataclasses import dataclass
3 | from pathlib import Path
4 | from typing import List
5 |
6 | import argbind
7 | import gradio as gr
8 | from audiotools import preference as pr
9 |
10 |
11 | @argbind.bind(without_prefix=True)
12 | @dataclass
13 | class Config:
14 | folder: str = None
15 | save_path: str = "results.csv"
16 | conditions: List[str] = None
17 | reference: str = None
18 | seed: int = 0
19 | share: bool = False
20 | n_samples: int = 10
21 |
22 |
23 | def get_text(wav_file: str):
24 | txt_file = Path(wav_file).with_suffix(".txt")
25 | if Path(txt_file).exists():
26 | with open(txt_file, "r") as f:
27 | txt = f.read()
28 | else:
29 | txt = ""
30 | return f"""
{txt}
"""
31 |
32 |
33 | def main(config: Config):
34 | with gr.Blocks() as app:
35 | save_path = config.save_path
36 | samples = gr.State(pr.Samples(config.folder, n_samples=config.n_samples))
37 |
38 | reference = config.reference
39 | conditions = config.conditions
40 |
41 | player = pr.Player(app)
42 | player.create()
43 | if reference is not None:
44 | player.add("Play Reference")
45 |
46 | user = pr.create_tracker(app)
47 | ratings = []
48 |
49 | with gr.Row():
50 | txt = gr.HTML("")
51 |
52 | with gr.Row():
53 | gr.Button("Rate audio quality", interactive=False)
54 | with gr.Column(scale=8):
55 | gr.HTML(pr.slider_mushra)
56 |
57 | for i in range(len(conditions)):
58 | with gr.Row().style(equal_height=True):
59 | x = string.ascii_uppercase[i]
60 | player.add(f"Play {x}")
61 | with gr.Column(scale=9):
62 | ratings.append(gr.Slider(value=50, interactive=True))
63 |
64 | def build(user, samples, *ratings):
65 | # Filter out samples user has done already, by looking in the CSV.
66 | samples.filter_completed(user, save_path)
67 |
68 | # Write results to CSV
69 | if samples.current > 0:
70 | start_idx = 1 if reference is not None else 0
71 | name = samples.names[samples.current - 1]
72 | result = {"sample": name, "user": user}
73 | for k, r in zip(samples.order[start_idx:], ratings):
74 | result[k] = r
75 | pr.save_result(result, save_path)
76 |
77 | updates, done, pbar = samples.get_next_sample(reference, conditions)
78 | wav_file = updates[0]["value"]
79 |
80 | txt_update = gr.update(value=get_text(wav_file))
81 |
82 | return (
83 | updates
84 | + [gr.update(value=50) for _ in ratings]
85 | + [done, samples, pbar, txt_update]
86 | )
87 |
88 | progress = gr.HTML()
89 | begin = gr.Button("Submit", elem_id="start-survey")
90 | begin.click(
91 | fn=build,
92 | inputs=[user, samples] + ratings,
93 | outputs=player.to_list() + ratings + [begin, samples, progress, txt],
94 | ).then(None, _js=pr.reset_player)
95 |
96 |     # Launch the MUSHRA listening-test app.
97 | app.launch(share=config.share)
98 |
99 |
100 | if __name__ == "__main__":
101 | args = argbind.parse_args()
102 | with argbind.scope(args):
103 | config = Config()
104 | main(config)
105 |
--------------------------------------------------------------------------------
/scripts/organize_daps.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import shutil
4 | from collections import defaultdict
5 | from typing import Tuple
6 |
7 | import argbind
8 | import numpy as np
9 | import tqdm
10 | from audiotools.core import util
11 |
12 |
13 | @argbind.bind()
14 | def split(
15 | audio_files, ratio: Tuple[float, float, float] = (0.8, 0.1, 0.1), seed: int = 0
16 | ):
17 | assert sum(ratio) == 1.0
18 | util.seed(seed)
19 |
20 | idx = np.arange(len(audio_files))
21 | np.random.shuffle(idx)
22 |
23 | b = np.cumsum([0] + list(ratio)) * len(idx)
24 | b = [int(_b) for _b in b]
25 | train_idx = idx[b[0] : b[1]]
26 | val_idx = idx[b[1] : b[2]]
27 | test_idx = idx[b[2] :]
28 |
29 | audio_files = np.array(audio_files)
30 | train_files = audio_files[train_idx]
31 | val_files = audio_files[val_idx]
32 | test_files = audio_files[test_idx]
33 |
34 | return train_files, val_files, test_files
35 |
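# Illustrative sketch (not part of the original script): split() partitions a
# file list into train/val/test according to the given ratio. The file names
# below are placeholders.
def _demo_split():
    files = [f"clip_{i}.wav" for i in range(10)]
    train_files, val_files, test_files = split(files, ratio=(0.8, 0.1, 0.1), seed=0)
    print(len(train_files), len(val_files), len(test_files))  # -> 8 1 1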
36 |
37 | def assign(val_split, test_split):
38 | def _assign(value):
39 | if value in val_split:
40 | return "val"
41 | if value in test_split:
42 | return "test"
43 | return "train"
44 |
45 | return _assign
46 |
47 |
48 | DAPS_VAL = ["f2", "m2"]
49 | DAPS_TEST = ["f10", "m10"]
50 |
51 |
52 | @argbind.bind(without_prefix=True)
53 | def process(
54 | dataset: str = "daps",
55 | daps_subset: str = "",
56 | ):
57 | get_split = None
58 | get_value = lambda path: path
59 |
60 | data_path = pathlib.Path("/data")
61 | dataset_path = data_path / dataset
62 | audio_files = util.find_audio(dataset_path)
63 |
64 | if dataset == "daps":
65 | get_split = assign(DAPS_VAL, DAPS_TEST)
66 | get_value = lambda path: (str(path).split("/")[-1].split("_", maxsplit=4)[0])
67 | audio_files = [
68 | x
69 | for x in util.find_audio(dataset_path)
70 | if daps_subset in str(x) and "breaths" not in str(x)
71 | ]
72 |
73 | if get_split is None:
74 | _, val, test = split(audio_files)
75 | get_split = assign(val, test)
76 |
77 | splits = defaultdict(list)
78 | for x in audio_files:
79 | _split = get_split(get_value(x))
80 | splits[_split].append(x)
81 |
82 | with util.chdir(dataset_path):
83 | for k, v in splits.items():
84 | v = sorted(v)
85 | print(f"Processing {k} in {dataset_path} of length {len(v)}")
86 | for _v in tqdm.tqdm(v):
87 | tgt_path = pathlib.Path(
88 | str(_v).replace(str(dataset_path), str(dataset_path / k))
89 | )
90 | tgt_path.parent.mkdir(parents=True, exist_ok=True)
91 | shutil.copyfile(_v, tgt_path)
92 |
93 |
94 | if __name__ == "__main__":
95 | args = argbind.parse_args()
96 | with argbind.scope(args):
97 | process()
98 |
--------------------------------------------------------------------------------
/scripts/save_test_set.py:
--------------------------------------------------------------------------------
1 | import csv
2 | from pathlib import Path
3 |
4 | import argbind
5 | import torch
6 | from audiotools.ml.decorators import Tracker
7 |
8 | import scripts.train as train
9 |
10 |
11 | @torch.no_grad()
12 | def process(batch, test_data):
13 | signal = test_data.transform(batch["signal"].clone(), **batch["transform_args"])
14 | return signal.cpu()
15 |
16 |
17 | @argbind.bind(without_prefix=True)
18 | @torch.no_grad()
19 | def save_test_set(args, sample_rate: int = 44100, output: str = "samples/input"):
20 | tracker = Tracker()
21 | with argbind.scope(args, "test"):
22 | test_data = train.create_dataset(sample_rate=sample_rate)
23 |
24 | global process
25 | process = tracker.track("process", len(test_data))(process)
26 |
27 | output = Path(output)
28 | output.mkdir(parents=True, exist_ok=True)
29 | (output.parent / "input").mkdir(parents=True, exist_ok=True)
30 | with open(output / "metadata.csv", "w") as csvfile:
31 | keys = ["path", "original"]
32 | writer = csv.DictWriter(csvfile, fieldnames=keys)
33 | writer.writeheader()
34 |
35 | with tracker.live:
36 | for i in range(len(test_data)):
37 | signal = process(test_data[i], test_data)
38 | input_path = output.parent / "input" / f"sample_{i}.wav"
39 | metadata = {
40 | "path": str(input_path),
41 | "original": str(signal.path_to_input_file),
42 | }
43 | writer.writerow(metadata)
44 | signal.write(input_path)
45 | tracker.done("test", f"N={len(test_data)}")
46 |
47 |
48 | if __name__ == "__main__":
49 | args = argbind.parse_args()
50 | with argbind.scope(args):
51 | save_test_set(args)
52 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = dac_jax
3 | version = attr: dac_jax.__version__
4 | url = https://github.com/DBraun/DAC-JAX
5 | author = David Braun
6 | author_email = braun@ccrma.stanford.edu
7 | description = Descript Audio Codec and EnCodec in JAX.
8 | long_description = file: README.md
9 | long_description_content_type = text/markdown
10 | keywords =
11 | audio
12 | compression
13 | machine learning
14 | license = MIT
15 | classifiers =
16 | Intended Audience :: Developers
17 | Natural Language :: English
18 | Programming Language :: Python :: 3.10
19 | Programming Language :: Python :: 3.11
20 | Programming Language :: Python :: 3.12
21 | Programming Language :: Python :: 3.13
22 | Topic :: Artistic Software
23 | Topic :: Multimedia
24 | Topic :: Multimedia :: Sound/Audio
25 | Topic :: Multimedia :: Sound/Audio :: Editors
26 | Topic :: Software Development :: Libraries
27 |
28 | [options]
29 | package_dir =
30 | = src
31 | packages = find:
32 | python_requires = >=3.10
33 | install_requires =
34 | argbind @ git+https://github.com/DBraun/argbind.git@improve.subclasses
35 | audiotree>=0.2.0
36 | clu>=0.0.12
37 | dm_aux @ git+https://github.com/DBraun/dm_aux.git@DBraun-patch-2
38 | einops>=0.8.0
39 | grain==0.2.*
40 | huggingface-hub
41 | jax-ai-stack>=2025.2.5
42 | jaxloudnorm @ git+https://github.com/boris-kuz/jaxloudnorm.git
43 | librosa>=0.10.1
44 | omegaconf
45 | tqdm>=4.66.4
46 |
47 | [options.packages.find]
48 | where = src
49 |
50 | [options.extras_require]
51 | dev =
52 | audiocraft
53 | descript-audiotools
54 | descript-audio-codec
55 | pytest
56 | pytest-cov
57 | pandas
58 |     pesq
59 |     encodec
61 |
--------------------------------------------------------------------------------
/src/dac_jax/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.1.0"
2 |
3 | __author__ = """David Braun"""
4 | __email__ = "braun@ccrma.stanford.edu"
5 |
6 | from dac_jax import nn
7 | from dac_jax import model
8 | from dac_jax import utils
9 | from dac_jax.utils import load_model, load_encodec_model
10 | from dac_jax.model import DACFile
11 | from dac_jax.model import DAC
12 | from dac_jax.model import EncodecModel
13 | from dac_jax.nn.quantize import QuantizedResult
14 |
--------------------------------------------------------------------------------
/src/dac_jax/__main__.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import argbind
4 |
5 | from dac_jax.utils import download_model
6 | from dac_jax.utils import download_encodec
7 | from dac_jax.utils.decode import decode
8 | from dac_jax.utils.encode import encode
9 |
10 | STAGES = ["encode", "decode", "download_model", "download_encodec"]
11 |
12 |
13 | def run(stage: str):
14 | """Run stages.
15 |
16 | Parameters
17 | ----------
18 | stage : str
19 | Stage to run
20 | """
21 | if stage not in STAGES:
22 | raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
23 | stage_fn = globals()[stage]
24 |
25 | stage_fn()
26 |
27 |
28 | if __name__ == "__main__":
29 | group = sys.argv.pop(1)
30 | args = argbind.parse_args(group=group)
31 |
32 | with argbind.scope(args):
33 | run(group)
34 |
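# Example invocations (illustrative only; the exact flags for each stage depend
# on how the corresponding functions are bound with argbind):
#
#   python -m dac_jax download_model
#   python -m dac_jax encode --help
#   python -m dac_jax decode --help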
--------------------------------------------------------------------------------
/src/dac_jax/audio_utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import math
3 | from pathlib import Path
4 | from typing import List, Optional, Tuple, Union
5 |
6 | import chex
7 | import dm_aux as aux
8 | from einops import rearrange
9 | import jax.numpy as jnp
10 | import jax.scipy.signal
11 | import jaxloudnorm as jln
12 | import librosa
13 |
14 |
15 | def find_audio(folder: Union[str, Path], ext: List[str] = None) -> List[Path]:
16 | """Finds all audio files in a directory recursively.
17 | Returns a list.
18 |
19 | Parameters
20 | ----------
21 | folder : str
22 | Folder to look for audio files in, recursively.
23 | ext : List[str], optional
24 |         Extensions to look for (each including the leading dot), by default
25 | ``['.wav', '.flac', '.mp3', '.mp4']``.
26 |
27 | Copied from
28 | https://github.com/descriptinc/audiotools/blob/7776c296c711db90176a63ff808c26e0ee087263/audiotools/core/util.py#L225
29 | """
30 | if ext is None:
31 | ext = [".wav", ".flac", ".mp3", ".mp4"]
32 |
33 | folder = Path(folder)
34 | # Take care of case where user has passed in an audio file directly
35 | # into one of the calling functions.
36 | if str(folder).endswith(tuple(ext)):
37 | # if, however, there's a glob in the path, we need to
38 | # return the glob, not the file.
39 | if "*" in str(folder):
40 | return glob.glob(str(folder), recursive=("**" in str(folder)))
41 | else:
42 | return [folder]
43 |
44 | files = []
45 | for x in ext:
46 | files += folder.glob(f"**/*{x}")
47 | return files
48 |
49 |
50 | def compute_stft_padding(
51 | length, window_length: int, hop_length: int, match_stride: bool
52 | ):
53 | """Compute how the STFT should be padded, based on match_stride.
54 |
55 | Parameters
56 | ----------
57 | length: int
58 | window_length : int
59 | Window length of STFT.
60 | hop_length : int
61 | Hop length of STFT.
62 | match_stride : bool
63 | Whether to match stride, making the STFT have the same alignment as convolutional layers.
64 |
65 | Returns
66 | -------
67 | tuple
68 | Amount to pad on either side of audio.
69 | """
70 | if match_stride:
71 | assert (
72 | hop_length == window_length // 4
73 | ), "For match_stride, hop must equal n_fft // 4"
74 | right_pad = math.ceil(length / hop_length) * hop_length - length
75 | pad = (window_length - hop_length) // 2
76 | else:
77 | right_pad = 0
78 | pad = 0
79 |
80 | return right_pad, pad
81 |
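# Illustrative sketch (not part of the upstream module): shows the padding
# values produced for one second of 44.1 kHz audio with a 2048-sample window
# and hop_length = window_length // 4, as required when match_stride=True.
def _demo_compute_stft_padding():
    right_pad, pad = compute_stft_padding(
        length=44100, window_length=2048, hop_length=512, match_stride=True
    )
    # pad = (2048 - 512) // 2 = 768 on each side; right_pad rounds the signal
    # length up to a whole number of hops (ceil(44100 / 512) * 512 - 44100).
    print(right_pad, pad)  # -> 444 768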
82 |
83 | def stft(
84 | x: jnp.ndarray,
85 | frame_length=2048,
86 | hop_factor=0.25,
87 | window="hann",
88 | match_stride=False,
89 | padding_type: str = "reflect",
90 | ):
91 | """Reference:
92 | https://github.com/descriptinc/audiotools/blob/7776c296c711db90176a63ff808c26e0ee087263/audiotools/core/audio_signal.py#L1123
93 | """
94 |
95 | batch_size, num_channels, audio_length = x.shape
96 |
97 | frame_step = int(frame_length * hop_factor)
98 |
99 | right_pad, pad = compute_stft_padding(
100 | audio_length, frame_length, frame_step, match_stride
101 | )
102 | x = jnp.pad(
103 | x, pad_width=((0, 0), (0, 0), (pad, pad + right_pad)), mode=padding_type
104 | )
105 |
106 | x = rearrange(x, "b c t -> (b c) t")
107 |
108 | if window == "sqrt_hann":
109 | from scipy import signal as scipy_signal
110 |
111 | window = jnp.sqrt(scipy_signal.get_window("hann", frame_length))
112 |
113 | # todo: https://github.com/google-deepmind/dm_aux/issues/2
114 | stft_data = aux.spectral.stft(
115 | x,
116 | n_fft=frame_length,
117 | frame_step=frame_step,
118 | window_fn=window,
119 | pad_mode=padding_type,
120 | pad=aux.spectral.Pad.BOTH,
121 | )
122 | stft_data = rearrange(stft_data, "(b c) nt nf -> b c nf nt", b=batch_size)
123 |
124 | if match_stride:
125 | # Drop first two and last two frames, which are added
126 | # because of padding. Now num_frames * hop_length = num_samples.
127 | if hop_factor == 0.25:
128 | stft_data = stft_data[..., 2:-2]
129 | else:
130 | # I think this would be correct if DAC torch ever allowed match_stride==True and hop_factor==0.5
131 | stft_data = stft_data[..., 1:-1]
132 |
133 | return stft_data
134 |
135 |
136 | def mel_spectrogram(
137 | spectrograms: chex.Array,
138 | log_scale: bool = True,
139 | sample_rate: int = 16000,
140 | frame_length: Optional[int] = 2048,
141 | num_features: int = 128,
142 | lower_edge_hertz: float = 0.0,
143 | upper_edge_hertz: Optional[float] = None,
144 | ) -> chex.Array:
145 | """Converts the spectrograms to Mel-scale.
146 |
147 | Adapted from dm_aux:
148 | https://github.com/google-deepmind/dm_aux/blob/77f5ed76df2928bac8550e1c5466c0dac2934be3/dm_aux/spectral.py#L312
149 |
150 | https://en.wikipedia.org/wiki/Mel_scale
151 |
152 | Args:
153 | spectrograms: Input spectrograms of shape [batch_size, time_steps,
154 | num_features].
155 | log_scale: Whether to return the mel_filterbanks in the log scale.
156 | sample_rate: The sample rate of the input audio.
157 | frame_length: The length of each spectrogram frame.
158 | num_features: The number of mel spectrogram features.
159 |     lower_edge_hertz: Lowest frequency to consider when generating mel filterbanks.
160 |     upper_edge_hertz: Highest frequency to consider when generating mel filterbanks.
161 | If None, use `sample_rate / 2.0`.
162 |
163 | Returns:
164 | Converted spectrograms in (log) Mel-scale.
165 | """
166 | # This setup mimics tf.signal.linear_to_mel_weight_matrix.
167 | linear_to_mel_weight_matrix = librosa.filters.mel(
168 | sr=sample_rate,
169 | n_fft=frame_length,
170 | n_mels=num_features,
171 | fmin=lower_edge_hertz,
172 | fmax=upper_edge_hertz,
173 | ).T
174 | spectrograms = jnp.matmul(spectrograms, linear_to_mel_weight_matrix)
175 |
176 | if log_scale:
177 | spectrograms = jnp.log(spectrograms + 1e-6)
178 | return spectrograms
179 |
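# Illustrative sketch (not part of the upstream module): chain stft() and
# mel_spectrogram() to obtain a log-mel spectrogram from a batch of audio. The
# batch size, duration, and sample rate below are arbitrary.
def _demo_mel_spectrogram():
    import jax

    x = jax.random.normal(jax.random.PRNGKey(0), (2, 1, 44100))  # [B, C, T]
    stft_data = stft(x, frame_length=2048, hop_factor=0.25)  # [B, C, F, T']
    magnitude = jnp.abs(stft_data)
    magnitude = rearrange(magnitude, "b c f t -> (b c) t f")
    mels = mel_spectrogram(
        magnitude, log_scale=True, sample_rate=44100, frame_length=2048
    )
    print(mels.shape)  # -> ((B * C), T', 128)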
180 |
181 | def decibel_loudness(stft_data: jnp.ndarray, clamp_eps=1e-5, pow=2.0) -> jnp.ndarray:
182 | return jnp.log10(jnp.power(jnp.maximum(jnp.abs(stft_data), clamp_eps), pow))
183 |
184 |
185 | def db2linear(decibels: jnp.ndarray):
186 | return jnp.pow(10.0, decibels / 20.0)
187 |
188 |
189 | def volume_norm(
190 | audio_data: jnp.ndarray,
191 | target_db: jnp.ndarray,
192 | sample_rate: int,
193 | filter_class: str = "K-weighting",
194 | block_size: float = 0.400,
195 | min_loudness: float = -70,
196 | zeros: int = 2048,
197 | ):
198 | """Calculates loudness using an implementation of ITU-R BS.1770-4.
199 | Allows control over gating block size and frequency weighting filters for
200 | additional control. Measure the integrated gated loudness of a signal.
201 |
202 |     The API is derived from PyLoudnorm, but this implementation uses jaxloudnorm
203 |     and is vectorized across batches with `jax.vmap`. An FIR approximation of the
204 |     IIR weighting filters is used to compute loudness for speed.
205 |
206 |     Uses the weighting filters and block size defined by the meter; the
207 |     integrated loudness is measured based upon the gating algorithm defined in
208 |     the ITU-R BS.1770-4 specification.
209 |
210 | Parameters
211 | ----------
212 | audio_data: jnp.ndarray
213 | audio signal [B, C, T]
214 | target_db: jnp.ndarray
215 | array of target decibel loudnesses [B]
216 | sample_rate: int
217 | sample rate of audio_data
218 | filter_class : str, optional
219 | Class of weighting filter used.
220 |         'K-weighting' (default), 'Fenton/Lee 1',
221 |         'Fenton/Lee 2', or 'Dash et al.',
222 | by default "K-weighting"
223 | block_size : float, optional
224 | Gating block size in seconds, by default 0.400
225 | min_loudness : float, optional
226 | Minimum loudness in decibels
227 | zeros : int, optional
228 | The length of the FIR filter. You should pick a power of 2 between 512 and 4096.
229 |
230 | Returns
231 | -------
232 | jnp.ndarray
233 | Audio normalized to `target_db` loudness
234 | jnp.ndarray
235 | Loudness of original audio data.
236 |
237 | Reference: https://github.com/descriptinc/audiotools/blob/master/audiotools/core/loudness.py
238 | """
239 |
240 | padded_audio = audio_data
241 |
242 | original_length = padded_audio.shape[-1]
243 | signal_duration = original_length / sample_rate
244 |
245 | if signal_duration < block_size:
246 | padded_audio = jnp.pad(
247 | padded_audio,
248 | pad_width=(
249 | (0, 0),
250 | (0, 0),
251 | (0, int(block_size * sample_rate) - original_length),
252 | ),
253 | )
254 |
255 | # create BS.1770 meter
256 | meter = jln.Meter(
257 | sample_rate,
258 | filter_class=filter_class,
259 | block_size=block_size,
260 | use_fir=True,
261 | zeros=zeros,
262 | )
263 |
264 | # measure loudness
265 | loudness = jax.vmap(meter.integrated_loudness)(
266 | rearrange(padded_audio, "b c t -> b t c")
267 | )
268 |
269 | loudness = jnp.maximum(loudness, jnp.full_like(loudness, min_loudness))
270 |
271 | audio_data = audio_data * db2linear(target_db - loudness)[:, None, None]
272 |
273 | return audio_data, loudness
274 |
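# Illustrative sketch (not part of the upstream module): normalize a batch of
# audio to a target loudness of -24 dB LUFS with volume_norm(). The shapes and
# sample rate below are arbitrary.
def _demo_volume_norm():
    import jax

    audio = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (2, 1, 44100))  # [B, C, T]
    target_db = jnp.full((2,), -24.0)  # one target loudness per batch item
    normalized, input_loudness = volume_norm(audio, target_db, sample_rate=44100)
    print(normalized.shape, input_loudness.shape)  # -> (2, 1, 44100) (2,)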
--------------------------------------------------------------------------------
/src/dac_jax/compare/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DBraun/DAC-JAX/919ce4a2a9ec4c5c3fa7d10dcb2944259da00865/src/dac_jax/compare/__init__.py
--------------------------------------------------------------------------------
/src/dac_jax/compare/encodec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from audiotools import AudioSignal
3 | from audiotools.ml import BaseModel
4 | from encodec import EncodecModel
5 |
6 |
7 | class Encodec(BaseModel):
8 | def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0):
9 | super().__init__()
10 |
11 | if sample_rate == 24000:
12 | self.model = EncodecModel.encodec_model_24khz()
13 | else:
14 | self.model = EncodecModel.encodec_model_48khz()
15 | self.model.set_target_bandwidth(bandwidth)
16 | self.sample_rate = 44100
17 |
18 | def forward(
19 | self,
20 | audio_data: torch.Tensor,
21 | sample_rate: int = 44100,
22 | n_quantizers: int = None,
23 | ):
24 | signal = AudioSignal(audio_data, sample_rate)
25 | signal.resample(self.model.sample_rate)
26 | recons = self.model(signal.audio_data)
27 | recons = AudioSignal(recons, self.model.sample_rate)
28 | recons.resample(sample_rate)
29 | return {"audio": recons.audio_data}
30 |
31 |
32 | if __name__ == "__main__":
33 | import numpy as np
34 | from functools import partial
35 |
36 | model = Encodec()
37 |
38 | for n, m in model.named_modules():
39 | o = m.extra_repr()
40 | p = sum([np.prod(p.size()) for p in m.parameters()])
41 | fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
42 | setattr(m, "extra_repr", partial(fn, o=o, p=p))
43 | print(model)
44 | print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
45 |
46 | length = 88200 * 2
47 | x = torch.randn(1, 1, length).to(model.device)
48 | x.requires_grad_(True)
49 | x.retain_grad()
50 |
51 | # Make a forward pass
52 | out = model(x)["audio"]
53 |
54 | print(x.shape, out.shape)
55 |
--------------------------------------------------------------------------------
/src/dac_jax/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .dac import DAC, DACFile
2 | from .discriminator import Discriminator
3 | from .encodec import SEANetEncoder, SEANetDecoder, EncodecModel
4 |
--------------------------------------------------------------------------------
/src/dac_jax/model/core.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from flax import linen as nn
3 | from jax import numpy as jnp
4 | import typing as tp
5 |
6 | from dac_jax.nn.encodec_quantize import QuantizedResult
7 |
8 |
9 | class CompressionModel(ABC, nn.Module):
10 | """Base API for all compression models that aim at being used as audio tokenizers
11 | with a language model.
12 | """
13 |
14 | @abstractmethod
15 | def __call__(self, x: jnp.ndarray) -> QuantizedResult: ...
16 |
17 | @abstractmethod
18 | def encode(self, x: jnp.ndarray) -> tp.Tuple[jnp.ndarray, tp.Optional[jnp.ndarray]]:
19 | """See `EncodecModel.encode`."""
20 | ...
21 |
22 | @abstractmethod
23 | def decode(
24 | self,
25 | codes: jnp.ndarray,
26 | scale: tp.Optional[jnp.ndarray] = None,
27 | length: int = None,
28 | ):
29 | """See `EncodecModel.decode`."""
30 | ...
31 |
32 | @abstractmethod
33 | def decode_latent(self, codes: jnp.ndarray):
34 | """Decode from the discrete codes to continuous latent space."""
35 | ...
36 |
37 | @property
38 | @abstractmethod
39 | def channels(self) -> int: ...
40 |
41 | @property
42 | @abstractmethod
43 | def frame_rate(self) -> float: ...
44 |
45 | @property
46 | @abstractmethod
47 | def sample_rate(self) -> int: ...
48 |
49 | @property
50 | @abstractmethod
51 | def cardinality(self) -> int: ...
52 |
53 | @property
54 | @abstractmethod
55 | def num_codebooks(self) -> int: ...
56 |
57 | @property
58 | @abstractmethod
59 | def total_codebooks(self) -> int: ...
60 |
--------------------------------------------------------------------------------
/src/dac_jax/model/discriminator.py:
--------------------------------------------------------------------------------
1 | from dataclasses import field
2 |
3 | from audiotree.resample import resample
4 | from einops import rearrange
5 | import flax.linen as nn
6 | import jax
7 | from jax import numpy as jnp
8 |
9 | from dac_jax.audio_utils import stft
10 | from dac_jax.nn.layers import make_initializer
11 |
12 |
13 | class LeakyReLU(nn.Module):
14 |
15 | negative_slope: float = 0.01
16 |
17 | @nn.compact
18 | def __call__(self, x):
19 | return nn.leaky_relu(x, negative_slope=self.negative_slope)
20 |
21 |
22 | class WNConv(nn.Conv):
23 |
24 | act: bool = True
25 |
26 | @nn.compact
27 | def __call__(self, x):
28 |
29 | kernel_init = make_initializer(
30 | x.shape[-1],
31 | self.features,
32 | self.kernel_size,
33 | self.feature_group_count,
34 | mode="fan_in",
35 | )
36 |
37 | if self.use_bias:
38 | # note: we just ignore whatever self.bias_init is
39 | bias_init = make_initializer(
40 | x.shape[-1],
41 | self.features,
42 | self.kernel_size,
43 | self.feature_group_count,
44 | mode="fan_in",
45 | )
46 | else:
47 | bias_init = None
48 |
49 | conv = nn.Conv(
50 | features=self.features,
51 | kernel_size=self.kernel_size,
52 | strides=self.strides,
53 | padding=self.padding,
54 | input_dilation=self.input_dilation,
55 | kernel_dilation=self.kernel_dilation,
56 | feature_group_count=self.feature_group_count,
57 | use_bias=self.use_bias,
58 | mask=self.mask,
59 | dtype=self.dtype,
60 | param_dtype=self.param_dtype,
61 | precision=self.precision,
62 | kernel_init=kernel_init,
63 | bias_init=bias_init,
64 | )
65 | scale_init = nn.initializers.constant(1 / jnp.sqrt(3))
66 | block = nn.WeightNorm(conv, scale_init=scale_init)
67 | x = block(x)
68 |
69 | if self.act:
70 | x = LeakyReLU(0.1)(x)
71 |
72 | return x
73 |
74 |
75 | class MPD(nn.Module):
76 |
77 | period: int
78 |
79 | def pad_to_period(self, x):
80 | t = x.shape[-1]
81 | x = jnp.pad(
82 | x,
83 | pad_width=((0, 0), (0, 0), (0, self.period - t % self.period)),
84 | mode="reflect",
85 | )
86 | return x
87 |
88 | @nn.compact
89 | def __call__(self, x):
90 | convs = [
91 | WNConv(
92 | features=32,
93 | kernel_size=(5, 1),
94 | strides=(3, 1),
95 | padding=((2, 2), (0, 0)),
96 | ),
97 | WNConv(
98 | features=128,
99 | kernel_size=(5, 1),
100 | strides=(3, 1),
101 | padding=((2, 2), (0, 0)),
102 | ),
103 | WNConv(
104 | features=512,
105 | kernel_size=(5, 1),
106 | strides=(3, 1),
107 | padding=((2, 2), (0, 0)),
108 | ),
109 | WNConv(
110 | features=1024,
111 | kernel_size=(5, 1),
112 | strides=(3, 1),
113 | padding=((2, 2), (0, 0)),
114 | ),
115 | WNConv(
116 | features=1024,
117 | kernel_size=(5, 1),
118 | strides=(1, 1),
119 | padding=((2, 2), (0, 0)),
120 | ),
121 | WNConv(features=1, kernel_size=(3, 1), padding=((1, 1), (0, 0)), act=False),
122 | ]
123 |
124 | fmap = []
125 |
126 | x = self.pad_to_period(x)
127 | x = rearrange(x, "b c (l p) -> b l p c", p=self.period)
128 |
129 | for layer in convs:
130 | x = layer(x)
131 | fmap.append(x)
132 |
133 | return fmap
134 |
135 |
136 | class MSD(nn.Module):
137 |
138 | rate: int = 1
139 | sample_rate: int = 44100
140 |
141 | @nn.compact
142 | def __call__(self, x):
143 | convs = [
144 | WNConv(features=16, kernel_size=15, strides=1, padding=7),
145 | WNConv(
146 | features=64,
147 | kernel_size=41,
148 | strides=4,
149 | feature_group_count=4,
150 | padding=20,
151 | ),
152 | WNConv(
153 | features=256,
154 | kernel_size=41,
155 | strides=4,
156 | feature_group_count=16,
157 | padding=20,
158 | ),
159 | WNConv(
160 | features=1024,
161 | kernel_size=41,
162 | strides=4,
163 | feature_group_count=64,
164 | padding=20,
165 | ),
166 | WNConv(
167 | features=1024,
168 | kernel_size=41,
169 | strides=4,
170 | feature_group_count=256,
171 | padding=20,
172 | ),
173 | WNConv(features=1024, kernel_size=5, strides=1, padding=2),
174 | WNConv(features=1, kernel_size=3, strides=1, padding=1, act=False),
175 | ]
176 |
177 | x = resample(x, old_sr=self.sample_rate, new_sr=self.sample_rate // self.rate)
178 |
179 | x = rearrange(x, "b c l -> b l c")
180 |
181 | fmap = []
182 |
183 | for layer in convs:
184 | x = layer(x)
185 | fmap.append(x)
186 |
187 | return fmap
188 |
189 |
190 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
191 |
192 |
193 | class MRD(nn.Module):
194 |
195 | window_length: int
196 | hop_factor: float = 0.25
197 | sample_rate: int = 44100
198 | bands: list = field(default_factory=lambda: BANDS)
199 |
200 | def __post_init__(self) -> None:
201 | n_fft = self.window_length // 2 + 1
202 | self.bands = [(int(low * n_fft), int(high * n_fft)) for low, high in self.bands]
203 | super().__post_init__()
204 |
205 | @nn.compact
206 | def __call__(self, x):
207 | """Complex multi-band spectrogram discriminator.
208 | Parameters
209 | ----------
210 | window_length : int
211 | Window length of STFT.
212 | hop_factor : float, optional
213 | Hop factor of the STFT, defaults to ``0.25 * window_length``.
214 | sample_rate : int, optional
215 | Sampling rate of audio in Hz, by default 44100
216 | bands : list, optional
217 | Bands to run discriminator over.
218 | """
219 |
220 | ch = 32
221 | convs = lambda: [
222 | WNConv(
223 | features=ch,
224 | kernel_size=(3, 9),
225 | strides=(1, 1),
226 | padding=((1, 1), (4, 4)),
227 | ),
228 | WNConv(
229 | features=ch,
230 | kernel_size=(3, 9),
231 | strides=(1, 2),
232 | padding=((1, 1), (4, 4)),
233 | ),
234 | WNConv(
235 | features=ch,
236 | kernel_size=(3, 9),
237 | strides=(1, 2),
238 | padding=((1, 1), (4, 4)),
239 | ),
240 | WNConv(
241 | features=ch,
242 | kernel_size=(3, 9),
243 | strides=(1, 2),
244 | padding=((1, 1), (4, 4)),
245 | ),
246 | WNConv(
247 | features=ch,
248 | kernel_size=(3, 3),
249 | strides=(1, 1),
250 | padding=((1, 1), (1, 1)),
251 | ),
252 | ]
253 | band_convs = [convs() for _ in range(len(self.bands))]
254 | conv_post = WNConv(
255 | features=1,
256 | kernel_size=(3, 3),
257 | strides=(1, 1),
258 | padding=((1, 1), (1, 1)),
259 | act=False,
260 | )
261 |
262 | x_bands = self.get_bands(x)
263 | fmap = []
264 |
265 | x = []
266 | for band, stack in zip(x_bands, band_convs):
267 | band = rearrange(band, "b c t f -> b t f c")
268 | for layer in stack:
269 | band = layer(band)
270 | fmap.append(band)
271 | x.append(band)
272 |
273 | x = jnp.concatenate(x, axis=-2) # concatenate along frequency axis
274 | x = conv_post(x)
275 | fmap.append(x)
276 |
277 | return fmap
278 |
279 | def get_bands(self, x):
280 | stft_data = stft(
281 | x,
282 | frame_length=self.window_length,
283 | hop_factor=self.hop_factor,
284 | match_stride=True,
285 | )
286 | x = self.as_real(stft_data)
287 | x = rearrange(
288 | x, "b c f t ri -> (b c) ri t f", c=1, ri=2
289 | ) # ri is 2 for real and imaginary
290 | # Split into bands
291 | x_bands = [x[..., low:high] for low, high in self.bands]
292 | return x_bands
293 |
294 | @staticmethod
295 | def as_real(x: jnp.ndarray) -> jnp.ndarray:
296 | # https://github.com/google/jax/issues/9496#issuecomment-1033961377
297 | if not jnp.issubdtype(x.dtype, jnp.complexfloating):
298 | return x
299 |
300 | return jnp.stack([x.real, x.imag], axis=-1)
301 |
302 |
303 | class Discriminator(nn.Module):
304 |
305 | rates: list = field(default_factory=lambda: [])
306 | periods: list = field(default_factory=lambda: [2, 3, 5, 7, 11])
307 | fft_sizes: list = field(default_factory=lambda: [2048, 1024, 512])
308 | sample_rate: int = 44100
309 | bands: list = field(default_factory=lambda: BANDS)
310 |
311 | @staticmethod
312 | def preprocess(y: jnp.ndarray):
313 | # Remove DC offset
314 | y = y - y.mean(axis=-1, keepdims=True)
315 | # Peak normalize the volume of input audio
316 | y = 0.8 * y / (jnp.abs(y).max(axis=-1, keepdims=True) + 1e-9)
317 | return y
318 |
319 | @nn.compact
320 | def __call__(self, x):
321 | """Discriminator that combines multiple discriminators.
322 |
323 | Parameters
324 | ----------
325 | rates : list, optional
326 | sampling rates (in Hz) to run MSD at, by default []
327 | If empty, MSD is not used.
328 | periods : list, optional
329 | periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
330 | fft_sizes : list, optional
331 | Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
332 | sample_rate : int, optional
333 | Sampling rate of audio in Hz, by default 44100
334 | bands : list, optional
335 | Bands to run MRD at, by default `BANDS`
336 | """
337 | discriminators = []
338 | discriminators += [MPD(p) for p in self.periods]
339 | discriminators += [MSD(r, sample_rate=self.sample_rate) for r in self.rates]
340 | discriminators += [
341 | MRD(f, sample_rate=self.sample_rate, bands=self.bands)
342 | for f in self.fft_sizes
343 | ]
344 | x = self.preprocess(x)
345 | fmaps = [d(x) for d in discriminators]
346 | return fmaps
347 |
348 |
349 | if __name__ == "__main__":
350 | import numpy as np
351 |
352 | disc = Discriminator()
353 | x = jnp.zeros(shape=(1, 1, 44100))
354 |
355 | print(
356 | disc.tabulate(
357 | jax.random.key(1),
358 | x,
359 | # compute_flops=True,
360 | # compute_vjp_flops=True,
361 | depth=3,
362 | # column_kwargs={"width": 400},
363 | console_kwargs={"width": 400},
364 | )
365 | )
366 |
367 | results, variables = disc.init_with_output(jax.random.key(3), x)
368 |
369 | for i, result in enumerate(results):
370 | print(f"disc{i}")
371 | for i, _r in enumerate(result):
372 | r = np.array(_r)
373 | print(
374 | r.shape,
375 | f"{r.mean().item():,.5f}, {r.min().item():,.5f} {r.max().item():,.5f}",
376 | )
377 | print("All Done!")
378 |
--------------------------------------------------------------------------------
/src/dac_jax/model/encodec.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from dataclasses import field
8 | import typing as tp
9 |
10 | from einops import rearrange
11 | from flax import linen as nn
12 | from jax import numpy as jnp
13 | import numpy as np
14 |
15 | from dac_jax.model.core import CompressionModel
16 | from dac_jax.nn.quantize import QuantizedResult
17 |
18 | from dac_jax.nn.encodec_layers import (
19 | StreamableConv1d,
20 | StreamableConvTranspose1d,
21 | StreamableLSTM,
22 | )
23 |
24 |
25 | class SEANetResnetBlock(nn.Module):
26 | """Residual block from SEANet model.
27 |
28 | Args:
29 | dim (int): Dimension of the input/output.
30 | kernel_sizes (list): List of kernel sizes for the convolutions.
31 | dilations (list): List of dilations for the convolutions.
32 | activation (str): Activation function.
33 | activation_params (dict): Parameters to provide to the activation function.
34 | norm (str): Normalization method.
35 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
36 | causal (bool): Whether to use fully causal convolution.
37 | pad_mode (str): Padding mode for the convolutions.
38 | compress (int): Reduced dimensionality in residual branches (from Demucs v3).
39 | true_skip (bool): Whether to use true skip connection or a simple
40 | (streamable) convolution as the skip connection.
41 | """
42 |
43 | dim: int
44 | kernel_sizes: tp.List[int] = field(default_factory=lambda: [3, 1])
45 | dilations: tp.List[int] = field(default_factory=lambda: [1, 1])
46 | activation: str = "elu"
47 | activation_params: dict = field(default_factory=lambda: {"alpha": 1.0})
48 | norm: str = "none"
49 | norm_params: tp.Dict[str, tp.Any] = field(default_factory=lambda: {})
50 | causal: int = 0 # bool
51 | pad_mode: str = "reflect"
52 | compress: int = 2
53 | true_skip: int = 1 # bool
54 |
55 | @nn.compact
56 | def __call__(self, x):
57 | assert len(self.kernel_sizes) == len(
58 | self.dilations
59 | ), "Number of kernel sizes should match number of dilations"
60 | act = lambda y: getattr(nn.activation, self.activation)(
61 | y, **self.activation_params
62 | )
63 | hidden = self.dim // self.compress
64 | block = []
65 | for i, (kernel_size, dilation) in enumerate(
66 | zip(self.kernel_sizes, self.dilations)
67 | ):
68 | out_chs = self.dim if i == len(self.kernel_sizes) - 1 else hidden
69 | block += [
70 | act,
71 | StreamableConv1d(
72 | out_chs,
73 | kernel_size=kernel_size,
74 | dilation=dilation,
75 | norm=self.norm,
76 | norm_kwargs=self.norm_params,
77 | causal=self.causal,
78 | pad_mode=self.pad_mode,
79 | ),
80 | ]
81 | block = nn.Sequential(block)
82 | if self.true_skip:
83 | return x + block(x)
84 | else:
85 | shortcut = StreamableConv1d(
86 | self.dim,
87 | kernel_size=1,
88 | norm=self.norm,
89 | norm_kwargs=self.norm_params,
90 | causal=self.causal,
91 | pad_mode=self.pad_mode,
92 | )
93 |
94 | return shortcut(x) + block(x)
95 |
96 |
97 | class SEANetEncoder(nn.Module):
98 | """SEANet encoder.
99 |
100 | Args:
101 | channels (int): Audio channels.
102 | dimension (int): Intermediate representation dimension.
103 | n_filters (int): Base width for the model.
104 | n_residual_layers (int): nb of residual layers.
105 | ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
106 | upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
107 | that must match the decoder order. We use the decoder order as some models may only employ the decoder.
108 | activation (str): Activation function.
109 | activation_params (dict): Parameters to provide to the activation function.
110 | norm (str): Normalization method.
111 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
112 | kernel_size (int): Kernel size for the initial convolution.
113 |         last_kernel_size (int): Kernel size for the last convolution.
114 | residual_kernel_size (int): Kernel size for the residual layers.
115 | dilation_base (int): How much to increase the dilation with each layer.
116 | causal (bool): Whether to use fully causal convolution.
117 | pad_mode (str): Padding mode for the convolutions.
118 | true_skip (bool): Whether to use true skip connection or a simple
119 | (streamable) convolution as the skip connection in the residual network blocks.
120 | compress (int): Reduced dimensionality in residual branches (from Demucs v3).
121 | lstm (int): Number of LSTM layers at the end of the encoder.
122 | disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
123 | For the encoder, it corresponds to the N first blocks.
124 | """
125 |
126 | channels: int = 1
127 | dimension: int = 128
128 | n_filters: int = 32
129 | n_residual_layers: int = 3
130 | ratios: tp.List[int] = field(default_factory=lambda: [8, 5, 4, 2])
131 | activation: str = "elu"
132 | activation_params: dict = field(default_factory=lambda: {"alpha": 1.0})
133 | norm: str = "none"
134 | norm_params: tp.Dict[str, tp.Any] = field(default_factory=lambda: {})
135 | kernel_size: int = 7
136 | last_kernel_size: int = 7
137 | residual_kernel_size: int = 3
138 | dilation_base: int = 2
139 | causal: bool = False
140 | pad_mode: str = "reflect"
141 | true_skip: bool = True
142 | compress: int = 2
143 | lstm: int = 0
144 | disable_norm_outer_blocks: int = 0
145 |
146 | def __post_init__(self) -> None:
147 | self.hop_length = np.prod(self.ratios)
148 | self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
149 | assert (
150 | self.disable_norm_outer_blocks >= 0
151 | and self.disable_norm_outer_blocks <= self.n_blocks
152 | ), (
153 | "Number of blocks for which to disable norm is invalid."
154 | "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
155 | )
156 | super().__post_init__()
157 |
158 | @nn.compact
159 | def __call__(self, x):
160 | act = lambda y: getattr(nn.activation, self.activation)(
161 | y, **self.activation_params
162 | )
163 | mult = 1
164 | layers = [
165 | StreamableConv1d(
166 | mult * self.n_filters,
167 | kernel_size=self.kernel_size,
168 | norm="none" if self.disable_norm_outer_blocks >= 1 else self.norm,
169 | norm_kwargs=self.norm_params,
170 | causal=self.causal,
171 | pad_mode=self.pad_mode,
172 | )
173 | ]
174 | # Downsample to raw audio scale
175 | for i, ratio in enumerate(reversed(self.ratios)):
176 | block_norm = (
177 | "none" if self.disable_norm_outer_blocks >= i + 2 else self.norm
178 | )
179 | # Add residual layers
180 | for j in range(self.n_residual_layers):
181 | layers += [
182 | SEANetResnetBlock(
183 | mult * self.n_filters,
184 | kernel_sizes=[self.residual_kernel_size, 1],
185 | dilations=[self.dilation_base**j, 1],
186 | norm=block_norm,
187 | norm_params=self.norm_params,
188 | activation=self.activation,
189 | activation_params=self.activation_params,
190 | causal=self.causal,
191 | pad_mode=self.pad_mode,
192 | compress=self.compress,
193 | true_skip=self.true_skip,
194 | )
195 | ]
196 |
197 | # Add downsampling layers
198 | layers += [
199 | act,
200 | StreamableConv1d(
201 | mult * self.n_filters * 2,
202 | kernel_size=ratio * 2,
203 | stride=ratio,
204 | norm=block_norm,
205 | norm_kwargs=self.norm_params,
206 | causal=self.causal,
207 | pad_mode=self.pad_mode,
208 | ),
209 | ]
210 | mult *= 2
211 |
212 | if self.lstm:
213 | layers += [StreamableLSTM(mult * self.n_filters, num_layers=self.lstm)]
214 |
215 | layers += [
216 | act,
217 | StreamableConv1d(
218 | self.dimension,
219 | kernel_size=self.last_kernel_size,
220 | norm=(
221 | "none"
222 | if self.disable_norm_outer_blocks == self.n_blocks
223 | else self.norm
224 | ),
225 | norm_kwargs=self.norm_params,
226 | causal=self.causal,
227 | pad_mode=self.pad_mode,
228 | ),
229 | ]
230 |
231 | model = nn.Sequential(layers)
232 | x = rearrange(x, "B C T -> B T C")
233 | return model(x)
234 |
235 |
236 | class SEANetDecoder(nn.Module):
237 | """SEANet decoder.
238 |
239 | Args:
240 | channels (int): Audio channels.
241 | dimension (int): Intermediate representation dimension.
242 | n_filters (int): Base width for the model.
243 | n_residual_layers (int): nb of residual layers.
244 | ratios (Sequence[int]): kernel size and stride ratios.
245 | activation (str): Activation function.
246 | activation_params (dict): Parameters to provide to the activation function.
247 | final_activation (str): Final activation function after all convolutions.
248 | final_activation_params (dict): Parameters to provide to the activation function.
249 | norm (str): Normalization method.
250 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
251 | kernel_size (int): Kernel size for the initial convolution.
252 |         last_kernel_size (int): Kernel size for the last convolution.
253 | residual_kernel_size (int): Kernel size for the residual layers.
254 | dilation_base (int): How much to increase the dilation with each layer.
255 | causal (bool): Whether to use fully causal convolution.
256 | pad_mode (str): Padding mode for the convolutions.
257 |         true_skip (bool): Whether to use true skip connection or a simple
258 | (streamable) convolution as the skip connection in the residual network blocks.
259 | compress (int): Reduced dimensionality in residual branches (from Demucs v3).
260 | lstm (int): Number of LSTM layers at the end of the encoder.
261 | disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
262 | For the decoder, it corresponds to the N last blocks.
263 | trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
264 | If equal to 1.0, it means that all the trimming is done at the right.
265 | """
266 |
267 | channels: int = 1
268 | dimension: int = 128
269 | n_filters: int = 32
270 | n_residual_layers: int = 3
271 | ratios: tp.List[int] = field(default_factory=lambda: [8, 5, 4, 2])
272 | activation: str = "elu"
273 | activation_params: dict = field(default_factory=lambda: {"alpha": 1.0})
274 | final_activation: tp.Optional[str] = None
275 | final_activation_params: tp.Optional[dict] = None
276 | norm: str = "none"
277 | norm_params: tp.Dict[str, tp.Any] = field(default_factory=lambda: {})
278 | kernel_size: int = 7
279 | last_kernel_size: int = 7
280 | residual_kernel_size: int = 3
281 | dilation_base: int = 2
282 | causal: bool = False
283 | pad_mode: str = "reflect"
284 | true_skip: bool = True
285 | compress: int = 2
286 | lstm: int = 0
287 | disable_norm_outer_blocks: int = 0
288 | trim_right_ratio: float = 1.0
289 |
290 | def __post_init__(self) -> None:
291 | self.hop_length = np.prod(self.ratios)
292 | self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
293 | assert (
294 | self.disable_norm_outer_blocks >= 0
295 | and self.disable_norm_outer_blocks <= self.n_blocks
296 | ), (
297 | "Number of blocks for which to disable norm is invalid."
298 | "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
299 | )
300 | super().__post_init__()
301 |
302 | @nn.compact
303 | def __call__(self, z):
304 | z = z.transpose(0, 2, 1)
305 | act = lambda y: getattr(nn.activation, self.activation)(
306 | y, **self.activation_params
307 | )
308 | mult = int(2 ** len(self.ratios))
309 | layers = [
310 | StreamableConv1d(
311 | mult * self.n_filters,
312 | kernel_size=self.kernel_size,
313 | norm=(
314 | "none"
315 | if self.disable_norm_outer_blocks == self.n_blocks
316 | else self.norm
317 | ),
318 | norm_kwargs=self.norm_params,
319 | causal=self.causal,
320 | pad_mode=self.pad_mode,
321 | )
322 | ]
323 |
324 | if self.lstm:
325 | layers += [StreamableLSTM(mult * self.n_filters, num_layers=self.lstm)]
326 |
327 | # Upsample to raw audio scale
328 | for i, ratio in enumerate(self.ratios):
329 | block_norm = (
330 | "none"
331 | if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1)
332 | else self.norm
333 | )
334 | # Add upsampling layers
335 | layers += [
336 | act,
337 | StreamableConvTranspose1d(
338 | mult * self.n_filters // 2,
339 | kernel_size=ratio * 2,
340 | stride=ratio,
341 | norm=block_norm,
342 | norm_kwargs=self.norm_params,
343 | causal=self.causal,
344 | trim_right_ratio=self.trim_right_ratio,
345 | ),
346 | ]
347 | # Add residual layers
348 | for j in range(self.n_residual_layers):
349 | layers += [
350 | SEANetResnetBlock(
351 | mult * self.n_filters // 2,
352 | kernel_sizes=[self.residual_kernel_size, 1],
353 | dilations=[self.dilation_base**j, 1],
354 | activation=self.activation,
355 | activation_params=self.activation_params,
356 | norm=block_norm,
357 | norm_params=self.norm_params,
358 | causal=self.causal,
359 | pad_mode=self.pad_mode,
360 | compress=self.compress,
361 | true_skip=self.true_skip,
362 | )
363 | ]
364 |
365 | mult //= 2
366 |
367 | # Add final layers
368 | layers += [
369 | act,
370 | StreamableConv1d(
371 | self.channels,
372 | kernel_size=self.last_kernel_size,
373 | norm="none" if self.disable_norm_outer_blocks >= 1 else self.norm,
374 | norm_kwargs=self.norm_params,
375 | causal=self.causal,
376 | pad_mode=self.pad_mode,
377 | ),
378 | ]
379 | # Add optional final activation to decoder (eg. tanh)
380 | if self.final_activation is not None:
381 | final_act = getattr(nn, self.final_activation)
382 | final_activation_params = self.final_activation_params or {}
383 | layers += [final_act(**final_activation_params)]
384 | model = nn.Sequential(layers)
385 | y = model(z)
386 | y = rearrange(y, "B T C -> B C T")
387 | return y
388 |
389 |
390 | class EncodecModel(CompressionModel):
391 | """Encodec model operating on the raw waveform.
392 |
393 | Args:
394 | encoder (nn.Module): Encoder network.
395 | decoder (nn.Module): Decoder network.
396 | quantizer (qt.BaseQuantizer): Quantizer network.
397 | frame_rate (int): Frame rate for the latent representation.
398 | sample_rate (int): Audio sample rate.
399 | channels (int): Number of audio channels.
400 | causal (bool): Whether to use a causal version of the model.
401 | renormalize (bool): Whether to renormalize the audio before running the model.
402 | """
403 |
404 | encoder: nn.Module
405 | decoder: nn.Module
406 | quantizer: nn.Module # todo: qt.BaseQuantizer,
407 | causal: int = 0 # bool
408 | renormalize: int = 0 # bool
409 |
410 | # todo: must declare these?
411 | frame_rate: float = 0 # todo: or int?
412 | sample_rate: int = 0
413 | channels: int = 0
414 |
415 | def __post_init__(self) -> None:
416 | if self.causal:
417 | # we force disabling here to avoid handling linear overlap of segments
418 | # as supported in original EnCodec codebase.
419 | assert not self.renormalize, "Causal model does not support renormalize"
420 | super().__post_init__()
421 |
422 | @property
423 | def total_codebooks(self):
424 | """Total number of quantizer codebooks available."""
425 | return self.quantizer.total_codebooks
426 |
427 | @property
428 | def num_codebooks(self):
429 | """Active number of codebooks used by the quantizer."""
430 | return self.quantizer.num_codebooks
431 |
432 | def set_num_codebooks(self, n: int):
433 | """Set the active number of codebooks used by the quantizer."""
434 | self.quantizer.set_num_codebooks(n)
435 |
436 | @property
437 | def cardinality(self):
438 | """Cardinality of each codebook."""
439 | return self.quantizer.bins
440 |
441 | def preprocess(
442 | self, x: jnp.ndarray
443 | ) -> tp.Tuple[jnp.ndarray, tp.Optional[jnp.ndarray]]:
444 | scale: tp.Optional[jnp.ndarray]
445 | if self.renormalize:
446 | mono = x.mean(axis=1, keepdims=True)
447 | volume = jnp.sqrt(jnp.square(mono).mean(axis=2, keepdims=True))
448 | scale = 1e-8 + volume
449 | x = x / scale
450 | scale = scale.reshape(-1, 1)
451 | else:
452 | scale = None
453 | return x, scale
454 |
455 | def postprocess(
456 | self, x: jnp.ndarray, scale: tp.Optional[jnp.ndarray] = None
457 | ) -> jnp.ndarray:
458 | if scale is not None:
459 | assert self.renormalize
460 | x = x * scale.reshape(-1, 1, 1)
461 | return x
462 |
463 | def __call__(self, x: jnp.ndarray, train=False) -> QuantizedResult:
464 | assert x.ndim == 3
465 | length = x.shape[-1]
466 | x, scale = self.preprocess(x)
467 |
468 | emb = self.encoder(x)
469 | q_res: QuantizedResult = self.quantizer(emb, self.frame_rate, train=train)
470 | out = self.decoder(q_res.z)
471 |
472 | # remove extra padding added by the encoder and decoder
473 | assert out.shape[-1] >= length, (out.shape[-1], length)
474 | out = out[..., :length]
475 |
476 | q_res.recons = self.postprocess(out, scale)
477 |
478 | return q_res
479 |
480 | def encode(
481 | self, x: jnp.ndarray, n_quantizers: int = None
482 | ) -> tp.Tuple[jnp.ndarray, tp.Optional[jnp.ndarray]]:
483 | """Encode the given input tensor to quantized representation along with scale parameter.
484 |
485 | Args:
486 | x (jnp.ndarray): Float tensor of shape [B, C, T]
487 |
488 | Returns:
489 | codes, scale (tuple of jnp.ndarray, jnp.ndarray): Tuple composed of:
490 |                 codes: an integer tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
491 | scale: a float tensor containing the scale for audio renormalization.
492 | """
493 | assert x.ndim == 3
494 | x, scale = self.preprocess(x)
495 | emb = self.encoder(x)
496 | emb = emb.transpose(0, 2, 1)
497 | codes = self.quantizer.encode(emb, n_quantizers)
498 | return codes, scale
499 |
500 | def decode(
501 | self,
502 | codes: jnp.ndarray,
503 | scale: tp.Optional[jnp.ndarray] = None,
504 | length: int = None,
505 | ):
506 | """Decode the given codes to a reconstructed representation, using the scale to perform
507 | audio denormalization if needed.
508 |
509 | Args:
510 | codes (jnp.ndarray): Int tensor of shape [B, K, T]
511 | scale (jnp.ndarray, optional): Float tensor containing the scale value.
512 |
513 | Returns:
514 | out (jnp.ndarray): Float tensor of shape [B, C, T], the reconstructed audio.
515 | """
516 | emb = self.decode_latent(codes)
517 | out = self.decoder(emb)
518 | out = self.postprocess(out, scale)
519 |
520 | # remove extra padding added by the encoder and decoder
521 | if length is not None:
522 | out = out[..., :length]
523 | return out
524 |
525 | def decode_latent(self, codes: jnp.ndarray):
526 | """Decode from the discrete codes to continuous latent space."""
527 | return self.quantizer.decode(codes)
528 |
--------------------------------------------------------------------------------
/src/dac_jax/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from . import layers
2 | from . import loss
3 | from . import quantize
4 |
--------------------------------------------------------------------------------
/src/dac_jax/nn/encodec_layers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from dataclasses import field
8 | import math
9 | import typing as tp
10 | import warnings
11 |
12 | from flax import linen as nn
13 | from jax import numpy as jnp
14 |
15 | from dac_jax.nn.layers import make_initializer
16 |
17 |
18 | CONV_NORMALIZATIONS = frozenset(
19 | ["none", "weight_norm", "spectral_norm", "time_group_norm"]
20 | )
21 |
22 |
23 | def apply_parametrization_norm(module: nn.Module, norm: str = "none"):
24 | assert norm in CONV_NORMALIZATIONS
25 | if norm == "weight_norm":
26 | # why we use scale_init: https://github.com/google/flax/issues/4138
27 | scale_init = nn.initializers.constant(1 / jnp.sqrt(3))
28 | return nn.WeightNorm(module, scale_init=scale_init)
29 | elif norm == "spectral_norm":
30 | return nn.SpectralNorm(module)
31 | else:
32 |         # We already checked that `norm` is in CONV_NORMALIZATIONS, so any other choice
33 | # doesn't need reparametrization.
34 | return module
35 |
36 |
37 | def get_norm_module(
38 | module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
39 | ):
40 | """Return the proper normalization module. If causal is True, this will ensure the returned
41 | module is causal, or return an error if the normalization doesn't support causal evaluation.
42 | """
43 | assert norm in CONV_NORMALIZATIONS
44 | if norm == "time_group_norm":
45 | if causal:
46 | raise ValueError("GroupNorm doesn't support causal evaluation.")
47 | assert isinstance(module, nn.Conv)
48 | return nn.GroupNorm(num_groups=1, **norm_kwargs)
49 | else:
50 | return lambda x: x
51 |
52 |
53 | def get_extra_padding_for_conv1d(
54 | x: jnp.ndarray, kernel_size: int, stride: int, padding_total: int = 0
55 | ) -> int:
56 | """See `pad_for_conv1d`."""
57 | length = x.shape[-2]
58 | n_frames = (length - kernel_size + padding_total) / stride + 1
59 | ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
60 | return ideal_length - length
61 |
62 |
63 | def pad_for_conv1d(
64 | x: jnp.ndarray, kernel_size: int, stride: int, padding_total: int = 0
65 | ):
66 | """Pad for a convolution to make sure that the last window is full.
67 | Extra padding is added at the end. This is required to ensure that we can rebuild
68 | an output of the same length, as otherwise, even with padding, some time steps
69 | might get removed.
70 | For instance, with total padding = 4, kernel size = 4, stride = 2:
71 | 0 0 1 2 3 4 5 0 0 # (0s are padding)
72 | 1 2 3 # (output frames of a convolution, last 0 is never used)
73 | 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
74 | 1 2 3 4 # once you removed padding, we are missing one time step !
75 | """
76 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
77 | return jnp.pad(x, ((0, 0), (0, extra_padding), (0, 0)))
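# A worked example with hypothetical numbers: for an input of length 7 on the time axis,
# kernel_size=4, stride=2 and padding_total=2, n_frames = (7 - 4 + 2) / 2 + 1 = 3.5 and
# ideal_length = (ceil(3.5) - 1) * 2 + (4 - 2) = 8, so one extra zero is appended and the
# last convolution window is full.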
78 |
79 |
80 | def pad1d(
81 | x: jnp.ndarray,
82 | paddings: tp.Tuple[int, int],
83 | mode: str = "constant",
84 | value: float = 0.0,
85 | ):
86 |     """Tiny wrapper around jnp.pad, just to allow for reflect padding on a small input.
87 |     If this is the case, we insert extra 0 padding to the right before the reflection happens.
88 | """
89 | length = x.shape[-2]
90 | padding_left, padding_right = paddings
91 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
92 | if mode == "constant":
93 | pad_kwargs = {"constant_values": value}
94 | else:
95 | pad_kwargs = {}
96 | if mode == "reflect":
97 | max_pad = max(padding_left, padding_right)
98 | extra_pad = 0
99 | if length <= max_pad:
100 | extra_pad = max_pad - length + 1
101 | x = jnp.pad(x, ((0, 0), (0, extra_pad), (0, 0)))
102 | padded = jnp.pad(
103 | x, pad_width=((0, 0), paddings, (0, 0)), mode=mode, **pad_kwargs
104 | )
105 | end = padded.shape[-2] - extra_pad
106 | return padded[:, :end, :]
107 | else:
108 | return jnp.pad(x, pad_width=((0, 0), paddings, (0, 0)), mode=mode, **pad_kwargs)
109 |
110 |
111 | def unpad1d(x: jnp.ndarray, paddings: tp.Tuple[int, int]):
112 | """Remove padding from x, handling properly zero padding. Only for 1d!"""
113 | padding_left, padding_right = paddings
114 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
115 | assert (padding_left + padding_right) <= x.shape[-2]
116 | end = x.shape[-2] - padding_right
117 | return x[:, padding_left:end, :]
118 |
119 |
120 | class NormConv1d(nn.Conv):
121 | """Wrapper around Conv and normalization applied to this conv
122 | to provide a uniform interface across normalization approaches.
123 | """
124 |
125 | causal: bool = False
126 | norm: str = "none"
127 | norm_kwargs: tp.Dict[str, tp.Any] = field(default_factory=lambda: {})
128 |
129 | @nn.compact
130 | def __call__(self, x):
131 |
132 | # note: we just ignore whatever self.kernel_init is
133 | kernel_init = make_initializer(
134 | x.shape[-1],
135 | self.features,
136 | self.kernel_size,
137 | self.feature_group_count,
138 | mode="fan_in",
139 | )
140 |
141 | if self.use_bias:
142 | # note: we just ignore whatever self.bias_init is
143 | bias_init = make_initializer(
144 | x.shape[-1],
145 | self.features,
146 | self.kernel_size,
147 | self.feature_group_count,
148 | mode="fan_in",
149 | )
150 | else:
151 | bias_init = None
152 |
153 | conv = nn.Conv(
154 | features=self.features,
155 | kernel_size=(self.kernel_size,),
156 | strides=(self.strides,),
157 | padding="VALID",
158 | input_dilation=self.input_dilation,
159 | kernel_dilation=self.kernel_dilation,
160 | feature_group_count=self.feature_group_count,
161 | use_bias=self.use_bias,
162 | mask=self.mask,
163 | dtype=self.dtype,
164 | param_dtype=self.param_dtype,
165 | precision=self.precision,
166 | kernel_init=kernel_init,
167 | bias_init=bias_init,
168 | )
169 | conv = apply_parametrization_norm(conv, self.norm)
170 | norm = get_norm_module(conv, self.causal, self.norm, **self.norm_kwargs)
171 | x = conv(x)
172 | x = norm(x)
173 | return x
174 |
175 |
176 | class NormConv2d(nn.Conv):
177 | """Wrapper around Conv and normalization applied to this conv
178 | to provide a uniform interface across normalization approaches.
179 | """
180 |
181 | norm: str = "none"
182 | norm_kwargs: tp.Dict[str, tp.Any] = field(default_factory=lambda: {})
183 |
184 | @nn.compact
185 | def __call__(self, x):
186 |
187 | # note: we just ignore whatever self.kernel_init is
188 | kernel_init = make_initializer(
189 | x.shape[-1],
190 | self.features,
191 | self.kernel_size,
192 | self.feature_group_count,
193 | mode="fan_in",
194 | )
195 |
196 | if self.use_bias:
197 | # note: we just ignore whatever self.bias_init is
198 | bias_init = make_initializer(
199 | x.shape[-1],
200 | self.features,
201 | self.kernel_size,
202 | self.feature_group_count,
203 | mode="fan_in",
204 | )
205 | else:
206 | bias_init = None
207 |
208 | conv = nn.Conv(
209 | features=self.features,
210 | kernel_size=self.kernel_size,
211 | strides=self.strides,
212 | padding="VALID",
213 | input_dilation=self.input_dilation,
214 | kernel_dilation=self.kernel_dilation,
215 | feature_group_count=self.feature_group_count,
216 | use_bias=self.use_bias,
217 | mask=self.mask,
218 | dtype=self.dtype,
219 | param_dtype=self.param_dtype,
220 | precision=self.precision,
221 | kernel_init=kernel_init,
222 | bias_init=bias_init,
223 | )
224 | conv = apply_parametrization_norm(conv, self.norm)
225 | norm = get_norm_module(conv, causal=False, norm=self.norm, **self.norm_kwargs)
226 | x = conv(x)
227 | x = norm(x)
228 | return x
229 |
230 |
231 | class NormConvTranspose1d(nn.ConvTranspose):
232 | """Wrapper around ConvTranspose1d and normalization applied to this conv
233 | to provide a uniform interface across normalization approaches.
234 | """
235 |
236 | causal: bool = False
237 | norm: str = "none"
238 | norm_kwargs: tp.Dict[str, tp.Any] = field(default_factory=lambda: {})
239 |
240 | @nn.compact
241 | def __call__(self, x):
242 | groups = 1
243 | # note: we just ignore whatever self.kernel_init is
244 | kernel_init = make_initializer(
245 | x.shape[-1],
246 | self.features,
247 | self.kernel_size,
248 | groups,
249 | mode="fan_out",
250 | )
251 |
252 | if self.use_bias:
253 | # note: we just ignore whatever self.bias_init is
254 | bias_init = make_initializer(
255 | x.shape[-1],
256 | self.features,
257 | self.kernel_size,
258 | groups,
259 | mode="fan_out",
260 | )
261 | else:
262 | bias_init = None
263 |
264 | convtr = nn.ConvTranspose(
265 | features=self.features,
266 | kernel_size=self.kernel_size,
267 | strides=self.strides,
268 | padding="VALID",
269 | kernel_dilation=self.kernel_dilation,
270 | use_bias=self.use_bias,
271 | mask=self.mask,
272 | dtype=self.dtype,
273 | param_dtype=self.param_dtype,
274 | precision=self.precision,
275 | kernel_init=kernel_init,
276 | bias_init=bias_init,
277 | transpose_kernel=True, # note: this helps us load weights from PyTorch
278 | )
279 | convtr = apply_parametrization_norm(convtr, self.norm)
280 | norm = get_norm_module(convtr, self.causal, self.norm, **self.norm_kwargs)
281 | x = convtr(x)
282 | x = norm(x)
283 | return x
284 |
285 |
286 | class NormConvTranspose2d(nn.ConvTranspose):
287 | """Wrapper around ConvTranspose2d and normalization applied to this conv
288 | to provide a uniform interface across normalization approaches.
289 | """
290 |
291 | norm: str = "none"
292 | norm_kwargs: tp.Dict[str, tp.Any] = field(default_factory=lambda: {})
293 |
294 | @nn.compact
295 | def __call__(self, x):
296 | groups = 1
297 | # note: we just ignore whatever self.kernel_init is
298 | kernel_init = make_initializer(
299 | x.shape[-1],
300 | self.features,
301 | self.kernel_size,
302 | groups,
303 | mode="fan_out",
304 | )
305 |
306 | if self.use_bias:
307 | # note: we just ignore whatever self.bias_init is
308 | bias_init = make_initializer(
309 | x.shape[-1],
310 | self.features,
311 | self.kernel_size,
312 | groups,
313 | mode="fan_out",
314 | )
315 | else:
316 | bias_init = None
317 |
318 | convtr = nn.ConvTranspose(
319 | features=self.features,
320 | kernel_size=self.kernel_size,
321 | strides=self.strides,
322 | padding="VALID",
323 | kernel_dilation=self.kernel_dilation,
324 | use_bias=self.use_bias,
325 | mask=self.mask,
326 | dtype=self.dtype,
327 | param_dtype=self.param_dtype,
328 | precision=self.precision,
329 | kernel_init=kernel_init,
330 | bias_init=bias_init,
331 | transpose_kernel=True, # note: this helps us load weights from PyTorch
332 | )
333 | convtr = apply_parametrization_norm(convtr, self.norm)
334 | norm = get_norm_module(convtr, causal=False, norm=self.norm, **self.norm_kwargs)
335 | x = convtr(x)
336 | x = norm(x)
337 | return x
338 |
339 |
340 | class StreamableConv1d(nn.Module):
341 | """Conv1d with some builtin handling of asymmetric or causal padding
342 | and normalization.
343 | """
344 |
345 | out_channels: int
346 | kernel_size: int
347 | stride: int = 1
348 | dilation: int = 1
349 | groups: int = 1
350 | bias: bool = True
351 | causal: bool = False
352 | norm: str = "none"
353 | norm_kwargs: tp.Dict[str, tp.Any] = field(default_factory=lambda: {})
354 | pad_mode: str = "reflect"
355 |
356 | def __post_init__(self) -> None:
357 | # warn user on unusual setup between dilation and stride
358 | if self.stride > 1 and self.dilation > 1:
359 | warnings.warn(
360 | "StreamableConv1d has been initialized with stride > 1 and dilation > 1"
361 | f" (kernel_size={self.kernel_size} stride={self.stride}, dilation={self.dilation})."
362 | )
363 | super().__post_init__()
364 |
365 | @nn.compact
366 | def __call__(self, x):
367 | conv = NormConv1d(
368 | self.out_channels,
369 | kernel_size=self.kernel_size,
370 | strides=self.stride,
371 | kernel_dilation=self.dilation,
372 | feature_group_count=self.groups,
373 | use_bias=self.bias,
374 | causal=self.causal,
375 | norm=self.norm,
376 | norm_kwargs=self.norm_kwargs,
377 | )
378 | B, T, C = x.shape
379 | kernel_size = conv.kernel_size
380 | stride = conv.strides
381 | dilation = conv.kernel_dilation
382 | kernel_size = (
383 | kernel_size - 1
384 | ) * dilation + 1 # effective kernel size with dilations
385 | padding_total = kernel_size - stride
386 | extra_padding = get_extra_padding_for_conv1d(
387 | x, kernel_size, stride, padding_total
388 | )
389 | if self.causal:
390 | # Left padding for causal
391 | x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
392 | else:
393 | # Asymmetric padding required for odd strides
394 | padding_right = padding_total // 2
395 | padding_left = padding_total - padding_right
396 | x = pad1d(
397 | x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
398 | )
399 | y = conv(x)
400 | return y
401 |
402 |
403 | class StreamableConvTranspose1d(nn.Module):
404 | """ConvTranspose1d with some builtin handling of asymmetric or causal padding
405 | and normalization.
406 | """
407 |
408 | out_channels: int
409 | kernel_size: int
410 | stride: int = 1
411 | causal: bool = False
412 | norm: str = "none"
413 | trim_right_ratio: float = 1.0
414 | norm_kwargs: tp.Dict[str, tp.Any] = field(default_factory=lambda: {})
415 |
416 | def __post_init__(self):
417 | assert (
418 | self.causal or self.trim_right_ratio == 1.0
419 | ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
420 | assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
421 | super().__post_init__()
422 |
423 | @nn.compact
424 | def __call__(self, x):
425 | convtr = NormConvTranspose1d(
426 | self.out_channels,
427 | kernel_size=self.kernel_size,
428 | strides=self.stride,
429 | causal=self.causal,
430 | norm=self.norm,
431 | norm_kwargs=self.norm_kwargs,
432 | )
433 | kernel_size = convtr.kernel_size
434 | stride = convtr.strides
435 | padding_total = kernel_size - stride
436 |
437 | y = convtr(x)
438 |
439 | # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
440 | # removed at the very end, when keeping only the right length for the output,
441 | # as removing it here would require also passing the length at the matching layer
442 | # in the encoder.
443 | if self.causal:
444 | # Trim the padding on the right according to the specified ratio
445 | # if trim_right_ratio = 1.0, trim everything from right
446 | padding_right = math.ceil(padding_total * self.trim_right_ratio)
447 | padding_left = padding_total - padding_right
448 | y = unpad1d(y, (padding_left, padding_right))
449 | else:
450 | # Asymmetric padding required for odd strides
451 | padding_right = padding_total // 2
452 | padding_left = padding_total - padding_right
453 | y = unpad1d(y, (padding_left, padding_right))
454 | return y
455 |
456 |
457 | class StreamableLSTM(nn.Module):
458 | """LSTM without worrying about the hidden state, nor the layout of the data.
459 | Expects input as convolutional layout.
460 | """
461 |
462 | dimension: int
463 | num_layers: int = 2
464 | skip: int = 1 # bool
465 |
466 | @nn.compact
467 | def __call__(self, x):
468 | y = x
469 | for _ in range(self.num_layers):
470 | y = nn.RNN(nn.LSTMCell(self.dimension))(y)
471 |
472 | if self.skip:
473 | y = y + x
474 |
475 | return y
476 |
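# A minimal usage sketch (illustrative only; shapes are hypothetical). Inputs are
# channels-last (batch, time, channels), the layout assumed throughout this module.
if __name__ == "__main__":
    import jax

    conv = StreamableConv1d(
        out_channels=32, kernel_size=7, stride=2, causal=True, norm="weight_norm"
    )
    x = jnp.zeros((1, 16000, 1))  # one second of hypothetical mono 16 kHz audio
    variables = conv.init(jax.random.PRNGKey(0), x)
    y = conv.apply(variables, x)
    print(y.shape)  # (1, 8000, 32): the time axis is halved by the stride-2 convolution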
--------------------------------------------------------------------------------
/src/dac_jax/nn/layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import flax.linen as nn
3 |
4 | import jax
5 | import jax.numpy as jnp
6 |
7 |
8 | def default_stride(strides):
9 | if strides is None:
10 | return 1
11 | if isinstance(strides, int):
12 | return strides
13 | return strides[0]
14 |
15 |
16 | def default_kernel_dilation(kernel_dilation):
17 | if kernel_dilation is None:
18 | return 1
19 | if isinstance(kernel_dilation, int):
20 | return kernel_dilation
21 | return kernel_dilation[0]
22 |
23 |
24 | def default_kernel_size(kernel_size):
25 | if kernel_size is None:
26 | return 1
27 | if isinstance(kernel_size, int):
28 | return kernel_size
29 | return kernel_size[0]
30 |
31 |
32 | def conv_to_delay(s, d, k, L):
33 | L = (L - 1) * s + d * (k - 1) + 1
34 | L = math.ceil(L)
35 | return L
36 |
37 |
38 | def convtranspose_to_delay(s, d, k, L):
39 | L = ((L - d * (k - 1) - 1) / s) + 1
40 | L = math.ceil(L)
41 | return L
42 |
43 |
44 | def conv_to_output_length(s, d, k, L):
45 | L = ((L - d * (k - 1) - 1) / s) + 1
46 | L = math.floor(L)
47 | return L
48 |
49 |
50 | def convtranspose_to_output_length(s, d, k, L):
51 | L = (L - 1) * s + d * (k - 1) + 1
52 | L = math.floor(L)
53 | return L
54 |
55 |
56 | def make_initializer(in_channels, out_channels, kernel_size, groups, mode="fan_in"):
57 | # https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
58 | if mode == "fan_in":
59 | c = in_channels
60 | elif mode == "fan_out":
61 | c = out_channels
62 | else:
63 | raise ValueError(f"Unexpected mode: {mode}")
64 | k = groups / (c * jnp.prod(jnp.array(kernel_size)))
65 | scale = jnp.sqrt(k)
66 | return lambda key, shape, dtype: jax.random.uniform(
67 | key, shape, minval=-scale, maxval=scale, dtype=dtype
68 | )
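# Example with hypothetical numbers: for a Conv1d with in_channels=64, kernel_size=(3,)
# and groups=1 in "fan_in" mode, k = 1 / (64 * 3) and scale = sqrt(k), about 0.072, so
# weights and biases are drawn from U(-0.072, 0.072), matching PyTorch's Conv1d default.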
69 |
70 |
71 | class WNConv1d(nn.Conv):
72 |
73 | @nn.compact
74 | def __call__(self, x):
75 | # https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/model/dac.py#L18-L21
76 | # https://github.com/google/flax/issues/4091
77 | # Note: we are just ignoring whatever self.kernel_init and self.bias_init are.
78 | kernel_init = jax.nn.initializers.truncated_normal(
79 | 0.02, lower=-2 / 0.02, upper=2 / 0.02
80 | )
81 | bias_init = nn.initializers.zeros
82 |
83 | conv = nn.Conv(
84 | features=self.features,
85 | kernel_size=self.kernel_size,
86 | strides=self.strides,
87 | padding=self.padding,
88 | input_dilation=self.input_dilation,
89 | kernel_dilation=self.kernel_dilation,
90 | feature_group_count=self.feature_group_count,
91 | use_bias=self.use_bias,
92 | mask=self.mask,
93 | dtype=self.dtype,
94 | param_dtype=self.param_dtype,
95 | precision=self.precision,
96 | kernel_init=kernel_init,
97 | bias_init=bias_init,
98 | )
99 | scale_init = nn.initializers.constant(1 / jnp.sqrt(3))
100 | block = nn.WeightNorm(conv, scale_init=scale_init)
101 | x = block(x)
102 | return x
103 |
104 | @staticmethod
105 | def delay(s, d, k, L):
106 | s = default_stride(s)
107 | d = default_kernel_dilation(d)
108 | k = default_kernel_size(k)
109 | return conv_to_delay(s, d, k, L)
110 |
111 | @staticmethod
112 | def output_length(s, d, k, L):
113 | s = default_stride(s)
114 | d = default_kernel_dilation(d)
115 | k = default_kernel_size(k)
116 | return conv_to_output_length(s, d, k, L)
117 |
118 |
119 | class WNConvTranspose1d(nn.ConvTranspose):
120 |
121 | @nn.compact
122 | def __call__(self, x):
123 |
124 | groups = 1
125 | # note: we just ignore whatever self.kernel_init is
126 | kernel_init = make_initializer(
127 | x.shape[-1],
128 | self.features,
129 | self.kernel_size,
130 | groups,
131 | mode="fan_out",
132 | )
133 |
134 | if self.use_bias:
135 | # note: we just ignore whatever self.bias_init is
136 | bias_init = make_initializer(
137 | x.shape[-1],
138 | self.features,
139 | self.kernel_size,
140 | groups,
141 | mode="fan_out",
142 | )
143 | else:
144 | bias_init = None
145 |
146 | conv = nn.ConvTranspose(
147 | features=self.features,
148 | kernel_size=self.kernel_size,
149 | strides=self.strides,
150 | padding=self.padding,
151 | kernel_dilation=self.kernel_dilation,
152 | use_bias=self.use_bias,
153 | mask=self.mask,
154 | dtype=self.dtype,
155 | param_dtype=self.param_dtype,
156 | precision=self.precision,
157 | kernel_init=kernel_init,
158 | bias_init=bias_init,
159 | transpose_kernel=True, # note: this helps us load weights from PyTorch
160 | )
161 | scale_init = nn.initializers.constant(1 / jnp.sqrt(3))
162 | block = nn.WeightNorm(conv, scale_init=scale_init)
163 | x = block(x)
164 | return x
165 |
166 | @staticmethod
167 | def delay(s, d, k, L):
168 | s = default_stride(s)
169 | d = default_kernel_dilation(d)
170 | k = default_kernel_size(k)
171 | return convtranspose_to_delay(s, d, k, L)
172 |
173 | @staticmethod
174 | def output_length(s, d, k, L):
175 | s = default_stride(s)
176 | d = default_kernel_dilation(d)
177 | k = default_kernel_size(k)
178 | return convtranspose_to_output_length(s, d, k, L)
179 |
180 |
181 | class Snake1d(nn.Module):
182 |
183 | channels: int
184 |
185 | @nn.compact
186 | def __call__(self, x):
187 | alpha = self.param("alpha", nn.initializers.ones, (1, 1, self.channels))
188 | x = x + jnp.reciprocal(alpha + 1e-9) * jnp.square(jnp.sin(alpha * x))
189 | return x
190 |
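# A minimal usage sketch (illustrative only; shapes are hypothetical). Both layers
# expect channels-last (batch, time, channels) inputs.
if __name__ == "__main__":
    x = jnp.linspace(-1.0, 1.0, 64).reshape(1, 64, 1)

    snake = Snake1d(channels=1)
    variables = snake.init(jax.random.PRNGKey(0), x)
    y = snake.apply(variables, x)  # with alpha initialized to ones: roughly x + sin(x) ** 2

    conv = WNConv1d(features=16, kernel_size=(3,), padding="SAME")
    variables = conv.init(jax.random.PRNGKey(1), x)
    z = conv.apply(variables, x)
    print(y.shape, z.shape)  # (1, 64, 1) (1, 64, 16)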
--------------------------------------------------------------------------------
/src/dac_jax/nn/loss.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import os
3 | from typing import Callable, Optional
4 |
5 | from einops import rearrange
6 | import jax
7 | import jax.numpy as jnp
8 | import numpy as np
9 |
10 | from dac_jax.audio_utils import stft, decibel_loudness, mel_spectrogram
11 |
12 |
13 | def l1_loss(y_true: jnp.ndarray, y_pred: jnp.ndarray, reduction="mean") -> jnp.ndarray:
14 |
15 | errors = jnp.abs(y_pred - y_true)
16 | if reduction == "none":
17 | return errors
18 | elif reduction == "mean":
19 | return jnp.mean(errors)
20 | elif reduction == "sum":
21 | return jnp.sum(errors)
22 | else:
23 | raise ValueError(f"Invalid reduction method: {reduction}")
24 |
25 |
26 | def sisdr_loss(
27 | y_true: jnp.ndarray,
28 | y_pred: jnp.ndarray,
29 | scaling: int = True,
30 | reduction: str = "mean",
31 | zero_mean: int = True,
32 | clip_min: int = None,
33 | ):
34 | """
35 | Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
36 | of estimated and reference audio signals or aligned features.
37 |
38 | Parameters
39 | ----------
40 |     y_true : jnp.ndarray
41 |         Reference signal
42 |     y_pred : jnp.ndarray
43 |         Estimate signal
44 | scaling : int, optional
45 | Whether to use scale-invariant (True) or
46 | signal-to-noise ratio (False), by default True
47 | reduction : str, optional
48 | How to reduce across the batch (either 'mean',
49 |         'sum', or 'none'), by default 'mean'
50 | zero_mean : int, optional
51 | Zero mean the references and estimates before
52 | computing the loss, by default True
53 | clip_min : int, optional
54 | The minimum possible loss value. Helps network
55 | to not focus on making already good examples better, by default None
56 |
57 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
58 | """
59 |
60 | eps = 1e-8
61 | # nb, nc, nt
62 | references = y_true
63 | estimates = y_pred
64 |
65 | nb = references.shape[0]
66 | references = references.reshape(nb, 1, -1).transpose(0, 2, 1)
67 | estimates = estimates.reshape(nb, 1, -1).transpose(0, 2, 1)
68 |
69 | # samples now on axis 1
70 | if zero_mean:
71 | mean_reference = references.mean(axis=1, keepdims=True)
72 | mean_estimate = estimates.mean(axis=1, keepdims=True)
73 | else:
74 | mean_reference = 0
75 | mean_estimate = 0
76 |
77 | _references = references - mean_reference
78 | _estimates = estimates - mean_estimate
79 |
80 | references_projection = jnp.square(_references).sum(axis=-2) + eps
81 | references_on_estimates = (_estimates * _references).sum(axis=-2) + eps
82 |
83 | scale = (
84 | jnp.expand_dims(references_on_estimates / references_projection, 1)
85 | if scaling
86 | else 1
87 | )
88 |
89 | e_true = scale * _references
90 | e_res = _estimates - e_true
91 |
92 | signal = jnp.square(e_true).sum(axis=1)
93 | noise = jnp.square(e_res).sum(axis=1)
94 | sdr = -10 * jnp.log10(signal / noise + eps)
95 |
96 | if clip_min is not None:
97 | sdr = jnp.maximum(sdr, clip_min)
98 |
99 | if reduction == "mean":
100 | sdr = sdr.mean()
101 | elif reduction == "sum":
102 | sdr = sdr.sum()
103 | return sdr
104 |
105 |
106 | def discriminator_loss(fake, real):
107 | """
108 | Computes a discriminator loss, given the outputs of the discriminator
109 | used on a fake input and a real input.
110 | """
111 | d_fake, d_real = fake, real
112 |
113 | loss_d = 0
114 | for x_fake, x_real in zip(d_fake, d_real):
115 | loss_d = loss_d + jnp.square(x_fake[-1]).mean()
116 | loss_d = loss_d + jnp.square(1 - x_real[-1]).mean()
117 |         # We could normalize by the number of feature maps, but the original DAC doesn't, so it stays commented out:
118 | # loss_d = loss_d / len(d_fake)
119 | return loss_d
120 |
121 |
122 | def generator_loss(fake, real):
123 | """
124 | Computes a generator loss, given the outputs of the discriminator
125 | used on a fake input and a real input.
126 | """
127 | d_fake, d_real = fake, jax.lax.stop_gradient(real)
128 |
129 | loss_g = 0
130 | for x_fake in d_fake:
131 | loss_g = loss_g + jnp.square(1 - x_fake[-1]).mean()
132 |
133 |     # We could normalize by the number of feature maps, but the original DAC doesn't, so it stays commented out:
134 | # loss_g = loss_g / len(d_fake)
135 |
136 | loss_feature = 0
137 |
138 | for i in range(len(d_fake)):
139 | for j in range(len(d_fake[i]) - 1):
140 | loss_feature = loss_feature + l1_loss(d_fake[i][j], d_real[i][j])
141 |
142 |     # We could normalize by the number of feature maps, but the original DAC doesn't, so it stays commented out:
143 | # loss_feature = loss_feature / sum([len(d_fake[i])-1 for i in range(len(d_fake))])
144 |
145 | return loss_g, loss_feature
146 |
147 |
148 | def multiscale_stft_loss(
149 | y_true: jnp.ndarray,
150 | y_pred: jnp.ndarray,
151 | window_lengths=None,
152 | loss_fn: Callable = l1_loss,
153 | clamp_eps: float = 1e-5,
154 | mag_weight: float = 1.0,
155 | log_weight: float = 1.0,
156 | pow: float = 2.0,
157 | match_stride: Optional[bool] = False,
158 | window: str = "hann",
159 | ):
160 | """Computes the multiscale STFT loss from [1].
161 |
162 | Parameters
163 | ----------
164 |     y_true : jnp.ndarray
165 |         Reference signal
166 |     y_pred : jnp.ndarray
167 |         Estimate signal
168 | window_lengths : List[int], optional
169 | Length of each window of each STFT, by default [2048, 512]
170 | loss_fn : typing.Callable, optional
171 | How to compare each loss, by default l1_loss
172 | clamp_eps : float, optional
173 | Clamp on the log magnitude, below, by default 1e-5
174 | mag_weight : float, optional
175 | Weight of raw magnitude portion of loss, by default 1.0
176 | log_weight : float, optional
177 | Weight of log magnitude portion of loss, by default 1.0
178 | pow : float, optional
179 | Power to raise magnitude to before taking log, by default 2.0
180 | match_stride : bool, optional
181 | Whether to match the stride of convolutional layers, by default False
182 | window : str or tuple or array_like, optional
183 | Desired window to use. If `window` is a string or tuple, it is
184 | passed to `get_window` to generate the window values, which are
185 | DFT-even by default. See `get_window` for a list of windows and
186 | required parameters. If `window` is array_like it will be used
187 | directly as the window and its length must be nperseg. Defaults
188 | to a Hann window.
189 |
190 | Returns
191 | -------
192 | jnp.ndarray
193 | Multi-scale STFT loss.
194 |
195 | References
196 | ----------
197 |
198 | 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
199 | "DDSP: Differentiable Digital Signal Processing."
200 | International Conference on Learning Representations. 2019.
201 |
202 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
203 | """
204 |
205 | x = y_pred
206 | y = y_true
207 |
208 | loss = jnp.zeros(())
209 |
210 | if window_lengths is None:
211 | window_lengths = [2048, 512]
212 |
213 | for frame_length in window_lengths:
214 | stft_fun = partial(
215 | stft,
216 | frame_length=frame_length,
217 | hop_factor=0.25,
218 | window=window,
219 | match_stride=match_stride,
220 | )
221 | x_stft = stft_fun(x)
222 | y_stft = stft_fun(y)
223 |
224 | loss = loss + log_weight * loss_fn(
225 | decibel_loudness(x_stft, clamp_eps=clamp_eps, pow=pow),
226 | decibel_loudness(y_stft, clamp_eps=clamp_eps, pow=pow),
227 | )
228 | loss = loss + mag_weight * loss_fn(jnp.abs(x_stft), jnp.abs(y_stft))
229 |
230 | return loss
231 |
232 |
233 | def mel_spectrogram_loss(
234 | y_true: jnp.ndarray,
235 | y_pred: jnp.ndarray,
236 | sample_rate: int,
237 | n_mels=None,
238 | window_lengths=None,
239 | loss_fn: Callable = l1_loss,
240 | clamp_eps: float = 1e-5,
241 | mag_weight: float = 1.0,
242 | log_weight: float = 1.0,
243 | pow: float = 2.0,
244 | match_stride: Optional[bool] = False,
245 | lower_edge_hz=None,
246 | upper_edge_hz=None,
247 | window: str = "hann",
248 | ):
249 | """Compute distance between mel spectrograms. Can be used in a multiscale way.
250 |
251 | Parameters
252 | ----------
253 |     y_true : jnp.ndarray
254 |         Reference signal
255 |     y_pred : jnp.ndarray
256 |         Estimate signal
257 | sample_rate : int
258 | Sample rate
259 | n_mels : List[int]
260 | Number of mel bins per STFT, by default [150, 80],
261 | window_lengths : List[int], optional
262 | Length of each window of each STFT, by default [2048, 512]
263 | loss_fn : typing.Callable, optional
264 |         How to compare each loss, by default l1_loss
265 | clamp_eps : float, optional
266 | Clamp on the log magnitude, below, by default 1e-5
267 | mag_weight : float, optional
268 | Weight of raw magnitude portion of loss, by default 1.0
269 | log_weight : float, optional
270 | Weight of log magnitude portion of loss, by default 1.0
271 | pow : float, optional
272 | Power to raise magnitude to before taking log, by default 2.0
273 | match_stride : bool, optional
274 | Whether to match the stride of convolutional layers, by default False
275 |     lower_edge_hz: List[float], optional
276 |         Lowest frequency to consider when generating the mel filterbanks.
277 |     upper_edge_hz: List[float], optional
278 |         Highest frequency to consider when generating the mel filterbanks.
279 | window : str or tuple or array_like, optional
280 | Desired window to use. If `window` is a string or tuple, it is
281 | passed to `get_window` to generate the window values, which are
282 | DFT-even by default. See `get_window` for a list of windows and
283 | required parameters. If `window` is array_like it will be used
284 | directly as the window and its length must be nperseg. Defaults
285 | to a Hann window.
286 |
287 | Returns
288 | -------
289 | jnp.ndarray
290 | Mel loss.
291 |
292 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
293 | """
294 |
295 | x = y_pred
296 | y = y_true
297 |
298 | if n_mels is None:
299 | n_mels = [150, 80]
300 |
301 | if window_lengths is None:
302 | window_lengths = [2048, 512]
303 |
304 | if lower_edge_hz is None:
305 | lower_edge_hz = [0.0, 0.0]
306 |
307 | if upper_edge_hz is None:
308 | upper_edge_hz = [None, None] # librosa converts None to sample_rate/2
309 |
310 | def decibel_fn(mels: jnp.ndarray) -> jnp.ndarray:
311 | return jnp.log10(jnp.pow(jnp.maximum(mels, clamp_eps), pow))
312 |
313 | loss = jnp.zeros(())
314 | for features, fmin, fmax, frame_length in zip(
315 | n_mels, lower_edge_hz, upper_edge_hz, window_lengths
316 | ):
317 |
318 | def spectrogram_fn(signal):
319 | stft_data = stft(
320 | signal,
321 | frame_length=frame_length,
322 | hop_factor=0.25,
323 | window=window,
324 | match_stride=match_stride,
325 | )
326 | stft_data = rearrange(stft_data, "b c nf nt -> (b c) nt nf")
327 |
328 | spectrogram = jnp.abs(stft_data)
329 | return spectrogram
330 |
331 | x_spectrogram = spectrogram_fn(x)
332 | y_spectrogram = spectrogram_fn(y)
333 |
334 | nf = x_spectrogram.shape[-1]
335 |
336 | mel_fun = partial(
337 | mel_spectrogram,
338 | log_scale=False,
339 | sample_rate=sample_rate,
340 | frame_length=2 * (nf - 1),
341 | num_features=features,
342 | lower_edge_hertz=fmin,
343 | upper_edge_hertz=fmax,
344 | )
345 |
346 | x_mels = mel_fun(x_spectrogram)
347 | y_mels = mel_fun(y_spectrogram)
348 |
349 | loss = loss + log_weight * loss_fn(decibel_fn(x_mels), decibel_fn(y_mels))
350 | loss = loss + mag_weight * loss_fn(x_mels, y_mels)
351 |
352 | return loss
353 |
354 |
355 | def phase_loss(
356 | y_true: jnp.ndarray,
357 | y_pred: jnp.ndarray,
358 | window_length: int = 2048,
359 | hop_factor: float = 0.25,
360 | ):
361 | """Computes phase loss between an estimate and a reference signal.
362 |
363 | Parameters
364 | ----------
365 |     y_true : jnp.ndarray
366 |         Reference signal
367 |     y_pred : jnp.ndarray
368 |         Estimate signal
369 | window_length : int, optional
370 | Length of STFT window, by default 2048
371 | hop_factor : float, optional
372 |         Hop factor between 0 and 1, which is multiplied by the STFT window
373 |         length to determine the hop size, by default 0.25.
374 |
375 | Returns
376 | -------
377 | jnp.ndarray
378 | Phase loss.
379 |
380 | Implementation adapted from https://github.com/descriptinc/audiotools/blob/7776c296c711db90176a63ff808c26e0ee087263/audiotools/metrics/spectral.py#L195
381 | """
382 |
383 | x = y_pred
384 | y = y_true
385 |
386 | stft_fun = partial(
387 | stft, frame_length=window_length, hop_factor=hop_factor, window="hann"
388 | )
389 |
390 | x_stft = stft_fun(x)
391 | y_stft = stft_fun(y)
392 |
393 | def phase(spec):
394 | return jnp.angle(spec)
395 |
396 | # Take circular difference
397 | diff = phase(x_stft) - phase(y_stft)
398 |     diff = jnp.where(diff < -jnp.pi, diff + 2 * jnp.pi, diff)
399 |     diff = jnp.where(diff > jnp.pi, diff - 2 * jnp.pi, diff)
400 |
401 | # Scale true magnitude to weights in [0, 1]
402 | x_mag = jnp.abs(x_stft)
403 | x_min, x_max = x_mag.min(), x_mag.max()
404 | weights = (x_mag - x_min) / (x_max - x_min)
405 |
406 | # Take weighted mean of all phase errors
407 | loss = jnp.square(weights * diff).mean()
408 | return loss
409 |
410 |
411 | def stoi(
412 | estimates: jnp.ndarray,
413 | references: jnp.ndarray,
414 | sample_rate: int,
415 | extended: int = False,
416 | ):
417 | """Short term objective intelligibility
418 | Computes the STOI (See [1][2]) of a de-noised signal compared to a clean
419 |     signal. The output is expected to have a monotonic relation with the
420 | subjective speech-intelligibility, where a higher score denotes better
421 | speech intelligibility. Uses pystoi under the hood.
422 |
423 | Parameters
424 | ----------
425 | estimates : jnp.ndarray
426 | De-noised speech
427 | references : jnp.ndarray
428 | Clean original speech
429 | sample_rate: int
430 | Sample rate of the references
431 | extended : int, optional
432 | Boolean, whether to use the extended STOI described in [3], by default False
433 |
434 | Returns
435 | -------
436 | Tensor[float]
437 | Short time objective intelligibility measure between clean and
438 | de-noised speech
439 |
440 | References
441 | ----------
442 | 1. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
443 | Objective Intelligibility Measure for Time-Frequency Weighted Noisy
444 | Speech', ICASSP 2010, Texas, Dallas.
445 | 2. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for
446 | Intelligibility Prediction of Time-Frequency Weighted Noisy Speech',
447 | IEEE Transactions on Audio, Speech, and Language Processing, 2011.
448 | 3. Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the
449 | Intelligibility of Speech Masked by Modulated Noise Maskers',
450 | IEEE Transactions on Audio, Speech and Language Processing, 2016.
451 | """
452 | import pystoi
453 |
454 | if estimates.ndim == 3:
455 | estimates = jnp.average(estimates, axis=-2) # to mono
456 | if references.ndim == 3:
457 | references = jnp.average(references, axis=-2) # to mono
458 |
459 | stois = []
460 | for reference, estimate in zip(references, estimates):
461 | _stoi = pystoi.stoi(
462 | np.array(reference),
463 |             np.array(estimate),
464 | sample_rate,
465 | extended=extended,
466 | )
467 | stois.append(_stoi)
468 | return jnp.array(np.array(stois))
469 |
470 |
471 | def pesq(
472 | estimates: jnp.ndarray,
473 | estimates_sample_rate: int,
474 | references: jnp.ndarray,
475 | references_sample_rate: int,
476 | mode: str = "wb",
477 | target_sr: int = 16000,
478 | ):
479 |     """Perceptual Evaluation of Speech Quality (PESQ) between estimates and references.
480 |
481 | Parameters
482 | ----------
483 | estimates : jnp.ndarray
484 | Degraded audio signal
485 | estimates_sample_rate: int
486 | Sample rate of the estimates
487 | references : jnp.ndarray
488 | Reference audio signal
489 | references_sample_rate: int
490 | Sample rate of the references
491 | mode : str, optional
492 | 'wb' (wide-band) or 'nb' (narrow-band), by default "wb"
493 | target_sr : int, optional
494 | Target sample rate, by default 16000
495 |
496 | Returns
497 | -------
498 | Tensor[float]
499 | PESQ score: P.862.2 Prediction (MOS-LQO)
500 | """
501 | from pesq import pesq as pesq_fn
502 | from audiotree.resample import resample
503 |
504 | if estimates.ndim == 3:
505 | estimates = jnp.average(estimates, axis=-2, keepdims=True) # to mono
506 | if references.ndim == 3:
507 | references = jnp.average(references, axis=-2, keepdims=True) # to mono
508 |
509 | estimates = resample(estimates, old_sr=estimates_sample_rate, new_sr=target_sr)
510 | references = resample(references, old_sr=references_sample_rate, new_sr=target_sr)
511 |
512 | pesqs = []
513 | for reference, estimate in zip(references, estimates):
514 | _pesq = pesq_fn(
515 |             target_sr,  # both signals were resampled to target_sr above
516 | np.array(reference[0]),
517 | np.array(estimate[0]),
518 | mode,
519 | )
520 | pesqs.append(_pesq)
521 | return jnp.array(np.array(pesqs))
522 |
523 |
524 | def visqol(
525 | estimates: jnp.ndarray,
526 | estimates_sample_rate: int,
527 | references: jnp.ndarray,
528 | references_sample_rate: int,
529 | mode: str = "audio",
530 | ): # pragma: no cover
531 | """ViSQOL score.
532 |
533 | Parameters
534 | ----------
535 | estimates : jnp.ndarray
536 | Degraded audio
537 | references : jnp.ndarray
538 | Reference audio
539 | mode : str, optional
540 | 'audio' or 'speech', by default 'audio'
541 |
542 | Returns
543 | -------
544 | Tensor[float]
545 | ViSQOL score (MOS-LQO)
546 | """
547 | from visqol import visqol_lib_py
548 | from visqol.pb2 import visqol_config_pb2
549 | from visqol.pb2 import similarity_result_pb2
550 | from audiotree.resample import resample
551 |
552 | config = visqol_config_pb2.VisqolConfig()
553 | if mode == "audio":
554 | target_sr = 48000
555 | config.options.use_speech_scoring = False
556 | svr_model_path = "libsvm_nu_svr_model.txt"
557 | elif mode == "speech":
558 | target_sr = 16000
559 | config.options.use_speech_scoring = True
560 | svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite"
561 | else:
562 | raise ValueError(f"Unrecognized mode: {mode}")
563 | config.audio.sample_rate = target_sr
564 | config.options.svr_model_path = os.path.join(
565 | os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path
566 | )
567 |
568 | api = visqol_lib_py.VisqolApi()
569 | api.Create(config)
570 |
571 | if estimates.ndim == 3:
572 | estimates = jnp.average(estimates, axis=-2, keepdims=True) # to mono
573 | if references.ndim == 3:
574 | references = jnp.average(references, axis=-2, keepdims=True) # to mono
575 |
576 | estimates = resample(estimates, old_sr=estimates_sample_rate, new_sr=target_sr)
577 | references = resample(references, old_sr=references_sample_rate, new_sr=target_sr)
578 |
579 | visqols = []
580 | for reference, estimate in zip(references, estimates):
581 | _visqol = api.Measure(
582 | np.array(reference[0], dtype=np.float32),
583 | np.array(estimate[0], dtype=np.float32),
584 | )
585 | visqols.append(_visqol.moslqo)
586 | return jnp.array(np.array(visqols))
587 |
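# A minimal usage sketch (illustrative only; the signals and discriminator feature maps
# below are random or toy placeholders).
if __name__ == "__main__":
    key1, key2 = jax.random.split(jax.random.PRNGKey(0))
    reference = jax.random.normal(key1, (2, 1, 4096))  # (batch, channels, samples)
    estimate = reference + 0.01 * jax.random.normal(key2, (2, 1, 4096))

    print(l1_loss(reference, estimate))  # mean absolute error
    print(sisdr_loss(reference, estimate))  # negative SI-SDR in dB, so lower is better

    # Discriminator outputs are nested lists: one list of feature maps per
    # sub-discriminator, with the last entry of each inner list being the logits.
    fake = [[jnp.zeros((2, 16)), jnp.zeros((2, 1))]]
    real = [[jnp.ones((2, 16)), jnp.ones((2, 1))]]
    print(discriminator_loss(fake, real))
    print(generator_loss(fake, real))  # (adversarial loss, feature-matching loss)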
--------------------------------------------------------------------------------
/src/dac_jax/nn/quantize.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Tuple
2 |
3 | from einops import rearrange
4 | import flax.linen as nn
5 | import jax
6 | import jax.numpy as jnp
7 | import jax.random
8 |
9 | from dac_jax.nn.encodec_quantize import QuantizedResult
10 | from dac_jax.nn.layers import WNConv1d
11 |
12 |
13 | def mse_loss(
14 | predictions: jnp.ndarray, targets: jnp.ndarray, reduction="mean"
15 | ) -> jnp.ndarray:
16 | errors = (predictions - targets) ** 2
17 | if reduction == "none":
18 | return errors
19 | elif reduction == "mean":
20 | return jnp.mean(errors)
21 | elif reduction == "sum":
22 | return jnp.sum(errors)
23 | else:
24 | raise ValueError(f"Invalid reduction method: {reduction}")
25 |
26 |
27 | def normalize(x, ord=2, axis=1, eps=1e-12):
28 | """Normalizes an array along a specified dimension.
29 |
30 | Args:
31 | x: A JAX array to normalize.
32 | ord: The order of the norm (default is 2, corresponding to L2-norm).
33 | axis: The dimension along which to normalize.
34 | eps: A small constant to avoid division by zero.
35 |
36 | Returns:
37 | A JAX array with normalized vectors.
38 |
39 | Reference:
40 | https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
41 | """
42 | denom = jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=True)
43 | denom = jnp.maximum(eps, denom)
44 | return x / denom
45 |
46 |
47 | class VectorQuantize(nn.Module):
48 | """
49 | Implementation of VQ similar to Karpathy's repo:
50 | https://github.com/karpathy/deep-vector-quantization
51 | Additionally uses following tricks from Improved VQGAN
52 | (https://arxiv.org/pdf/2110.04627.pdf):
53 | 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
54 | for improved codebook usage
55 | 2. l2-normalized codes: Converts Euclidean distance to cosine similarity which
56 | improves training stability
57 | """
58 |
59 | input_dim: int
60 | codebook_size: int
61 | codebook_dim: int
62 |
63 | def setup(self):
64 | self.in_proj = WNConv1d(features=self.codebook_dim, kernel_size=(1,))
65 | self.out_proj = WNConv1d(features=self.input_dim, kernel_size=(1,))
66 | # PyTorch uses a normal distribution for weight initialization of Embeddings.
67 | self.codebook = nn.Embed(
68 | num_embeddings=self.codebook_size,
69 | features=self.codebook_dim,
70 | embedding_init=nn.initializers.normal(stddev=1),
71 | )
72 |
73 | def __call__(
74 | self, z
75 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
76 |         """Quantizes the input tensor using a fixed codebook and returns the corresponding codebook vectors.
77 |
78 | Parameters
79 | ----------
80 | z : Tensor[B x T x D]
81 |
82 | Returns
83 | -------
84 | Tensor[B x T x D]
85 | Quantized continuous representation of input
86 | Tensor[1]
87 | Commitment loss to train encoder to predict vectors closer to codebook
88 | entries
89 | Tensor[1]
90 | Codebook loss to update the codebook
91 | Tensor[B x T]
92 | Codebook indices (quantized discrete representation of input)
93 | Tensor[B x T x D]
94 | Projected latents (continuous representation of input before quantization)
95 | """
96 |
97 | # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
98 | z_e = self.in_proj(z) # z_e : (B x T x D)
99 | z_q, indices = self.decode_latents(z_e)
100 |
101 | commitment_loss = mse_loss(
102 | z_e, jax.lax.stop_gradient(z_q), reduction="none"
103 | ).mean([1, 2])
104 | codebook_loss = mse_loss(
105 | z_q, jax.lax.stop_gradient(z_e), reduction="none"
106 | ).mean([1, 2])
107 |
108 | z_q = z_e + jax.lax.stop_gradient(
109 | z_q - z_e
110 | ) # noop in forward pass, straight-through gradient estimator in backward pass
111 |
112 | z_q = self.out_proj(z_q)
113 |
114 | return z_q, commitment_loss, codebook_loss, indices, z_e
115 |
116 | def embed_code(self, embed_id):
117 | return self.codebook(embed_id)
118 |
119 | def decode_code(self, embed_id):
120 | return self.embed_code(embed_id)
121 |
122 | def decode_latents(self, latents: jnp.ndarray):
123 | encodings = rearrange(latents, "b t d -> (b t) d", d=self.codebook_dim)
124 | codebook = self.codebook.embedding # codebook: (N x D)
125 | # L2 normalize encodings and codebook (ViT-VQGAN)
126 | encodings = normalize(encodings)
127 | codebook = normalize(codebook)
128 |
129 | # Compute Euclidean distance with codebook
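        # Because encodings and codebook rows are both L2-normalized, this squared
        # distance equals 2 - 2 * (encodings @ codebook.T), so taking argmax(-dist)
        # below is equivalent to maximizing cosine similarity (the ViT-VQGAN trick
        # mentioned in the class docstring).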
130 | dist = (
131 | jnp.square(encodings).sum(1, keepdims=True)
132 | - 2 * encodings @ codebook.transpose()
133 | + jnp.square(codebook).sum(1, keepdims=True).transpose()
134 | )
135 | indices = rearrange(
136 | jnp.argmax(-dist, axis=1), "(b t) -> b t", b=latents.shape[0]
137 | )
138 | z_q = self.decode_code(indices)
139 | return z_q, indices
140 |
141 |
142 | class ResidualVectorQuantize(nn.Module):
143 | """
144 | Introduced in SoundStream: An End-to-End Neural Audio Codec
145 | https://arxiv.org/abs/2107.03312
146 | """
147 |
148 | input_dim: int = 512
149 | num_codebooks: int = 9
150 | codebook_size: int = 1024
151 | codebook_dim: Union[int, list] = 8
152 | quantizer_dropout: float = 0.0
153 |
154 | def __post_init__(self) -> None:
155 | if isinstance(self.codebook_dim, int):
156 | self.codebook_dim = [self.codebook_dim for _ in range(self.num_codebooks)]
157 | super().__post_init__()
158 |
159 | def setup(self) -> None:
160 |
161 | self.quantizers = [
162 | VectorQuantize(self.input_dim, self.codebook_size, self.codebook_dim[i])
163 | for i in range(self.num_codebooks)
164 | ]
165 |
166 | def __call__(self, z, n_quantizers: int = None, train=True) -> QuantizedResult:
167 | z_q = 0
168 | residual = z
169 | commitment_loss = jnp.zeros(())
170 | codebook_loss = jnp.zeros(())
171 |
172 | codebook_indices = []
173 | latents = []
174 |
175 | if n_quantizers is None:
176 | n_quantizers = self.num_codebooks
177 | if train:
178 | n_quantizers = jnp.ones((z.shape[0],)) * self.num_codebooks + 1
179 | dropout = jax.random.randint(
180 | self.make_rng("rng_stream"),
181 | shape=(z.shape[0],),
182 | minval=1,
183 | maxval=self.num_codebooks + 1,
184 | )
185 | n_dropout = int(z.shape[0] * self.quantizer_dropout)
186 | n_quantizers = n_quantizers.at[:n_dropout].set(dropout[:n_dropout])
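            # With quantizer dropout, the first n_dropout examples in the batch train
            # with a random number of codebooks in [1, num_codebooks]; the remaining
            # examples keep n_quantizers = num_codebooks + 1, so the mask below never
            # zeroes any of their quantizer outputs.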
187 |
188 | # todo: this loop would possibly compile faster if jax.lax.scan were used
189 | for i, quantizer in enumerate(self.quantizers):
190 | if not train and i >= n_quantizers:
191 | break
192 |
193 | z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
194 | residual
195 | )
196 |
197 | # Create mask to apply quantizer dropout
198 | mask = jnp.full((z.shape[0],), fill_value=i) < n_quantizers
199 | z_q = z_q + z_q_i * mask[:, None, None]
200 | residual = residual - z_q_i
201 |
202 | # Sum losses
203 | commitment_loss = commitment_loss + (commitment_loss_i * mask).mean()
204 | codebook_loss = codebook_loss + (codebook_loss_i * mask).mean()
205 |
206 | codebook_indices.append(indices_i)
207 | latents.append(z_e_i)
208 |
209 | codes = jnp.stack(codebook_indices, axis=1)
210 | latents = jnp.concatenate(latents, axis=2).transpose(0, 2, 1)
211 |
212 |         # Optionally normalize by the number of codebooks (disabled to match the original DAC):
213 | # commitment_loss = commitment_loss / self.num_codebooks
214 | # codebook_loss = codebook_loss / self.num_codebooks
215 |
216 | return QuantizedResult(
217 | z_q,
218 | codes=codes,
219 | bandwidth=None,
220 | penalty=None,
221 | metrics=None,
222 | latents=latents,
223 | commitment_loss=commitment_loss,
224 | codebook_loss=codebook_loss,
225 | )
226 |
227 | def from_codes(self, codes: jnp.ndarray):
228 | """Given the quantized codes, reconstruct the continuous representation
229 | Parameters
230 | ----------
231 | codes : Tensor[B x T x N]
232 | Quantized discrete representation of input
233 | Returns
234 | -------
235 | Tensor[B x T x D]
236 | Quantized continuous representation of input
237 | """
238 | z_q = 0.0
239 | z_p = []
240 | num_codebooks = codes.shape[-2]
241 | assert num_codebooks <= self.num_codebooks
242 |
243 | # todo: use jax.lax.scan for this loop
244 | for i in range(num_codebooks):
245 | z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
246 | z_p.append(z_p_i)
247 |
248 | z_q_i = self.quantizers[i].out_proj(z_p_i)
249 | z_q = z_q + z_q_i
250 |
251 | return z_q, jnp.concatenate(z_p, axis=1), codes
252 |
253 | def from_latents(self, latents: jnp.ndarray):
254 | # todo: this function hasn't been tested/used yet.
255 |
256 | """Given the unquantized latents, reconstruct the
257 | continuous representation after quantization.
258 |
259 | Parameters
260 | ----------
261 | latents : Tensor[B x T x N]
262 | Continuous representation of input after projection
263 |
264 | Returns # todo: make this return info correct
265 | -------
266 | Tensor[B x T x D]
267 | Quantized representation of full-projected space
268 | Tensor[B x T x D]
269 | Quantized representation of latent space
270 | """
271 | z_q = 0
272 | z_p = []
273 | codes = []
274 | dims = jnp.cumsum([0] + [q.codebook_dim for q in self.quantizers])
275 |
276 | num_codebooks = jnp.where(dims <= latents.shape[2])[0].max(
277 | axis=0, keepdims=True
278 | ) # todo: check
279 |
280 | # todo: use jax.lax.scan for this loop
281 | for i in range(num_codebooks):
282 | j, k = dims[i], dims[i + 1]
283 | z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, :, j:k])
284 | z_p.append(z_p_i)
285 | codes.append(codes_i)
286 |
287 | z_q_i = self.quantizers[i].out_proj(z_p_i)
288 | z_q = z_q + z_q_i
289 |
290 | return z_q, jnp.concatenate(z_p, axis=1), jnp.stack(codes, axis=1)
291 |
292 |
293 | if __name__ == "__main__":
294 | rvq = ResidualVectorQuantize(quantizer_dropout=True)
295 | key = jax.random.PRNGKey(0)
296 | key, subkey = jax.random.split(key)
297 | x = jax.random.normal(key=subkey, shape=(16, 80, 512))
298 |
299 | key, subkey = jax.random.split(key)
300 | params = rvq.init({"params": subkey, "rng_stream": jax.random.key(4)}, x)["params"]
301 |     result = rvq.apply(
302 |         {"params": params}, x, rngs={"rng_stream": jax.random.key(4)}
303 |     )  # __call__ returns a QuantizedResult rather than a tuple
304 |     print(result.latents.shape)
305 |
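    # A hypothetical extension of the demo above: rebuild the continuous representation
    # from the discrete codes via `from_codes`, invoked through Flax's `method=` argument.
    z_q, z_p, codes = rvq.apply(
        {"params": params}, result.codes, method=ResidualVectorQuantize.from_codes
    )
    print(z_q.shape)  # same (batch, time, input_dim) shape as the quantized output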
--------------------------------------------------------------------------------
/src/dac_jax/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import environ
3 | from pathlib import Path
4 | import json
5 | import typing as tp
6 |
7 | import argbind
8 | from huggingface_hub import hf_hub_download
9 | import numpy as np
10 | from omegaconf import OmegaConf
11 | import torch
12 |
13 | from dac_jax.utils import load_torch_weights_encodec
14 | from dac_jax.utils import load_torch_weights
15 | from dac_jax.model import DAC
16 | from dac_jax.model import EncodecModel, SEANetEncoder, SEANetDecoder
17 | from dac_jax.nn.encodec_quantize import ResidualVectorQuantizer
18 |
19 |
20 | def get_audiocraft_cache_dir() -> tp.Optional[str]:
21 | return os.environ.get('AUDIOCRAFT_CACHE_DIR', None)
22 |
23 |
24 | def _get_state_dict(
25 | file_or_url_or_id: tp.Union[Path, str],
26 | filename: tp.Optional[str] = None,
27 | device='cpu',
28 | cache_dir: tp.Optional[str] = None,
29 | ):
30 | if cache_dir is None:
31 | cache_dir = get_audiocraft_cache_dir()
32 | # Return the state dict either from a file or url
33 | file_or_url_or_id = str(file_or_url_or_id)
34 | assert isinstance(file_or_url_or_id, str)
35 |
36 | if os.path.isfile(file_or_url_or_id):
37 | return torch.load(file_or_url_or_id, map_location=device)
38 |
39 | if os.path.isdir(file_or_url_or_id):
40 | file = f"{file_or_url_or_id}/{filename}"
41 | return torch.load(file, map_location=device)
42 |
43 | elif file_or_url_or_id.startswith('https://'):
44 | return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
45 |
46 | else:
47 | assert filename is not None, "filename needs to be defined if using HF checkpoints"
48 |
49 | file = hf_hub_download(
50 | repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir,
51 | library_name="audiocraft", library_version="1.3.0")
52 | return torch.load(file, map_location=device)
53 |
54 |
55 | try:
56 | from audiocraft.models.loaders import load_compression_model_ckpt
57 | except Exception as e:
58 | def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
59 | return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
60 |
61 |
62 | __MODEL_LATEST_TAGS__ = {
63 | ("44khz", "8kbps"): "0.0.1",
64 | ("24khz", "8kbps"): "0.0.4",
65 | ("16khz", "8kbps"): "0.0.5",
66 | ("44khz", "16kbps"): "1.0.0",
67 | }
68 |
69 | __MODEL_URLS__ = {
70 | (
71 | "44khz",
72 | "0.0.1",
73 | "8kbps",
74 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
75 | (
76 | "24khz",
77 | "0.0.4",
78 | "8kbps",
79 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
80 | (
81 | "16khz",
82 | "0.0.5",
83 | "8kbps",
84 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
85 | (
86 | "44khz",
87 | "1.0.0",
88 | "16kbps",
89 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
90 | }
91 |
92 |
93 | def convert_torch_weights_to_numpy(
94 | torch_weights_path: Path, write_path: Path, metadata_path: Path
95 | ):
96 |
97 | if write_path.exists() and metadata_path.exists():
98 | return
99 |
100 | if not write_path.exists():
101 | write_path.parent.mkdir(parents=True, exist_ok=True)
102 |
103 | weights = torch.load(str(torch_weights_path), map_location=torch.device("cpu"))
104 |
105 | kwargs = weights["metadata"]["kwargs"]
106 | with open(metadata_path, "w") as f:
107 | f.write(json.dumps(kwargs))
108 |
109 | weights = weights["state_dict"]
110 | weights = {key: value.numpy() for key, value in weights.items()}
111 |
112 | allow_pickle = (
113 | True # todo: https://github.com/descriptinc/descript-audio-codec/issues/53
114 | )
115 |
116 | np.save(write_path, weights, allow_pickle=allow_pickle)
117 |
118 |
119 | @argbind.bind(group="download_encodec", positional=True, without_prefix=True)
120 | def download_encodec(
121 | name: str = "facebook/musicgen-small",
122 | ):
123 | if (
124 | "DAC_JAX_CACHE" in environ
125 | and environ["DAC_JAX_CACHE"].strip()
126 | and os.path.isabs(environ["DAC_JAX_CACHE"])
127 | ):
128 | cache_home = environ["DAC_JAX_CACHE"]
129 | cache_home = Path(cache_home)
130 | else:
131 | cache_home = Path.home() / ".cache" / "dac_jax"
132 |
133 | safename = name.replace("/", "_")
134 |
135 | metadata_path = cache_home / f"encodec_weights_{safename}.json"
136 | jax_write_path = cache_home / f"encodec_jax_weights_{safename}.npy"
137 |
138 | if jax_write_path.exists() and metadata_path.exists():
139 | return jax_write_path, metadata_path
140 |
141 | torch_model_path = cache_home / f"encodec_weights_{safename}.pth"
142 |
143 | if not torch_model_path.exists():
144 | torch_model_path.parent.mkdir(parents=True, exist_ok=True)
145 |
146 | file_or_url_or_id = name
147 | pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=str(cache_home))
148 | cfg = OmegaConf.create(pkg["xp.cfg"])
149 |
150 | weights = pkg["best_state"]
151 | weights = {key: value.numpy() for key, value in weights.items()}
152 |
153 | jax_write_path.parent.mkdir(parents=True, exist_ok=True)
154 |
155 | allow_pickle = (
156 | True # todo: https://github.com/descriptinc/descript-audio-codec/issues/53
157 | )
158 |
159 | np.save(jax_write_path, weights, allow_pickle=allow_pickle)
160 |
161 | OmegaConf.save(config=cfg, f=metadata_path)
162 |
163 | return jax_write_path, metadata_path
164 |
165 |
166 | # todo: we don't call this function `download` because that would conflict with the PyTorch implementation's `download`.
167 | # and we need to be able to run both in our tests.
168 | # Reference issue: https://github.com/pseeth/argbind/?tab=readme-ov-file#bound-function-names-should-be-unique
169 | @argbind.bind(group="download_model", positional=True, without_prefix=True)
170 | def download_model(
171 | model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
172 | ):
173 | """
174 | Function that downloads the weights file from URL if a local cache is not found.
175 |
176 | Parameters
177 | ----------
178 | model_type : str
179 | The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
180 | model_bitrate: str
181 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
182 | Only 44khz model supports 16kbps.
183 | tag : str
184 | The tag of the model to download. Defaults to "latest".
185 |
186 | Returns
187 | -------
188 | Path
189 | Directory path required to load model via audiotools.
190 | """
191 | model_type = model_type.lower()
192 | tag = tag.lower()
193 |
194 | if (
195 | "DAC_JAX_CACHE" in environ
196 | and environ["DAC_JAX_CACHE"].strip()
197 | and os.path.isabs(environ["DAC_JAX_CACHE"])
198 | ):
199 | cache_home = environ["DAC_JAX_CACHE"]
200 | cache_home = Path(cache_home)
201 | else:
202 | cache_home = Path.home() / ".cache" / "dac_jax"
203 |
204 | metadata_path = cache_home / f"weights_{model_type}_{model_bitrate}_{tag}.json"
205 | jax_write_path = cache_home / f"jax_weights_{model_type}_{model_bitrate}_{tag}.npy"
206 |
207 | if jax_write_path.exists() and metadata_path.exists():
208 | return jax_write_path, metadata_path
209 |
210 | assert model_type in [
211 | "44khz",
212 | "24khz",
213 | "16khz",
214 | ], "model_type must be one of '44khz', '24khz', or '16khz'"
215 |
216 | assert model_bitrate in [
217 | "8kbps",
218 | "16kbps",
219 | ], "model_bitrate must be one of '8kbps', or '16kbps'"
220 |
221 | if tag == "latest":
222 | tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
223 |
224 | download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
225 |
226 | if download_link is None:
227 | raise ValueError(
228 | f"Could not find model with tag {tag} and model type {model_type}"
229 | )
230 |
231 | torch_model_path = cache_home / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
232 |
233 | if not torch_model_path.exists():
234 | torch_model_path.parent.mkdir(parents=True, exist_ok=True)
235 |
236 | # Download the model
237 | import requests
238 |
239 | response = requests.get(download_link)
240 |
241 | if response.status_code != 200:
242 | raise ValueError(
243 | f"Could not download model. Received response code {response.status_code}"
244 | )
245 | torch_model_path.write_bytes(response.content)
246 |
247 | convert_torch_weights_to_numpy(torch_model_path, jax_write_path, metadata_path)
248 |
249 | # remove torch model because it's not needed anymore.
250 | if torch_model_path.exists():
251 | os.remove(torch_model_path)
252 |
253 | return jax_write_path, metadata_path
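# Example invocation (illustrative; kept as a comment so that importing this module does
# not trigger a download). The cache directory honors the DAC_JAX_CACHE environment
# variable when it is set to an absolute path:
#
#     jax_weights_path, metadata_path = download_model(
#         model_type="44khz", model_bitrate="8kbps", tag="latest"
#     )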
254 |
255 |
256 | def load_encodec_model(
257 | name: str = "facebook/musicgen-small",
258 | load_path: str = None,
259 | metadata_path: str = None,
260 | ):
261 | if not load_path or not metadata_path:
262 | load_path, metadata_path = download_encodec(name)
263 |
264 | kwargs = OmegaConf.load(metadata_path)
265 |
266 | seanet_kwargs = kwargs["seanet"]
267 |
268 | common_kwargs = {
269 | "channels": kwargs["channels"],
270 | "dimension": seanet_kwargs["dimension"],
271 | "n_filters": seanet_kwargs["n_filters"],
272 | "n_residual_layers": seanet_kwargs["n_residual_layers"],
273 | "ratios": seanet_kwargs["ratios"],
274 | "activation": seanet_kwargs["activation"].lower(),
275 | "activation_params": OmegaConf.to_object(seanet_kwargs["activation_params"]),
276 | "norm": seanet_kwargs["norm"],
277 | "norm_params": OmegaConf.to_object(seanet_kwargs["norm_params"]),
278 | "kernel_size": seanet_kwargs["kernel_size"],
279 | "last_kernel_size": seanet_kwargs["last_kernel_size"],
280 | "residual_kernel_size": seanet_kwargs["residual_kernel_size"],
281 | "dilation_base": seanet_kwargs["dilation_base"],
282 | "causal": kwargs["encodec"]["causal"],
283 | "pad_mode": seanet_kwargs["pad_mode"],
284 | "true_skip": seanet_kwargs["true_skip"],
285 | "compress": seanet_kwargs["compress"],
286 | "lstm": seanet_kwargs["lstm"],
287 | "disable_norm_outer_blocks": seanet_kwargs["disable_norm_outer_blocks"],
288 | }
289 | encoder_override_kwargs = {}
290 | decoder_override_kwargs = {
291 | "trim_right_ratio": seanet_kwargs["decoder"]["trim_right_ratio"],
292 | "final_activation": seanet_kwargs["decoder"]["final_activation"],
293 | "final_activation_params": seanet_kwargs["decoder"]["final_activation_params"],
294 | }
295 | encoder_kwargs = {**common_kwargs, **encoder_override_kwargs}
296 | decoder_kwargs = {**common_kwargs, **decoder_override_kwargs}
297 |
298 | rvq_kwargs = kwargs["rvq"]
299 | quantizer_kwargs = {
300 | "dimension": seanet_kwargs["dimension"],
301 | "n_q": rvq_kwargs["n_q"],
302 | "q_dropout": rvq_kwargs["q_dropout"],
303 | "bins": rvq_kwargs["bins"],
304 | "decay": rvq_kwargs["decay"],
305 | "kmeans_init": rvq_kwargs["kmeans_init"],
306 | "kmeans_iters": rvq_kwargs["kmeans_iters"],
307 | "threshold_ema_dead_code": rvq_kwargs["threshold_ema_dead_code"],
308 | "orthogonal_reg_weight": rvq_kwargs["orthogonal_reg_weight"],
309 | "orthogonal_reg_active_codes_only": rvq_kwargs[
310 | "orthogonal_reg_active_codes_only"
311 | ],
312 | "orthogonal_reg_max_codes": None, # todo:
313 | }
314 |
315 | encoder = SEANetEncoder(**encoder_kwargs)
316 | decoder = SEANetDecoder(**decoder_kwargs)
317 | quantizer = ResidualVectorQuantizer(**quantizer_kwargs)
318 |
319 | sample_rate = kwargs["sample_rate"]
320 |
321 | encodec_model = EncodecModel(
322 | encoder=encoder,
323 | decoder=decoder,
324 | quantizer=quantizer,
325 | causal=kwargs["encodec"]["causal"],
326 | renormalize=kwargs["encodec"]["renormalize"],
327 | frame_rate=sample_rate // encoder.hop_length,
328 | sample_rate=sample_rate,
329 | channels=kwargs["channels"],
330 | )
331 |
332 | allow_pickle = (
333 | True # todo: https://github.com/descriptinc/descript-audio-codec/issues/53
334 | )
335 |
336 | torch_params = np.load(load_path, allow_pickle=allow_pickle)
337 | torch_params = torch_params.item() # todo
338 |
339 | variables = load_torch_weights_encodec.torch_to_linen(
340 | torch_params,
341 | encodec_model.encoder.ratios,
342 | encodec_model.decoder.ratios,
343 | encodec_model.num_codebooks,
344 | )
345 |
346 | return encodec_model, variables
347 |
348 |
349 | def load_model(
350 | model_type: str = "44khz",
351 | model_bitrate: str = "8kbps",
352 | tag: str = "latest",
353 | load_path: str = None,
354 | metadata_path: str = None,
355 | padding=True,
356 | ):
357 | # reference:
358 | # https://flax.readthedocs.io/en/latest/guides/training_techniques/transfer_learning.html#create-a-function-for-model-loading
359 |
360 | if not load_path or not metadata_path:
361 | load_path, metadata_path = download_model(
362 | model_type=model_type, model_bitrate=model_bitrate, tag=tag
363 | )
364 |
365 | with open(str(metadata_path), "r") as f:
366 | kwargs = json.loads(f.read())
367 |
368 | kwargs["padding"] = padding # todo: seems like bad design
369 | kwargs["num_codebooks"] = kwargs.pop("n_codebooks")
370 |
371 | model = DAC(**kwargs)
372 |
373 | allow_pickle = (
374 | True # todo: https://github.com/descriptinc/descript-audio-codec/issues/53
375 | )
376 |
377 | torch_params = np.load(load_path, allow_pickle=allow_pickle)
378 | torch_params = torch_params.item()
379 |
380 | variables = load_torch_weights.torch_to_linen(
381 | torch_params, model.encoder_rates, model.decoder_rates, model.num_codebooks
382 | )
383 |
384 | return model, variables
385 |
386 |
387 | if __name__ == "__main__":
388 | load_encodec_model()
389 |
--------------------------------------------------------------------------------
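A minimal usage sketch for the loaders above, following the pattern exercised in tests/test_binding.py; the audio path is a placeholder, and weights are cached under ~/.cache/dac_jax (or $DAC_JAX_CACHE when set to an absolute path):

    import jax.numpy as jnp
    import librosa

    import dac_jax

    # Download (or reuse cached) weights and bind them to the Flax module.
    model, variables = dac_jax.load_model(model_type="44khz", model_bitrate="8kbps")
    model = model.bind(variables)

    # Load a clip as mono and add batch/channel dimensions: (1, 1, samples).
    signal, sample_rate = librosa.load("audio.wav", sr=44100, mono=True)
    signal = jnp.array(signal, dtype=jnp.float32)[jnp.newaxis, jnp.newaxis, :]

    dac_file = model.encode_to_dac(signal, sample_rate)  # quantized codes
    recons = model.decode(dac_file)                      # reconstructed audio
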
/src/dac_jax/utils/decode.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from pathlib import Path
3 |
4 | import argbind
5 | from tqdm import tqdm
6 |
7 | import jax
8 |
9 | from dac_jax import DACFile
10 | from dac_jax.utils import load_model
11 |
12 |
13 | warnings.filterwarnings(
14 | "ignore", category=UserWarning
15 | ) # ignore librosa warnings related to mel bins
16 |
17 |
18 | # note: this CLI entry point takes string/path arguments, so it is not jit-compiled; only decompress_chunk below is.
19 | @argbind.bind(group="decode", positional=True, without_prefix=True)
20 | def decode(
21 | input: str,
22 | output: str = "",
23 | weights_path: str = "",
24 | model_tag: str = "latest",
25 | model_bitrate: str = "8kbps",
26 | model_type: str = "44khz",
27 | verbose: bool = False,
28 | ):
29 | """Decode audio from codes.
30 |
31 | Parameters
32 | ----------
33 | input : str
34 | Path to input directory or file
35 | output : str, optional
36 | Path to output directory, by default "".
37 | If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
38 | weights_path : str, optional
39 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet
40 | using the model_tag and model_type.
41 | model_tag : str, optional
42 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
43 | model_bitrate: str
44 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
45 | model_type : str, optional
46 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if
47 | `weights_path` is specified.
48 | """
49 | model, variables = load_model(
50 | model_type=model_type,
51 | model_bitrate=model_bitrate,
52 | tag=model_tag,
53 | load_path=weights_path,
54 | )
55 |
56 | # Find all .dac files in input directory
57 | _input = Path(input)
58 | input_files = list(_input.glob("**/*.dac"))
59 |
60 | # If input is a .dac file, add it to the list
61 | if _input.suffix == ".dac":
62 | input_files.append(_input)
63 |
64 | # Create output directory
65 | output = Path(output)
66 | output.mkdir(parents=True, exist_ok=True)
67 |
68 | @jax.jit
69 | def decompress_chunk(c):
70 | return model.apply(variables, c, method="decompress_chunk")
71 |
72 | for i in tqdm(range(len(input_files)), desc="Decoding files"):
73 | # Load file
74 | dac_file = DACFile.load(input_files[i])
75 |
76 | # Reconstruct audio from codes
77 | recons = model.decompress(decompress_chunk, dac_file, verbose=verbose)
78 |
79 | # Compute output path
80 | relative_path = input_files[i].relative_to(input)
81 | output_dir = output / relative_path.parent
82 | if not relative_path.name:
83 | output_dir = output
84 | relative_path = input_files[i]
85 | output_name = relative_path.with_suffix(".wav").name
86 | output_path = output_dir / output_name
87 | output_path.parent.mkdir(parents=True, exist_ok=True)
88 |
89 | # Write to file
90 | recons.write(output_path)
91 |
92 |
93 | if __name__ == "__main__":
94 | args = argbind.parse_args()
95 | with argbind.scope(args):
96 | decode()
97 |
--------------------------------------------------------------------------------
/src/dac_jax/utils/encode.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from pathlib import Path
3 |
4 | import argbind
5 | import librosa
6 | from tqdm import tqdm
7 |
8 | import jax
9 | import jax.numpy as jnp
10 |
11 | from dac_jax import load_model
12 | from dac_jax.audio_utils import find_audio
13 |
14 | warnings.filterwarnings(
15 | "ignore", category=UserWarning
16 | ) # ignore librosa warnings related to mel bins
17 |
18 |
19 | # note: this CLI entry point takes string/path arguments, so it is not jit-compiled; only compress_chunk below is.
20 | @argbind.bind(group="encode", positional=True, without_prefix=True)
21 | def encode(
22 | input: str,
23 | output: str = "",
24 | weights_path: str = "",
25 | model_tag: str = "latest",
26 | model_bitrate: str = "8kbps",
27 | n_quantizers: int = None,
28 | model_type: str = "44khz",
29 | win_duration: float = 5.0,
30 | verbose: bool = False,
31 | ):
32 | """Encode audio files in input path to .dac format.
33 |
34 | Parameters
35 | ----------
36 | input : str
37 | Path to input audio file or directory
38 | output : str, optional
39 | Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input`
40 | is re-created in `output`.
41 | weights_path : str, optional
42 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet
43 | using the model_tag and model_type.
44 | model_tag : str, optional
45 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
46 | model_bitrate: str
47 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
48 | n_quantizers : int, optional
49 | Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model
50 | will compress at maximum bitrate.
51 | model_type : str, optional
52 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if
53 | `weights_path` is specified.
54 | """
55 | model, variables = load_model(
56 | model_type=model_type,
57 | model_bitrate=model_bitrate,
58 | tag=model_tag,
59 | load_path=weights_path,
60 | )
61 |
62 | # Find all audio files in input path
63 | input = Path(input)
64 | audio_files = find_audio(input)
65 |
66 | output = Path(output)
67 | output.mkdir(parents=True, exist_ok=True)
68 |
69 | @jax.jit
70 | def compress_chunk(x):
71 | return model.apply(variables, x, method="compress_chunk")
72 |
73 | for audio_file in tqdm(audio_files, desc="Encoding files"):
74 | # Load file with original sample rate
75 | signal, sample_rate = librosa.load(audio_file, sr=None, mono=False)
76 | while signal.ndim < 3:
77 | signal = jnp.expand_dims(signal, axis=0)
78 |
79 | # Encode audio to .dac format
80 | dac_file = model.compress(
81 | compress_chunk,
82 | signal,
83 | sample_rate,
84 | win_duration=win_duration,
85 | verbose=verbose,
86 | n_quantizers=n_quantizers,
87 | )
88 |
89 | # Compute output path
90 | relative_path = audio_file.relative_to(input)
91 | output_dir = output / relative_path.parent
92 | if not relative_path.name:
93 | output_dir = output
94 | relative_path = audio_file
95 | output_name = relative_path.with_suffix(".dac").name
96 | output_path = output_dir / output_name
97 | output_path.parent.mkdir(parents=True, exist_ok=True)
98 |
99 | dac_file.save(output_path)
100 |
101 |
102 | if __name__ == "__main__":
103 | args = argbind.parse_args()
104 | with argbind.scope(args):
105 | encode()
106 |
--------------------------------------------------------------------------------
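Both CLI entry points above can also be driven programmatically through argbind, mirroring how tests/test_cli.py invokes them; the directory names below are placeholders:

    import argbind

    from dac_jax.__main__ import run

    # Encode every audio file found under "input/" to .dac, then decode back to .wav.
    with argbind.scope({"input": "input", "output": "encoded"}):
        run("encode")
    with argbind.scope({"input": "encoded", "output": "decoded"}):
        run("decode")
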
/src/dac_jax/utils/load_torch_weights.py:
--------------------------------------------------------------------------------
1 | def torch_to_linen(
2 | torch_params: dict,
3 | encoder_rates: tuple[int] = None,
4 | decoder_rates: tuple[int] = None,
5 | num_codebooks: int = 9,
6 | ) -> dict:
7 | """Convert PyTorch parameters to Linen nested dictionaries"""
8 |
9 | if encoder_rates is None:
10 | encoder_rates = [2, 4, 8, 8]
11 | if decoder_rates is None:
12 | decoder_rates = [8, 8, 4, 2]
13 |
14 | def parse_wn_conv(flax_params, from_prefix, to_i: int):
15 | d = {}
16 | d[f"Conv_0"] = {
17 | "bias": torch_params[f"{from_prefix}.bias"],
18 | "kernel": torch_params[f"{from_prefix}.weight_v"].T,
19 | }
20 | d[f"WeightNorm_0"] = {
21 | f"Conv_0/kernel/scale": torch_params[f"{from_prefix}.weight_g"].squeeze(
22 | (1, 2)
23 | )
24 | }
25 | flax_params[f"WNConv1d_{to_i}"] = d
26 |
27 | def parse_wn_convtranspose(flax_params, from_prefix, to_i: int):
28 | d = {}
29 | d[f"ConvTranspose_0"] = {
30 | "bias": torch_params[f"{from_prefix}.bias"],
31 | "kernel": torch_params[f"{from_prefix}.weight_v"].transpose(),
32 | }
33 | d[f"WeightNorm_0"] = {
34 | f"ConvTranspose_0/kernel/scale": torch_params[
35 | f"{from_prefix}.weight_g"
36 | ].squeeze((1, 2))
37 | }
38 | flax_params[f"WNConvTranspose1d_{to_i}"] = d
39 |
40 | def parse_residual_unit(flax_params, from_prefix, to_i):
41 | d = {}
42 | d["Snake1d_0"] = {
43 | "alpha": torch_params[f"{from_prefix}.block.0.alpha"].transpose(0, 2, 1)
44 | }
45 | parse_wn_conv(d, f"{from_prefix}.block.1", 0)
46 | d["Snake1d_1"] = {
47 | "alpha": torch_params[f"{from_prefix}.block.2.alpha"].transpose(0, 2, 1)
48 | }
49 | parse_wn_conv(d, f"{from_prefix}.block.3", 1)
50 | flax_params[f"ResidualUnit_{to_i}"] = d
51 |
52 | def parse_encoder_block(flax_params, from_prefix, to_i):
53 | d = {}
54 | for i in range(3):
55 | parse_residual_unit(d, f"{from_prefix}.block.{i}", i)
56 |
57 | d["Snake1d_0"] = {
58 | "alpha": torch_params[f"{from_prefix}.block.3.alpha"].transpose(0, 2, 1)
59 | }
60 |
61 | parse_wn_conv(d, f"{from_prefix}.block.4", 0)
62 | flax_params[f"EncoderBlock_{to_i}"] = d
63 |
64 | def parse_decoder_block(flax_params, from_prefix, to_i):
65 | d = {}
66 | d["Snake1d_0"] = {
67 | "alpha": torch_params[f"{from_prefix}.block.0.alpha"].transpose(0, 2, 1)
68 | }
69 |
70 | parse_wn_convtranspose(d, f"{from_prefix}.block.1", 0)
71 |
72 | for i in range(3):
73 | parse_residual_unit(d, f"{from_prefix}.block.{i+2}", i)
74 |
75 | flax_params[f"DecoderBlock_{to_i}"] = d
76 |
77 | flax_params = {"encoder": {}, "decoder": {}, "quantizer": {}}
78 |
79 | i = 0
80 | # add Encoder
81 | parse_wn_conv(flax_params["encoder"], f"encoder.block.{i}", 0)
82 |
83 | # add EncoderBlocks
84 | for _ in encoder_rates:
85 | parse_encoder_block(flax_params["encoder"], f"encoder.block.{i+1}", i)
86 | i += 1
87 |
88 | i += 1
89 | flax_params["encoder"]["Snake1d_0"] = {
90 | "alpha": torch_params[f"encoder.block.{i}.alpha"].transpose(0, 2, 1)
91 | }
92 |
93 | i += 1
94 | parse_wn_conv(flax_params["encoder"], f"encoder.block.{i}", 1)
95 |
96 | # Add Quantizer
97 | for i in range(num_codebooks):
98 | quantizer = {}
99 | quantizer["in_proj"] = {
100 | "WeightNorm_0": {
101 | "Conv_0/kernel/scale": torch_params[
102 | f"quantizer.quantizers.{i}.in_proj.weight_g"
103 | ].squeeze((1, 2))
104 | },
105 | "Conv_0": {
106 | "bias": torch_params[f"quantizer.quantizers.{i}.in_proj.bias"],
107 | "kernel": torch_params[f"quantizer.quantizers.{i}.in_proj.weight_v"].T,
108 | },
109 | }
110 | quantizer["codebook"] = {
111 | "embedding": torch_params[f"quantizer.quantizers.{i}.codebook.weight"]
112 | }
113 | quantizer["out_proj"] = {
114 | "WeightNorm_0": {
115 | "Conv_0/kernel/scale": torch_params[
116 | f"quantizer.quantizers.{i}.out_proj.weight_g"
117 | ].squeeze((1, 2))
118 | },
119 | "Conv_0": {
120 | "bias": torch_params[f"quantizer.quantizers.{i}.out_proj.bias"],
121 | "kernel": torch_params[f"quantizer.quantizers.{i}.out_proj.weight_v"].T,
122 | },
123 | }
124 | flax_params["quantizer"][f"quantizers_{i}"] = quantizer
125 |
126 | i = 0
127 | # Add Decoder
128 | parse_wn_conv(flax_params["decoder"], f"decoder.model.{i}", 0)
129 |
130 | # Add DecoderBlocks
131 | for _ in decoder_rates:
132 | parse_decoder_block(flax_params["decoder"], f"decoder.model.{i+1}", i)
133 | i += 1
134 |
135 | i += 1
136 | flax_params["decoder"]["Snake1d_0"] = {
137 | "alpha": torch_params[f"decoder.model.{i}.alpha"].transpose(0, 2, 1)
138 | }
139 |
140 | i += 1
141 | parse_wn_conv(flax_params["decoder"], f"decoder.model.{i}", 1)
142 |
143 | return {"params": flax_params}
144 |
--------------------------------------------------------------------------------
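For illustration, a sketch of how the dictionary produced by torch_to_linen above is consumed (the same wiring load_model performs); the .npy path is a placeholder for the file written by convert_torch_weights_to_numpy, and the import path is assumed to be dac_jax.utils.load_torch_weights:

    import numpy as np

    from dac_jax.utils import load_torch_weights

    # "weights.npy" is a placeholder for the converted PyTorch state dict.
    torch_params = np.load("weights.npy", allow_pickle=True).item()

    variables = load_torch_weights.torch_to_linen(
        torch_params,
        encoder_rates=(2, 4, 8, 8),
        decoder_rates=(8, 8, 4, 2),
        num_codebooks=9,
    )

    # Top level mirrors the module tree:
    # {"params": {"encoder": ..., "decoder": ..., "quantizer": ...}}
    print(list(variables["params"]["quantizer"].keys()))  # quantizers_0 ... quantizers_8
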
/src/dac_jax/utils/load_torch_weights_encodec.py:
--------------------------------------------------------------------------------
1 | from jax import numpy as jnp
2 |
3 |
4 | def streamable(torch_params, prefix: str):
5 | return {
6 | "NormConv1d_0": {
7 | "WeightNorm_0": {
8 | "Conv_0/kernel/scale": torch_params[
9 | f"{prefix}.conv.conv.weight_g"
10 | ].squeeze((1, 2)),
11 | },
12 | "Conv_0": {
13 | "bias": torch_params[f"{prefix}.conv.conv.bias"],
14 | "kernel": torch_params[f"{prefix}.conv.conv.weight_v"].T,
15 | },
16 | }
17 | }
18 |
19 |
20 | def streamable_transpose(torch_params, prefix: str):
21 | return {
22 | "NormConvTranspose1d_0": {
23 | "WeightNorm_0": {
24 | "ConvTranspose_0/kernel/scale": torch_params[
25 | f"{prefix}.convtr.convtr.weight_g"
26 | ].squeeze((1, 2)),
27 | },
28 | "ConvTranspose_0": {
29 | "bias": torch_params[f"{prefix}.convtr.convtr.bias"],
30 | "kernel": torch_params[f"{prefix}.convtr.convtr.weight_v"].T,
31 | },
32 | }
33 | }
34 |
35 |
36 | def lstm(torch_params, prefix: str, i: int):
37 | weight_ih_l0 = torch_params[f"{prefix}.lstm.weight_ih_l{i}"]
38 | weight_hh_l0 = torch_params[f"{prefix}.lstm.weight_hh_l{i}"]
39 | bias_ih_l0 = torch_params[f"{prefix}.lstm.bias_ih_l{i}"]
40 | bias_hh_l0 = torch_params[f"{prefix}.lstm.bias_hh_l{i}"]
41 |
42 | weight_hh_l0 = weight_hh_l0.transpose(1, 0)
43 | weight_ih_l0 = weight_ih_l0.transpose(1, 0)
44 |
45 | # https://github.com/pytorch/pytorch/blob/40de63be097ce6d499aac15fc58ed27ca33e5227/aten/src/ATen/native/RNN.cpp#L1560-L1564
46 | kernel_hi, kernel_hf, kernel_hg, kernel_ho = jnp.split(weight_hh_l0, 4, axis=1)
47 | kernel_ii, kernel_if, kernel_ig, kernel_io = jnp.split(weight_ih_l0, 4, axis=1)
48 |
49 | bias = bias_ih_l0 + bias_hh_l0
50 |
51 | bias_i, bias_f, bias_g, bias_o = jnp.split(bias, 4)
52 |
53 | return {
54 | "hi": {
55 | "bias": bias_i,
56 | "kernel": kernel_hi,
57 | },
58 | "hf": {
59 | "bias": bias_f,
60 | "kernel": kernel_hf,
61 | },
62 | "hg": {
63 | "bias": bias_g,
64 | "kernel": kernel_hg,
65 | },
66 | "ho": {
67 | "bias": bias_o,
68 | "kernel": kernel_ho,
69 | },
70 | "ii": {
71 | "kernel": kernel_ii,
72 | },
73 | "if": {
74 | "kernel": kernel_if,
75 | },
76 | "ig": {
77 | "kernel": kernel_ig,
78 | },
79 | "io": {
80 | "kernel": kernel_io,
81 | },
82 | }
83 |
84 |
85 | def torch_to_encoder(torch_params: dict, encoder_rates: tuple[int] = None):
86 | d = {}
87 |
88 | i = 0
89 | j = 0
90 | for _ in range(len(encoder_rates)):
91 | d[f"StreamableConv1d_{i}"] = streamable(torch_params, f"encoder.model.{j}")
92 | j += 1
93 | d[f"SEANetResnetBlock_{i}"] = {
94 | f"StreamableConv1d_0": streamable(
95 | torch_params, f"encoder.model.{j}.block.1"
96 | ),
97 | f"StreamableConv1d_1": streamable(
98 | torch_params, f"encoder.model.{j}.block.3"
99 | ),
100 | }
101 | i += 1
102 | j += 2
103 |
104 | d[f"StreamableConv1d_{i}"] = streamable(torch_params, f"encoder.model.{j}")
105 |
106 | j += 1
107 | lstm_layers = 2 # todo:
108 | d[f"StreamableLSTM_0"] = {
109 | f"LSTMCell_{k}": lstm(torch_params, f"encoder.model.{j}", k)
110 | for k in range(lstm_layers)
111 | }
112 | j += lstm_layers
113 |
114 | i += 1
115 | d[f"StreamableConv1d_{i}"] = streamable(torch_params, f"encoder.model.{j}")
116 |
117 | return d
118 |
119 |
120 | def torch_to_decoder(torch_params: dict, decoder_rates: tuple[int] = None):
121 | d = {}
122 |
123 | i = 0
124 | j = 0
125 |
126 | d[f"StreamableConv1d_{i}"] = streamable(torch_params, f"decoder.model.{j}")
127 | j += 1
128 | lstm_layers = 2 # todo:
129 | d[f"StreamableLSTM_0"] = {
130 | f"LSTMCell_{k}": lstm(torch_params, f"decoder.model.{j}", k)
131 | for k in range(lstm_layers)
132 | }
133 | j += lstm_layers
134 | for k in range(len(decoder_rates)):
135 | d[f"StreamableConvTranspose1d_{i}"] = streamable_transpose(
136 | torch_params, f"decoder.model.{j}"
137 | )
138 | j += 1
139 | d[f"SEANetResnetBlock_{i}"] = {
140 | f"StreamableConv1d_0": streamable(
141 | torch_params, f"decoder.model.{j}.block.1"
142 | ),
143 | f"StreamableConv1d_1": streamable(
144 | torch_params, f"decoder.model.{j}.block.3"
145 | ),
146 | }
147 | i += 1
148 | j += 2
149 |
150 | d[f"StreamableConv1d_1"] = streamable(torch_params, f"decoder.model.{j}")
151 |
152 | return d
153 |
154 |
155 | def torch_to_quantizer(torch_params: dict, n_quantizers):
156 | d = {
157 | f"layers_{i}": {
158 | "_codebook": {
159 | "embed": torch_params[f"quantizer.vq.layers.{i}._codebook.embed"],
160 | "embed_avg": torch_params[
161 | f"quantizer.vq.layers.{i}._codebook.embed_avg"
162 | ],
163 | }
164 | }
165 | for i in range(n_quantizers)
166 | }
167 |
168 | return {"vq": d}
169 |
170 |
171 | def torch_to_linen(
172 | torch_params: dict,
173 | encoder_rates: tuple[int] = None,
174 | decoder_rates: tuple[int] = None,
175 | num_codebooks: int = 9,
176 | ) -> dict:
177 | """Convert PyTorch parameters to Linen nested dictionaries"""
178 |
179 | if encoder_rates is None:
180 | encoder_rates = [2, 4, 8, 8]
181 | if decoder_rates is None:
182 | decoder_rates = [8, 8, 4, 2]
183 |
184 | return {
185 | "params": {
186 | "encoder": torch_to_encoder(torch_params, encoder_rates=encoder_rates),
187 | "decoder": torch_to_decoder(torch_params, decoder_rates=decoder_rates),
188 | "quantizer": torch_to_quantizer(torch_params, num_codebooks),
189 | }
190 | }
191 |
--------------------------------------------------------------------------------
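The lstm() helper above repacks PyTorch's fused LSTM parameters (the four gates i, f, g, o stacked along the first axis, with separate input and hidden biases) into per-gate Flax-style kernels with a single summed bias. A standalone numpy sketch with made-up sizes shows why the two parameterizations give identical gate pre-activations:

    import numpy as np

    hidden, inp = 4, 3
    rng = np.random.default_rng(0)
    w_ih = rng.normal(size=(4 * hidden, inp))    # PyTorch: gates stacked as (i, f, g, o)
    w_hh = rng.normal(size=(4 * hidden, hidden))
    b_ih = rng.normal(size=(4 * hidden,))
    b_hh = rng.normal(size=(4 * hidden,))
    x = rng.normal(size=(inp,))
    h = rng.normal(size=(hidden,))

    # PyTorch-style pre-activations for all four gates at once.
    torch_gates = w_ih @ x + b_ih + w_hh @ h + b_hh

    # Flax-style: transpose, split per gate, fold the two biases into one.
    kernels_i = np.split(w_ih.T, 4, axis=1)
    kernels_h = np.split(w_hh.T, 4, axis=1)
    biases = np.split(b_ih + b_hh, 4)
    flax_gates = np.concatenate(
        [x @ ki + h @ kh + b for ki, kh, b in zip(kernels_i, kernels_h, biases)]
    )
    assert np.allclose(torch_gates, flax_gates)
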
/tests/README.md:
--------------------------------------------------------------------------------
1 | `60013__qubodup__whoosh.flac`:
2 | https://freesound.org/people/qubodup/sounds/60013/
3 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DBraun/DAC-JAX/919ce4a2a9ec4c5c3fa7d10dcb2944259da00865/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_audio_utils.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 |
3 | from einops import rearrange
4 | import jax.numpy as jnp
5 | import numpy as np
6 | import pytest
7 |
8 | from dac_jax.audio_utils import stft, mel_spectrogram
9 | from dac_jax.nn.loss import mel_spectrogram_loss, multiscale_stft_loss
10 |
11 | from dac.nn.loss import MelSpectrogramLoss, MultiScaleSTFTLoss
12 | from audiotools import AudioSignal
13 |
14 |
15 | @pytest.mark.parametrize(
16 | "match_stride,hop_factor,length",
17 | product(
18 | [False, True],
19 | [0.25, 0.5],
20 | [44100, 44101],
21 | ),
22 | )
23 | def test_mel_same_as_audiotools(match_stride: bool, hop_factor: float, length: int):
24 |
25 | if hop_factor == 0.5 and match_stride:
26 | return # for some reason DAC torch disallows this
27 |
28 | sample_rate = 44100
29 |
30 | B = 1
31 | x = np.random.uniform(low=-1, high=1, size=(B, 1, length))
32 |
33 | signal1 = AudioSignal(x, sample_rate=sample_rate)
34 |
35 | window_length = 2048
36 | hop_length = int(window_length * hop_factor)
37 |
38 | stft_kwargs = {
39 | "window_length": window_length,
40 | "hop_length": hop_length,
41 | "window_type": "hann",
42 | "match_stride": match_stride,
43 | "padding_type": "reflect",
44 | }
45 |
46 | n_mels = 80
47 |
48 | mel1 = signal1.mel_spectrogram(n_mels=n_mels, **stft_kwargs)
49 |
50 | stft1 = signal1.stft_data
51 |
52 | stft_data = stft(
53 | jnp.array(x),
54 | frame_length=stft_kwargs["window_length"],
55 | hop_factor=hop_factor,
56 | window=stft_kwargs["window_type"],
57 | match_stride=stft_kwargs["match_stride"],
58 | padding_type=stft_kwargs["padding_type"],
59 | )
60 |
61 | assert np.allclose(np.abs(stft1), np.abs(stft_data), atol=1e-4)
62 |
63 | stft_data = rearrange(stft_data, "b c nf nt -> (b c) nt nf")
64 |
65 | spectrogram = jnp.abs(stft_data)
66 |
67 | mel2 = mel_spectrogram(
68 | spectrogram,
69 | log_scale=False,
70 | sample_rate=sample_rate,
71 | num_features=n_mels,
72 | frame_length=stft_kwargs["window_length"],
73 | )
74 |
75 | mel2 = rearrange(mel2, "(b c) t bins -> b c bins t", b=B)
76 |
77 | assert np.allclose(mel1, np.array(mel2), atol=1e-4)
78 |
79 |
80 | @pytest.mark.parametrize(
81 | "length",
82 | (44100, 44101),
83 | )
84 | def test_mel_loss_same_as_dac_torch(length: int):
85 |
86 | sample_rate = 44100
87 |
88 | x1 = np.random.uniform(low=-1, high=1, size=(1, 1, length))
89 | x2 = x1 * 0.5
90 |
91 | signal1 = AudioSignal(x1, sample_rate=sample_rate)
92 | signal2 = AudioSignal(x2, sample_rate=sample_rate)
93 |
94 | loss1 = mel_spectrogram_loss(jnp.array(x1), jnp.array(x2), sample_rate=sample_rate)
95 | loss2 = MelSpectrogramLoss()(signal1, signal2)
96 |
97 | assert np.isclose(np.array(loss1), loss2)
98 |
99 |
100 | @pytest.mark.parametrize(
101 | "length",
102 | (44100, 44101),
103 | )
104 | def test_multiscale_stft_loss_same_as_dac_torch(length: int):
105 | sample_rate = 44100
106 |
107 | x1 = np.random.uniform(low=-1, high=1, size=(1, 1, length))
108 | x2 = x1 * 0.5
109 |
110 | signal1 = AudioSignal(x1, sample_rate=sample_rate)
111 | signal2 = AudioSignal(x2, sample_rate=sample_rate)
112 |
113 | loss1 = multiscale_stft_loss(jnp.array(x1), jnp.array(x2))
114 | loss2 = MultiScaleSTFTLoss()(signal1, signal2)
115 |
116 | assert np.isclose(np.array(loss1), loss2)
117 |
118 |
119 | if __name__ == "__main__":
120 | # test_mel_same_as_audiotools()
121 | # test_mel_loss_same_as_dac_torch()
122 | # test_multiscale_stft_loss_same_as_dac_torch()
123 | # test_stft_equivalence(True)
124 | # test_stft_equivalence(False)
125 | # test_stft_equivalence2(0.5)
126 | test_mel_same_as_audiotools(False, 0.25, 44100)
127 |
--------------------------------------------------------------------------------
/tests/test_binding.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from pathlib import Path
3 | import tempfile
4 |
5 | import jax.numpy as jnp
6 | import librosa
7 |
8 | import dac_jax
9 |
10 |
11 | def test_binding():
12 |
13 | # Download a model and bind variables to it.
14 | model, variables = dac_jax.load_model(model_type="44khz")
15 | model = model.bind(variables)
16 |
17 | # Load audio file
18 | filepath = Path(__file__).parent / "assets" / "60013__qubodup__whoosh.flac"
19 | signal, sample_rate = librosa.load(filepath, sr=44100, mono=True, duration=0.5)
20 |
21 | signal = jnp.array(signal, dtype=jnp.float32)
22 | while signal.ndim < 3:
23 | signal = jnp.expand_dims(signal, axis=0)
24 |
25 | # Encode audio signal as one long file (may run out of GPU memory on long files)
26 | dac_file = model.encode_to_dac(signal, sample_rate)
27 |
28 | with tempfile.TemporaryDirectory() as tmpdirname:
29 | filepath = os.path.join(tmpdirname, "dac_file_001.dac")
30 |
31 | # Save to a file
32 | dac_file.save(filepath)
33 |
34 | # Load a file
35 | dac_file = dac_jax.DACFile.load(filepath)
36 |
37 | # Decode audio signal
38 | y = model.decode(dac_file)
39 |
40 | # reconstruction mean-square error
41 | mse = jnp.square(y - signal).mean()
42 |
43 | # Informal expected maximum MSE
44 | assert mse.item() < 0.005
45 |
46 |
47 | if __name__ == "__main__":
48 | test_binding()
49 |
--------------------------------------------------------------------------------
/tests/test_cli.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for CLI.
3 | """
4 |
5 | import subprocess
6 | from pathlib import Path
7 |
8 | import argbind
9 | import numpy as np
10 | import pytest
11 | import soundfile
12 |
13 | from dac_jax.__main__ import run
14 |
15 |
16 | def setup_module(module):
17 | data_dir = Path(__file__).parent / "tmp_assets"
18 | data_dir.mkdir(exist_ok=True, parents=True)
19 | input_dir = data_dir / "input"
20 | input_dir.mkdir(exist_ok=True, parents=True)
21 |
22 | for i in range(5):
23 | sample_rate = 44_100
24 | signal = np.random.randn(sample_rate)  # one second of mono noise
25 | soundfile.write(input_dir / f"sample_{i}.wav", signal, samplerate=sample_rate)
26 | return input_dir
27 |
28 |
29 | def teardown_module(module):
30 | repo_root = Path(__file__).parent.parent
31 | subprocess.check_output(["rm", "-rf", f"{repo_root}/tests/tmp_assets"])
32 |
33 |
34 | @pytest.mark.parametrize("model_type", ["44khz", "24khz", "16khz"])
35 | def test_reconstruction(model_type):
36 | # Test encoding
37 | input_dir = Path(__file__).parent / "tmp_assets" / "input"
38 | output_dir = input_dir.parent / model_type / "encoded_output"
39 | args = {
40 | "input": str(input_dir),
41 | "output": str(output_dir),
42 | "model_type": model_type,
43 | }
44 | with argbind.scope(args):
45 | run("encode")
46 |
47 | # Test decoding
48 | input_dir = output_dir
49 | output_dir = input_dir.parent / model_type / "decoded_output"
50 | args = {
51 | "input": str(input_dir),
52 | "output": str(output_dir),
53 | "model_type": model_type,
54 | }
55 | with argbind.scope(args):
56 | run("decode")
57 |
58 |
59 | def test_compression():
60 | # Test encoding
61 | input_dir = Path(__file__).parent / "tmp_assets" / "input"
62 | output_dir = input_dir.parent / "encoded_output_quantizers"
63 | args = {
64 | "input": str(input_dir),
65 | "output": str(output_dir),
66 | "n_quantizers": 3,
67 | }
68 | with argbind.scope(args):
69 | run("encode")
70 |
71 | # Open .dac file
72 | dac_file = output_dir / "sample_0.dac"
73 | allow_pickle = True # todo:
74 | artifacts = np.load(dac_file, allow_pickle=allow_pickle)[()]
75 | codes = artifacts["codes"]
76 |
77 | # Ensure that the number of quantizers is correct
78 | assert codes.shape[2] == 3
79 |
80 | # Ensure that dtype of compression is uint16
81 | assert codes.dtype == np.uint16
82 |
83 |
84 | # CUDA_VISIBLE_DEVICES=0 python -m pytest tests/test_cli.py -s
85 |
--------------------------------------------------------------------------------
/tests/test_dac_equivalence.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["XLA_FLAGS"] = (
4 | " --xla_gpu_deterministic_ops=true" # todo: https://github.com/google/flax/discussions/3382
5 | )
6 | os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
7 | os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
8 |
9 | from functools import partial
10 | from pathlib import Path
11 |
12 | import torch
13 |
14 | torch.use_deterministic_algorithms(True)
15 |
16 | import jax
17 | from jax import numpy as jnp
18 | from jax import random
19 |
20 | import dac as dac_torch
21 | from audiotools import AudioSignal
22 |
23 | import librosa
24 | import numpy as np
25 |
26 | import dac_jax
27 | from dac_jax import QuantizedResult
28 |
29 |
30 | def _torch_padding(np_data) -> dict[str, np.ndarray]:
31 |
32 | model_path = dac_torch.utils.download(model_type="44khz")
33 | model = dac_torch.DAC.load(model_path)
34 |
35 | x = torch.from_numpy(np_data)
36 | sample_rate = model.sample_rate # note: not always true outside of this test
37 | x = model.preprocess(x, sample_rate)
38 | z, codes, latents, commitment_loss, codebook_loss = model.encode(x)
39 |
40 | # Decode audio signal
41 | audio = model.decode(z)
42 |
43 | d = {
44 | "audio": audio,
45 | "z": z,
46 | "codes": codes,
47 | "latents": latents,
48 | "vq/commitment_loss": commitment_loss,
49 | "vq/codebook_loss": codebook_loss,
50 | }
51 |
52 | d = {k: v.detach().cpu().numpy() for k, v in d.items()}
53 |
54 | return d
55 |
56 |
57 | def _torch_compress(np_data, win_duration: float):
58 |
59 | model = dac_torch.utils.load_model(model_type="44khz")
60 |
61 | sample_rate = model.sample_rate # note: not always true outside of this test
62 | x = AudioSignal(np_data, sample_rate=sample_rate)
63 |
64 | dac_file = model.compress(x, win_duration=win_duration)
65 | # get an embedding z for just a single chunk, only for the sake of comparing to jax
66 | c = dac_file.codes[..., : dac_file.chunk_length]
67 | z = model.quantizer.from_codes(c)[0]
68 | z = z.detach().cpu().numpy()
69 |
70 | recons = model.decompress(dac_file).audio_data
71 | recons = recons.cpu().numpy()
72 |
73 | return dac_file.codes, z, recons
74 |
75 |
76 | def _jax_padding(np_data) -> dict[str, np.ndarray]:
77 |
78 | model, variables = dac_jax.load_model(model_type="44khz")
79 |
80 | q_res: QuantizedResult = model.apply(
81 | variables, jnp.array(np_data), model.sample_rate, train=False
82 | )
83 |
84 | # Multiply by model.num_codebooks since we normalize by num_codebooks and torch doesn't.
85 | # q_res.commitment_loss = q_res.commitment_loss*model.num_codebooks
86 | # q_res.codebook_loss = q_res.codebook_loss * model.num_codebooks
87 |
88 | y = {
89 | "audio": q_res.recons,
90 | "z": q_res.z.transpose(0, 2, 1),
91 | "latents": q_res.latents,
92 | "codes": q_res.codes,
93 | "vq/codebook_loss": q_res.codebook_loss,
94 | "vq/commitment_loss": q_res.commitment_loss,
95 | }
96 |
97 | y = jax.tree.map(lambda x: np.array(x), y)
98 | return y
99 |
100 |
101 | def _jax_padding_jit(np_data):
102 |
103 | model, variables = dac_jax.load_model(model_type="44khz")
104 |
105 | @jax.jit
106 | def encode_to_codes(x: jnp.ndarray):
107 | codes, scale = model.apply(
108 | variables,
109 | x,
110 | method="encode",
111 | )
112 | return codes, scale
113 |
114 | @partial(jax.jit, static_argnums=(1, 2))
115 | def decode_from_codes(codes: jnp.ndarray, scale, length: int = None):
116 | recons = model.apply(
117 | variables,
118 | codes,
119 | scale,
120 | length,
121 | method="decode",
122 | )
123 |
124 | return recons
125 |
126 | x = jnp.array(np_data)
127 |
128 | original_length = x.shape[-1]
129 |
130 | codes, scale = encode_to_codes(x)
131 | assert codes.shape[1] == model.num_codebooks
132 |
133 | recons = decode_from_codes(codes, scale, original_length)
134 |
135 | return np.array(recons), np.array(codes)
136 |
137 |
138 | def _jax_compress(np_data, win_duration: float):
139 |
140 | # set padding to False since we're using the chunk functions
141 | model, variables = dac_jax.load_model(model_type="44khz", padding=False)
142 | sample_rate = 44100
143 |
144 | @jax.jit
145 | def compress_chunk(x):
146 | return model.apply(variables, x, method="compress_chunk")
147 |
148 | @jax.jit
149 | def decompress_chunk(c):
150 | return model.apply(variables, c, method="decompress_chunk")
151 |
152 | @jax.jit
153 | def decode_latent(c):
154 | return model.apply(variables, c, method="decode_latent")
155 |
156 | key = jax.random.key(0)
157 | subkey1, subkey2, subkey3 = jax.random.split(key, 3)
158 | x = jax.random.normal(subkey1, shape=(1, 1, int(sample_rate * 2)))
159 |
160 | _ = model.init({"params": subkey2, "rng_stream": subkey3}, x, sample_rate)
161 |
162 | x = jnp.array(np_data)
163 | dac_file = model.compress(compress_chunk, x, sample_rate, win_duration=win_duration)
164 |
165 | codes = dac_file.codes
166 |
167 | # get an embedding z for just a single chunk, only for the sake of comparing to torch
168 | z = decode_latent(codes[:, :, : dac_file.chunk_length]).transpose(0, 2, 1)
169 |
170 | recons = model.decompress(decompress_chunk, dac_file)
171 | recons = np.array(recons)
172 |
173 | return codes, z, recons
174 |
175 |
176 | def test_equivalence_padding():
177 |
178 | np.random.seed(0)
179 | np_data = np.random.normal(loc=0, scale=1, size=(1, 1, 4096)).astype(np.float32)
180 |
181 | jax_result = _jax_padding(np_data)
182 | torch_result = _torch_padding(np_data)
183 | assert set(jax_result.keys()) == set(torch_result.keys())
184 | assert list(jax_result.keys())
185 | for key in jax_result.keys():
186 | # print(f"key: {key}, torch: {torch_result[key].shape}, jax: {jax_result[key].shape}")
187 | if key == "latents":
188 | # todo: why do we need to accept lower absolute tolerance for this key?
189 | atol = 1e-3
190 | elif key in ["vq/commitment_loss", "vq/codebook_loss"]:
191 | # todo: why do we need to accept lower absolute tolerance for these keys?
192 | atol = 1e-3
193 | elif key == "codes":
194 | atol = 1e-8
195 | elif key == "audio":
196 | atol = 1e-5
197 | elif key == "z":
198 | atol = 1e-5
199 | else:
200 | raise ValueError(f"Unexpected key '{key}'.")
201 | assert (
202 | jax_result[key].shape == torch_result[key].shape
203 | ), f"key: {key}, torch: {torch_result[key].shape}, jax: {jax_result[key].shape}"
204 | assert np.allclose(
205 | jax_result[key], torch_result[key], atol=atol
206 | ), f"Failed to match outputs for key: {key} and atol: {atol}"
207 |
208 | jax_recons, jax_codes = _jax_padding_jit(np_data)
209 |
210 | assert np.allclose(torch_result["codes"], jax_codes)
211 | assert np.allclose(
212 | torch_result["audio"], jax_recons, atol=1e-4
213 | ) # todo: reduce atol to 1e-5
214 |
215 |
216 | def test_equivalence_compress(verbose=False):
217 |
218 | def compress_helper(np_data, atol, win_duration=0.38):
219 |
220 | jax_codes, jax_z, jax_recons = _jax_compress(np_data, win_duration)
221 | torch_codes, torch_z, torch_recons = _torch_compress(np_data, win_duration)
222 | assert np.allclose(jax_codes, torch_codes)
223 | np.testing.assert_almost_equal(
224 | torch_z, jax_z, decimal=5
225 | ) # todo: raise this to decimal=6
226 | if verbose:
227 | print("max diff: ", jnp.abs(jax_recons - torch_recons).max())
228 | assert np.allclose(jax_recons, torch_recons, atol=atol)
229 |
230 | np_data, sr = librosa.load(
231 | Path(__file__).parent / "assets/60013__qubodup__whoosh.flac", sr=None, mono=True
232 | )
233 | np_data = np.expand_dims(np.array(np_data), 0)
234 | np_data = np.expand_dims(np.array(np_data), 0)
235 | np_data = np.concatenate([np_data, np_data, np_data, np_data], axis=-1)
236 | compress_helper(np_data, atol=1e-5)
237 |
238 | np.random.seed(0)
239 | num_samples = int(44100 * 10)
240 | np_data = 0.5 * np.random.uniform(low=-1, high=1, size=(1, 1, num_samples)).astype(
241 | np.float32
242 | )
243 | # todo: for compressing/decompressing noise, why must we use a higher absolute tolerance?
244 | compress_helper(np_data, atol=0.003)
245 |
246 |
247 | if __name__ == "__main__":
248 | test_equivalence_padding()
249 | test_equivalence_compress()
250 | print("All Done!")
251 |
--------------------------------------------------------------------------------
/tests/test_encodec_equivalence.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["XLA_FLAGS"] = (
4 | " --xla_gpu_deterministic_ops=true" # todo: https://github.com/google/flax/discussions/3382
5 | )
6 | os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
7 | os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
8 |
9 | from functools import partial
10 | from pathlib import Path
11 |
12 | from audiocraft.models import MusicGen
13 | import jax
14 | from jax import numpy as jnp
15 | from jax import random
16 | import librosa
17 | import numpy as np
18 | import torch
19 |
20 | from dac_jax import load_encodec_model, QuantizedResult
21 |
22 |
23 | def run_jax_model1(np_data):
24 |
25 | x = jnp.array(np_data)
26 |
27 | encodec_model, variables = load_encodec_model("facebook/musicgen-small")
28 |
29 | result: QuantizedResult = encodec_model.apply(
30 | variables, x, train=False, rngs={"rng_stream": random.key(0)}
31 | )
32 | recons = result.recons
33 | codes = result.codes
34 | assert codes.shape[1] == encodec_model.num_codebooks
35 |
36 | return np.array(recons), np.array(codes)
37 |
38 |
39 | def run_jax_model2(np_data):
40 | """jax.jit version of run_jax_model1"""
41 |
42 | model, variables = load_encodec_model()
43 |
44 | @jax.jit
45 | def encode_to_codes(x: jnp.ndarray):
46 | codes, scale = model.apply(
47 | variables,
48 | x,
49 | method="encode",
50 | )
51 | return codes, scale
52 |
53 | @partial(jax.jit, static_argnums=(1, 2))
54 | def decode_from_codes(codes: jnp.ndarray, scale, length: int = None):
55 | recons = model.apply(
56 | variables,
57 | codes,
58 | scale,
59 | length,
60 | method="decode",
61 | )
62 |
63 | return recons
64 |
65 | x = jnp.array(np_data)
66 |
67 | original_length = x.shape[-1]
68 |
69 | codes, scale = encode_to_codes(x)
70 | assert codes.shape[1] == model.num_codebooks
71 |
72 | recons = decode_from_codes(codes, scale, original_length)
73 |
74 | return np.array(recons), np.array(codes)
75 |
76 |
77 | def run_torch_model(np_data):
78 | model = MusicGen.get_pretrained("facebook/musicgen-small")
79 | x = torch.from_numpy(np_data).cuda()
80 | result = model.compression_model(x)
81 |
82 | recons = result.x.detach().cpu().numpy()
83 | codes = result.codes.detach().cpu().numpy()
84 | assert codes.shape[1] == model.compression_model.num_codebooks
85 |
86 | return recons, codes
87 |
88 |
89 | def test_encoded_equivalence():
90 | np_data, sr = librosa.load(
91 | Path(__file__).parent / "assets/60013__qubodup__whoosh.flac", sr=None, mono=True
92 | )
93 | np_data = np.expand_dims(np.array(np_data), 0)
94 | np_data = np.expand_dims(np.array(np_data), 0)
95 | np_data = np.concatenate([np_data, np_data, np_data, np_data], axis=-1)
96 |
97 | np_data *= 0.5
98 |
99 | torch_recons, torch_codes = run_torch_model(np_data)
100 | jax_recons, jax_codes = run_jax_model1(np_data)
101 |
102 | assert np.allclose(torch_codes, jax_codes)
103 | assert np.allclose(torch_recons, jax_recons, atol=1e-4) # todo: reduce atol to 1e-5
104 |
105 | jax_recons, jax_codes = run_jax_model2(np_data)
106 |
107 | assert np.allclose(torch_codes, jax_codes)
108 | assert np.allclose(torch_recons, jax_recons, atol=1e-4) # todo: reduce atol to 1e-5
109 |
110 |
111 | if __name__ == "__main__":
112 | test_encoded_equivalence()
113 |
--------------------------------------------------------------------------------
/tests/test_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for CLI.
3 | """
4 |
5 | import os
6 | import shlex
7 | import subprocess
8 | from pathlib import Path
9 |
10 | import argbind
11 | import numpy as np
12 | from audiotools import AudioSignal
13 |
14 | from dac_jax.__main__ import run
15 |
16 |
17 | def make_fake_data(data_dir=Path(__file__).parent / "tmp_assets"):
18 | data_dir.mkdir(exist_ok=True, parents=True)
19 | input_dir = data_dir / "input"
20 | input_dir.mkdir(exist_ok=True, parents=True)
21 |
22 | for i in range(100):
23 | signal = AudioSignal(np.random.randn(44_100 * 5), 44_100)
24 | signal.write(input_dir / f"sample_{i}.wav")
25 | return input_dir
26 |
27 |
28 | def make_fake_data_tree():
29 | data_dir = Path(__file__).parent / "tmp_assets"
30 |
31 | for relative_dir in [
32 | "train/speech",
33 | "train/music",
34 | "train/env",
35 | "val/speech",
36 | "val/music",
37 | "val/env",
38 | "test/speech",
39 | "test/music",
40 | "test/env",
41 | ]:
42 | leaf_dir = data_dir / relative_dir
43 | leaf_dir.mkdir(exist_ok=True, parents=True)
44 | make_fake_data(leaf_dir)
45 | return {
46 | split: {
47 | key: [str(data_dir / f"{split}/{key}")]
48 | for key in ["speech", "music", "env"]
49 | }
50 | for split in ["train", "val", "test"]
51 | }
52 |
53 |
54 | def setup_module(module):
55 | # Make fake dataset dir
56 | input_datasets = make_fake_data_tree()
57 | repo_root = Path(__file__).parent.parent
58 |
59 | # Load baseline conf and modify it for testing
60 | conf = argbind.load_args(repo_root / "conf" / "ablations" / "baseline.yml")
61 |
62 | for key in ["train", "val", "test"]:
63 | conf[f"{key}/build_dataset.folders"] = input_datasets[key]
64 | conf["num_iters"] = 1
65 | conf["val/AudioDataset.n_examples"] = 1
66 | conf["val_idx"] = [0]
67 | conf["val_batch_size"] = 1
68 |
69 | argbind.dump_args(conf, Path(__file__).parent / "tmp_assets" / "conf.yml")
70 |
71 |
72 | def teardown_module(module):
73 | repo_root = Path(__file__).parent.parent
74 | # Remove fake dataset dir
75 | subprocess.check_output(["rm", "-rf", f"{repo_root}/tests/tmp_assets"])
76 | subprocess.check_output(["rm", "-rf", f"{repo_root}/tests/runs"])
77 |
78 |
79 | def test_single_gpu_train():
80 | env = os.environ.copy()
81 | repo_root = Path(__file__).parent.parent
82 | args = shlex.split(
83 | f"python {repo_root}/scripts/train.py --args.load {repo_root}/tests/tmp_assets/conf.yml --train.save_path {repo_root}/tests/runs/baseline"
84 | )
85 | subprocess.check_output(args, env=env)
86 |
87 |
88 | def test_multi_gpu_train():
89 | pass # todo:
90 |
--------------------------------------------------------------------------------