├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── conf ├── 1gpu.yml ├── ablations │ ├── baseline.yml │ ├── diff-mb.yml │ ├── equal-mb.yml │ ├── no-adv.yml │ ├── no-data-balance.yml │ ├── no-low-hop.yml │ ├── no-mb.yml │ ├── no-mpd-msd.yml │ ├── no-mpd.yml │ └── only-speech.yml ├── base.yml ├── downsampling │ ├── 1024x.yml │ ├── 128x.yml │ ├── 1536x.yml │ └── 768x.yml ├── final │ ├── 16khz.yml │ ├── 24khz.yml │ ├── 44khz-16kbps.yml │ └── 44khz.yml ├── neuronic.yml ├── quantizer │ ├── 24kbps.yml │ ├── 256d.yml │ ├── 2d.yml │ ├── 32d.yml │ ├── 4d.yml │ ├── 512d.yml │ ├── dropout-0.0.yml │ ├── dropout-0.25.yml │ └── dropout-0.5.yml └── size │ ├── medium.yml │ └── small.yml ├── jobs ├── benchmark.slurm └── simple.slurm ├── pyproject.toml ├── scripts ├── benchmark.py ├── compute_entropy.py ├── evaluate.py ├── get_samples.py ├── input_pipeline.py ├── mushra.py ├── organize_daps.py ├── save_test_set.py └── train.py ├── setup.cfg ├── src └── dac_jax │ ├── __init__.py │ ├── __main__.py │ ├── audio_utils.py │ ├── compare │ ├── __init__.py │ └── encodec.py │ ├── model │ ├── __init__.py │ ├── core.py │ ├── dac.py │ ├── discriminator.py │ └── encodec.py │ ├── nn │ ├── __init__.py │ ├── encodec_layers.py │ ├── encodec_quantize.py │ ├── layers.py │ ├── loss.py │ └── quantize.py │ └── utils │ ├── __init__.py │ ├── decode.py │ ├── encode.py │ ├── load_torch_weights.py │ └── load_torch_weights_encodec.py └── tests ├── README.md ├── __init__.py ├── test_audio_utils.py ├── test_binding.py ├── test_cli.py ├── test_dac_equivalence.py ├── test_encodec_equivalence.py └── test_train.py /.gitattributes: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DBraun/DAC-JAX/919ce4a2a9ec4c5c3fa7d10dcb2944259da00865/.gitattributes -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/env.sh 108 | venv/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | # PyCharm 131 | .idea 132 | 133 | # Files created by experiments 134 | output/ 135 | snapshot/ 136 | *.m4a 137 | *.wav 138 | notebooks/scratch.ipynb 139 | notebooks/inspect.ipynb 140 | notebooks/effects.ipynb 141 | notebooks/*.ipynb 142 | notebooks/*.gif 143 | notebooks/*.wav 144 | notebooks/*.mp4 145 | *runs/ 146 | boards/ 147 | samples/ 148 | *.ipynb 149 | tmp/ 150 | 151 | results.json 152 | metrics.csv 153 | mprofile_* 154 | mem.png 155 | 156 | results/ 157 | mprofile* 158 | *.png 159 | # do not ignore the test wav file 160 | !tests/audio/short_test_audio.wav 161 | !tests/audio/output.wav 162 | */.DS_Store 163 | .DS_Store 164 | env.sh 165 | _codebraid/ 166 | **/*.html 167 | **/*.exec.md 168 | flagged/ 169 | log.txt 170 | ckpt/ 171 | .syncthing* 172 | tests/assets/ 173 | archived/ 174 | 175 | *_remote_module_* 176 | *.zip 177 | *.pth 178 | encoded_out/ 179 | recon/ 180 | recons/ 181 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023-present, Descript 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DAC-JAX and EnCodec-JAX 2 | 3 | This repository holds **unofficial** JAX implementations of Descript's DAC and Meta's EnCodec. 4 | We are not affiliated with Descript or Meta. 
5 | 6 | You can read the DAC-JAX paper [here](https://arxiv.org/abs/2405.11554). 7 | 8 | ## Background 9 | 10 | In 2022, Meta published "[High Fidelity Neural Audio Compression](https://arxiv.org/abs/2210.13438)". 11 | They eventually open-sourced the code inside [AudioCraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/ENCODEC.md). 12 | 13 | In 2023, Descript published a related work "[High-Fidelity Audio Compression with Improved RVQGAN](https://arxiv.org/abs/2306.06546)" 14 | and released their code under the name [DAC](https://github.com/descriptinc/descript-audio-codec/) (Descript Audio Codec). 15 | 16 | Both EnCodec and DAC are neural audio codecs which use residual vector quantization inside a fully convolutional 17 | encoder-decoder architecture. 18 | 19 | ## Usage 20 | 21 | ### Installation 22 | 23 | 1. Upgrade `pip` and `setuptools`: 24 | ```bash 25 | pip install --upgrade pip setuptools 26 | ``` 27 | 28 | 2. Install the **CPU** version of [PyTorch](https://pytorch.org/). 29 | We strongly suggest the CPU version because trying to install a GPU version can conflict with JAX's CUDA-related installation. 30 | PyTorch is required because it's used to load pretrained model weights. 31 | 32 | 3. Install [JAX](https://jax.readthedocs.io/en/latest/installation.html) (with GPU support). 33 | 34 | 4. Install DAC-JAX with one of the following: 35 | 36 | 40 | 41 | ``` 42 | pip install git+https://github.com/DBraun/DAC-JAX 43 | ``` 44 | 45 | Or, 46 | 47 | ```bash 48 | python -m pip install . 49 | ``` 50 | 51 | Or, if you intend to contribute, clone and do an [editable install](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs): 52 | ```bash 53 | python -m pip install -e ".[dev]" 54 | ``` 55 | 56 | ### Weights 57 | The original Descript repository releases model weights under the MIT license. These weights are for models that natively support 16 kHz, 24kHz, and 44.1kHz sampling rates. Our scripts download these PyTorch weights and load them into JAX. 58 | Weights are automatically downloaded when you first run an `encode` or `decode` command. You can download them in advance with one of the following commands: 59 | ```bash 60 | python -m dac_jax download_model # downloads the default 44kHz variant 61 | python -m dac_jax download_model --model_type 44khz --model_bitrate 16kbps # downloads the 44kHz 16 kbps variant 62 | python -m dac_jax download_model --model_type 44khz # downloads the 44kHz variant 63 | python -m dac_jax download_model --model_type 24khz # downloads the 24kHz variant 64 | python -m dac_jax download_model --model_type 16khz # downloads the 16kHz variant 65 | ``` 66 | 67 | EnCodec weights can be downloaded similarly. This will download the 32 kHz EnCodec used in [MusicGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md). 68 | ```bash 69 | python -m dac_jax download_encodec 70 | ``` 71 | 72 | For both DAC and EnCodec, the default download location is `~/.cache/dac_jax`. You can change the location by setting an **absolute path** value for an environment variable `DAC_JAX_CACHE`. For example, on macOS/Linux: 73 | ```bash 74 | export DAC_JAX_CACHE=/Users/admin/my-project/dac_jax_models 75 | ``` 76 | 77 | If you do this, remember to still have `DAC_JAX_CACHE` set before you use the `load_model` function. 78 | 79 | ### Compress audio 80 | ``` 81 | python -m dac_jax encode /path/to/input --output /path/to/output/codes 82 | ``` 83 | 84 | This command will create `.dac` files with the same name as the input files. 
85 | It will also preserve the directory structure relative to input root and 86 | re-create it in the output directory. Please use `python -m dac_jax encode --help` 87 | for more options. 88 | 89 | ### Reconstruct audio from compressed codes 90 | ``` 91 | python -m dac_jax decode /path/to/output/codes --output /path/to/reconstructed_input 92 | ``` 93 | 94 | This command will create `.wav` files with the same name as the input files. 95 | It will also preserve the directory structure relative to input root and 96 | re-create it in the output directory. Please use `python -m dac_jax decode --help` 97 | for more options. 98 | 99 | ### Programmatic usage (DAC and EnCodec) 100 | 101 | Here we use `jax.jit` for optimized encoding and decoding. 102 | This does not do sample-rate conversion or volume normalization in the encoder or decoder. 103 | 104 | ```python 105 | from functools import partial 106 | 107 | import jax 108 | from jax import numpy as jnp 109 | import librosa 110 | 111 | import dac_jax 112 | 113 | model, variables = dac_jax.load_model(model_type="44khz") 114 | 115 | # If you want to use pretrained 32 kHz EnCodec from Meta's MusicGen, use this: 116 | # model, variables = dac_jax.load_encodec_model() 117 | 118 | @jax.jit 119 | def encode_to_codes(x: jnp.ndarray): 120 | codes, scale = model.apply( 121 | variables, 122 | x, 123 | method="encode", 124 | ) 125 | return codes, scale 126 | 127 | @partial(jax.jit, static_argnums=(1, 2)) 128 | def decode_from_codes(codes: jnp.ndarray, scale, length: int = None): 129 | recons = model.apply( 130 | variables, 131 | codes, 132 | scale, 133 | length, 134 | method="decode", 135 | ) 136 | return recons 137 | 138 | # Load a mono audio file with the correct sample rate 139 | signal, sample_rate = librosa.load('input.wav', sr=model.sample_rate, mono=True, duration=.5) 140 | 141 | signal = jnp.array(signal, dtype=jnp.float32) 142 | while signal.ndim < 3: 143 | signal = jnp.expand_dims(signal, axis=0) 144 | 145 | original_length = signal.shape[-1] 146 | 147 | codes, scale = encode_to_codes(signal) 148 | assert codes.shape[1] == model.num_codebooks 149 | 150 | recons = decode_from_codes(codes, scale, original_length) 151 | ``` 152 | 153 | ### DAC with Binding 154 | 155 | Here we use DAC-JAX as a "[bound](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html#bind)" module, freeing us from repeatedly passing variables as an argument and using `.apply`. Note that bound modules are not meant to be used in fine-tuning. 156 | 157 | ```python 158 | import dac_jax 159 | from dac_jax import DACFile 160 | 161 | from jax import numpy as jnp 162 | import librosa 163 | 164 | # Download a model and bind variables to it. 165 | model, variables = dac_jax.load_model(model_type="44khz") 166 | model = model.bind(variables) 167 | 168 | # Load a mono audio file 169 | signal, sample_rate = librosa.load('input.wav', sr=44100, mono=True, duration=.5) 170 | 171 | signal = jnp.array(signal, dtype=jnp.float32) 172 | while signal.ndim < 3: 173 | signal = jnp.expand_dims(signal, axis=0) 174 | 175 | # Encode audio signal as one long file (may run out of GPU memory on long files). 176 | # This performs resampling to the codec's sample rate and volume normalization. 177 | dac_file = model.encode_to_dac(signal, sample_rate) 178 | 179 | # Save to a file 180 | dac_file.save("dac_file_001.dac") 181 | 182 | # Load a file 183 | dac_file = DACFile.load("dac_file_001.dac") 184 | 185 | # Decode audio signal. 
Since we're passing a dac_file, this undoes the 186 | # previous sample rate conversion and volume normalization. 187 | y = model.decode(dac_file) 188 | 189 | # Calculate mean-square error of reconstruction in time-domain 190 | mse = jnp.square(y-signal).mean() 191 | ``` 192 | 193 | ### DAC compression with constant GPU memory regardless of input length: 194 | 195 | ```python 196 | import dac_jax 197 | 198 | import jax 199 | import jax.numpy as jnp 200 | import librosa 201 | 202 | # Download a model and set padding to False because we will use the chunk functions. 203 | model, variables = dac_jax.load_model(model_type="44khz", padding=False) 204 | 205 | # Load a mono audio file at any sample rate 206 | signal, sample_rate = librosa.load('input.wav', sr=None, mono=True) 207 | 208 | signal = jnp.array(signal, dtype=jnp.float32) 209 | while signal.ndim < 3: 210 | # signal will eventually be shaped [B, C, T] 211 | signal = jnp.expand_dims(signal, axis=0) 212 | 213 | # Jit-compile these functions because they're used inside a loop over chunks. 214 | @jax.jit 215 | def compress_chunk(x): 216 | return model.apply(variables, x, method='compress_chunk') 217 | 218 | @jax.jit 219 | def decompress_chunk(c): 220 | return model.apply(variables, c, method='decompress_chunk') 221 | 222 | win_duration = 0.5 # Adjust based on your GPU's memory size 223 | dac_file = model.compress(compress_chunk, signal, sample_rate, win_duration=win_duration) 224 | 225 | # Save and load to and from disk 226 | dac_file.save("compressed.dac") 227 | dac_file = dac_jax.DACFile.load("compressed.dac") 228 | 229 | # Decompress it back to audio 230 | y = model.decompress(decompress_chunk, dac_file) 231 | ``` 232 | 233 | ## DAC Training 234 | The baseline model configuration can be trained using the following commands. 235 | 236 | ```bash 237 | python scripts/train.py --args.load conf/final/44khz.yml --train.ckpt_dir="/tmp/dac_jax_runs" 238 | ``` 239 | 240 | In root directory, monitor with Tensorboard (`runs` will appear next to `scripts`): 241 | ```bash 242 | tensorboard --logdir="/tmp/dac_jax_runs" 243 | ``` 244 | 245 | ## Testing 246 | 247 | ``` 248 | python -m pytest tests 249 | ``` 250 | 251 | ## Limitations 252 | 253 | Pull requests—especially ones which address any of the limitations below—are welcome. 254 | 255 | * We implement the "chunked" `compress`/`decompress` methods from the PyTorch repository, although this technique has some problems outlined [here](https://github.com/descriptinc/descript-audio-codec/issues/39). 256 | * We have not run all evaluation scripts in the `scripts` directory. For some of them, it makes sense to just keep using PyTorch instead of JAX. 257 | * The model architecture code (`model/dac.py`) has many static methods to help with finding DAC's `delay` and `output_length`. Please help us refactor this so that code is not so duplicated and at risk of typos. 258 | * In `audio_utils.py` we use [DM_AUX's](https://github.com/google-deepmind/dm_aux) STFT function instead of `jax.scipy.signal.stft`. We believe this is faster but requires more memory. 259 | * The source code of DAC-JAX has some `todo:` markings which indicate (mostly minor) improvements we'd like to have. 260 | * We don't have a Docker image yet like the original [DAC repository](https://github.com/descriptinc/descript-audio-codec) does. 261 | * Please check the limitations of [argbind](https://github.com/pseeth/argbind?tab=readme-ov-file#limitations-and-known-issues). 262 | * We don't provide a training script for EnCodec. 
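For reference, the following is a minimal sketch of what the `jax.scipy.signal.stft` alternative mentioned in the `audio_utils.py` limitation above could look like. The frame and hop sizes here are illustrative assumptions, not the values actually used in this repository:

```python
import jax.numpy as jnp
from jax.scipy.signal import stft

def stft_magnitude(x: jnp.ndarray, frame_length: int = 2048, hop_length: int = 512) -> jnp.ndarray:
    """Magnitude STFT of a mono waveform `x` of shape [T] (illustrative sketch only)."""
    # jax.scipy.signal.stft mirrors scipy.signal.stft and returns (freqs, times, Zxx).
    _, _, spec = stft(
        x,
        window="hann",
        nperseg=frame_length,
        noverlap=frame_length - hop_length,
    )
    return jnp.abs(spec)
```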
263 | 264 | ## Citation 265 | 266 | If you use this repository in your work, please cite EnCodec: 267 | ``` 268 | @article{defossez2022high, 269 | title={High fidelity neural audio compression}, 270 | author={D{\'e}fossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi}, 271 | journal={arXiv preprint arXiv:2210.13438}, 272 | year={2022} 273 | } 274 | ``` 275 | 276 | DAC: 277 | 278 | ``` 279 | @article{kumar2024high, 280 | title={High-fidelity audio compression with improved rvqgan}, 281 | author={Kumar, Rithesh and Seetharaman, Prem and Luebs, Alejandro and Kumar, Ishaan and Kumar, Kundan}, 282 | journal={Advances in Neural Information Processing Systems}, 283 | volume={36}, 284 | year={2024} 285 | } 286 | ``` 287 | 288 | 289 | 290 | and DAC-JAX: 291 | 292 | ``` 293 | @misc{braun2024dacjax, 294 | title={{DAC-JAX}: A {JAX} Implementation of the Descript Audio Codec}, 295 | author={David Braun}, 296 | year={2024}, 297 | eprint={2405.11554}, 298 | archivePrefix={arXiv}, 299 | primaryClass={cs.SD} 300 | } 301 | ``` 302 | -------------------------------------------------------------------------------- /conf/1gpu.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | 4 | train.batch_size: 12 5 | train.val_batch_size: 12 6 | -------------------------------------------------------------------------------- /conf/ablations/baseline.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | -------------------------------------------------------------------------------- /conf/ablations/diff-mb.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | Discriminator.sample_rate: 44100 6 | Discriminator.fft_sizes: [2048, 1024, 512] 7 | Discriminator.bands: 8 | - [0.0, 0.05] 9 | - [0.05, 0.1] 10 | - [0.1, 0.25] 11 | - [0.25, 0.5] 12 | - [0.5, 1.0] 13 | 14 | 15 | # re-weight lambdas to make up for 16 | # lost discriminators vs baseline 17 | lambdas: 18 | mel/loss: 15.0 19 | adv/feat_loss: 5.0 20 | adv/gen_loss: 1.0 21 | vq/commitment_loss: 0.25 22 | vq/codebook_loss: 1.0 -------------------------------------------------------------------------------- /conf/ablations/equal-mb.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | Discriminator.sample_rate: 44100 6 | Discriminator.fft_sizes: [2048, 1024, 512] 7 | Discriminator.bands: 8 | - [0.0, 0.2] 9 | - [0.2, 0.4] 10 | - [0.4, 0.6] 11 | - [0.6, 0.8] 12 | - [0.8, 1.0] 13 | 14 | 15 | # re-weight lambdas to make up for 16 | # lost discriminators vs baseline 17 | lambdas: 18 | mel/loss: 15.0 19 | adv/feat_loss: 5.0 20 | adv/gen_loss: 1.0 21 | vq/commitment_loss: 0.25 22 | vq/codebook_loss: 1.0 -------------------------------------------------------------------------------- /conf/ablations/no-adv.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | lambdas: 6 | mel/loss: 1.0 7 | waveform/loss: 1.0 8 | vq/commitment_loss: 0.25 9 | vq/codebook_loss: 1.0 -------------------------------------------------------------------------------- /conf/ablations/no-data-balance.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | 
train/build_dataset.folders: 6 | speech: 7 | - /data/daps/train 8 | - /data/vctk 9 | - /data/vocalset 10 | - /data/read_speech 11 | - /data/french_speech 12 | - /data/emotional_speech/ 13 | - /data/common_voice/ 14 | - /data/german_speech/ 15 | - /data/russian_speech/ 16 | - /data/spanish_speech/ 17 | music: 18 | - /data/musdb/train 19 | - /data/jamendo 20 | general: 21 | - /data/audioset/data/unbalanced_train_segments/ 22 | - /data/audioset/data/balanced_train_segments/ 23 | -------------------------------------------------------------------------------- /conf/ablations/no-low-hop.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | mel_spectrogram_loss.n_mels: [80] 6 | mel_spectrogram_loss.window_lengths: [512] 7 | mel_spectrogram_loss.lower_edge_hz: [0] 8 | mel_spectrogram_loss.upper_edge_hz: [null] 9 | mel_spectrogram_loss.pow: 1.0 10 | mel_spectrogram_loss.clamp_eps: 1.0e-5 11 | mel_spectrogram_loss.mag_weight: 0.0 12 | 13 | lambdas: 14 | mel/loss: 100.0 15 | adv/feat_loss: 2.0 16 | adv/gen_loss: 1.0 17 | vq/commitment_loss: 0.25 18 | vq/codebook_loss: 1.0 -------------------------------------------------------------------------------- /conf/ablations/no-mb.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | Discriminator.sample_rate: 44100 6 | Discriminator.fft_sizes: [2048, 1024, 512] 7 | Discriminator.bands: 8 | - [0.0, 1.0] 9 | 10 | # re-weight lambdas to make up for 11 | # lost discriminators vs baseline 12 | lambdas: 13 | mel/loss: 15.0 14 | adv/feat_loss: 5.0 15 | adv/gen_loss: 1.0 16 | vq/commitment_loss: 0.25 17 | vq/codebook_loss: 1.0 -------------------------------------------------------------------------------- /conf/ablations/no-mpd-msd.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | Discriminator.sample_rate: 44100 6 | Discriminator.rates: [] 7 | Discriminator.periods: [] 8 | Discriminator.fft_sizes: [2048, 1024, 512] 9 | Discriminator.bands: 10 | - [0.0, 0.1] 11 | - [0.1, 0.25] 12 | - [0.25, 0.5] 13 | - [0.5, 0.75] 14 | - [0.75, 1.0] 15 | 16 | lambdas: 17 | mel/loss: 15.0 18 | adv/feat_loss: 2.66 19 | adv/gen_loss: 1.0 20 | vq/commitment_loss: 0.25 21 | vq/codebook_loss: 1.0 -------------------------------------------------------------------------------- /conf/ablations/no-mpd.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | Discriminator.sample_rate: 44100 6 | Discriminator.rates: [1] 7 | Discriminator.periods: [] 8 | Discriminator.fft_sizes: [2048, 1024, 512] 9 | Discriminator.bands: 10 | - [0.0, 0.1] 11 | - [0.1, 0.25] 12 | - [0.25, 0.5] 13 | - [0.5, 0.75] 14 | - [0.75, 1.0] 15 | 16 | lambdas: 17 | mel/loss: 15.0 18 | adv/feat_loss: 2.5 19 | adv/gen_loss: 1.0 20 | vq/commitment_loss: 0.25 21 | vq/codebook_loss: 1.0 -------------------------------------------------------------------------------- /conf/ablations/only-speech.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | train/build_dataset.folders: 6 | speech_fb: 7 | - /data/daps/train 8 | speech_hq: 9 | - /data/vctk 10 | - /data/vocalset 11 | - /data/read_speech 12 | - /data/french_speech 13 | speech_uq: 14 | - /data/emotional_speech/ 
15 | - /data/common_voice/ 16 | - /data/german_speech/ 17 | - /data/russian_speech/ 18 | - /data/spanish_speech/ 19 | 20 | val/build_dataset.folders: 21 | speech_hq: 22 | - /data/daps/val 23 | -------------------------------------------------------------------------------- /conf/base.yml: -------------------------------------------------------------------------------- 1 | # Model setup 2 | DAC.sample_rate: 44100 3 | DAC.encoder_dim: 64 4 | DAC.encoder_rates: [2, 4, 8, 8] 5 | DAC.decoder_dim: 1536 6 | DAC.decoder_rates: [8, 8, 4, 2] 7 | 8 | # Quantization 9 | DAC.num_codebooks: 9 10 | DAC.codebook_size: 1024 11 | DAC.codebook_dim: 8 12 | DAC.quantizer_dropout: 1.0 13 | 14 | # Discriminator 15 | Discriminator.sample_rate: 44100 16 | Discriminator.rates: [] 17 | Discriminator.periods: [2, 3, 5, 7, 11] 18 | Discriminator.fft_sizes: [2048, 1024, 512] 19 | Discriminator.bands: 20 | - [0.0, 0.1] 21 | - [0.1, 0.25] 22 | - [0.25, 0.5] 23 | - [0.5, 0.75] 24 | - [0.75, 1.0] 25 | 26 | # Schedules 27 | create_generator_schedule.learning_rate: 1e-4 28 | create_generator_schedule.lr_gamma: 0.999996 29 | 30 | create_discriminator_schedule.learning_rate: 1e-4 31 | create_discriminator_schedule.lr_gamma: 0.999996 32 | 33 | # Optimization 34 | create_generator_optimizer.adam_b1: 0.8 35 | create_generator_optimizer.adam_b2: 0.99 36 | create_generator_optimizer.adam_weight_decay: .01 37 | create_generator_optimizer.grad_clip: 1e3 38 | 39 | create_discriminator_optimizer.adam_b1: 0.8 40 | create_discriminator_optimizer.adam_b2: 0.99 41 | create_discriminator_optimizer.adam_weight_decay: .01 42 | create_discriminator_optimizer.grad_clip: 10 43 | 44 | #lambdas: 45 | # mel/loss: 15.0 46 | # adv/feat_loss: 200 # 2.0 * (5+5+5+5+5+25+25+25) = 2.0 * 100 47 | # # 2.0 comes from the PyTorch DAC base.yml Then we multiply since we normalized the magnitude 48 | # # of our feature loss differently than the PyTorch version. 49 | # # 5 is (6-1) where 6 is number of convs in MPD. Then there are 5 of these because the 50 | # # number of periods is 5. 51 | # # 25 is number of bands (5) times the number of convs (5) in MRD. Then there are 3 of these 52 | # # because of the number of fft sizes is 3. 53 | # adv/gen_loss: 8 # 1.0 * 8 where 8 is number of Discriminator rates+periods+fft sizes = (0+5+3) 54 | # # 1.0 comes from the PyTorch DAC base.yml 55 | # vq/commitment_loss: 2.25 # 0.25 * 9 since we normalize based on the number of codebooks. 56 | # vq/codebook_loss: 9 # 1 * 9 since we normalized based on the number of codebooks. 
57 | 58 | lambdas: 59 | mel/loss: 15.0 60 | adv/feat_loss: 2 61 | adv/gen_loss: 1 62 | vq/commitment_loss: 0.25 63 | vq/codebook_loss: 1 64 | 65 | train.batch_size: 72 66 | train.val_batch_size: 100 67 | train.sample_batch_size: 100 68 | train.num_iterations: 250000 69 | train.valid_freq: 1000 70 | train.sample_freq: 10000 71 | train.ckpt_max_keep: 4 72 | train.seed: 0 73 | train.tabulate: 1 74 | 75 | EarlyStopping.min_delta: .001 76 | EarlyStopping.patience: 4 77 | 78 | log_training.log_every_steps: 10 79 | 80 | # Loss setup 81 | multiscale_stft_loss.window_lengths: [2048, 512] 82 | 83 | mel_spectrogram_loss.n_mels: [5, 10, 20, 40, 80, 160, 320] 84 | mel_spectrogram_loss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048] 85 | mel_spectrogram_loss.lower_edge_hz: [0, 0, 0, 0, 0, 0, 0] 86 | mel_spectrogram_loss.upper_edge_hz: [null, null, null, null, null, null, null] 87 | mel_spectrogram_loss.pow: 1.0 88 | mel_spectrogram_loss.clamp_eps: 1.0e-5 89 | mel_spectrogram_loss.mag_weight: 0.0 90 | 91 | # Data Augmentation 92 | VolumeNorm.config: 93 | min_db: -16 94 | max_db: -16 95 | 96 | train/augment_batch.transforms: 97 | - VolumeNorm 98 | - RescaleAudio 99 | - ShiftPhase 100 | 101 | val/augment_batch.transforms: 102 | - VolumeNorm 103 | - RescaleAudio 104 | 105 | sample/augment_batch.transforms: 106 | - VolumeNorm 107 | - RescaleAudio 108 | 109 | # Data 110 | # This should be equivalent to how DAC used salient_excerpt from AudioTools. 111 | SaliencyParams.enabled: 1 112 | SaliencyParams.num_tries: 8 113 | SaliencyParams.loudness_cutoff: -40 114 | SaliencyParams.search_function: SaliencyParams.search_uniform 115 | 116 | # Data 117 | create_dataset.worker_count: 0 118 | create_dataset.worker_buffer_size: 1 119 | 120 | create_dataset.extensions: 121 | - .wav 122 | - .flac 123 | - .ogg 124 | # - .mp3 125 | 126 | train/create_dataset.duration: 0.38 127 | val/create_dataset.duration: 5.0 128 | sample/create_dataset.duration: 5.0 129 | test/create_dataset.duration: 10.0 130 | 131 | val/create_dataset.num_steps: 4 132 | 133 | train/create_dataset.sources: 134 | speech_fb: 135 | - /data/daps/train 136 | speech_hq: 137 | - /data/vctk 138 | - /data/vocalset 139 | - /data/read_speech 140 | - /data/french_speech 141 | speech_uq: 142 | - /data/emotional_speech/ 143 | - /data/common_voice/ 144 | - /data/german_speech/ 145 | - /data/russian_speech/ 146 | - /data/spanish_speech/ 147 | music_hq: 148 | - /data/musdb/train 149 | music_uq: 150 | - /data/jamendo 151 | general: 152 | - /data/audioset/data/unbalanced_train_segments/ 153 | - /data/audioset/data/balanced_train_segments/ 154 | 155 | val/create_dataset.sources: 156 | speech_hq: 157 | - /data/daps/val 158 | music_hq: 159 | - /data/musdb/test 160 | general: 161 | - /data/audioset/data/eval_segments/ 162 | 163 | sample/create_dataset.sources: 164 | speech_hq: 165 | - /data/daps/val 166 | music_hq: 167 | - /data/musdb/test 168 | general: 169 | - /data/audioset/data/eval_segments/ 170 | 171 | test/create_dataset.sources: 172 | speech_hq: 173 | - /data/daps/test 174 | music_hq: 175 | - /data/musdb/test 176 | general: 177 | - /data/audioset/data/eval_segments/ 178 | -------------------------------------------------------------------------------- /conf/downsampling/1024x.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | # Model setup 6 | DAC.sample_rate: 44100 7 | DAC.encoder_dim: 64 8 | DAC.encoder_rates: [2, 8, 8, 8] 9 | DAC.decoder_dim: 1536 10 | 
DAC.decoder_rates: [8, 4, 4, 2, 2, 2] 11 | 12 | # Quantization 13 | DAC.num_codebooks: 19 14 | DAC.codebook_size: 1024 15 | DAC.codebook_dim: 8 16 | DAC.quantizer_dropout: 1.0 17 | -------------------------------------------------------------------------------- /conf/downsampling/128x.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | # Model setup 6 | DAC.sample_rate: 44100 7 | DAC.encoder_dim: 64 8 | DAC.encoder_rates: [2, 4, 4, 4] 9 | DAC.decoder_dim: 1536 10 | DAC.decoder_rates: [4, 4, 2, 2, 2, 1] 11 | 12 | # Quantization 13 | DAC.num_codebooks: 2 14 | DAC.codebook_size: 1024 15 | DAC.codebook_dim: 8 16 | DAC.quantizer_dropout: 1.0 17 | -------------------------------------------------------------------------------- /conf/downsampling/1536x.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | # Model setup 6 | DAC.sample_rate: 44100 7 | DAC.encoder_dim: 96 8 | DAC.encoder_rates: [2, 8, 8, 12] 9 | DAC.decoder_dim: 1536 10 | DAC.decoder_rates: [12, 4, 4, 2, 2, 2] 11 | 12 | # Quantization 13 | DAC.num_codebooks: 28 14 | DAC.codebook_size: 1024 15 | DAC.codebook_dim: 8 16 | DAC.quantizer_dropout: 1.0 17 | -------------------------------------------------------------------------------- /conf/downsampling/768x.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | # Model setup 6 | DAC.sample_rate: 44100 7 | DAC.encoder_dim: 64 8 | DAC.encoder_rates: [2, 6, 8, 8] 9 | DAC.decoder_dim: 1536 10 | DAC.decoder_rates: [6, 4, 4, 2, 2, 2] 11 | 12 | # Quantization 13 | DAC.num_codebooks: 14 14 | DAC.codebook_size: 1024 15 | DAC.codebook_dim: 8 16 | DAC.quantizer_dropout: 1.0 17 | -------------------------------------------------------------------------------- /conf/final/16khz.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | 4 | DAC.sample_rate: 16000 5 | 6 | DAC.encoder_rates: [2, 4, 5, 8] 7 | 8 | DAC.decoder_rates: [8, 5, 4, 2] 9 | 10 | DAC.num_codebooks: 12 11 | 12 | DAC.quantizer_dropout: 0.5 13 | 14 | Discriminator.sample_rate: 16000 15 | 16 | train.num_iterations: 400000 -------------------------------------------------------------------------------- /conf/final/24khz.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | 4 | DAC.sample_rate: 24000 5 | 6 | DAC.encoder_rates: [2, 4, 5, 8] 7 | 8 | DAC.decoder_rates: [8, 5, 4, 2] 9 | 10 | DAC.num_codebooks: 32 11 | 12 | DAC.quantizer_dropout: 0.5 13 | 14 | Discriminator.sample_rate: 24000 15 | 16 | train.num_iterations: 400000 -------------------------------------------------------------------------------- /conf/final/44khz-16kbps.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | 4 | DAC.num_codebooks: 18 # Max bitrate of 16kbps 5 | 6 | DAC.quantizer_dropout: 0.5 7 | 8 | train.num_iterations: 400000 -------------------------------------------------------------------------------- /conf/final/44khz.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | 4 | DAC.quantizer_dropout: 0.5 5 | 6 | train.num_iterations: 400000 
-------------------------------------------------------------------------------- /conf/neuronic.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/final/44khz.yml 3 | 4 | train.batch_size: 4 5 | train.val_batch_size: 4 6 | train.sample_batch_size: 1 7 | train.valid_freq: 4000 8 | train.sample_freq: 4000 9 | train.ckpt_max_keep: 1 10 | train.tabulate: 0 11 | 12 | EarlyStopping.patience: 10 13 | 14 | # Data 15 | create_dataset.worker_count: 0 16 | 17 | # Data Augmentation 18 | VolumeChange.config: 19 | min_db: -10 20 | max_db: 0 21 | 22 | #train/build_transforms.augment: 23 | # - VolumeChange 24 | # - RescaleAudio 25 | # - ShiftPhase 26 | # 27 | #val/build_transforms.augment: 28 | # - VolumeChange 29 | # - RescaleAudio 30 | # 31 | #sample/build_transforms.augment: 32 | # - VolumeNorm 33 | # - RescaleAudio 34 | 35 | # Data 36 | # This should be equivalent to how DAC used salient_excerpt from AudioTools. 37 | SaliencyParams.enabled: 1 38 | SaliencyParams.num_tries: 8 39 | SaliencyParams.loudness_cutoff: -40 40 | SaliencyParams.search_function: SaliencyParams.search_bias_early 41 | 42 | train/create_dataset.sources: 43 | musdb18hq: 44 | - /scratch/$USER/datasets/musdb18hq/train/*/mixture.wav 45 | # nsynth: 46 | # - /scratch/$USER/datasets/nsynth/nsynth-train/audio 47 | 48 | val/create_dataset.num_steps: 100 49 | val/create_dataset.duration: 2 50 | val/create_dataset.sources: 51 | musdb18hq: 52 | - /scratch/$USER/datasets/musdb18hq/test/*/mixture.wav 53 | # nsynth: 54 | # - /scratch/$USER/datasets/nsynth/nsynth-valid/audio 55 | 56 | sample/create_dataset.duration: 2 57 | sample/create_dataset.sources: 58 | musdb18hq: 59 | - /scratch/$USER/datasets/musdb18hq/test/*/mixture.wav 60 | # nsynth: 61 | # - /scratch/$USER/datasets/nsynth/nsynth-valid/audio 62 | 63 | test/create_dataset.duration: 4 64 | test/create_dataset.sources: 65 | musdb18hq: 66 | - /scratch/$USER/datasets/musdb18hq/test/*/mixture.wav 67 | # nsynth: 68 | # - /scratch/$USER/datasets/nsynth/nsynth-test/audio -------------------------------------------------------------------------------- /conf/quantizer/24kbps.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | DAC.num_codebooks: 28 6 | -------------------------------------------------------------------------------- /conf/quantizer/256d.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | DAC.codebook_dim: 256 6 | -------------------------------------------------------------------------------- /conf/quantizer/2d.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | DAC.codebook_dim: 2 6 | -------------------------------------------------------------------------------- /conf/quantizer/32d.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | DAC.codebook_dim: 32 6 | -------------------------------------------------------------------------------- /conf/quantizer/4d.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | DAC.codebook_dim: 4 6 | -------------------------------------------------------------------------------- /conf/quantizer/512d.yml: 
-------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | DAC.codebook_dim: 512 6 | -------------------------------------------------------------------------------- /conf/quantizer/dropout-0.0.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | DAC.quantizer_dropout: 0.0 6 | -------------------------------------------------------------------------------- /conf/quantizer/dropout-0.25.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | DAC.quantizer_dropout: 0.25 6 | -------------------------------------------------------------------------------- /conf/quantizer/dropout-0.5.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | DAC.quantizer_dropout: 0.5 6 | -------------------------------------------------------------------------------- /conf/size/medium.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | DAC.decoder_dim: 1024 6 | -------------------------------------------------------------------------------- /conf/size/small.yml: -------------------------------------------------------------------------------- 1 | $include: 2 | - conf/base.yml 3 | - conf/1gpu.yml 4 | 5 | DAC.decoder_dim: 512 6 | -------------------------------------------------------------------------------- /jobs/benchmark.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=jax-gpu # create a short name for your job 3 | #SBATCH --nodes=1 # node count 4 | #SBATCH --ntasks=1 # total number of tasks across all nodes 5 | #SBATCH --cpus-per-task=1 # cpu-cores per task (>1 if multi-threaded tasks) 6 | #SBATCH --mem-per-cpu=16G # RAM usage per cpu-core 7 | #SBATCH --gres=gpu:1 # number of gpus per node 8 | #SBATCH --time=03:00:00 # total run time limit (HH:MM:SS) 9 | #SBATCH --mail-type=END # choice could be 'fail' 10 | #SBATCH --mail-user=db1224@princeton.edu 11 | 12 | module purge 13 | module load anaconda3/2024.02 14 | 15 | eval "$(conda shell.bash hook)" 16 | conda activate jax-env 17 | 18 | python scripts/benchmark.py --model_type=16khz 19 | python scripts/benchmark.py --model_type=24khz 20 | python scripts/benchmark.py --model_type=44khz 21 | -------------------------------------------------------------------------------- /jobs/simple.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=DAC-JAX # create a short name for your job 3 | #SBATCH --nodes=1 # node count 4 | #SBATCH --ntasks=1 # total number of tasks across all nodes 5 | #SBATCH --cpus-per-task=2 # cpu-cores per task (>1 if multi-threaded tasks) 6 | #SBATCH --mem-per-cpu=16G # RAM usage per cpu-core 7 | #SBATCH --gres=gpu:2 # number of gpus per node 8 | #SBATCH --time=60:00:00 # total run time limit (HH:MM:SS) 9 | #SBATCH --signal=B:USR1@120 # 120 sec grace period for cleanup after timeout 10 | #SBATCH --signal=B:SIGTERM@120 # 120 sec grace period for cleanup after scancel is sent 11 | #SBATCH --mail-type=END # choice could be 'fail' 12 | #SBATCH --mail-user=db1224@princeton.edu 13 | 14 | function cleanup() { 15 | echo 'Running cleanup script' 16 | kill 
$TRAIN_PID 17 | kill $TB_PID 18 | cp -r "/scratch/$USER/runs" "/n/fs/audiovis/$USER/DAC-JAX/runs" 19 | rm -rf "/scratch/$USER" 20 | exit 0 21 | } 22 | 23 | ## Trap the SIGTERM signal (sent by scancel) and call the cleanup function 24 | trap cleanup EXIT SIGINT SIGTERM 25 | 26 | module purge 27 | module load anaconda3/2024.02 28 | 29 | eval "$(conda shell.bash hook)" 30 | 31 | conda activate ../Terrapin/.env/jax-env 32 | export PYTHONPATH=$PWD 33 | 34 | ## prepare data 35 | ##echo "$(date '+%H:%M:%S'): Copying data to /scratch" 36 | ##mkdir -p "/scratch/$USER/datasets" 37 | ##rsync -a --info=progress2 --no-i-r "/n/fs/audiovis/$USER/datasets/nsynth" "/scratch/$USER/datasets" 38 | # 39 | ##cd "/scratch/$USER/datasets/nsynth" || exit 40 | ##echo "$(date '+%H:%M:%S'): Unzipping test" 41 | ##tar -xzf nsynth-test.jsonwav.tar.gz 42 | ##echo "$(date '+%H:%M:%S'): Unzipping valid" 43 | ##tar -xzf nsynth-valid.jsonwav.tar.gz 44 | ##echo "$(date '+%H:%M:%S'): Unzipping train" 45 | ##tar -xzf nsynth-train.jsonwav.tar.gz 46 | ##echo "$(date '+%H:%M:%S'): Copied data to /scratch" 47 | 48 | ## prepare data 49 | echo "$(date '+%H:%M:%S'): Copying data to /scratch" 50 | mkdir -p "/scratch/$USER/datasets" 51 | rsync -a --info=progress2 --no-i-r "/n/fs/audiovis/$USER/datasets/musdb18hq" "/scratch/$USER/datasets" 52 | 53 | cd "/scratch/$USER/datasets/musdb18hq" || exit 54 | echo "$(date '+%H:%M:%S'): Unzipping musdb18hq" 55 | unzip -q musdb18hq.zip 56 | echo "$(date '+%H:%M:%S'): Copied data to /scratch" 57 | 58 | ## Launch TensorBoard and get the process ID of TensorBoard 59 | tensorboard --logdir="/scratch/$USER/runs" --port=10013 --samples_per_plugin audio=20 --bind_all & TB_PID=$! 60 | 61 | cd "/n/fs/audiovis/$USER/DAC-JAX" || exit 62 | python scripts/train.py \ 63 | --args.load conf/neuronic.yml \ 64 | --train.name="slurm_$SLURM_JOB_ID" \ 65 | --train.ckpt_dir="/scratch/$USER/runs" \ 66 | & TRAIN_PID=$! 67 | 68 | wait $TRAIN_PID 69 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /scripts/benchmark.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import argbind 4 | import jax 5 | from jax import random 6 | 7 | from dac_jax import load_model 8 | 9 | 10 | @argbind.bind(without_prefix=True) 11 | def benchmark_dac(model_type="44khz", model_bitrate='8kbps', win_durations: List[str] = None): 12 | 13 | if win_durations is None: 14 | win_durations = [0.37, 0.38, 0.42, 0.46, 0.5, 1, 5, 10, 20] 15 | else: 16 | win_durations = [float(x) for x in win_durations] 17 | 18 | # Set padding to False since we're using chunk functions. 
19 | model, variables = load_model(model_type=model_type, model_bitrate=model_bitrate, padding=False) 20 | 21 | @jax.jit 22 | def compress_chunk(x): 23 | return model.apply(variables, x, method='compress_chunk') 24 | 25 | @jax.jit 26 | def decompress_chunk(c): 27 | return model.apply(variables, c, method='decompress_chunk') 28 | 29 | audio_sr = model.sample_rate # not always a valid assumption, in case you copy-paste this elsewhere 30 | 31 | print(f'Benchmarking model: {model_type}, {model_bitrate}') 32 | 33 | for win_duration in win_durations: 34 | # Force chunk-encoding by making duration 1 more than win_duration: 35 | # (one day the compress function will default to unchunked if the audio length is <= win_duration) 36 | T = 1 + int(win_duration * model.sample_rate) 37 | x = random.normal(random.key(0), shape=(1, 1, T)) 38 | try: 39 | dac_file = model.compress(compress_chunk, x, audio_sr, win_duration=win_duration, benchmark=True) 40 | recons = model.decompress(decompress_chunk, dac_file, benchmark=True) 41 | except Exception as e: 42 | print(f'Exception for win duration "{win_duration}": {e}') 43 | 44 | 45 | if __name__ == "__main__": 46 | # example usage: 47 | # python3 benchmark.py --model_type=16khz --win_durations="0.5 1 5 10 20" 48 | print(f'devices: {jax.devices()}') 49 | 50 | args = argbind.parse_args() 51 | with argbind.scope(args): 52 | benchmark_dac() 53 | 54 | 55 | # @argbind.bind(without_prefix=True) 56 | # def benchmark_dac_encode(model_type="44khz", model_bitrate='8kbps', batch_size: int = 1, durations: List[str] = None): 57 | # 58 | # if durations is None: 59 | # durations = [1, 2, 4, 8, 16, 32] 60 | # else: 61 | # durations = [float(x) for x in durations] 62 | # 63 | # model, variables = load_model(model_type=model_type, model_bitrate=model_bitrate) 64 | # 65 | # @jax.jit 66 | # def encode(audio): 67 | # audio = model.apply(variables, audio, model.sample_rate, method="preprocess") 68 | # _, codes, _, _, _ = model.apply(variables, audio, train=False, method="encode") 69 | # return codes 70 | # 71 | # for duration in durations: 72 | # print(f'Benchmarking encode for model: {model_type}, {model_bitrate} with duration {duration} sec and batch size {batch_size}.') 73 | # 74 | # T = int(duration * model.sample_rate) 75 | # x = random.normal(random.key(0), shape=(batch_size, 1, T)) 76 | # import tqdm 77 | # for _ in tqdm.trange(100): 78 | # try: 79 | # encode(x) 80 | # except Exception as e: 81 | # print(f'Exception for duration "{duration}": {e}') 82 | 83 | 84 | # if __name__ == "__main__": 85 | # # example usage: 86 | # # python3 benchmark.py --model_type=44khz --durations="5" --batch_size=8 87 | # print(f'devices: {jax.devices()}') 88 | # 89 | # args = argbind.parse_args() 90 | # with argbind.scope(args): 91 | # benchmark_dac_encode() 92 | -------------------------------------------------------------------------------- /scripts/compute_entropy.py: -------------------------------------------------------------------------------- 1 | import argbind 2 | import jax 3 | from audiotools import AudioSignal 4 | import numpy as np 5 | import tqdm 6 | 7 | from dac_jax import load_model 8 | from dac_jax.audio_utils import find_audio 9 | 10 | 11 | @argbind.bind(without_prefix=True, positional=True) 12 | def main( 13 | folder: str, 14 | model_path: str, 15 | metadata_path: str, 16 | n_samples: int = 1024, 17 | ): 18 | files = find_audio(folder)[:n_samples] 19 | key = jax.random.key(0) 20 | key, subkey = jax.random.split(key) 21 | signals = [ 22 | AudioSignal.salient_excerpt(f, subkey, 
loudness_cutoff=-20, duration=1.0) 23 | for f in files 24 | ] 25 | 26 | assert model_path is not None 27 | assert metadata_path is not None 28 | 29 | model, variables = load_model(load_path=model_path, metadata_path=metadata_path) 30 | model = model.bind(variables) 31 | 32 | codes = [] 33 | for x in tqdm.tqdm(signals): 34 | x = jax.device_put(x, model.device) 35 | o = model.encode(x.audio_data, x.sample_rate) 36 | codes.append(np.array(o["codes"])) 37 | 38 | codes = np.concatenate(codes, axis=-1) 39 | entropy = [] 40 | 41 | for i in range(codes.shape[1]): 42 | codes_ = codes[0, i, :] 43 | counts = np.bincount(codes_) 44 | counts = (counts / counts.sum()) 45 | counts = np.maximum(counts, 1e-10) 46 | entropy.append(-(counts * np.log(counts)).sum().item() * np.log2(np.e)) 47 | 48 | pct = sum(entropy) / (10 * len(entropy)) 49 | print(f"Entropy for each codebook: {entropy}") 50 | print(f"Effective percentage: {pct * 100}%") 51 | 52 | 53 | if __name__ == "__main__": 54 | args = argbind.parse_args() 55 | with argbind.scope(args): 56 | main() 57 | -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import multiprocessing as mp 3 | from concurrent.futures import ProcessPoolExecutor 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | 7 | import argbind 8 | from audiotools import AudioSignal 9 | from audiotools import metrics 10 | from audiotools.core import util 11 | from audiotools.ml.decorators import Tracker 12 | import jax.numpy as jnp 13 | import numpy as np 14 | 15 | from dac_jax.nn.loss import multiscale_stft_loss, mel_spectrogram_loss, sisdr_loss, l1_loss 16 | 17 | 18 | @dataclass 19 | class State: 20 | stft_loss: multiscale_stft_loss 21 | mel_loss: mel_spectrogram_loss 22 | waveform_loss: l1_loss 23 | sisdr_loss: sisdr_loss 24 | 25 | 26 | def get_metrics(signal_path, recons_path, state): 27 | output = {} 28 | signal = AudioSignal(signal_path) 29 | recons = AudioSignal(recons_path) 30 | for sr in [22050, 44100]: 31 | x = signal.clone().resample(sr) 32 | y = recons.clone().resample(sr) 33 | k = "22k" if sr == 22050 else "44k" 34 | output.update( 35 | { 36 | f"mel-{k}": state.mel_loss(x, y), 37 | f"stft-{k}": state.stft_loss(x, y), 38 | f"waveform-{k}": state.waveform_loss(x, y), 39 | f"sisdr-{k}": state.sisdr_loss(x, y), 40 | f"visqol-audio-{k}": metrics.quality.visqol(x, y), 41 | f"visqol-speech-{k}": metrics.quality.visqol(x, y, "speech"), 42 | } 43 | ) 44 | output["path"] = signal.path_to_file 45 | output.update(signal.metadata) 46 | return output 47 | 48 | 49 | @argbind.bind(without_prefix=True) 50 | def evaluate( 51 | input: str = "samples/input", 52 | output: str = "samples/output", 53 | n_proc: int = 50, 54 | ): 55 | tracker = Tracker() 56 | 57 | state = State( 58 | waveform_loss=l1_loss, 59 | stft_loss=multiscale_stft_loss, 60 | mel_loss=mel_spectrogram_loss, 61 | sisdr_loss=sisdr_loss, 62 | ) 63 | 64 | audio_files = util.find_audio(input) 65 | output = Path(output) 66 | output.mkdir(parents=True, exist_ok=True) 67 | 68 | @tracker.track("metrics", len(audio_files)) 69 | def record(future, writer): 70 | o = future.result() 71 | for k, v in o.items(): 72 | if isinstance(v, jnp.ndarray): # todo: 73 | o[k] = np.array(v).item() # todo: 74 | writer.writerow(o) 75 | o.pop("path") 76 | return o 77 | 78 | futures = [] 79 | with tracker.live: 80 | with open(output / "metrics.csv", "w") as csvfile: 81 | with 
ProcessPoolExecutor(n_proc, mp.get_context("fork")) as pool: 82 | for i in range(len(audio_files)): 83 | future = pool.submit( 84 | get_metrics, audio_files[i], output / audio_files[i].name, state 85 | ) 86 | futures.append(future) 87 | 88 | keys = list(futures[0].result().keys()) 89 | writer = csv.DictWriter(csvfile, fieldnames=keys) 90 | writer.writeheader() 91 | 92 | for future in futures: 93 | record(future, writer) 94 | 95 | tracker.done("test", f"N={len(audio_files)}") 96 | 97 | 98 | if __name__ == "__main__": 99 | args = argbind.parse_args() 100 | with argbind.scope(args): 101 | evaluate() 102 | -------------------------------------------------------------------------------- /scripts/get_samples.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import argbind 4 | from audiotools import AudioSignal 5 | from audiotools.ml.decorators import Tracker 6 | from train import Accelerator 7 | from train import DAC 8 | 9 | from dac_jax.audio_utils import find_audio 10 | from dac_jax.compare.encodec import Encodec 11 | 12 | 13 | Encodec = argbind.bind(Encodec) 14 | 15 | 16 | def load_state( 17 | accel: Accelerator, 18 | tracker: Tracker, 19 | save_path: str, 20 | tag: str = "latest", 21 | load_weights: bool = False, 22 | model_type: str = "dac", 23 | bandwidth: float = 24.0, 24 | ): 25 | kwargs = { 26 | "folder": f"{save_path}/{tag}", 27 | "map_location": "cpu", 28 | "package": not load_weights, 29 | } 30 | tracker.print(f"Resuming from {str(Path('.').absolute())}/{kwargs['folder']}") 31 | 32 | if model_type == "dac": 33 | generator, _ = DAC.load_from_folder(**kwargs) 34 | elif model_type == "encodec": 35 | generator = Encodec(bandwidth=bandwidth) 36 | 37 | generator = accel.prepare_model(generator) 38 | return generator 39 | 40 | 41 | def process(signal, accel, generator, **kwargs): 42 | signal = signal.to(accel.device) 43 | recons = generator(signal.audio_data, signal.sample_rate, **kwargs)["audio"] 44 | recons = AudioSignal(recons, signal.sample_rate) 45 | recons = recons.normalize(signal.loudness()) 46 | return recons.cpu() 47 | 48 | 49 | @argbind.bind(without_prefix=True) 50 | def get_samples( 51 | accel, 52 | path: str = "ckpt", 53 | input: str = "samples/input", 54 | output: str = "samples/output", 55 | model_type: str = "dac", 56 | model_tag: str = "latest", 57 | bandwidth: float = 24.0, 58 | n_quantizers: int = None, 59 | ): 60 | tracker = Tracker(log_file=f"{path}/eval.txt", rank=accel.local_rank) 61 | generator = load_state( 62 | accel, 63 | tracker, 64 | save_path=path, 65 | model_type=model_type, 66 | bandwidth=bandwidth, 67 | tag=model_tag, 68 | ) 69 | kwargs = {"n_quantizers": n_quantizers} if model_type == "dac" else {} 70 | 71 | audio_files = find_audio(input) 72 | 73 | global process 74 | process = tracker.track("process", len(audio_files))(process) 75 | 76 | output = Path(output) 77 | output.mkdir(parents=True, exist_ok=True) 78 | 79 | with tracker.live: 80 | for i in range(len(audio_files)): 81 | signal = AudioSignal(audio_files[i]) 82 | recons = process(signal, accel, generator, **kwargs) 83 | recons.write(output / audio_files[i].name) 84 | 85 | tracker.done("test", f"N={len(audio_files)}") 86 | 87 | 88 | if __name__ == "__main__": 89 | args = argbind.parse_args() 90 | with argbind.scope(args): 91 | with Accelerator() as accel: 92 | get_samples(accel) 93 | -------------------------------------------------------------------------------- /scripts/input_pipeline.py: 
-------------------------------------------------------------------------------- 1 | from typing import List, Mapping 2 | 3 | import argbind 4 | from audiotree import SaliencyParams 5 | from audiotree.datasources import ( 6 | AudioDataSimpleSource, 7 | AudioDataBalancedSource, 8 | ) 9 | from audiotree.transforms import ReduceBatchTransform 10 | from grain import python as grain 11 | 12 | SaliencyParams = argbind.bind(SaliencyParams, "train", "val", "test", "sample") 13 | 14 | 15 | @argbind.bind("train", "val", "test", "sample") 16 | def create_dataset( 17 | batch_size: int, 18 | sample_rate: int, 19 | duration: float = 0.2, 20 | sources: Mapping[str, List[str]] = None, 21 | extensions: List[str] = None, 22 | mono: int = 1, # bool 23 | train: int = 0, # bool 24 | num_steps: int = None, 25 | seed: int = 0, 26 | worker_count: int = 0, 27 | worker_buffer_size: int = 2, 28 | enable_profiling: int = 0, # bool 29 | num_epochs: int = 1, # for train/val use 1, but for sample set it to None so that it loops forever. 30 | ): 31 | 32 | assert sources is not None 33 | 34 | if train: 35 | assert num_steps is not None and num_steps > 0 36 | datasource = AudioDataBalancedSource( 37 | sources=sources, 38 | num_records=num_steps * batch_size, 39 | sample_rate=sample_rate, 40 | mono=mono, 41 | duration=duration, 42 | extensions=extensions, 43 | saliency_params=SaliencyParams(), # rely on argbind, 44 | ) 45 | else: 46 | datasource = AudioDataSimpleSource( 47 | sources=sources, 48 | num_records=num_steps * batch_size if num_steps is not None else None, 49 | sample_rate=sample_rate, 50 | mono=mono, 51 | duration=duration, 52 | extensions=extensions, 53 | ) 54 | 55 | shard_options = grain.NoSharding() # todo: 56 | 57 | index_sampler = grain.IndexSampler( 58 | num_records=len(datasource), 59 | num_epochs=num_epochs, 60 | shard_options=shard_options, 61 | shuffle=bool(train), 62 | seed=seed, 63 | ) 64 | 65 | pygrain_ops = [ 66 | grain.Batch(batch_size=batch_size, drop_remainder=True), 67 | ReduceBatchTransform(), 68 | ] 69 | 70 | dataloader = grain.DataLoader( 71 | data_source=datasource, 72 | sampler=index_sampler, 73 | operations=pygrain_ops, 74 | worker_count=worker_count, 75 | worker_buffer_size=worker_buffer_size, 76 | shard_options=shard_options, 77 | enable_profiling=bool(enable_profiling), 78 | ) 79 | 80 | return dataloader 81 | 82 | 83 | if __name__ == "__main__": 84 | 85 | from tqdm import tqdm 86 | from absl import logging 87 | 88 | logging.set_verbosity(logging.INFO) 89 | 90 | folder1 = "/mnt/d/Datasets/dx7/patches-DX7-AllTheWeb-Bridge-Music-Recording-Studio-Sysex-Set-4-Instruments-Bass-Bass3-bass-10-syx-01-SUPERBASS2-note69" 91 | folder2 = "/mnt/d/Datasets/dx7/patches-DX7-AllTheWeb-Bridge-Music-Recording-Studio-Sysex-Set-4-Instruments-Accordion-ACCORD01-SYX-06-AKKORDEON-note69" 92 | 93 | sources = { 94 | "a": [folder1], 95 | "b": [folder2], 96 | } 97 | 98 | num_steps = 1000 99 | 100 | ds = create_dataset( 101 | batch_size=32, 102 | sample_rate=44_100, 103 | sources=sources, 104 | duration=0.5, 105 | train=True, 106 | mono=True, 107 | seed=0, 108 | num_steps=num_steps, 109 | extensions=None, 110 | worker_count=0, 111 | worker_buffer_size=1, 112 | saliency_params=SaliencyParams(False, 8, -70), 113 | ) 114 | 115 | for x in tqdm(ds, total=num_steps, desc="Grain Dataset"): 116 | pass 117 | -------------------------------------------------------------------------------- /scripts/mushra.py: -------------------------------------------------------------------------------- 1 | import string 2 | from dataclasses 
import dataclass 3 | from pathlib import Path 4 | from typing import List 5 | 6 | import argbind 7 | import gradio as gr 8 | from audiotools import preference as pr 9 | 10 | 11 | @argbind.bind(without_prefix=True) 12 | @dataclass 13 | class Config: 14 | folder: str = None 15 | save_path: str = "results.csv" 16 | conditions: List[str] = None 17 | reference: str = None 18 | seed: int = 0 19 | share: bool = False 20 | n_samples: int = 10 21 | 22 | 23 | def get_text(wav_file: str): 24 | txt_file = Path(wav_file).with_suffix(".txt") 25 | if Path(txt_file).exists(): 26 | with open(txt_file, "r") as f: 27 | txt = f.read() 28 | else: 29 | txt = "" 30 | return f"""
{txt}
""" 31 | 32 | 33 | def main(config: Config): 34 | with gr.Blocks() as app: 35 | save_path = config.save_path 36 | samples = gr.State(pr.Samples(config.folder, n_samples=config.n_samples)) 37 | 38 | reference = config.reference 39 | conditions = config.conditions 40 | 41 | player = pr.Player(app) 42 | player.create() 43 | if reference is not None: 44 | player.add("Play Reference") 45 | 46 | user = pr.create_tracker(app) 47 | ratings = [] 48 | 49 | with gr.Row(): 50 | txt = gr.HTML("") 51 | 52 | with gr.Row(): 53 | gr.Button("Rate audio quality", interactive=False) 54 | with gr.Column(scale=8): 55 | gr.HTML(pr.slider_mushra) 56 | 57 | for i in range(len(conditions)): 58 | with gr.Row().style(equal_height=True): 59 | x = string.ascii_uppercase[i] 60 | player.add(f"Play {x}") 61 | with gr.Column(scale=9): 62 | ratings.append(gr.Slider(value=50, interactive=True)) 63 | 64 | def build(user, samples, *ratings): 65 | # Filter out samples user has done already, by looking in the CSV. 66 | samples.filter_completed(user, save_path) 67 | 68 | # Write results to CSV 69 | if samples.current > 0: 70 | start_idx = 1 if reference is not None else 0 71 | name = samples.names[samples.current - 1] 72 | result = {"sample": name, "user": user} 73 | for k, r in zip(samples.order[start_idx:], ratings): 74 | result[k] = r 75 | pr.save_result(result, save_path) 76 | 77 | updates, done, pbar = samples.get_next_sample(reference, conditions) 78 | wav_file = updates[0]["value"] 79 | 80 | txt_update = gr.update(value=get_text(wav_file)) 81 | 82 | return ( 83 | updates 84 | + [gr.update(value=50) for _ in ratings] 85 | + [done, samples, pbar, txt_update] 86 | ) 87 | 88 | progress = gr.HTML() 89 | begin = gr.Button("Submit", elem_id="start-survey") 90 | begin.click( 91 | fn=build, 92 | inputs=[user, samples] + ratings, 93 | outputs=player.to_list() + ratings + [begin, samples, progress, txt], 94 | ).then(None, _js=pr.reset_player) 95 | 96 | # Comment this back in to actually launch the script. 
97 | app.launch(share=config.share) 98 | 99 | 100 | if __name__ == "__main__": 101 | args = argbind.parse_args() 102 | with argbind.scope(args): 103 | config = Config() 104 | main(config) 105 | -------------------------------------------------------------------------------- /scripts/organize_daps.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import shutil 4 | from collections import defaultdict 5 | from typing import Tuple 6 | 7 | import argbind 8 | import numpy as np 9 | import tqdm 10 | from audiotools.core import util 11 | 12 | 13 | @argbind.bind() 14 | def split( 15 | audio_files, ratio: Tuple[float, float, float] = (0.8, 0.1, 0.1), seed: int = 0 16 | ): 17 | assert sum(ratio) == 1.0 18 | util.seed(seed) 19 | 20 | idx = np.arange(len(audio_files)) 21 | np.random.shuffle(idx) 22 | 23 | b = np.cumsum([0] + list(ratio)) * len(idx) 24 | b = [int(_b) for _b in b] 25 | train_idx = idx[b[0] : b[1]] 26 | val_idx = idx[b[1] : b[2]] 27 | test_idx = idx[b[2] :] 28 | 29 | audio_files = np.array(audio_files) 30 | train_files = audio_files[train_idx] 31 | val_files = audio_files[val_idx] 32 | test_files = audio_files[test_idx] 33 | 34 | return train_files, val_files, test_files 35 | 36 | 37 | def assign(val_split, test_split): 38 | def _assign(value): 39 | if value in val_split: 40 | return "val" 41 | if value in test_split: 42 | return "test" 43 | return "train" 44 | 45 | return _assign 46 | 47 | 48 | DAPS_VAL = ["f2", "m2"] 49 | DAPS_TEST = ["f10", "m10"] 50 | 51 | 52 | @argbind.bind(without_prefix=True) 53 | def process( 54 | dataset: str = "daps", 55 | daps_subset: str = "", 56 | ): 57 | get_split = None 58 | get_value = lambda path: path 59 | 60 | data_path = pathlib.Path("/data") 61 | dataset_path = data_path / dataset 62 | audio_files = util.find_audio(dataset_path) 63 | 64 | if dataset == "daps": 65 | get_split = assign(DAPS_VAL, DAPS_TEST) 66 | get_value = lambda path: (str(path).split("/")[-1].split("_", maxsplit=4)[0]) 67 | audio_files = [ 68 | x 69 | for x in util.find_audio(dataset_path) 70 | if daps_subset in str(x) and "breaths" not in str(x) 71 | ] 72 | 73 | if get_split is None: 74 | _, val, test = split(audio_files) 75 | get_split = assign(val, test) 76 | 77 | splits = defaultdict(list) 78 | for x in audio_files: 79 | _split = get_split(get_value(x)) 80 | splits[_split].append(x) 81 | 82 | with util.chdir(dataset_path): 83 | for k, v in splits.items(): 84 | v = sorted(v) 85 | print(f"Processing {k} in {dataset_path} of length {len(v)}") 86 | for _v in tqdm.tqdm(v): 87 | tgt_path = pathlib.Path( 88 | str(_v).replace(str(dataset_path), str(dataset_path / k)) 89 | ) 90 | tgt_path.parent.mkdir(parents=True, exist_ok=True) 91 | shutil.copyfile(_v, tgt_path) 92 | 93 | 94 | if __name__ == "__main__": 95 | args = argbind.parse_args() 96 | with argbind.scope(args): 97 | process() 98 | -------------------------------------------------------------------------------- /scripts/save_test_set.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from pathlib import Path 3 | 4 | import argbind 5 | import torch 6 | from audiotools.ml.decorators import Tracker 7 | 8 | import scripts.train as train 9 | 10 | 11 | @torch.no_grad() 12 | def process(batch, test_data): 13 | signal = test_data.transform(batch["signal"].clone(), **batch["transform_args"]) 14 | return signal.cpu() 15 | 16 | 17 | @argbind.bind(without_prefix=True) 18 | @torch.no_grad() 19 | def save_test_set(args, 
sample_rate: int = 44100, output: str = "samples/input"): 20 | tracker = Tracker() 21 | with argbind.scope(args, "test"): 22 | test_data = train.create_dataset(sample_rate=sample_rate) 23 | 24 | global process 25 | process = tracker.track("process", len(test_data))(process) 26 | 27 | output = Path(output) 28 | output.mkdir(parents=True, exist_ok=True) 29 | (output.parent / "input").mkdir(parents=True, exist_ok=True) 30 | with open(output / "metadata.csv", "w") as csvfile: 31 | keys = ["path", "original"] 32 | writer = csv.DictWriter(csvfile, fieldnames=keys) 33 | writer.writeheader() 34 | 35 | with tracker.live: 36 | for i in range(len(test_data)): 37 | signal = process(test_data[i], test_data) 38 | input_path = output.parent / "input" / f"sample_{i}.wav" 39 | metadata = { 40 | "path": str(input_path), 41 | "original": str(signal.path_to_input_file), 42 | } 43 | writer.writerow(metadata) 44 | signal.write(input_path) 45 | tracker.done("test", f"N={len(test_data)}") 46 | 47 | 48 | if __name__ == "__main__": 49 | args = argbind.parse_args() 50 | with argbind.scope(args): 51 | save_test_set(args) 52 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = dac_jax 3 | version = attr: dac_jax.__version__ 4 | url = https://github.com/DBraun/DAC-JAX 5 | author = David Braun 6 | author_email = braun@ccrma.stanford.edu 7 | description = Descript Audio Codec and EnCodec in JAX. 8 | long_description = file: README.md 9 | long_description_content_type = "text/markdown" 10 | keywords = 11 | audio 12 | compression 13 | machine learning 14 | license = MIT 15 | classifiers = 16 | Intended Audience :: Developers 17 | Natural Language :: English 18 | Programming Language :: Python :: 3.10 19 | Programming Language :: Python :: 3.11 20 | Programming Language :: Python :: 3.12 21 | Programming Language :: Python :: 3.13 22 | Topic :: Artistic Software 23 | Topic :: Multimedia 24 | Topic :: Multimedia :: Sound/Audio 25 | Topic :: Multimedia :: Sound/Audio :: Editors 26 | Topic :: Software Development :: Libraries 27 | 28 | [options] 29 | package_dir = 30 | = src 31 | packages = find: 32 | python_requires = >=3.10 33 | install_requires = 34 | argbind @ git+https://github.com/DBraun/argbind.git@improve.subclasses 35 | audiotree>=0.2.0 36 | clu>=0.0.12 37 | dm_aux @ git+https://github.com/DBraun/dm_aux.git@DBraun-patch-2 38 | einops>=0.8.0 39 | grain==0.2.* 40 | huggingface-hub 41 | jax-ai-stack>=2025.2.5 42 | jaxloudnorm @ git+https://github.com/boris-kuz/jaxloudnorm.git 43 | librosa>=0.10.1 44 | omegaconf 45 | tqdm>=4.66.4 46 | 47 | [options.packages.find] 48 | where = src 49 | 50 | [options.extras_require] 51 | dev = 52 | audiocraft 53 | descript-audiotools 54 | descript-audio-codec 55 | pytest 56 | pytest-cov 57 | pandas 58 | pandas 59 | pesq 60 | encodec 61 | -------------------------------------------------------------------------------- /src/dac_jax/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.0" 2 | 3 | __author__ = """David Braun""" 4 | __email__ = "braun@ccrma.stanford.edu" 5 | 6 | from dac_jax import nn 7 | from dac_jax import model 8 | from dac_jax import utils 9 | from dac_jax.utils import load_model, load_encodec_model 10 | from dac_jax.model import DACFile 11 | from dac_jax.model import DAC 12 | from dac_jax.model import EncodecModel 13 | from dac_jax.nn.quantize import QuantizedResult 
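# A minimal usage sketch (for illustration only; the `model_type` argument name and the
# (model, variables) return convention of `load_model` are assumptions based on the upstream
# DAC interface, not something guaranteed by this file):
#
#     import jax.numpy as jnp
#     import dac_jax
#
#     model, variables = dac_jax.load_model(model_type="44khz")  # assumed signature; downloads pretrained weights
#     x = jnp.zeros((1, 1, 44100))  # [batch, channels, samples]
#     # Encoding/decoding then goes through the Flax module, e.g. via model.apply(variables, ...).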
14 | -------------------------------------------------------------------------------- /src/dac_jax/__main__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import argbind 4 | 5 | from dac_jax.utils import download_model 6 | from dac_jax.utils import download_encodec 7 | from dac_jax.utils.decode import decode 8 | from dac_jax.utils.encode import encode 9 | 10 | STAGES = ["encode", "decode", "download_model", "download_encodec"] 11 | 12 | 13 | def run(stage: str): 14 | """Run stages. 15 | 16 | Parameters 17 | ---------- 18 | stage : str 19 | Stage to run 20 | """ 21 | if stage not in STAGES: 22 | raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}") 23 | stage_fn = globals()[stage] 24 | 25 | stage_fn() 26 | 27 | 28 | if __name__ == "__main__": 29 | group = sys.argv.pop(1) 30 | args = argbind.parse_args(group=group) 31 | 32 | with argbind.scope(args): 33 | run(group) 34 | -------------------------------------------------------------------------------- /src/dac_jax/audio_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import math 3 | from pathlib import Path 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import chex 7 | import dm_aux as aux 8 | from einops import rearrange 9 | import jax.numpy as jnp 10 | import jax.scipy.signal 11 | import jaxloudnorm as jln 12 | import librosa 13 | 14 | 15 | def find_audio(folder: Union[str, Path], ext: List[str] = None) -> List[Path]: 16 | """Finds all audio files in a directory recursively. 17 | Returns a list. 18 | 19 | Parameters 20 | ---------- 21 | folder : str 22 | Folder to look for audio files in, recursively. 23 | ext : List[str], optional 24 | Extensions to look for without the ., by default 25 | ``['.wav', '.flac', '.mp3', '.mp4']``. 26 | 27 | Copied from 28 | https://github.com/descriptinc/audiotools/blob/7776c296c711db90176a63ff808c26e0ee087263/audiotools/core/util.py#L225 29 | """ 30 | if ext is None: 31 | ext = [".wav", ".flac", ".mp3", ".mp4"] 32 | 33 | folder = Path(folder) 34 | # Take care of case where user has passed in an audio file directly 35 | # into one of the calling functions. 36 | if str(folder).endswith(tuple(ext)): 37 | # if, however, there's a glob in the path, we need to 38 | # return the glob, not the file. 39 | if "*" in str(folder): 40 | return glob.glob(str(folder), recursive=("**" in str(folder))) 41 | else: 42 | return [folder] 43 | 44 | files = [] 45 | for x in ext: 46 | files += folder.glob(f"**/*{x}") 47 | return files 48 | 49 | 50 | def compute_stft_padding( 51 | length, window_length: int, hop_length: int, match_stride: bool 52 | ): 53 | """Compute how the STFT should be padded, based on match_stride. 54 | 55 | Parameters 56 | ---------- 57 | length: int 58 | window_length : int 59 | Window length of STFT. 60 | hop_length : int 61 | Hop length of STFT. 62 | match_stride : bool 63 | Whether to match stride, making the STFT have the same alignment as convolutional layers. 64 | 65 | Returns 66 | ------- 67 | tuple 68 | Amount to pad on either side of audio. 
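Example (illustrative numbers): with length=44100, window_length=2048, hop_length=512 and match_stride=True, right_pad = ceil(44100 / 512) * 512 - 44100 = 444 and pad = (2048 - 512) // 2 = 768.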
69 | """ 70 | if match_stride: 71 | assert ( 72 | hop_length == window_length // 4 73 | ), "For match_stride, hop must equal n_fft // 4" 74 | right_pad = math.ceil(length / hop_length) * hop_length - length 75 | pad = (window_length - hop_length) // 2 76 | else: 77 | right_pad = 0 78 | pad = 0 79 | 80 | return right_pad, pad 81 | 82 | 83 | def stft( 84 | x: jnp.ndarray, 85 | frame_length=2048, 86 | hop_factor=0.25, 87 | window="hann", 88 | match_stride=False, 89 | padding_type: str = "reflect", 90 | ): 91 | """Reference: 92 | https://github.com/descriptinc/audiotools/blob/7776c296c711db90176a63ff808c26e0ee087263/audiotools/core/audio_signal.py#L1123 93 | """ 94 | 95 | batch_size, num_channels, audio_length = x.shape 96 | 97 | frame_step = int(frame_length * hop_factor) 98 | 99 | right_pad, pad = compute_stft_padding( 100 | audio_length, frame_length, frame_step, match_stride 101 | ) 102 | x = jnp.pad( 103 | x, pad_width=((0, 0), (0, 0), (pad, pad + right_pad)), mode=padding_type 104 | ) 105 | 106 | x = rearrange(x, "b c t -> (b c) t") 107 | 108 | if window == "sqrt_hann": 109 | from scipy import signal as scipy_signal 110 | 111 | window = jnp.sqrt(scipy_signal.get_window("hann", frame_length)) 112 | 113 | # todo: https://github.com/google-deepmind/dm_aux/issues/2 114 | stft_data = aux.spectral.stft( 115 | x, 116 | n_fft=frame_length, 117 | frame_step=frame_step, 118 | window_fn=window, 119 | pad_mode=padding_type, 120 | pad=aux.spectral.Pad.BOTH, 121 | ) 122 | stft_data = rearrange(stft_data, "(b c) nt nf -> b c nf nt", b=batch_size) 123 | 124 | if match_stride: 125 | # Drop first two and last two frames, which are added 126 | # because of padding. Now num_frames * hop_length = num_samples. 127 | if hop_factor == 0.25: 128 | stft_data = stft_data[..., 2:-2] 129 | else: 130 | # I think this would be correct if DAC torch ever allowed match_stride==True and hop_factor==0.5 131 | stft_data = stft_data[..., 1:-1] 132 | 133 | return stft_data 134 | 135 | 136 | def mel_spectrogram( 137 | spectrograms: chex.Array, 138 | log_scale: bool = True, 139 | sample_rate: int = 16000, 140 | frame_length: Optional[int] = 2048, 141 | num_features: int = 128, 142 | lower_edge_hertz: float = 0.0, 143 | upper_edge_hertz: Optional[float] = None, 144 | ) -> chex.Array: 145 | """Converts the spectrograms to Mel-scale. 146 | 147 | Adapted from dm_aux: 148 | https://github.com/google-deepmind/dm_aux/blob/77f5ed76df2928bac8550e1c5466c0dac2934be3/dm_aux/spectral.py#L312 149 | 150 | https://en.wikipedia.org/wiki/Mel_scale 151 | 152 | Args: 153 | spectrograms: Input spectrograms of shape [batch_size, time_steps, 154 | num_features]. 155 | log_scale: Whether to return the mel_filterbanks in the log scale. 156 | sample_rate: The sample rate of the input audio. 157 | frame_length: The length of each spectrogram frame. 158 | num_features: The number of mel spectrogram features. 159 | lower_edge_hertz: Lowest frequency to consider to general mel filterbanks. 160 | upper_edge_hertz: Highest frequency to consider to general mel filterbanks. 161 | If None, use `sample_rate / 2.0`. 162 | 163 | Returns: 164 | Converted spectrograms in (log) Mel-scale. 165 | """ 166 | # This setup mimics tf.signal.linear_to_mel_weight_matrix. 
167 | linear_to_mel_weight_matrix = librosa.filters.mel( 168 | sr=sample_rate, 169 | n_fft=frame_length, 170 | n_mels=num_features, 171 | fmin=lower_edge_hertz, 172 | fmax=upper_edge_hertz, 173 | ).T 174 | spectrograms = jnp.matmul(spectrograms, linear_to_mel_weight_matrix) 175 | 176 | if log_scale: 177 | spectrograms = jnp.log(spectrograms + 1e-6) 178 | return spectrograms 179 | 180 | 181 | def decibel_loudness(stft_data: jnp.ndarray, clamp_eps=1e-5, pow=2.0) -> jnp.ndarray: 182 | return jnp.log10(jnp.power(jnp.maximum(jnp.abs(stft_data), clamp_eps), pow)) 183 | 184 | 185 | def db2linear(decibels: jnp.ndarray): 186 | return jnp.pow(10.0, decibels / 20.0) 187 | 188 | 189 | def volume_norm( 190 | audio_data: jnp.ndarray, 191 | target_db: jnp.ndarray, 192 | sample_rate: int, 193 | filter_class: str = "K-weighting", 194 | block_size: float = 0.400, 195 | min_loudness: float = -70, 196 | zeros: int = 2048, 197 | ): 198 | """Calculates loudness using an implementation of ITU-R BS.1770-4. 199 | Allows control over gating block size and frequency weighting filters for 200 | additional control. Measures the integrated gated loudness of a signal. 201 | 202 | API is derived from PyLoudnorm, but this implementation is ported to JAX 203 | and is vectorized across batches via `jax.vmap`. An FIR approximation of the IIR 204 | filters is used to compute loudness for speed. 205 | 206 | Uses the weighting filters and block size defined by the meter; 207 | the integrated loudness is measured based upon the gating algorithm 208 | defined in the ITU-R BS.1770-4 specification. 209 | 210 | Parameters 211 | ---------- 212 | audio_data: jnp.ndarray 213 | audio signal [B, C, T] 214 | target_db: jnp.ndarray 215 | array of target decibel loudnesses [B] 216 | sample_rate: int 217 | sample rate of audio_data 218 | filter_class : str, optional 219 | Class of weighting filter used. 220 | 'K-weighting' (default), 'Fenton/Lee 1', 221 | 'Fenton/Lee 2', 'Dash et al.', 222 | by default "K-weighting" 223 | block_size : float, optional 224 | Gating block size in seconds, by default 0.400 225 | min_loudness : float, optional 226 | Minimum loudness in decibels 227 | zeros : int, optional 228 | The length of the FIR filter. You should pick a power of 2 between 512 and 4096. 229 | 230 | Returns 231 | ------- 232 | jnp.ndarray 233 | Audio normalized to `target_db` loudness 234 | jnp.ndarray 235 | Loudness of original audio data.
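Example (illustrative): `normalized, loudness = volume_norm(audio, jnp.full((batch,), -24.0), 44100)` applies a per-item gain so that each example in the batch lands at the -24 dB LUFS target, and also returns the measured loudness of the original audio.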
236 | 237 | Reference: https://github.com/descriptinc/audiotools/blob/master/audiotools/core/loudness.py 238 | """ 239 | 240 | padded_audio = audio_data 241 | 242 | original_length = padded_audio.shape[-1] 243 | signal_duration = original_length / sample_rate 244 | 245 | if signal_duration < block_size: 246 | padded_audio = jnp.pad( 247 | padded_audio, 248 | pad_width=( 249 | (0, 0), 250 | (0, 0), 251 | (0, int(block_size * sample_rate) - original_length), 252 | ), 253 | ) 254 | 255 | # create BS.1770 meter 256 | meter = jln.Meter( 257 | sample_rate, 258 | filter_class=filter_class, 259 | block_size=block_size, 260 | use_fir=True, 261 | zeros=zeros, 262 | ) 263 | 264 | # measure loudness 265 | loudness = jax.vmap(meter.integrated_loudness)( 266 | rearrange(padded_audio, "b c t -> b t c") 267 | ) 268 | 269 | loudness = jnp.maximum(loudness, jnp.full_like(loudness, min_loudness)) 270 | 271 | audio_data = audio_data * db2linear(target_db - loudness)[:, None, None] 272 | 273 | return audio_data, loudness 274 | -------------------------------------------------------------------------------- /src/dac_jax/compare/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DBraun/DAC-JAX/919ce4a2a9ec4c5c3fa7d10dcb2944259da00865/src/dac_jax/compare/__init__.py -------------------------------------------------------------------------------- /src/dac_jax/compare/encodec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from audiotools import AudioSignal 3 | from audiotools.ml import BaseModel 4 | from encodec import EncodecModel 5 | 6 | 7 | class Encodec(BaseModel): 8 | def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0): 9 | super().__init__() 10 | 11 | if sample_rate == 24000: 12 | self.model = EncodecModel.encodec_model_24khz() 13 | else: 14 | self.model = EncodecModel.encodec_model_48khz() 15 | self.model.set_target_bandwidth(bandwidth) 16 | self.sample_rate = 44100 17 | 18 | def forward( 19 | self, 20 | audio_data: torch.Tensor, 21 | sample_rate: int = 44100, 22 | n_quantizers: int = None, 23 | ): 24 | signal = AudioSignal(audio_data, sample_rate) 25 | signal.resample(self.model.sample_rate) 26 | recons = self.model(signal.audio_data) 27 | recons = AudioSignal(recons, self.model.sample_rate) 28 | recons.resample(sample_rate) 29 | return {"audio": recons.audio_data} 30 | 31 | 32 | if __name__ == "__main__": 33 | import numpy as np 34 | from functools import partial 35 | 36 | model = Encodec() 37 | 38 | for n, m in model.named_modules(): 39 | o = m.extra_repr() 40 | p = sum([np.prod(p.size()) for p in m.parameters()]) 41 | fn = lambda o, p: o + f" {p/1e6:<.3f}M params." 
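# Monkey-patch extra_repr so printing the model also reports each module's parameter count (in millions).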
42 | setattr(m, "extra_repr", partial(fn, o=o, p=p)) 43 | print(model) 44 | print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) 45 | 46 | length = 88200 * 2 47 | x = torch.randn(1, 1, length).to(model.device) 48 | x.requires_grad_(True) 49 | x.retain_grad() 50 | 51 | # Make a forward pass 52 | out = model(x)["audio"] 53 | 54 | print(x.shape, out.shape) 55 | -------------------------------------------------------------------------------- /src/dac_jax/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .dac import DAC, DACFile 2 | from .discriminator import Discriminator 3 | from .encodec import SEANetEncoder, SEANetDecoder, EncodecModel 4 | -------------------------------------------------------------------------------- /src/dac_jax/model/core.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from flax import linen as nn 3 | from jax import numpy as jnp 4 | import typing as tp 5 | 6 | from dac_jax.nn.encodec_quantize import QuantizedResult 7 | 8 | 9 | class CompressionModel(ABC, nn.Module): 10 | """Base API for all compression models that aim at being used as audio tokenizers 11 | with a language model. 12 | """ 13 | 14 | @abstractmethod 15 | def __call__(self, x: jnp.ndarray) -> QuantizedResult: ... 16 | 17 | @abstractmethod 18 | def encode(self, x: jnp.ndarray) -> tp.Tuple[jnp.ndarray, tp.Optional[jnp.ndarray]]: 19 | """See `EncodecModel.encode`.""" 20 | ... 21 | 22 | @abstractmethod 23 | def decode( 24 | self, 25 | codes: jnp.ndarray, 26 | scale: tp.Optional[jnp.ndarray] = None, 27 | length: int = None, 28 | ): 29 | """See `EncodecModel.decode`.""" 30 | ... 31 | 32 | @abstractmethod 33 | def decode_latent(self, codes: jnp.ndarray): 34 | """Decode from the discrete codes to continuous latent space.""" 35 | ... 36 | 37 | @property 38 | @abstractmethod 39 | def channels(self) -> int: ... 40 | 41 | @property 42 | @abstractmethod 43 | def frame_rate(self) -> float: ... 44 | 45 | @property 46 | @abstractmethod 47 | def sample_rate(self) -> int: ... 48 | 49 | @property 50 | @abstractmethod 51 | def cardinality(self) -> int: ... 52 | 53 | @property 54 | @abstractmethod 55 | def num_codebooks(self) -> int: ... 56 | 57 | @property 58 | @abstractmethod 59 | def total_codebooks(self) -> int: ... 
60 | -------------------------------------------------------------------------------- /src/dac_jax/model/discriminator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field 2 | 3 | from audiotree.resample import resample 4 | from einops import rearrange 5 | import flax.linen as nn 6 | import jax 7 | from jax import numpy as jnp 8 | 9 | from dac_jax.audio_utils import stft 10 | from dac_jax.nn.layers import make_initializer 11 | 12 | 13 | class LeakyReLU(nn.Module): 14 | 15 | negative_slope: float = 0.01 16 | 17 | @nn.compact 18 | def __call__(self, x): 19 | return nn.leaky_relu(x, negative_slope=self.negative_slope) 20 | 21 | 22 | class WNConv(nn.Conv): 23 | 24 | act: bool = True 25 | 26 | @nn.compact 27 | def __call__(self, x): 28 | 29 | kernel_init = make_initializer( 30 | x.shape[-1], 31 | self.features, 32 | self.kernel_size, 33 | self.feature_group_count, 34 | mode="fan_in", 35 | ) 36 | 37 | if self.use_bias: 38 | # note: we just ignore whatever self.bias_init is 39 | bias_init = make_initializer( 40 | x.shape[-1], 41 | self.features, 42 | self.kernel_size, 43 | self.feature_group_count, 44 | mode="fan_in", 45 | ) 46 | else: 47 | bias_init = None 48 | 49 | conv = nn.Conv( 50 | features=self.features, 51 | kernel_size=self.kernel_size, 52 | strides=self.strides, 53 | padding=self.padding, 54 | input_dilation=self.input_dilation, 55 | kernel_dilation=self.kernel_dilation, 56 | feature_group_count=self.feature_group_count, 57 | use_bias=self.use_bias, 58 | mask=self.mask, 59 | dtype=self.dtype, 60 | param_dtype=self.param_dtype, 61 | precision=self.precision, 62 | kernel_init=kernel_init, 63 | bias_init=bias_init, 64 | ) 65 | scale_init = nn.initializers.constant(1 / jnp.sqrt(3)) 66 | block = nn.WeightNorm(conv, scale_init=scale_init) 67 | x = block(x) 68 | 69 | if self.act: 70 | x = LeakyReLU(0.1)(x) 71 | 72 | return x 73 | 74 | 75 | class MPD(nn.Module): 76 | 77 | period: int 78 | 79 | def pad_to_period(self, x): 80 | t = x.shape[-1] 81 | x = jnp.pad( 82 | x, 83 | pad_width=((0, 0), (0, 0), (0, self.period - t % self.period)), 84 | mode="reflect", 85 | ) 86 | return x 87 | 88 | @nn.compact 89 | def __call__(self, x): 90 | convs = [ 91 | WNConv( 92 | features=32, 93 | kernel_size=(5, 1), 94 | strides=(3, 1), 95 | padding=((2, 2), (0, 0)), 96 | ), 97 | WNConv( 98 | features=128, 99 | kernel_size=(5, 1), 100 | strides=(3, 1), 101 | padding=((2, 2), (0, 0)), 102 | ), 103 | WNConv( 104 | features=512, 105 | kernel_size=(5, 1), 106 | strides=(3, 1), 107 | padding=((2, 2), (0, 0)), 108 | ), 109 | WNConv( 110 | features=1024, 111 | kernel_size=(5, 1), 112 | strides=(3, 1), 113 | padding=((2, 2), (0, 0)), 114 | ), 115 | WNConv( 116 | features=1024, 117 | kernel_size=(5, 1), 118 | strides=(1, 1), 119 | padding=((2, 2), (0, 0)), 120 | ), 121 | WNConv(features=1, kernel_size=(3, 1), padding=((1, 1), (0, 0)), act=False), 122 | ] 123 | 124 | fmap = [] 125 | 126 | x = self.pad_to_period(x) 127 | x = rearrange(x, "b c (l p) -> b l p c", p=self.period) 128 | 129 | for layer in convs: 130 | x = layer(x) 131 | fmap.append(x) 132 | 133 | return fmap 134 | 135 | 136 | class MSD(nn.Module): 137 | 138 | rate: int = 1 139 | sample_rate: int = 44100 140 | 141 | @nn.compact 142 | def __call__(self, x): 143 | convs = [ 144 | WNConv(features=16, kernel_size=15, strides=1, padding=7), 145 | WNConv( 146 | features=64, 147 | kernel_size=41, 148 | strides=4, 149 | feature_group_count=4, 150 | padding=20, 151 | ), 152 | WNConv( 153 | features=256, 154 | 
kernel_size=41, 155 | strides=4, 156 | feature_group_count=16, 157 | padding=20, 158 | ), 159 | WNConv( 160 | features=1024, 161 | kernel_size=41, 162 | strides=4, 163 | feature_group_count=64, 164 | padding=20, 165 | ), 166 | WNConv( 167 | features=1024, 168 | kernel_size=41, 169 | strides=4, 170 | feature_group_count=256, 171 | padding=20, 172 | ), 173 | WNConv(features=1024, kernel_size=5, strides=1, padding=2), 174 | WNConv(features=1, kernel_size=3, strides=1, padding=1, act=False), 175 | ] 176 | 177 | x = resample(x, old_sr=self.sample_rate, new_sr=self.sample_rate // self.rate) 178 | 179 | x = rearrange(x, "b c l -> b l c") 180 | 181 | fmap = [] 182 | 183 | for layer in convs: 184 | x = layer(x) 185 | fmap.append(x) 186 | 187 | return fmap 188 | 189 | 190 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] 191 | 192 | 193 | class MRD(nn.Module): 194 | 195 | window_length: int 196 | hop_factor: float = 0.25 197 | sample_rate: int = 44100 198 | bands: list = field(default_factory=lambda: BANDS) 199 | 200 | def __post_init__(self) -> None: 201 | n_fft = self.window_length // 2 + 1 202 | self.bands = [(int(low * n_fft), int(high * n_fft)) for low, high in self.bands] 203 | super().__post_init__() 204 | 205 | @nn.compact 206 | def __call__(self, x): 207 | """Complex multi-band spectrogram discriminator. 208 | Parameters 209 | ---------- 210 | window_length : int 211 | Window length of STFT. 212 | hop_factor : float, optional 213 | Hop factor of the STFT, defaults to ``0.25 * window_length``. 214 | sample_rate : int, optional 215 | Sampling rate of audio in Hz, by default 44100 216 | bands : list, optional 217 | Bands to run discriminator over. 218 | """ 219 | 220 | ch = 32 221 | convs = lambda: [ 222 | WNConv( 223 | features=ch, 224 | kernel_size=(3, 9), 225 | strides=(1, 1), 226 | padding=((1, 1), (4, 4)), 227 | ), 228 | WNConv( 229 | features=ch, 230 | kernel_size=(3, 9), 231 | strides=(1, 2), 232 | padding=((1, 1), (4, 4)), 233 | ), 234 | WNConv( 235 | features=ch, 236 | kernel_size=(3, 9), 237 | strides=(1, 2), 238 | padding=((1, 1), (4, 4)), 239 | ), 240 | WNConv( 241 | features=ch, 242 | kernel_size=(3, 9), 243 | strides=(1, 2), 244 | padding=((1, 1), (4, 4)), 245 | ), 246 | WNConv( 247 | features=ch, 248 | kernel_size=(3, 3), 249 | strides=(1, 1), 250 | padding=((1, 1), (1, 1)), 251 | ), 252 | ] 253 | band_convs = [convs() for _ in range(len(self.bands))] 254 | conv_post = WNConv( 255 | features=1, 256 | kernel_size=(3, 3), 257 | strides=(1, 1), 258 | padding=((1, 1), (1, 1)), 259 | act=False, 260 | ) 261 | 262 | x_bands = self.get_bands(x) 263 | fmap = [] 264 | 265 | x = [] 266 | for band, stack in zip(x_bands, band_convs): 267 | band = rearrange(band, "b c t f -> b t f c") 268 | for layer in stack: 269 | band = layer(band) 270 | fmap.append(band) 271 | x.append(band) 272 | 273 | x = jnp.concatenate(x, axis=-2) # concatenate along frequency axis 274 | x = conv_post(x) 275 | fmap.append(x) 276 | 277 | return fmap 278 | 279 | def get_bands(self, x): 280 | stft_data = stft( 281 | x, 282 | frame_length=self.window_length, 283 | hop_factor=self.hop_factor, 284 | match_stride=True, 285 | ) 286 | x = self.as_real(stft_data) 287 | x = rearrange( 288 | x, "b c f t ri -> (b c) ri t f", c=1, ri=2 289 | ) # ri is 2 for real and imaginary 290 | # Split into bands 291 | x_bands = [x[..., low:high] for low, high in self.bands] 292 | return x_bands 293 | 294 | @staticmethod 295 | def as_real(x: jnp.ndarray) -> jnp.ndarray: 296 | # 
https://github.com/google/jax/issues/9496#issuecomment-1033961377 297 | if not jnp.issubdtype(x.dtype, jnp.complexfloating): 298 | return x 299 | 300 | return jnp.stack([x.real, x.imag], axis=-1) 301 | 302 | 303 | class Discriminator(nn.Module): 304 | 305 | rates: list = field(default_factory=lambda: []) 306 | periods: list = field(default_factory=lambda: [2, 3, 5, 7, 11]) 307 | fft_sizes: list = field(default_factory=lambda: [2048, 1024, 512]) 308 | sample_rate: int = 44100 309 | bands: list = field(default_factory=lambda: BANDS) 310 | 311 | @staticmethod 312 | def preprocess(y: jnp.ndarray): 313 | # Remove DC offset 314 | y = y - y.mean(axis=-1, keepdims=True) 315 | # Peak normalize the volume of input audio 316 | y = 0.8 * y / (jnp.abs(y).max(axis=-1, keepdims=True) + 1e-9) 317 | return y 318 | 319 | @nn.compact 320 | def __call__(self, x): 321 | """Discriminator that combines multiple discriminators. 322 | 323 | Parameters 324 | ---------- 325 | rates : list, optional 326 | sampling rates (in Hz) to run MSD at, by default [] 327 | If empty, MSD is not used. 328 | periods : list, optional 329 | periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] 330 | fft_sizes : list, optional 331 | Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] 332 | sample_rate : int, optional 333 | Sampling rate of audio in Hz, by default 44100 334 | bands : list, optional 335 | Bands to run MRD at, by default `BANDS` 336 | """ 337 | discriminators = [] 338 | discriminators += [MPD(p) for p in self.periods] 339 | discriminators += [MSD(r, sample_rate=self.sample_rate) for r in self.rates] 340 | discriminators += [ 341 | MRD(f, sample_rate=self.sample_rate, bands=self.bands) 342 | for f in self.fft_sizes 343 | ] 344 | x = self.preprocess(x) 345 | fmaps = [d(x) for d in discriminators] 346 | return fmaps 347 | 348 | 349 | if __name__ == "__main__": 350 | import numpy as np 351 | 352 | disc = Discriminator() 353 | x = jnp.zeros(shape=(1, 1, 44100)) 354 | 355 | print( 356 | disc.tabulate( 357 | jax.random.key(1), 358 | x, 359 | # compute_flops=True, 360 | # compute_vjp_flops=True, 361 | depth=3, 362 | # column_kwargs={"width": 400}, 363 | console_kwargs={"width": 400}, 364 | ) 365 | ) 366 | 367 | results, variables = disc.init_with_output(jax.random.key(3), x) 368 | 369 | for i, result in enumerate(results): 370 | print(f"disc{i}") 371 | for i, _r in enumerate(result): 372 | r = np.array(_r) 373 | print( 374 | r.shape, 375 | f"{r.mean().item():,.5f}, {r.min().item():,.5f} {r.max().item():,.5f}", 376 | ) 377 | print("All Done!") 378 | -------------------------------------------------------------------------------- /src/dac_jax/model/encodec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from dataclasses import field 8 | import typing as tp 9 | 10 | from einops import rearrange 11 | from flax import linen as nn 12 | from jax import numpy as jnp 13 | import numpy as np 14 | 15 | from dac_jax.model.core import CompressionModel 16 | from dac_jax.nn.quantize import QuantizedResult 17 | 18 | from dac_jax.nn.encodec_layers import ( 19 | StreamableConv1d, 20 | StreamableConvTranspose1d, 21 | StreamableLSTM, 22 | ) 23 | 24 | 25 | class SEANetResnetBlock(nn.Module): 26 | """Residual block from SEANet model. 
27 | 28 | Args: 29 | dim (int): Dimension of the input/output. 30 | kernel_sizes (list): List of kernel sizes for the convolutions. 31 | dilations (list): List of dilations for the convolutions. 32 | activation (str): Activation function. 33 | activation_params (dict): Parameters to provide to the activation function. 34 | norm (str): Normalization method. 35 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. 36 | causal (bool): Whether to use fully causal convolution. 37 | pad_mode (str): Padding mode for the convolutions. 38 | compress (int): Reduced dimensionality in residual branches (from Demucs v3). 39 | true_skip (bool): Whether to use true skip connection or a simple 40 | (streamable) convolution as the skip connection. 41 | """ 42 | 43 | dim: int 44 | kernel_sizes: tp.List[int] = field(default_factory=lambda: [3, 1]) 45 | dilations: tp.List[int] = field(default_factory=lambda: [1, 1]) 46 | activation: str = "elu" 47 | activation_params: dict = field(default_factory=lambda: {"alpha": 1.0}) 48 | norm: str = "none" 49 | norm_params: tp.Dict[str, tp.Any] = field(default_factory=lambda: {}) 50 | causal: int = 0 # bool 51 | pad_mode: str = "reflect" 52 | compress: int = 2 53 | true_skip: int = 1 # bool 54 | 55 | @nn.compact 56 | def __call__(self, x): 57 | assert len(self.kernel_sizes) == len( 58 | self.dilations 59 | ), "Number of kernel sizes should match number of dilations" 60 | act = lambda y: getattr(nn.activation, self.activation)( 61 | y, **self.activation_params 62 | ) 63 | hidden = self.dim // self.compress 64 | block = [] 65 | for i, (kernel_size, dilation) in enumerate( 66 | zip(self.kernel_sizes, self.dilations) 67 | ): 68 | out_chs = self.dim if i == len(self.kernel_sizes) - 1 else hidden 69 | block += [ 70 | act, 71 | StreamableConv1d( 72 | out_chs, 73 | kernel_size=kernel_size, 74 | dilation=dilation, 75 | norm=self.norm, 76 | norm_kwargs=self.norm_params, 77 | causal=self.causal, 78 | pad_mode=self.pad_mode, 79 | ), 80 | ] 81 | block = nn.Sequential(block) 82 | if self.true_skip: 83 | return x + block(x) 84 | else: 85 | shortcut = StreamableConv1d( 86 | self.dim, 87 | kernel_size=1, 88 | norm=self.norm, 89 | norm_kwargs=self.norm_params, 90 | causal=self.causal, 91 | pad_mode=self.pad_mode, 92 | ) 93 | 94 | return shortcut(x) + block(x) 95 | 96 | 97 | class SEANetEncoder(nn.Module): 98 | """SEANet encoder. 99 | 100 | Args: 101 | channels (int): Audio channels. 102 | dimension (int): Intermediate representation dimension. 103 | n_filters (int): Base width for the model. 104 | n_residual_layers (int): nb of residual layers. 105 | ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of 106 | upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here 107 | that must match the decoder order. We use the decoder order as some models may only employ the decoder. 108 | activation (str): Activation function. 109 | activation_params (dict): Parameters to provide to the activation function. 110 | norm (str): Normalization method. 111 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. 112 | kernel_size (int): Kernel size for the initial convolution. 113 | last_kernel_size (int): Kernel size for the initial convolution. 114 | residual_kernel_size (int): Kernel size for the residual layers. 115 | dilation_base (int): How much to increase the dilation with each layer. 
116 | causal (bool): Whether to use fully causal convolution. 117 | pad_mode (str): Padding mode for the convolutions. 118 | true_skip (bool): Whether to use true skip connection or a simple 119 | (streamable) convolution as the skip connection in the residual network blocks. 120 | compress (int): Reduced dimensionality in residual branches (from Demucs v3). 121 | lstm (int): Number of LSTM layers at the end of the encoder. 122 | disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm. 123 | For the encoder, it corresponds to the N first blocks. 124 | """ 125 | 126 | channels: int = 1 127 | dimension: int = 128 128 | n_filters: int = 32 129 | n_residual_layers: int = 3 130 | ratios: tp.List[int] = field(default_factory=lambda: [8, 5, 4, 2]) 131 | activation: str = "elu" 132 | activation_params: dict = field(default_factory=lambda: {"alpha": 1.0}) 133 | norm: str = "none" 134 | norm_params: tp.Dict[str, tp.Any] = field(default_factory=lambda: {}) 135 | kernel_size: int = 7 136 | last_kernel_size: int = 7 137 | residual_kernel_size: int = 3 138 | dilation_base: int = 2 139 | causal: bool = False 140 | pad_mode: str = "reflect" 141 | true_skip: bool = True 142 | compress: int = 2 143 | lstm: int = 0 144 | disable_norm_outer_blocks: int = 0 145 | 146 | def __post_init__(self) -> None: 147 | self.hop_length = np.prod(self.ratios) 148 | self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks 149 | assert ( 150 | self.disable_norm_outer_blocks >= 0 151 | and self.disable_norm_outer_blocks <= self.n_blocks 152 | ), ( 153 | "Number of blocks for which to disable norm is invalid." 154 | "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." 155 | ) 156 | super().__post_init__() 157 | 158 | @nn.compact 159 | def __call__(self, x): 160 | act = lambda y: getattr(nn.activation, self.activation)( 161 | y, **self.activation_params 162 | ) 163 | mult = 1 164 | layers = [ 165 | StreamableConv1d( 166 | mult * self.n_filters, 167 | kernel_size=self.kernel_size, 168 | norm="none" if self.disable_norm_outer_blocks >= 1 else self.norm, 169 | norm_kwargs=self.norm_params, 170 | causal=self.causal, 171 | pad_mode=self.pad_mode, 172 | ) 173 | ] 174 | # Downsample to raw audio scale 175 | for i, ratio in enumerate(reversed(self.ratios)): 176 | block_norm = ( 177 | "none" if self.disable_norm_outer_blocks >= i + 2 else self.norm 178 | ) 179 | # Add residual layers 180 | for j in range(self.n_residual_layers): 181 | layers += [ 182 | SEANetResnetBlock( 183 | mult * self.n_filters, 184 | kernel_sizes=[self.residual_kernel_size, 1], 185 | dilations=[self.dilation_base**j, 1], 186 | norm=block_norm, 187 | norm_params=self.norm_params, 188 | activation=self.activation, 189 | activation_params=self.activation_params, 190 | causal=self.causal, 191 | pad_mode=self.pad_mode, 192 | compress=self.compress, 193 | true_skip=self.true_skip, 194 | ) 195 | ] 196 | 197 | # Add downsampling layers 198 | layers += [ 199 | act, 200 | StreamableConv1d( 201 | mult * self.n_filters * 2, 202 | kernel_size=ratio * 2, 203 | stride=ratio, 204 | norm=block_norm, 205 | norm_kwargs=self.norm_params, 206 | causal=self.causal, 207 | pad_mode=self.pad_mode, 208 | ), 209 | ] 210 | mult *= 2 211 | 212 | if self.lstm: 213 | layers += [StreamableLSTM(mult * self.n_filters, num_layers=self.lstm)] 214 | 215 | layers += [ 216 | act, 217 | StreamableConv1d( 218 | self.dimension, 219 | kernel_size=self.last_kernel_size, 220 | norm=( 221 | "none" 222 | if 
self.disable_norm_outer_blocks == self.n_blocks 223 | else self.norm 224 | ), 225 | norm_kwargs=self.norm_params, 226 | causal=self.causal, 227 | pad_mode=self.pad_mode, 228 | ), 229 | ] 230 | 231 | model = nn.Sequential(layers) 232 | x = rearrange(x, "B C T -> B T C") 233 | return model(x) 234 | 235 | 236 | class SEANetDecoder(nn.Module): 237 | """SEANet decoder. 238 | 239 | Args: 240 | channels (int): Audio channels. 241 | dimension (int): Intermediate representation dimension. 242 | n_filters (int): Base width for the model. 243 | n_residual_layers (int): nb of residual layers. 244 | ratios (Sequence[int]): kernel size and stride ratios. 245 | activation (str): Activation function. 246 | activation_params (dict): Parameters to provide to the activation function. 247 | final_activation (str): Final activation function after all convolutions. 248 | final_activation_params (dict): Parameters to provide to the activation function. 249 | norm (str): Normalization method. 250 | norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. 251 | kernel_size (int): Kernel size for the initial convolution. 252 | last_kernel_size (int): Kernel size for the initial convolution. 253 | residual_kernel_size (int): Kernel size for the residual layers. 254 | dilation_base (int): How much to increase the dilation with each layer. 255 | causal (bool): Whether to use fully causal convolution. 256 | pad_mode (str): Padding mode for the convolutions. 257 | true_skip (bool): Whether to use true skip connection or a simple. 258 | (streamable) convolution as the skip connection in the residual network blocks. 259 | compress (int): Reduced dimensionality in residual branches (from Demucs v3). 260 | lstm (int): Number of LSTM layers at the end of the encoder. 261 | disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm. 262 | For the decoder, it corresponds to the N last blocks. 263 | trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. 264 | If equal to 1.0, it means that all the trimming is done at the right. 265 | """ 266 | 267 | channels: int = 1 268 | dimension: int = 128 269 | n_filters: int = 32 270 | n_residual_layers: int = 3 271 | ratios: tp.List[int] = field(default_factory=lambda: [8, 5, 4, 2]) 272 | activation: str = "elu" 273 | activation_params: dict = field(default_factory=lambda: {"alpha": 1.0}) 274 | final_activation: tp.Optional[str] = None 275 | final_activation_params: tp.Optional[dict] = None 276 | norm: str = "none" 277 | norm_params: tp.Dict[str, tp.Any] = field(default_factory=lambda: {}) 278 | kernel_size: int = 7 279 | last_kernel_size: int = 7 280 | residual_kernel_size: int = 3 281 | dilation_base: int = 2 282 | causal: bool = False 283 | pad_mode: str = "reflect" 284 | true_skip: bool = True 285 | compress: int = 2 286 | lstm: int = 0 287 | disable_norm_outer_blocks: int = 0 288 | trim_right_ratio: float = 1.0 289 | 290 | def __post_init__(self) -> None: 291 | self.hop_length = np.prod(self.ratios) 292 | self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks 293 | assert ( 294 | self.disable_norm_outer_blocks >= 0 295 | and self.disable_norm_outer_blocks <= self.n_blocks 296 | ), ( 297 | "Number of blocks for which to disable norm is invalid." 298 | "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." 
299 | ) 300 | super().__post_init__() 301 | 302 | @nn.compact 303 | def __call__(self, z): 304 | z = z.transpose(0, 2, 1) 305 | act = lambda y: getattr(nn.activation, self.activation)( 306 | y, **self.activation_params 307 | ) 308 | mult = int(2 ** len(self.ratios)) 309 | layers = [ 310 | StreamableConv1d( 311 | mult * self.n_filters, 312 | kernel_size=self.kernel_size, 313 | norm=( 314 | "none" 315 | if self.disable_norm_outer_blocks == self.n_blocks 316 | else self.norm 317 | ), 318 | norm_kwargs=self.norm_params, 319 | causal=self.causal, 320 | pad_mode=self.pad_mode, 321 | ) 322 | ] 323 | 324 | if self.lstm: 325 | layers += [StreamableLSTM(mult * self.n_filters, num_layers=self.lstm)] 326 | 327 | # Upsample to raw audio scale 328 | for i, ratio in enumerate(self.ratios): 329 | block_norm = ( 330 | "none" 331 | if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) 332 | else self.norm 333 | ) 334 | # Add upsampling layers 335 | layers += [ 336 | act, 337 | StreamableConvTranspose1d( 338 | mult * self.n_filters // 2, 339 | kernel_size=ratio * 2, 340 | stride=ratio, 341 | norm=block_norm, 342 | norm_kwargs=self.norm_params, 343 | causal=self.causal, 344 | trim_right_ratio=self.trim_right_ratio, 345 | ), 346 | ] 347 | # Add residual layers 348 | for j in range(self.n_residual_layers): 349 | layers += [ 350 | SEANetResnetBlock( 351 | mult * self.n_filters // 2, 352 | kernel_sizes=[self.residual_kernel_size, 1], 353 | dilations=[self.dilation_base**j, 1], 354 | activation=self.activation, 355 | activation_params=self.activation_params, 356 | norm=block_norm, 357 | norm_params=self.norm_params, 358 | causal=self.causal, 359 | pad_mode=self.pad_mode, 360 | compress=self.compress, 361 | true_skip=self.true_skip, 362 | ) 363 | ] 364 | 365 | mult //= 2 366 | 367 | # Add final layers 368 | layers += [ 369 | act, 370 | StreamableConv1d( 371 | self.channels, 372 | kernel_size=self.last_kernel_size, 373 | norm="none" if self.disable_norm_outer_blocks >= 1 else self.norm, 374 | norm_kwargs=self.norm_params, 375 | causal=self.causal, 376 | pad_mode=self.pad_mode, 377 | ), 378 | ] 379 | # Add optional final activation to decoder (eg. tanh) 380 | if self.final_activation is not None: 381 | final_act = getattr(nn, self.final_activation) 382 | final_activation_params = self.final_activation_params or {} 383 | layers += [final_act(**final_activation_params)] 384 | model = nn.Sequential(layers) 385 | y = model(z) 386 | y = rearrange(y, "B T C -> B C T") 387 | return y 388 | 389 | 390 | class EncodecModel(CompressionModel): 391 | """Encodec model operating on the raw waveform. 392 | 393 | Args: 394 | encoder (nn.Module): Encoder network. 395 | decoder (nn.Module): Decoder network. 396 | quantizer (qt.BaseQuantizer): Quantizer network. 397 | frame_rate (int): Frame rate for the latent representation. 398 | sample_rate (int): Audio sample rate. 399 | channels (int): Number of audio channels. 400 | causal (bool): Whether to use a causal version of the model. 401 | renormalize (bool): Whether to renormalize the audio before running the model. 402 | """ 403 | 404 | encoder: nn.Module 405 | decoder: nn.Module 406 | quantizer: nn.Module # todo: qt.BaseQuantizer, 407 | causal: int = 0 # bool 408 | renormalize: int = 0 # bool 409 | 410 | # todo: must declare these? 411 | frame_rate: float = 0 # todo: or int? 
412 | sample_rate: int = 0 413 | channels: int = 0 414 | 415 | def __post_init__(self) -> None: 416 | if self.causal: 417 | # we force disabling here to avoid handling linear overlap of segments 418 | # as supported in original EnCodec codebase. 419 | assert not self.renormalize, "Causal model does not support renormalize" 420 | super().__post_init__() 421 | 422 | @property 423 | def total_codebooks(self): 424 | """Total number of quantizer codebooks available.""" 425 | return self.quantizer.total_codebooks 426 | 427 | @property 428 | def num_codebooks(self): 429 | """Active number of codebooks used by the quantizer.""" 430 | return self.quantizer.num_codebooks 431 | 432 | def set_num_codebooks(self, n: int): 433 | """Set the active number of codebooks used by the quantizer.""" 434 | self.quantizer.set_num_codebooks(n) 435 | 436 | @property 437 | def cardinality(self): 438 | """Cardinality of each codebook.""" 439 | return self.quantizer.bins 440 | 441 | def preprocess( 442 | self, x: jnp.ndarray 443 | ) -> tp.Tuple[jnp.ndarray, tp.Optional[jnp.ndarray]]: 444 | scale: tp.Optional[jnp.ndarray] 445 | if self.renormalize: 446 | mono = x.mean(axis=1, keepdims=True) 447 | volume = jnp.sqrt(jnp.square(mono).mean(axis=2, keepdims=True)) 448 | scale = 1e-8 + volume 449 | x = x / scale 450 | scale = scale.reshape(-1, 1) 451 | else: 452 | scale = None 453 | return x, scale 454 | 455 | def postprocess( 456 | self, x: jnp.ndarray, scale: tp.Optional[jnp.ndarray] = None 457 | ) -> jnp.ndarray: 458 | if scale is not None: 459 | assert self.renormalize 460 | x = x * scale.reshape(-1, 1, 1) 461 | return x 462 | 463 | def __call__(self, x: jnp.ndarray, train=False) -> QuantizedResult: 464 | assert x.ndim == 3 465 | length = x.shape[-1] 466 | x, scale = self.preprocess(x) 467 | 468 | emb = self.encoder(x) 469 | q_res: QuantizedResult = self.quantizer(emb, self.frame_rate, train=train) 470 | out = self.decoder(q_res.z) 471 | 472 | # remove extra padding added by the encoder and decoder 473 | assert out.shape[-1] >= length, (out.shape[-1], length) 474 | out = out[..., :length] 475 | 476 | q_res.recons = self.postprocess(out, scale) 477 | 478 | return q_res 479 | 480 | def encode( 481 | self, x: jnp.ndarray, n_quantizers: int = None 482 | ) -> tp.Tuple[jnp.ndarray, tp.Optional[jnp.ndarray]]: 483 | """Encode the given input tensor to quantized representation along with scale parameter. 484 | 485 | Args: 486 | x (jnp.ndarray): Float tensor of shape [B, C, T] 487 | 488 | Returns: 489 | codes, scale (tuple of jnp.ndarray, jnp.ndarray): Tuple composed of: 490 | codes: a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep. 491 | scale: a float tensor containing the scale for audio renormalization. 492 | """ 493 | assert x.ndim == 3 494 | x, scale = self.preprocess(x) 495 | emb = self.encoder(x) 496 | emb = emb.transpose(0, 2, 1) 497 | codes = self.quantizer.encode(emb, n_quantizers) 498 | return codes, scale 499 | 500 | def decode( 501 | self, 502 | codes: jnp.ndarray, 503 | scale: tp.Optional[jnp.ndarray] = None, 504 | length: int = None, 505 | ): 506 | """Decode the given codes to a reconstructed representation, using the scale to perform 507 | audio denormalization if needed. 508 | 509 | Args: 510 | codes (jnp.ndarray): Int tensor of shape [B, K, T] 511 | scale (jnp.ndarray, optional): Float tensor containing the scale value. 512 | 513 | Returns: 514 | out (jnp.ndarray): Float tensor of shape [B, C, T], the reconstructed audio. 
515 | """ 516 | emb = self.decode_latent(codes) 517 | out = self.decoder(emb) 518 | out = self.postprocess(out, scale) 519 | 520 | # remove extra padding added by the encoder and decoder 521 | if length is not None: 522 | out = out[..., :length] 523 | return out 524 | 525 | def decode_latent(self, codes: jnp.ndarray): 526 | """Decode from the discrete codes to continuous latent space.""" 527 | return self.quantizer.decode(codes) 528 | -------------------------------------------------------------------------------- /src/dac_jax/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from . import layers 2 | from . import loss 3 | from . import quantize 4 | -------------------------------------------------------------------------------- /src/dac_jax/nn/encodec_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from dataclasses import field 8 | import math 9 | import typing as tp 10 | import warnings 11 | 12 | from flax import linen as nn 13 | from jax import numpy as jnp 14 | 15 | from dac_jax.nn.layers import make_initializer 16 | 17 | 18 | CONV_NORMALIZATIONS = frozenset( 19 | ["none", "weight_norm", "spectral_norm", "time_group_norm"] 20 | ) 21 | 22 | 23 | def apply_parametrization_norm(module: nn.Module, norm: str = "none"): 24 | assert norm in CONV_NORMALIZATIONS 25 | if norm == "weight_norm": 26 | # why we use scale_init: https://github.com/google/flax/issues/4138 27 | scale_init = nn.initializers.constant(1 / jnp.sqrt(3)) 28 | return nn.WeightNorm(module, scale_init=scale_init) 29 | elif norm == "spectral_norm": 30 | return nn.SpectralNorm(module) 31 | else: 32 | # We already check was in CONV_NORMALIZATION, so any other choice 33 | # doesn't need reparametrization. 34 | return module 35 | 36 | 37 | def get_norm_module( 38 | module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs 39 | ): 40 | """Return the proper normalization module. If causal is True, this will ensure the returned 41 | module is causal, or return an error if the normalization doesn't support causal evaluation. 42 | """ 43 | assert norm in CONV_NORMALIZATIONS 44 | if norm == "time_group_norm": 45 | if causal: 46 | raise ValueError("GroupNorm doesn't support causal evaluation.") 47 | assert isinstance(module, nn.Conv) 48 | return nn.GroupNorm(num_groups=1, **norm_kwargs) 49 | else: 50 | return lambda x: x 51 | 52 | 53 | def get_extra_padding_for_conv1d( 54 | x: jnp.ndarray, kernel_size: int, stride: int, padding_total: int = 0 55 | ) -> int: 56 | """See `pad_for_conv1d`.""" 57 | length = x.shape[-2] 58 | n_frames = (length - kernel_size + padding_total) / stride + 1 59 | ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) 60 | return ideal_length - length 61 | 62 | 63 | def pad_for_conv1d( 64 | x: jnp.ndarray, kernel_size: int, stride: int, padding_total: int = 0 65 | ): 66 | """Pad for a convolution to make sure that the last window is full. 67 | Extra padding is added at the end. This is required to ensure that we can rebuild 68 | an output of the same length, as otherwise, even with padding, some time steps 69 | might get removed. 
70 | For instance, with total padding = 4, kernel size = 4, stride = 2: 71 | 0 0 1 2 3 4 5 0 0 # (0s are padding) 72 | 1 2 3 # (output frames of a convolution, last 0 is never used) 73 | 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) 74 | 1 2 3 4 # once you removed padding, we are missing one time step ! 75 | """ 76 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) 77 | return jnp.pad(x, ((0, 0), (0, extra_padding), (0, 0))) 78 | 79 | 80 | def pad1d( 81 | x: jnp.ndarray, 82 | paddings: tp.Tuple[int, int], 83 | mode: str = "constant", 84 | value: float = 0.0, 85 | ): 86 | """Tiny wrapper around F.pad, just to allow for reflect padding on small input. 87 | If this is the case, we insert extra 0 padding to the right before the reflection happen. 88 | """ 89 | length = x.shape[-2] 90 | padding_left, padding_right = paddings 91 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) 92 | if mode == "constant": 93 | pad_kwargs = {"constant_values": value} 94 | else: 95 | pad_kwargs = {} 96 | if mode == "reflect": 97 | max_pad = max(padding_left, padding_right) 98 | extra_pad = 0 99 | if length <= max_pad: 100 | extra_pad = max_pad - length + 1 101 | x = jnp.pad(x, ((0, 0), (0, extra_pad), (0, 0))) 102 | padded = jnp.pad( 103 | x, pad_width=((0, 0), paddings, (0, 0)), mode=mode, **pad_kwargs 104 | ) 105 | end = padded.shape[-2] - extra_pad 106 | return padded[:, :end, :] 107 | else: 108 | return jnp.pad(x, pad_width=((0, 0), paddings, (0, 0)), mode=mode, **pad_kwargs) 109 | 110 | 111 | def unpad1d(x: jnp.ndarray, paddings: tp.Tuple[int, int]): 112 | """Remove padding from x, handling properly zero padding. Only for 1d!""" 113 | padding_left, padding_right = paddings 114 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) 115 | assert (padding_left + padding_right) <= x.shape[-2] 116 | end = x.shape[-2] - padding_right 117 | return x[:, padding_left:end, :] 118 | 119 | 120 | class NormConv1d(nn.Conv): 121 | """Wrapper around Conv and normalization applied to this conv 122 | to provide a uniform interface across normalization approaches. 
123 | """ 124 | 125 | causal: bool = False 126 | norm: str = "none" 127 | norm_kwargs: tp.Dict[str, tp.Any] = field(default_factory=lambda: {}) 128 | 129 | @nn.compact 130 | def __call__(self, x): 131 | 132 | # note: we just ignore whatever self.kernel_init is 133 | kernel_init = make_initializer( 134 | x.shape[-1], 135 | self.features, 136 | self.kernel_size, 137 | self.feature_group_count, 138 | mode="fan_in", 139 | ) 140 | 141 | if self.use_bias: 142 | # note: we just ignore whatever self.bias_init is 143 | bias_init = make_initializer( 144 | x.shape[-1], 145 | self.features, 146 | self.kernel_size, 147 | self.feature_group_count, 148 | mode="fan_in", 149 | ) 150 | else: 151 | bias_init = None 152 | 153 | conv = nn.Conv( 154 | features=self.features, 155 | kernel_size=(self.kernel_size,), 156 | strides=(self.strides,), 157 | padding="VALID", 158 | input_dilation=self.input_dilation, 159 | kernel_dilation=self.kernel_dilation, 160 | feature_group_count=self.feature_group_count, 161 | use_bias=self.use_bias, 162 | mask=self.mask, 163 | dtype=self.dtype, 164 | param_dtype=self.param_dtype, 165 | precision=self.precision, 166 | kernel_init=kernel_init, 167 | bias_init=bias_init, 168 | ) 169 | conv = apply_parametrization_norm(conv, self.norm) 170 | norm = get_norm_module(conv, self.causal, self.norm, **self.norm_kwargs) 171 | x = conv(x) 172 | x = norm(x) 173 | return x 174 | 175 | 176 | class NormConv2d(nn.Conv): 177 | """Wrapper around Conv and normalization applied to this conv 178 | to provide a uniform interface across normalization approaches. 179 | """ 180 | 181 | norm: str = "none" 182 | norm_kwargs: tp.Dict[str, tp.Any] = field(default_factory=lambda: {}) 183 | 184 | @nn.compact 185 | def __call__(self, x): 186 | 187 | # note: we just ignore whatever self.kernel_init is 188 | kernel_init = make_initializer( 189 | x.shape[-1], 190 | self.features, 191 | self.kernel_size, 192 | self.feature_group_count, 193 | mode="fan_in", 194 | ) 195 | 196 | if self.use_bias: 197 | # note: we just ignore whatever self.bias_init is 198 | bias_init = make_initializer( 199 | x.shape[-1], 200 | self.features, 201 | self.kernel_size, 202 | self.feature_group_count, 203 | mode="fan_in", 204 | ) 205 | else: 206 | bias_init = None 207 | 208 | conv = nn.Conv( 209 | features=self.features, 210 | kernel_size=self.kernel_size, 211 | strides=self.strides, 212 | padding="VALID", 213 | input_dilation=self.input_dilation, 214 | kernel_dilation=self.kernel_dilation, 215 | feature_group_count=self.feature_group_count, 216 | use_bias=self.use_bias, 217 | mask=self.mask, 218 | dtype=self.dtype, 219 | param_dtype=self.param_dtype, 220 | precision=self.precision, 221 | kernel_init=kernel_init, 222 | bias_init=bias_init, 223 | ) 224 | conv = apply_parametrization_norm(conv, self.norm) 225 | norm = get_norm_module(conv, causal=False, norm=self.norm, **self.norm_kwargs) 226 | x = conv(x) 227 | x = norm(x) 228 | return x 229 | 230 | 231 | class NormConvTranspose1d(nn.ConvTranspose): 232 | """Wrapper around ConvTranspose1d and normalization applied to this conv 233 | to provide a uniform interface across normalization approaches. 
234 | """ 235 | 236 | causal: bool = False 237 | norm: str = "none" 238 | norm_kwargs: tp.Dict[str, tp.Any] = field(default_factory=lambda: {}) 239 | 240 | @nn.compact 241 | def __call__(self, x): 242 | groups = 1 243 | # note: we just ignore whatever self.kernel_init is 244 | kernel_init = make_initializer( 245 | x.shape[-1], 246 | self.features, 247 | self.kernel_size, 248 | groups, 249 | mode="fan_out", 250 | ) 251 | 252 | if self.use_bias: 253 | # note: we just ignore whatever self.bias_init is 254 | bias_init = make_initializer( 255 | x.shape[-1], 256 | self.features, 257 | self.kernel_size, 258 | groups, 259 | mode="fan_out", 260 | ) 261 | else: 262 | bias_init = None 263 | 264 | convtr = nn.ConvTranspose( 265 | features=self.features, 266 | kernel_size=self.kernel_size, 267 | strides=self.strides, 268 | padding="VALID", 269 | kernel_dilation=self.kernel_dilation, 270 | use_bias=self.use_bias, 271 | mask=self.mask, 272 | dtype=self.dtype, 273 | param_dtype=self.param_dtype, 274 | precision=self.precision, 275 | kernel_init=kernel_init, 276 | bias_init=bias_init, 277 | transpose_kernel=True, # note: this helps us load weights from PyTorch 278 | ) 279 | convtr = apply_parametrization_norm(convtr, self.norm) 280 | norm = get_norm_module(convtr, self.causal, self.norm, **self.norm_kwargs) 281 | x = convtr(x) 282 | x = norm(x) 283 | return x 284 | 285 | 286 | class NormConvTranspose2d(nn.ConvTranspose): 287 | """Wrapper around ConvTranspose2d and normalization applied to this conv 288 | to provide a uniform interface across normalization approaches. 289 | """ 290 | 291 | norm: str = "none" 292 | norm_kwargs: tp.Dict[str, tp.Any] = field(default_factory=lambda: {}) 293 | 294 | @nn.compact 295 | def __call__(self, x): 296 | groups = 1 297 | # note: we just ignore whatever self.kernel_init is 298 | kernel_init = make_initializer( 299 | x.shape[-1], 300 | self.features, 301 | self.kernel_size, 302 | groups, 303 | mode="fan_out", 304 | ) 305 | 306 | if self.use_bias: 307 | # note: we just ignore whatever self.bias_init is 308 | bias_init = make_initializer( 309 | x.shape[-1], 310 | self.features, 311 | self.kernel_size, 312 | groups, 313 | mode="fan_out", 314 | ) 315 | else: 316 | bias_init = None 317 | 318 | convtr = nn.ConvTranspose( 319 | features=self.features, 320 | kernel_size=self.kernel_size, 321 | strides=self.strides, 322 | padding="VALID", 323 | kernel_dilation=self.kernel_dilation, 324 | use_bias=self.use_bias, 325 | mask=self.mask, 326 | dtype=self.dtype, 327 | param_dtype=self.param_dtype, 328 | precision=self.precision, 329 | kernel_init=kernel_init, 330 | bias_init=bias_init, 331 | transpose_kernel=True, # note: this helps us load weights from PyTorch 332 | ) 333 | convtr = apply_parametrization_norm(convtr, self.norm) 334 | norm = get_norm_module(convtr, causal=False, norm=self.norm, **self.norm_kwargs) 335 | x = convtr(x) 336 | x = norm(x) 337 | return x 338 | 339 | 340 | class StreamableConv1d(nn.Module): 341 | """Conv1d with some builtin handling of asymmetric or causal padding 342 | and normalization. 
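    Expects channels-last input of shape (batch, time, channels); the padding implied by the
    kernel size, stride and dilation is computed here and applied before the inner NormConv1d,
    which runs with padding="VALID".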
343 | """ 344 | 345 | out_channels: int 346 | kernel_size: int 347 | stride: int = 1 348 | dilation: int = 1 349 | groups: int = 1 350 | bias: bool = True 351 | causal: bool = False 352 | norm: str = "none" 353 | norm_kwargs: tp.Dict[str, tp.Any] = field(default_factory=lambda: {}) 354 | pad_mode: str = "reflect" 355 | 356 | def __post_init__(self) -> None: 357 | # warn user on unusual setup between dilation and stride 358 | if self.stride > 1 and self.dilation > 1: 359 | warnings.warn( 360 | "StreamableConv1d has been initialized with stride > 1 and dilation > 1" 361 | f" (kernel_size={self.kernel_size} stride={self.stride}, dilation={self.dilation})." 362 | ) 363 | super().__post_init__() 364 | 365 | @nn.compact 366 | def __call__(self, x): 367 | conv = NormConv1d( 368 | self.out_channels, 369 | kernel_size=self.kernel_size, 370 | strides=self.stride, 371 | kernel_dilation=self.dilation, 372 | feature_group_count=self.groups, 373 | use_bias=self.bias, 374 | causal=self.causal, 375 | norm=self.norm, 376 | norm_kwargs=self.norm_kwargs, 377 | ) 378 | B, T, C = x.shape 379 | kernel_size = conv.kernel_size 380 | stride = conv.strides 381 | dilation = conv.kernel_dilation 382 | kernel_size = ( 383 | kernel_size - 1 384 | ) * dilation + 1 # effective kernel size with dilations 385 | padding_total = kernel_size - stride 386 | extra_padding = get_extra_padding_for_conv1d( 387 | x, kernel_size, stride, padding_total 388 | ) 389 | if self.causal: 390 | # Left padding for causal 391 | x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) 392 | else: 393 | # Asymmetric padding required for odd strides 394 | padding_right = padding_total // 2 395 | padding_left = padding_total - padding_right 396 | x = pad1d( 397 | x, (padding_left, padding_right + extra_padding), mode=self.pad_mode 398 | ) 399 | y = conv(x) 400 | return y 401 | 402 | 403 | class StreamableConvTranspose1d(nn.Module): 404 | """ConvTranspose1d with some builtin handling of asymmetric or causal padding 405 | and normalization. 406 | """ 407 | 408 | out_channels: int 409 | kernel_size: int 410 | stride: int = 1 411 | causal: bool = False 412 | norm: str = "none" 413 | trim_right_ratio: float = 1.0 414 | norm_kwargs: tp.Dict[str, tp.Any] = field(default_factory=lambda: {}) 415 | 416 | def __post_init__(self): 417 | assert ( 418 | self.causal or self.trim_right_ratio == 1.0 419 | ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" 420 | assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0 421 | super().__post_init__() 422 | 423 | @nn.compact 424 | def __call__(self, x): 425 | convtr = NormConvTranspose1d( 426 | self.out_channels, 427 | kernel_size=self.kernel_size, 428 | strides=self.stride, 429 | causal=self.causal, 430 | norm=self.norm, 431 | norm_kwargs=self.norm_kwargs, 432 | ) 433 | kernel_size = convtr.kernel_size 434 | stride = convtr.strides 435 | padding_total = kernel_size - stride 436 | 437 | y = convtr(x) 438 | 439 | # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be 440 | # removed at the very end, when keeping only the right length for the output, 441 | # as removing it here would require also passing the length at the matching layer 442 | # in the encoder. 
443 | if self.causal: 444 | # Trim the padding on the right according to the specified ratio 445 | # if trim_right_ratio = 1.0, trim everything from right 446 | padding_right = math.ceil(padding_total * self.trim_right_ratio) 447 | padding_left = padding_total - padding_right 448 | y = unpad1d(y, (padding_left, padding_right)) 449 | else: 450 | # Asymmetric padding required for odd strides 451 | padding_right = padding_total // 2 452 | padding_left = padding_total - padding_right 453 | y = unpad1d(y, (padding_left, padding_right)) 454 | return y 455 | 456 | 457 | class StreamableLSTM(nn.Module): 458 | """LSTM without worrying about the hidden state, nor the layout of the data. 459 | Expects input as convolutional layout. 460 | """ 461 | 462 | dimension: int 463 | num_layers: int = 2 464 | skip: int = 1 # bool 465 | 466 | @nn.compact 467 | def __call__(self, x): 468 | y = x 469 | for _ in range(self.num_layers): 470 | y = nn.RNN(nn.LSTMCell(self.dimension))(y) 471 | 472 | if self.skip: 473 | y = y + x 474 | 475 | return y 476 | -------------------------------------------------------------------------------- /src/dac_jax/nn/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import flax.linen as nn 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | 8 | def default_stride(strides): 9 | if strides is None: 10 | return 1 11 | if isinstance(strides, int): 12 | return strides 13 | return strides[0] 14 | 15 | 16 | def default_kernel_dilation(kernel_dilation): 17 | if kernel_dilation is None: 18 | return 1 19 | if isinstance(kernel_dilation, int): 20 | return kernel_dilation 21 | return kernel_dilation[0] 22 | 23 | 24 | def default_kernel_size(kernel_size): 25 | if kernel_size is None: 26 | return 1 27 | if isinstance(kernel_size, int): 28 | return kernel_size 29 | return kernel_size[0] 30 | 31 | 32 | def conv_to_delay(s, d, k, L): 33 | L = (L - 1) * s + d * (k - 1) + 1 34 | L = math.ceil(L) 35 | return L 36 | 37 | 38 | def convtranspose_to_delay(s, d, k, L): 39 | L = ((L - d * (k - 1) - 1) / s) + 1 40 | L = math.ceil(L) 41 | return L 42 | 43 | 44 | def conv_to_output_length(s, d, k, L): 45 | L = ((L - d * (k - 1) - 1) / s) + 1 46 | L = math.floor(L) 47 | return L 48 | 49 | 50 | def convtranspose_to_output_length(s, d, k, L): 51 | L = (L - 1) * s + d * (k - 1) + 1 52 | L = math.floor(L) 53 | return L 54 | 55 | 56 | def make_initializer(in_channels, out_channels, kernel_size, groups, mode="fan_in"): 57 | # https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html 58 | if mode == "fan_in": 59 | c = in_channels 60 | elif mode == "fan_out": 61 | c = out_channels 62 | else: 63 | raise ValueError(f"Unexpected mode: {mode}") 64 | k = groups / (c * jnp.prod(jnp.array(kernel_size))) 65 | scale = jnp.sqrt(k) 66 | return lambda key, shape, dtype: jax.random.uniform( 67 | key, shape, minval=-scale, maxval=scale, dtype=dtype 68 | ) 69 | 70 | 71 | class WNConv1d(nn.Conv): 72 | 73 | @nn.compact 74 | def __call__(self, x): 75 | # https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/model/dac.py#L18-L21 76 | # https://github.com/google/flax/issues/4091 77 | # Note: we are just ignoring whatever self.kernel_init and self.bias_init are. 
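        # Assumption: jax.nn.initializers.truncated_normal takes its truncation bounds in units
        # of the standard deviation, so lower=-2/0.02 and upper=2/0.02 mirror PyTorch's
        # nn.init.trunc_normal_(std=0.02), whose default cutoffs of -2 and 2 are absolute values.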
78 | kernel_init = jax.nn.initializers.truncated_normal( 79 | 0.02, lower=-2 / 0.02, upper=2 / 0.02 80 | ) 81 | bias_init = nn.initializers.zeros 82 | 83 | conv = nn.Conv( 84 | features=self.features, 85 | kernel_size=self.kernel_size, 86 | strides=self.strides, 87 | padding=self.padding, 88 | input_dilation=self.input_dilation, 89 | kernel_dilation=self.kernel_dilation, 90 | feature_group_count=self.feature_group_count, 91 | use_bias=self.use_bias, 92 | mask=self.mask, 93 | dtype=self.dtype, 94 | param_dtype=self.param_dtype, 95 | precision=self.precision, 96 | kernel_init=kernel_init, 97 | bias_init=bias_init, 98 | ) 99 | scale_init = nn.initializers.constant(1 / jnp.sqrt(3)) 100 | block = nn.WeightNorm(conv, scale_init=scale_init) 101 | x = block(x) 102 | return x 103 | 104 | @staticmethod 105 | def delay(s, d, k, L): 106 | s = default_stride(s) 107 | d = default_kernel_dilation(d) 108 | k = default_kernel_size(k) 109 | return conv_to_delay(s, d, k, L) 110 | 111 | @staticmethod 112 | def output_length(s, d, k, L): 113 | s = default_stride(s) 114 | d = default_kernel_dilation(d) 115 | k = default_kernel_size(k) 116 | return conv_to_output_length(s, d, k, L) 117 | 118 | 119 | class WNConvTranspose1d(nn.ConvTranspose): 120 | 121 | @nn.compact 122 | def __call__(self, x): 123 | 124 | groups = 1 125 | # note: we just ignore whatever self.kernel_init is 126 | kernel_init = make_initializer( 127 | x.shape[-1], 128 | self.features, 129 | self.kernel_size, 130 | groups, 131 | mode="fan_out", 132 | ) 133 | 134 | if self.use_bias: 135 | # note: we just ignore whatever self.bias_init is 136 | bias_init = make_initializer( 137 | x.shape[-1], 138 | self.features, 139 | self.kernel_size, 140 | groups, 141 | mode="fan_out", 142 | ) 143 | else: 144 | bias_init = None 145 | 146 | conv = nn.ConvTranspose( 147 | features=self.features, 148 | kernel_size=self.kernel_size, 149 | strides=self.strides, 150 | padding=self.padding, 151 | kernel_dilation=self.kernel_dilation, 152 | use_bias=self.use_bias, 153 | mask=self.mask, 154 | dtype=self.dtype, 155 | param_dtype=self.param_dtype, 156 | precision=self.precision, 157 | kernel_init=kernel_init, 158 | bias_init=bias_init, 159 | transpose_kernel=True, # note: this helps us load weights from PyTorch 160 | ) 161 | scale_init = nn.initializers.constant(1 / jnp.sqrt(3)) 162 | block = nn.WeightNorm(conv, scale_init=scale_init) 163 | x = block(x) 164 | return x 165 | 166 | @staticmethod 167 | def delay(s, d, k, L): 168 | s = default_stride(s) 169 | d = default_kernel_dilation(d) 170 | k = default_kernel_size(k) 171 | return convtranspose_to_delay(s, d, k, L) 172 | 173 | @staticmethod 174 | def output_length(s, d, k, L): 175 | s = default_stride(s) 176 | d = default_kernel_dilation(d) 177 | k = default_kernel_size(k) 178 | return convtranspose_to_output_length(s, d, k, L) 179 | 180 | 181 | class Snake1d(nn.Module): 182 | 183 | channels: int 184 | 185 | @nn.compact 186 | def __call__(self, x): 187 | alpha = self.param("alpha", nn.initializers.ones, (1, 1, self.channels)) 188 | x = x + jnp.reciprocal(alpha + 1e-9) * jnp.square(jnp.sin(alpha * x)) 189 | return x 190 | -------------------------------------------------------------------------------- /src/dac_jax/nn/loss.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import os 3 | from typing import Callable, Optional 4 | 5 | from einops import rearrange 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | 10 | from 
dac_jax.audio_utils import stft, decibel_loudness, mel_spectrogram 11 | 12 | 13 | def l1_loss(y_true: jnp.ndarray, y_pred: jnp.ndarray, reduction="mean") -> jnp.ndarray: 14 | 15 | errors = jnp.abs(y_pred - y_true) 16 | if reduction == "none": 17 | return errors 18 | elif reduction == "mean": 19 | return jnp.mean(errors) 20 | elif reduction == "sum": 21 | return jnp.sum(errors) 22 | else: 23 | raise ValueError(f"Invalid reduction method: {reduction}") 24 | 25 | 26 | def sisdr_loss( 27 | y_true: jnp.ndarray, 28 | y_pred: jnp.ndarray, 29 | scaling: int = True, 30 | reduction: str = "mean", 31 | zero_mean: int = True, 32 | clip_min: int = None, 33 | ): 34 | """ 35 | Computes the Scale-Invariant Source-to-Distortion Ratio between a batch 36 | of estimated and reference audio signals or aligned features. 37 | 38 | Parameters 39 | ---------- 40 | y_true : jnp.ndarray 41 | Estimate jnp.ndarray 42 | y_pred : jnp.ndarray 43 | Reference jnp.ndarray 44 | scaling : int, optional 45 | Whether to use scale-invariant (True) or 46 | signal-to-noise ratio (False), by default True 47 | reduction : str, optional 48 | How to reduce across the batch (either 'mean', 49 | 'sum', or none).], by default ' mean' 50 | zero_mean : int, optional 51 | Zero mean the references and estimates before 52 | computing the loss, by default True 53 | clip_min : int, optional 54 | The minimum possible loss value. Helps network 55 | to not focus on making already good examples better, by default None 56 | 57 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py 58 | """ 59 | 60 | eps = 1e-8 61 | # nb, nc, nt 62 | references = y_true 63 | estimates = y_pred 64 | 65 | nb = references.shape[0] 66 | references = references.reshape(nb, 1, -1).transpose(0, 2, 1) 67 | estimates = estimates.reshape(nb, 1, -1).transpose(0, 2, 1) 68 | 69 | # samples now on axis 1 70 | if zero_mean: 71 | mean_reference = references.mean(axis=1, keepdims=True) 72 | mean_estimate = estimates.mean(axis=1, keepdims=True) 73 | else: 74 | mean_reference = 0 75 | mean_estimate = 0 76 | 77 | _references = references - mean_reference 78 | _estimates = estimates - mean_estimate 79 | 80 | references_projection = jnp.square(_references).sum(axis=-2) + eps 81 | references_on_estimates = (_estimates * _references).sum(axis=-2) + eps 82 | 83 | scale = ( 84 | jnp.expand_dims(references_on_estimates / references_projection, 1) 85 | if scaling 86 | else 1 87 | ) 88 | 89 | e_true = scale * _references 90 | e_res = _estimates - e_true 91 | 92 | signal = jnp.square(e_true).sum(axis=1) 93 | noise = jnp.square(e_res).sum(axis=1) 94 | sdr = -10 * jnp.log10(signal / noise + eps) 95 | 96 | if clip_min is not None: 97 | sdr = jnp.maximum(sdr, clip_min) 98 | 99 | if reduction == "mean": 100 | sdr = sdr.mean() 101 | elif reduction == "sum": 102 | sdr = sdr.sum() 103 | return sdr 104 | 105 | 106 | def discriminator_loss(fake, real): 107 | """ 108 | Computes a discriminator loss, given the outputs of the discriminator 109 | used on a fake input and a real input. 110 | """ 111 | d_fake, d_real = fake, real 112 | 113 | loss_d = 0 114 | for x_fake, x_real in zip(d_fake, d_real): 115 | loss_d = loss_d + jnp.square(x_fake[-1]).mean() 116 | loss_d = loss_d + jnp.square(1 - x_real[-1]).mean() 117 | # We normalize based on the number of feature maps, but the original DAC doesn't do this. 
118 | # loss_d = loss_d / len(d_fake) 119 | return loss_d 120 | 121 | 122 | def generator_loss(fake, real): 123 | """ 124 | Computes a generator loss, given the outputs of the discriminator 125 | used on a fake input and a real input. 126 | """ 127 | d_fake, d_real = fake, jax.lax.stop_gradient(real) 128 | 129 | loss_g = 0 130 | for x_fake in d_fake: 131 | loss_g = loss_g + jnp.square(1 - x_fake[-1]).mean() 132 | 133 | # We normalize based on the number of feature maps, but the original DAC doesn't do this. 134 | # loss_g = loss_g / len(d_fake) 135 | 136 | loss_feature = 0 137 | 138 | for i in range(len(d_fake)): 139 | for j in range(len(d_fake[i]) - 1): 140 | loss_feature = loss_feature + l1_loss(d_fake[i][j], d_real[i][j]) 141 | 142 | # We normalize based on the number of feature maps, but the original DAC doesn't do this. 143 | # loss_feature = loss_feature / sum([len(d_fake[i])-1 for i in range(len(d_fake))]) 144 | 145 | return loss_g, loss_feature 146 | 147 | 148 | def multiscale_stft_loss( 149 | y_true: jnp.ndarray, 150 | y_pred: jnp.ndarray, 151 | window_lengths=None, 152 | loss_fn: Callable = l1_loss, 153 | clamp_eps: float = 1e-5, 154 | mag_weight: float = 1.0, 155 | log_weight: float = 1.0, 156 | pow: float = 2.0, 157 | match_stride: Optional[bool] = False, 158 | window: str = "hann", 159 | ): 160 | """Computes the multiscale STFT loss from [1]. 161 | 162 | Parameters 163 | ---------- 164 | y_true : AudioSignal 165 | Estimate signal 166 | y_pred : AudioSignal 167 | Reference signal 168 | window_lengths : List[int], optional 169 | Length of each window of each STFT, by default [2048, 512] 170 | loss_fn : typing.Callable, optional 171 | How to compare each loss, by default l1_loss 172 | clamp_eps : float, optional 173 | Clamp on the log magnitude, below, by default 1e-5 174 | mag_weight : float, optional 175 | Weight of raw magnitude portion of loss, by default 1.0 176 | log_weight : float, optional 177 | Weight of log magnitude portion of loss, by default 1.0 178 | pow : float, optional 179 | Power to raise magnitude to before taking log, by default 2.0 180 | match_stride : bool, optional 181 | Whether to match the stride of convolutional layers, by default False 182 | window : str or tuple or array_like, optional 183 | Desired window to use. If `window` is a string or tuple, it is 184 | passed to `get_window` to generate the window values, which are 185 | DFT-even by default. See `get_window` for a list of windows and 186 | required parameters. If `window` is array_like it will be used 187 | directly as the window and its length must be nperseg. Defaults 188 | to a Hann window. 189 | 190 | Returns 191 | ------- 192 | jnp.ndarray 193 | Multi-scale STFT loss. 194 | 195 | References 196 | ---------- 197 | 198 | 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. 199 | "DDSP: Differentiable Digital Signal Processing." 200 | International Conference on Learning Representations. 2019. 
201 | 202 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py 203 | """ 204 | 205 | x = y_pred 206 | y = y_true 207 | 208 | loss = jnp.zeros(()) 209 | 210 | if window_lengths is None: 211 | window_lengths = [2048, 512] 212 | 213 | for frame_length in window_lengths: 214 | stft_fun = partial( 215 | stft, 216 | frame_length=frame_length, 217 | hop_factor=0.25, 218 | window=window, 219 | match_stride=match_stride, 220 | ) 221 | x_stft = stft_fun(x) 222 | y_stft = stft_fun(y) 223 | 224 | loss = loss + log_weight * loss_fn( 225 | decibel_loudness(x_stft, clamp_eps=clamp_eps, pow=pow), 226 | decibel_loudness(y_stft, clamp_eps=clamp_eps, pow=pow), 227 | ) 228 | loss = loss + mag_weight * loss_fn(jnp.abs(x_stft), jnp.abs(y_stft)) 229 | 230 | return loss 231 | 232 | 233 | def mel_spectrogram_loss( 234 | y_true: jnp.ndarray, 235 | y_pred: jnp.ndarray, 236 | sample_rate: int, 237 | n_mels=None, 238 | window_lengths=None, 239 | loss_fn: Callable = l1_loss, 240 | clamp_eps: float = 1e-5, 241 | mag_weight: float = 1.0, 242 | log_weight: float = 1.0, 243 | pow: float = 2.0, 244 | match_stride: Optional[bool] = False, 245 | lower_edge_hz=None, 246 | upper_edge_hz=None, 247 | window: str = "hann", 248 | ): 249 | """Compute distance between mel spectrograms. Can be used in a multiscale way. 250 | 251 | Parameters 252 | ---------- 253 | y_true : jnp.ndarray 254 | Estimate signal 255 | y_pred : jnp.ndarray 256 | Reference signal 257 | sample_rate : int 258 | Sample rate 259 | n_mels : List[int] 260 | Number of mel bins per STFT, by default [150, 80], 261 | window_lengths : List[int], optional 262 | Length of each window of each STFT, by default [2048, 512] 263 | loss_fn : typing.Callable, optional 264 | How to compare each loss, by default L1Loss() 265 | clamp_eps : float, optional 266 | Clamp on the log magnitude, below, by default 1e-5 267 | mag_weight : float, optional 268 | Weight of raw magnitude portion of loss, by default 1.0 269 | log_weight : float, optional 270 | Weight of log magnitude portion of loss, by default 1.0 271 | pow : float, optional 272 | Power to raise magnitude to before taking log, by default 2.0 273 | match_stride : bool, optional 274 | Whether to match the stride of convolutional layers, by default False 275 | lower_edge_hz: List[float], optional 276 | Lowest frequency to consider to general mel filterbanks. 277 | upper_edge_hz: List[float], optional 278 | Highest frequency to consider to general mel filterbanks. 279 | window : str or tuple or array_like, optional 280 | Desired window to use. If `window` is a string or tuple, it is 281 | passed to `get_window` to generate the window values, which are 282 | DFT-even by default. See `get_window` for a list of windows and 283 | required parameters. If `window` is array_like it will be used 284 | directly as the window and its length must be nperseg. Defaults 285 | to a Hann window. 286 | 287 | Returns 288 | ------- 289 | jnp.ndarray 290 | Mel loss. 
291 | 292 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py 293 | """ 294 | 295 | x = y_pred 296 | y = y_true 297 | 298 | if n_mels is None: 299 | n_mels = [150, 80] 300 | 301 | if window_lengths is None: 302 | window_lengths = [2048, 512] 303 | 304 | if lower_edge_hz is None: 305 | lower_edge_hz = [0.0, 0.0] 306 | 307 | if upper_edge_hz is None: 308 | upper_edge_hz = [None, None] # librosa converts None to sample_rate/2 309 | 310 | def decibel_fn(mels: jnp.ndarray) -> jnp.ndarray: 311 | return jnp.log10(jnp.pow(jnp.maximum(mels, clamp_eps), pow)) 312 | 313 | loss = jnp.zeros(()) 314 | for features, fmin, fmax, frame_length in zip( 315 | n_mels, lower_edge_hz, upper_edge_hz, window_lengths 316 | ): 317 | 318 | def spectrogram_fn(signal): 319 | stft_data = stft( 320 | signal, 321 | frame_length=frame_length, 322 | hop_factor=0.25, 323 | window=window, 324 | match_stride=match_stride, 325 | ) 326 | stft_data = rearrange(stft_data, "b c nf nt -> (b c) nt nf") 327 | 328 | spectrogram = jnp.abs(stft_data) 329 | return spectrogram 330 | 331 | x_spectrogram = spectrogram_fn(x) 332 | y_spectrogram = spectrogram_fn(y) 333 | 334 | nf = x_spectrogram.shape[-1] 335 | 336 | mel_fun = partial( 337 | mel_spectrogram, 338 | log_scale=False, 339 | sample_rate=sample_rate, 340 | frame_length=2 * (nf - 1), 341 | num_features=features, 342 | lower_edge_hertz=fmin, 343 | upper_edge_hertz=fmax, 344 | ) 345 | 346 | x_mels = mel_fun(x_spectrogram) 347 | y_mels = mel_fun(y_spectrogram) 348 | 349 | loss = loss + log_weight * loss_fn(decibel_fn(x_mels), decibel_fn(y_mels)) 350 | loss = loss + mag_weight * loss_fn(x_mels, y_mels) 351 | 352 | return loss 353 | 354 | 355 | def phase_loss( 356 | y_true: jnp.ndarray, 357 | y_pred: jnp.ndarray, 358 | window_length: int = 2048, 359 | hop_factor: float = 0.25, 360 | ): 361 | """Computes phase loss between an estimate and a reference signal. 362 | 363 | Parameters 364 | ---------- 365 | y_true : AudioSignal 366 | Reference signal 367 | y_pred : AudioSignal 368 | Estimate signal 369 | window_length : int, optional 370 | Length of STFT window, by default 2048 371 | hop_factor : float, optional 372 | Hop factor between 0 and 1, which is multiplied by the length of STFT 373 | window length to determine the hop size. 374 | 375 | Returns 376 | ------- 377 | jnp.ndarray 378 | Phase loss. 
379 | 
380 |     Implementation adapted from https://github.com/descriptinc/audiotools/blob/7776c296c711db90176a63ff808c26e0ee087263/audiotools/metrics/spectral.py#L195
381 |     """
382 | 
383 |     x = y_pred
384 |     y = y_true
385 | 
386 |     stft_fun = partial(
387 |         stft, frame_length=window_length, hop_factor=hop_factor, window="hann"
388 |     )
389 | 
390 |     x_stft = stft_fun(x)
391 |     y_stft = stft_fun(y)
392 | 
393 |     def phase(spec):
394 |         return jnp.angle(spec)
395 | 
396 |     # Take circular difference, wrapping into [-pi, pi] (jnp.where keeps this jit-compatible)
397 |     diff = phase(x_stft) - phase(y_stft)
398 |     diff = jnp.where(diff < -jnp.pi, diff + 2 * jnp.pi, diff)
399 |     diff = jnp.where(diff > jnp.pi, diff - 2 * jnp.pi, diff)
400 | 
401 |     # Scale true magnitude to weights in [0, 1]
402 |     x_mag = jnp.abs(x_stft)
403 |     x_min, x_max = x_mag.min(), x_mag.max()
404 |     weights = (x_mag - x_min) / (x_max - x_min)
405 | 
406 |     # Take weighted mean of all phase errors
407 |     loss = jnp.square(weights * diff).mean()
408 |     return loss
409 | 
410 | 
411 | def stoi(
412 |     estimates: jnp.ndarray,
413 |     references: jnp.ndarray,
414 |     sample_rate: int,
415 |     extended: int = False,
416 | ):
417 |     """Short term objective intelligibility
418 |     Computes the STOI (See [1][2]) of a de-noised signal compared to a clean
419 |     signal. The output is expected to have a monotonic relation with the
420 |     subjective speech-intelligibility, where a higher score denotes better
421 |     speech intelligibility. Uses pystoi under the hood.
422 | 
423 |     Parameters
424 |     ----------
425 |     estimates : jnp.ndarray
426 |         De-noised speech
427 |     references : jnp.ndarray
428 |         Clean original speech
429 |     sample_rate: int
430 |         Sample rate of the references
431 |     extended : int, optional
432 |         Boolean, whether to use the extended STOI described in [3], by default False
433 | 
434 |     Returns
435 |     -------
436 |     Tensor[float]
437 |         Short time objective intelligibility measure between clean and
438 |         de-noised speech
439 | 
440 |     References
441 |     ----------
442 |     1. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
443 |         Objective Intelligibility Measure for Time-Frequency Weighted Noisy
444 |         Speech', ICASSP 2010, Texas, Dallas.
445 |     2. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for
446 |         Intelligibility Prediction of Time-Frequency Weighted Noisy Speech',
447 |         IEEE Transactions on Audio, Speech, and Language Processing, 2011.
448 |     3. Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the
449 |         Intelligibility of Speech Masked by Modulated Noise Maskers',
450 |         IEEE Transactions on Audio, Speech and Language Processing, 2016.
451 |     """
452 |     import pystoi
453 | 
454 |     if estimates.ndim == 3:
455 |         estimates = jnp.average(estimates, axis=-2)  # to mono
456 |     if references.ndim == 3:
457 |         references = jnp.average(references, axis=-2)  # to mono
458 | 
459 |     stois = []
460 |     for reference, estimate in zip(references, estimates):
461 |         _stoi = pystoi.stoi(
462 |             np.array(reference),
463 |             np.array(estimate),
464 |             sample_rate,
465 |             extended=extended,
466 |         )
467 |         stois.append(_stoi)
468 |     return jnp.array(np.array(stois))
469 | 
470 | 
471 | def pesq(
472 |     estimates: jnp.ndarray,
473 |     estimates_sample_rate: int,
474 |     references: jnp.ndarray,
475 |     references_sample_rate: int,
476 |     mode: str = "wb",
477 |     target_sr: int = 16000,
478 | ):
479 |     """Perceptual Evaluation of Speech Quality (PESQ) between estimates and references.
480 | 
481 |     Parameters
482 |     ----------
483 |     estimates : jnp.ndarray
484 |         Degraded audio signal
485 |     estimates_sample_rate: int
486 |         Sample rate of the estimates
487 |     references : jnp.ndarray
488 |         Reference audio signal
489 |     references_sample_rate: int
490 |         Sample rate of the references
491 |     mode : str, optional
492 |         'wb' (wide-band) or 'nb' (narrow-band), by default "wb"
493 |     target_sr : int, optional
494 |         Target sample rate, by default 16000
495 | 
496 |     Returns
497 |     -------
498 |     Tensor[float]
499 |         PESQ score: P.862.2 Prediction (MOS-LQO)
500 |     """
501 |     from pesq import pesq as pesq_fn
502 |     from audiotree.resample import resample
503 | 
504 |     if estimates.ndim == 3:
505 |         estimates = jnp.average(estimates, axis=-2, keepdims=True)  # to mono
506 |     if references.ndim == 3:
507 |         references = jnp.average(references, axis=-2, keepdims=True)  # to mono
508 | 
509 |     estimates = resample(estimates, old_sr=estimates_sample_rate, new_sr=target_sr)
510 |     references = resample(references, old_sr=references_sample_rate, new_sr=target_sr)
511 | 
512 |     pesqs = []
513 |     for reference, estimate in zip(references, estimates):
514 |         _pesq = pesq_fn(
515 |             target_sr,
516 |             np.array(reference[0]),
517 |             np.array(estimate[0]),
518 |             mode,
519 |         )
520 |         pesqs.append(_pesq)
521 |     return jnp.array(np.array(pesqs))
522 | 
523 | 
524 | def visqol(
525 |     estimates: jnp.ndarray,
526 |     estimates_sample_rate: int,
527 |     references: jnp.ndarray,
528 |     references_sample_rate: int,
529 |     mode: str = "audio",
530 | ):  # pragma: no cover
531 |     """ViSQOL score.
532 | 533 | Parameters 534 | ---------- 535 | estimates : jnp.ndarray 536 | Degraded audio 537 | references : jnp.ndarray 538 | Reference audio 539 | mode : str, optional 540 | 'audio' or 'speech', by default 'audio' 541 | 542 | Returns 543 | ------- 544 | Tensor[float] 545 | ViSQOL score (MOS-LQO) 546 | """ 547 | from visqol import visqol_lib_py 548 | from visqol.pb2 import visqol_config_pb2 549 | from visqol.pb2 import similarity_result_pb2 550 | from audiotree.resample import resample 551 | 552 | config = visqol_config_pb2.VisqolConfig() 553 | if mode == "audio": 554 | target_sr = 48000 555 | config.options.use_speech_scoring = False 556 | svr_model_path = "libsvm_nu_svr_model.txt" 557 | elif mode == "speech": 558 | target_sr = 16000 559 | config.options.use_speech_scoring = True 560 | svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite" 561 | else: 562 | raise ValueError(f"Unrecognized mode: {mode}") 563 | config.audio.sample_rate = target_sr 564 | config.options.svr_model_path = os.path.join( 565 | os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path 566 | ) 567 | 568 | api = visqol_lib_py.VisqolApi() 569 | api.Create(config) 570 | 571 | if estimates.ndim == 3: 572 | estimates = jnp.average(estimates, axis=-2, keepdims=True) # to mono 573 | if references.ndim == 3: 574 | references = jnp.average(references, axis=-2, keepdims=True) # to mono 575 | 576 | estimates = resample(estimates, old_sr=estimates_sample_rate, new_sr=target_sr) 577 | references = resample(references, old_sr=references_sample_rate, new_sr=target_sr) 578 | 579 | visqols = [] 580 | for reference, estimate in zip(references, estimates): 581 | _visqol = api.Measure( 582 | np.array(reference[0], dtype=np.float32), 583 | np.array(estimate[0], dtype=np.float32), 584 | ) 585 | visqols.append(_visqol.moslqo) 586 | return jnp.array(np.array(visqols)) 587 | -------------------------------------------------------------------------------- /src/dac_jax/nn/quantize.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple 2 | 3 | from einops import rearrange 4 | import flax.linen as nn 5 | import jax 6 | import jax.numpy as jnp 7 | import jax.random 8 | 9 | from dac_jax.nn.encodec_quantize import QuantizedResult 10 | from dac_jax.nn.layers import WNConv1d 11 | 12 | 13 | def mse_loss( 14 | predictions: jnp.ndarray, targets: jnp.ndarray, reduction="mean" 15 | ) -> jnp.ndarray: 16 | errors = (predictions - targets) ** 2 17 | if reduction == "none": 18 | return errors 19 | elif reduction == "mean": 20 | return jnp.mean(errors) 21 | elif reduction == "sum": 22 | return jnp.sum(errors) 23 | else: 24 | raise ValueError(f"Invalid reduction method: {reduction}") 25 | 26 | 27 | def normalize(x, ord=2, axis=1, eps=1e-12): 28 | """Normalizes an array along a specified dimension. 29 | 30 | Args: 31 | x: A JAX array to normalize. 32 | ord: The order of the norm (default is 2, corresponding to L2-norm). 33 | axis: The dimension along which to normalize. 34 | eps: A small constant to avoid division by zero. 35 | 36 | Returns: 37 | A JAX array with normalized vectors. 
38 | 39 | Reference: 40 | https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html 41 | """ 42 | denom = jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=True) 43 | denom = jnp.maximum(eps, denom) 44 | return x / denom 45 | 46 | 47 | class VectorQuantize(nn.Module): 48 | """ 49 | Implementation of VQ similar to Karpathy's repo: 50 | https://github.com/karpathy/deep-vector-quantization 51 | Additionally uses following tricks from Improved VQGAN 52 | (https://arxiv.org/pdf/2110.04627.pdf): 53 | 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space 54 | for improved codebook usage 55 | 2. l2-normalized codes: Converts Euclidean distance to cosine similarity which 56 | improves training stability 57 | """ 58 | 59 | input_dim: int 60 | codebook_size: int 61 | codebook_dim: int 62 | 63 | def setup(self): 64 | self.in_proj = WNConv1d(features=self.codebook_dim, kernel_size=(1,)) 65 | self.out_proj = WNConv1d(features=self.input_dim, kernel_size=(1,)) 66 | # PyTorch uses a normal distribution for weight initialization of Embeddings. 67 | self.codebook = nn.Embed( 68 | num_embeddings=self.codebook_size, 69 | features=self.codebook_dim, 70 | embedding_init=nn.initializers.normal(stddev=1), 71 | ) 72 | 73 | def __call__( 74 | self, z 75 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: 76 | """Quantized the input tensor using a fixed codebook and returns the corresponding codebook vectors 77 | 78 | Parameters 79 | ---------- 80 | z : Tensor[B x T x D] 81 | 82 | Returns 83 | ------- 84 | Tensor[B x T x D] 85 | Quantized continuous representation of input 86 | Tensor[1] 87 | Commitment loss to train encoder to predict vectors closer to codebook 88 | entries 89 | Tensor[1] 90 | Codebook loss to update the codebook 91 | Tensor[B x T] 92 | Codebook indices (quantized discrete representation of input) 93 | Tensor[B x T x D] 94 | Projected latents (continuous representation of input before quantization) 95 | """ 96 | 97 | # Factorized codes (ViT-VQGAN) Project input into low-dimensional space 98 | z_e = self.in_proj(z) # z_e : (B x T x D) 99 | z_q, indices = self.decode_latents(z_e) 100 | 101 | commitment_loss = mse_loss( 102 | z_e, jax.lax.stop_gradient(z_q), reduction="none" 103 | ).mean([1, 2]) 104 | codebook_loss = mse_loss( 105 | z_q, jax.lax.stop_gradient(z_e), reduction="none" 106 | ).mean([1, 2]) 107 | 108 | z_q = z_e + jax.lax.stop_gradient( 109 | z_q - z_e 110 | ) # noop in forward pass, straight-through gradient estimator in backward pass 111 | 112 | z_q = self.out_proj(z_q) 113 | 114 | return z_q, commitment_loss, codebook_loss, indices, z_e 115 | 116 | def embed_code(self, embed_id): 117 | return self.codebook(embed_id) 118 | 119 | def decode_code(self, embed_id): 120 | return self.embed_code(embed_id) 121 | 122 | def decode_latents(self, latents: jnp.ndarray): 123 | encodings = rearrange(latents, "b t d -> (b t) d", d=self.codebook_dim) 124 | codebook = self.codebook.embedding # codebook: (N x D) 125 | # L2 normalize encodings and codebook (ViT-VQGAN) 126 | encodings = normalize(encodings) 127 | codebook = normalize(codebook) 128 | 129 | # Compute Euclidean distance with codebook 130 | dist = ( 131 | jnp.square(encodings).sum(1, keepdims=True) 132 | - 2 * encodings @ codebook.transpose() 133 | + jnp.square(codebook).sum(1, keepdims=True).transpose() 134 | ) 135 | indices = rearrange( 136 | jnp.argmax(-dist, axis=1), "(b t) -> b t", b=latents.shape[0] 137 | ) 138 | z_q = self.decode_code(indices) 139 | return z_q, indices 
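# Illustrative usage sketch (the shapes below are arbitrary):
#
#     z = jnp.zeros((1, 80, 512))  # channels-last latents: (batch, time, input_dim)
#     vq = VectorQuantize(input_dim=512, codebook_size=1024, codebook_dim=8)
#     variables = vq.init(jax.random.PRNGKey(0), z)
#     z_q, commitment_loss, codebook_loss, indices, z_e = vq.apply(variables, z)
#
# The straight-through estimator in __call__ lets gradients w.r.t. the quantized output bypass
# the codebook lookup and flow back to z_e (and hence the encoder), while the codebook
# embeddings are trained only by `codebook_loss`.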
140 | 141 | 142 | class ResidualVectorQuantize(nn.Module): 143 | """ 144 | Introduced in SoundStream: An End-to-End Neural Audio Codec 145 | https://arxiv.org/abs/2107.03312 146 | """ 147 | 148 | input_dim: int = 512 149 | num_codebooks: int = 9 150 | codebook_size: int = 1024 151 | codebook_dim: Union[int, list] = 8 152 | quantizer_dropout: float = 0.0 153 | 154 | def __post_init__(self) -> None: 155 | if isinstance(self.codebook_dim, int): 156 | self.codebook_dim = [self.codebook_dim for _ in range(self.num_codebooks)] 157 | super().__post_init__() 158 | 159 | def setup(self) -> None: 160 | 161 | self.quantizers = [ 162 | VectorQuantize(self.input_dim, self.codebook_size, self.codebook_dim[i]) 163 | for i in range(self.num_codebooks) 164 | ] 165 | 166 | def __call__(self, z, n_quantizers: int = None, train=True) -> QuantizedResult: 167 | z_q = 0 168 | residual = z 169 | commitment_loss = jnp.zeros(()) 170 | codebook_loss = jnp.zeros(()) 171 | 172 | codebook_indices = [] 173 | latents = [] 174 | 175 | if n_quantizers is None: 176 | n_quantizers = self.num_codebooks 177 | if train: 178 | n_quantizers = jnp.ones((z.shape[0],)) * self.num_codebooks + 1 179 | dropout = jax.random.randint( 180 | self.make_rng("rng_stream"), 181 | shape=(z.shape[0],), 182 | minval=1, 183 | maxval=self.num_codebooks + 1, 184 | ) 185 | n_dropout = int(z.shape[0] * self.quantizer_dropout) 186 | n_quantizers = n_quantizers.at[:n_dropout].set(dropout[:n_dropout]) 187 | 188 | # todo: this loop would possibly compile faster if jax.lax.scan were used 189 | for i, quantizer in enumerate(self.quantizers): 190 | if not train and i >= n_quantizers: 191 | break 192 | 193 | z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( 194 | residual 195 | ) 196 | 197 | # Create mask to apply quantizer dropout 198 | mask = jnp.full((z.shape[0],), fill_value=i) < n_quantizers 199 | z_q = z_q + z_q_i * mask[:, None, None] 200 | residual = residual - z_q_i 201 | 202 | # Sum losses 203 | commitment_loss = commitment_loss + (commitment_loss_i * mask).mean() 204 | codebook_loss = codebook_loss + (codebook_loss_i * mask).mean() 205 | 206 | codebook_indices.append(indices_i) 207 | latents.append(z_e_i) 208 | 209 | codes = jnp.stack(codebook_indices, axis=1) 210 | latents = jnp.concatenate(latents, axis=2).transpose(0, 2, 1) 211 | 212 | # normalize based on number of codebooks 213 | # commitment_loss = commitment_loss / self.num_codebooks 214 | # codebook_loss = codebook_loss / self.num_codebooks 215 | 216 | return QuantizedResult( 217 | z_q, 218 | codes=codes, 219 | bandwidth=None, 220 | penalty=None, 221 | metrics=None, 222 | latents=latents, 223 | commitment_loss=commitment_loss, 224 | codebook_loss=codebook_loss, 225 | ) 226 | 227 | def from_codes(self, codes: jnp.ndarray): 228 | """Given the quantized codes, reconstruct the continuous representation 229 | Parameters 230 | ---------- 231 | codes : Tensor[B x T x N] 232 | Quantized discrete representation of input 233 | Returns 234 | ------- 235 | Tensor[B x T x D] 236 | Quantized continuous representation of input 237 | """ 238 | z_q = 0.0 239 | z_p = [] 240 | num_codebooks = codes.shape[-2] 241 | assert num_codebooks <= self.num_codebooks 242 | 243 | # todo: use jax.lax.scan for this loop 244 | for i in range(num_codebooks): 245 | z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) 246 | z_p.append(z_p_i) 247 | 248 | z_q_i = self.quantizers[i].out_proj(z_p_i) 249 | z_q = z_q + z_q_i 250 | 251 | return z_q, jnp.concatenate(z_p, axis=1), codes 252 | 253 | def 
from_latents(self, latents: jnp.ndarray): 254 | # todo: this function hasn't been tested/used yet. 255 | 256 | """Given the unquantized latents, reconstruct the 257 | continuous representation after quantization. 258 | 259 | Parameters 260 | ---------- 261 | latents : Tensor[B x T x N] 262 | Continuous representation of input after projection 263 | 264 | Returns # todo: make this return info correct 265 | ------- 266 | Tensor[B x T x D] 267 | Quantized representation of full-projected space 268 | Tensor[B x T x D] 269 | Quantized representation of latent space 270 | """ 271 | z_q = 0 272 | z_p = [] 273 | codes = [] 274 | dims = jnp.cumsum([0] + [q.codebook_dim for q in self.quantizers]) 275 | 276 | num_codebooks = jnp.where(dims <= latents.shape[2])[0].max( 277 | axis=0, keepdims=True 278 | ) # todo: check 279 | 280 | # todo: use jax.lax.scan for this loop 281 | for i in range(num_codebooks): 282 | j, k = dims[i], dims[i + 1] 283 | z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, :, j:k]) 284 | z_p.append(z_p_i) 285 | codes.append(codes_i) 286 | 287 | z_q_i = self.quantizers[i].out_proj(z_p_i) 288 | z_q = z_q + z_q_i 289 | 290 | return z_q, jnp.concatenate(z_p, axis=1), jnp.stack(codes, axis=1) 291 | 292 | 293 | if __name__ == "__main__": 294 | rvq = ResidualVectorQuantize(quantizer_dropout=True) 295 | key = jax.random.PRNGKey(0) 296 | key, subkey = jax.random.split(key) 297 | x = jax.random.normal(key=subkey, shape=(16, 80, 512)) 298 | 299 | key, subkey = jax.random.split(key) 300 | params = rvq.init({"params": subkey, "rng_stream": jax.random.key(4)}, x)["params"] 301 | z_q, codes, latents, commitment_loss, codebook_loss = rvq.apply( 302 | {"params": params}, x, rngs={"rng_stream": jax.random.key(4)} 303 | ) 304 | print(latents.shape) 305 | -------------------------------------------------------------------------------- /src/dac_jax/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import environ 3 | from pathlib import Path 4 | import json 5 | import typing as tp 6 | 7 | import argbind 8 | from huggingface_hub import hf_hub_download 9 | import numpy as np 10 | from omegaconf import OmegaConf 11 | import torch 12 | 13 | from dac_jax.utils import load_torch_weights_encodec 14 | from dac_jax.utils import load_torch_weights 15 | from dac_jax.model import DAC 16 | from dac_jax.model import EncodecModel, SEANetEncoder, SEANetDecoder 17 | from dac_jax.nn.encodec_quantize import ResidualVectorQuantizer 18 | 19 | 20 | def get_audiocraft_cache_dir() -> tp.Optional[str]: 21 | return os.environ.get('AUDIOCRAFT_CACHE_DIR', None) 22 | 23 | 24 | def _get_state_dict( 25 | file_or_url_or_id: tp.Union[Path, str], 26 | filename: tp.Optional[str] = None, 27 | device='cpu', 28 | cache_dir: tp.Optional[str] = None, 29 | ): 30 | if cache_dir is None: 31 | cache_dir = get_audiocraft_cache_dir() 32 | # Return the state dict either from a file or url 33 | file_or_url_or_id = str(file_or_url_or_id) 34 | assert isinstance(file_or_url_or_id, str) 35 | 36 | if os.path.isfile(file_or_url_or_id): 37 | return torch.load(file_or_url_or_id, map_location=device) 38 | 39 | if os.path.isdir(file_or_url_or_id): 40 | file = f"{file_or_url_or_id}/{filename}" 41 | return torch.load(file, map_location=device) 42 | 43 | elif file_or_url_or_id.startswith('https://'): 44 | return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True) 45 | 46 | else: 47 | assert filename is not None, "filename needs to be 
defined if using HF checkpoints" 48 | 49 | file = hf_hub_download( 50 | repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir, 51 | library_name="audiocraft", library_version="1.3.0") 52 | return torch.load(file, map_location=device) 53 | 54 | 55 | try: 56 | from audiocraft.models.loaders import load_compression_model_ckpt 57 | except Exception as e: 58 | def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None): 59 | return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir) 60 | 61 | 62 | __MODEL_LATEST_TAGS__ = { 63 | ("44khz", "8kbps"): "0.0.1", 64 | ("24khz", "8kbps"): "0.0.4", 65 | ("16khz", "8kbps"): "0.0.5", 66 | ("44khz", "16kbps"): "1.0.0", 67 | } 68 | 69 | __MODEL_URLS__ = { 70 | ( 71 | "44khz", 72 | "0.0.1", 73 | "8kbps", 74 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", 75 | ( 76 | "24khz", 77 | "0.0.4", 78 | "8kbps", 79 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", 80 | ( 81 | "16khz", 82 | "0.0.5", 83 | "8kbps", 84 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", 85 | ( 86 | "44khz", 87 | "1.0.0", 88 | "16kbps", 89 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", 90 | } 91 | 92 | 93 | def convert_torch_weights_to_numpy( 94 | torch_weights_path: Path, write_path: Path, metadata_path: Path 95 | ): 96 | 97 | if write_path.exists() and metadata_path.exists(): 98 | return 99 | 100 | if not write_path.exists(): 101 | write_path.parent.mkdir(parents=True, exist_ok=True) 102 | 103 | weights = torch.load(str(torch_weights_path), map_location=torch.device("cpu")) 104 | 105 | kwargs = weights["metadata"]["kwargs"] 106 | with open(metadata_path, "w") as f: 107 | f.write(json.dumps(kwargs)) 108 | 109 | weights = weights["state_dict"] 110 | weights = {key: value.numpy() for key, value in weights.items()} 111 | 112 | allow_pickle = ( 113 | True # todo: https://github.com/descriptinc/descript-audio-codec/issues/53 114 | ) 115 | 116 | np.save(write_path, weights, allow_pickle=allow_pickle) 117 | 118 | 119 | @argbind.bind(group="download_encodec", positional=True, without_prefix=True) 120 | def download_encodec( 121 | name: str = "facebook/musicgen-small", 122 | ): 123 | if ( 124 | "DAC_JAX_CACHE" in environ 125 | and environ["DAC_JAX_CACHE"].strip() 126 | and os.path.isabs(environ["DAC_JAX_CACHE"]) 127 | ): 128 | cache_home = environ["DAC_JAX_CACHE"] 129 | cache_home = Path(cache_home) 130 | else: 131 | cache_home = Path.home() / ".cache" / "dac_jax" 132 | 133 | safename = name.replace("/", "_") 134 | 135 | metadata_path = cache_home / f"encodec_weights_{safename}.json" 136 | jax_write_path = cache_home / f"encodec_jax_weights_{safename}.npy" 137 | 138 | if jax_write_path.exists() and metadata_path.exists(): 139 | return jax_write_path, metadata_path 140 | 141 | torch_model_path = cache_home / f"encodec_weights_{safename}.pth" 142 | 143 | if not torch_model_path.exists(): 144 | torch_model_path.parent.mkdir(parents=True, exist_ok=True) 145 | 146 | file_or_url_or_id = name 147 | pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=str(cache_home)) 148 | cfg = OmegaConf.create(pkg["xp.cfg"]) 149 | 150 | weights = pkg["best_state"] 151 | weights = {key: value.numpy() for key, value in weights.items()} 152 | 153 | jax_write_path.parent.mkdir(parents=True, 
exist_ok=True) 154 | 155 | allow_pickle = ( 156 | True # todo: https://github.com/descriptinc/descript-audio-codec/issues/53 157 | ) 158 | 159 | np.save(jax_write_path, weights, allow_pickle=allow_pickle) 160 | 161 | OmegaConf.save(config=cfg, f=metadata_path) 162 | 163 | return jax_write_path, metadata_path 164 | 165 | 166 | # todo: we don't call this function `download` because that would conflict with the PyTorch implementation's `download`. 167 | # and we need to be able to run both in our tests. 168 | # Reference issue: https://github.com/pseeth/argbind/?tab=readme-ov-file#bound-function-names-should-be-unique 169 | @argbind.bind(group="download_model", positional=True, without_prefix=True) 170 | def download_model( 171 | model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest" 172 | ): 173 | """ 174 | Function that downloads the weights file from URL if a local cache is not found. 175 | 176 | Parameters 177 | ---------- 178 | model_type : str 179 | The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". 180 | model_bitrate: str 181 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 182 | Only 44khz model supports 16kbps. 183 | tag : str 184 | The tag of the model to download. Defaults to "latest". 185 | 186 | Returns 187 | ------- 188 | Path 189 | Directory path required to load model via audiotools. 190 | """ 191 | model_type = model_type.lower() 192 | tag = tag.lower() 193 | 194 | if ( 195 | "DAC_JAX_CACHE" in environ 196 | and environ["DAC_JAX_CACHE"].strip() 197 | and os.path.isabs(environ["DAC_JAX_CACHE"]) 198 | ): 199 | cache_home = environ["DAC_JAX_CACHE"] 200 | cache_home = Path(cache_home) 201 | else: 202 | cache_home = Path.home() / ".cache" / "dac_jax" 203 | 204 | metadata_path = cache_home / f"weights_{model_type}_{model_bitrate}_{tag}.json" 205 | jax_write_path = cache_home / f"jax_weights_{model_type}_{model_bitrate}_{tag}.npy" 206 | 207 | if jax_write_path.exists() and metadata_path.exists(): 208 | return jax_write_path, metadata_path 209 | 210 | assert model_type in [ 211 | "44khz", 212 | "24khz", 213 | "16khz", 214 | ], "model_type must be one of '44khz', '24khz', or '16khz'" 215 | 216 | assert model_bitrate in [ 217 | "8kbps", 218 | "16kbps", 219 | ], "model_bitrate must be one of '8kbps', or '16kbps'" 220 | 221 | if tag == "latest": 222 | tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] 223 | 224 | download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) 225 | 226 | if download_link is None: 227 | raise ValueError( 228 | f"Could not find model with tag {tag} and model type {model_type}" 229 | ) 230 | 231 | torch_model_path = cache_home / f"weights_{model_type}_{model_bitrate}_{tag}.pth" 232 | 233 | if not torch_model_path.exists(): 234 | torch_model_path.parent.mkdir(parents=True, exist_ok=True) 235 | 236 | # Download the model 237 | import requests 238 | 239 | response = requests.get(download_link) 240 | 241 | if response.status_code != 200: 242 | raise ValueError( 243 | f"Could not download model. Received response code {response.status_code}" 244 | ) 245 | torch_model_path.write_bytes(response.content) 246 | 247 | convert_torch_weights_to_numpy(torch_model_path, jax_write_path, metadata_path) 248 | 249 | # remove torch model because it's not needed anymore. 
250 | if torch_model_path.exists(): 251 | os.remove(torch_model_path) 252 | 253 | return jax_write_path, metadata_path 254 | 255 | 256 | def load_encodec_model( 257 | name: str = "facebook/musicgen-small", 258 | load_path: str = None, 259 | metadata_path: str = None, 260 | ): 261 | if not load_path or not metadata_path: 262 | load_path, metadata_path = download_encodec(name) 263 | 264 | kwargs = OmegaConf.load(metadata_path) 265 | 266 | seanet_kwargs = kwargs["seanet"] 267 | 268 | common_kwargs = { 269 | "channels": kwargs["channels"], 270 | "dimension": seanet_kwargs["dimension"], 271 | "n_filters": seanet_kwargs["n_filters"], 272 | "n_residual_layers": seanet_kwargs["n_residual_layers"], 273 | "ratios": seanet_kwargs["ratios"], 274 | "activation": seanet_kwargs["activation"].lower(), 275 | "activation_params": OmegaConf.to_object(seanet_kwargs["activation_params"]), 276 | "norm": seanet_kwargs["norm"], 277 | "norm_params": OmegaConf.to_object(seanet_kwargs["norm_params"]), 278 | "kernel_size": seanet_kwargs["kernel_size"], 279 | "last_kernel_size": seanet_kwargs["last_kernel_size"], 280 | "residual_kernel_size": seanet_kwargs["residual_kernel_size"], 281 | "dilation_base": seanet_kwargs["dilation_base"], 282 | "causal": kwargs["encodec"]["causal"], 283 | "pad_mode": seanet_kwargs["pad_mode"], 284 | "true_skip": seanet_kwargs["true_skip"], 285 | "compress": seanet_kwargs["compress"], 286 | "lstm": seanet_kwargs["compress"], 287 | "disable_norm_outer_blocks": seanet_kwargs["disable_norm_outer_blocks"], 288 | } 289 | encoder_override_kwargs = {} 290 | decoder_override_kwargs = { 291 | "trim_right_ratio": seanet_kwargs["decoder"]["trim_right_ratio"], 292 | "final_activation": seanet_kwargs["decoder"]["final_activation"], 293 | "final_activation_params": seanet_kwargs["decoder"]["final_activation_params"], 294 | } 295 | encoder_kwargs = {**common_kwargs, **encoder_override_kwargs} 296 | decoder_kwargs = {**common_kwargs, **decoder_override_kwargs} 297 | 298 | rvq_kwargs = kwargs["rvq"] 299 | quantizer_kwargs = { 300 | "dimension": seanet_kwargs["dimension"], 301 | "n_q": rvq_kwargs["n_q"], 302 | "q_dropout": rvq_kwargs["q_dropout"], 303 | "bins": rvq_kwargs["bins"], 304 | "decay": rvq_kwargs["decay"], 305 | "kmeans_init": rvq_kwargs["kmeans_init"], 306 | "kmeans_iters": rvq_kwargs["kmeans_iters"], 307 | "threshold_ema_dead_code": rvq_kwargs["threshold_ema_dead_code"], 308 | "orthogonal_reg_weight": rvq_kwargs["orthogonal_reg_weight"], 309 | "orthogonal_reg_active_codes_only": rvq_kwargs[ 310 | "orthogonal_reg_active_codes_only" 311 | ], 312 | "orthogonal_reg_max_codes": None, # todo: 313 | } 314 | 315 | encoder = SEANetEncoder(**encoder_kwargs) 316 | decoder = SEANetDecoder(**decoder_kwargs) 317 | quantizer = ResidualVectorQuantizer(**quantizer_kwargs) 318 | 319 | sample_rate = kwargs["sample_rate"] 320 | 321 | encodec_model = EncodecModel( 322 | encoder=encoder, 323 | decoder=decoder, 324 | quantizer=quantizer, 325 | causal=kwargs["encodec"]["causal"], 326 | renormalize=kwargs["encodec"]["renormalize"], 327 | frame_rate=sample_rate // encoder.hop_length, 328 | sample_rate=sample_rate, 329 | channels=kwargs["channels"], 330 | ) 331 | 332 | allow_pickle = ( 333 | True # todo: https://github.com/descriptinc/descript-audio-codec/issues/53 334 | ) 335 | 336 | torch_params = np.load(load_path, allow_pickle=allow_pickle) 337 | torch_params = torch_params.item() # todo 338 | 339 | variables = load_torch_weights_encodec.torch_to_linen( 340 | torch_params, 341 | encodec_model.encoder.ratios, 342 | 
encodec_model.decoder.ratios, 343 | encodec_model.num_codebooks, 344 | ) 345 | 346 | return encodec_model, variables 347 | 348 | 349 | def load_model( 350 | model_type: str = "44khz", 351 | model_bitrate: str = "8kbps", 352 | tag: str = "latest", 353 | load_path: str = None, 354 | metadata_path: str = None, 355 | padding=True, 356 | ): 357 | # reference: 358 | # https://flax.readthedocs.io/en/latest/guides/training_techniques/transfer_learning.html#create-a-function-for-model-loading 359 | 360 | if not load_path or not metadata_path: 361 | load_path, metadata_path = download_model( 362 | model_type=model_type, model_bitrate=model_bitrate, tag=tag 363 | ) 364 | 365 | with open(str(metadata_path), "r") as f: 366 | kwargs = json.loads(f.read()) 367 | 368 | kwargs["padding"] = padding # todo: seems like bad design 369 | kwargs["num_codebooks"] = kwargs.pop("n_codebooks") 370 | 371 | model = DAC(**kwargs) 372 | 373 | allow_pickle = ( 374 | True # todo: https://github.com/descriptinc/descript-audio-codec/issues/53 375 | ) 376 | 377 | torch_params = np.load(load_path, allow_pickle=allow_pickle) 378 | torch_params = torch_params.item() 379 | 380 | variables = load_torch_weights.torch_to_linen( 381 | torch_params, model.encoder_rates, model.decoder_rates, model.num_codebooks 382 | ) 383 | 384 | return model, variables 385 | 386 | 387 | if __name__ == "__main__": 388 | load_encodec_model() 389 | -------------------------------------------------------------------------------- /src/dac_jax/utils/decode.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | 4 | import argbind 5 | from tqdm import tqdm 6 | 7 | import jax 8 | 9 | from dac_jax import DACFile 10 | from dac_jax.utils import load_model 11 | 12 | 13 | warnings.filterwarnings( 14 | "ignore", category=UserWarning 15 | ) # ignore librosa warnings related to mel bins 16 | 17 | 18 | 19 | @argbind.bind(group="decode", positional=True, without_prefix=True) 20 | def decode( 21 | input: str, 22 | output: str = "", 23 | weights_path: str = "", 24 | model_tag: str = "latest", 25 | model_bitrate: str = "8kbps", 26 | model_type: str = "44khz", 27 | verbose: bool = False, 28 | ): 29 | """Decode audio from codes. 30 | 31 | Parameters 32 | ---------- 33 | input : str 34 | Path to input directory or file 35 | output : str, optional 36 | Path to output directory, by default "". 37 | If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. 38 | weights_path : str, optional 39 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet 40 | using the model_tag and model_type. 41 | model_tag : str, optional 42 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. 43 | model_bitrate: str 44 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 45 | model_type : str, optional 46 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if 47 | `weights_path` is specified.
48 | """ 49 | model, variables = load_model( 50 | model_type=model_type, 51 | model_bitrate=model_bitrate, 52 | tag=model_tag, 53 | load_path=weights_path, 54 | ) 55 | 56 | # Find all .dac files in input directory 57 | _input = Path(input) 58 | input_files = list(_input.glob("**/*.dac")) 59 | 60 | # If input is a .dac file, add it to the list 61 | if _input.suffix == ".dac": 62 | input_files.append(_input) 63 | 64 | # Create output directory 65 | output = Path(output) 66 | output.mkdir(parents=True, exist_ok=True) 67 | 68 | @jax.jit 69 | def decompress_chunk(c): 70 | return model.apply(variables, c, method="decompress_chunk") 71 | 72 | for i in tqdm(range(len(input_files)), desc=f"Decoding files"): 73 | # Load file 74 | dac_file = DACFile.load(input_files[i]) 75 | 76 | # Reconstruct audio from codes 77 | recons = model.decompress(decompress_chunk, dac_file, verbose=verbose) 78 | 79 | # Compute output path 80 | relative_path = input_files[i].relative_to(input) 81 | output_dir = output / relative_path.parent 82 | if not relative_path.name: 83 | output_dir = output 84 | relative_path = input_files[i] 85 | output_name = relative_path.with_suffix(".wav").name 86 | output_path = output_dir / output_name 87 | output_path.parent.mkdir(parents=True, exist_ok=True) 88 | 89 | # Write to file 90 | recons.write(output_path) 91 | 92 | 93 | if __name__ == "__main__": 94 | args = argbind.parse_args() 95 | with argbind.scope(args): 96 | decode() 97 | -------------------------------------------------------------------------------- /src/dac_jax/utils/encode.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | 4 | import argbind 5 | import librosa 6 | from tqdm import tqdm 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | 11 | from dac_jax import load_model 12 | from dac_jax.audio_utils import find_audio 13 | 14 | warnings.filterwarnings( 15 | "ignore", category=UserWarning 16 | ) # ignore librosa warnings related to mel bins 17 | 18 | 19 | @jax.jit 20 | @argbind.bind(group="encode", positional=True, without_prefix=True) 21 | def encode( 22 | input: str, 23 | output: str = "", 24 | weights_path: str = "", 25 | model_tag: str = "latest", 26 | model_bitrate: str = "8kbps", 27 | n_quantizers: int = None, 28 | model_type: str = "44khz", 29 | win_duration: float = 5.0, 30 | verbose: bool = False, 31 | ): 32 | """Encode audio files in input path to .dac format. 33 | 34 | Parameters 35 | ---------- 36 | input : str 37 | Path to input audio file or directory 38 | output : str, optional 39 | Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` 40 | is re-created in `output`. 41 | weights_path : str, optional 42 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet 43 | using the model_tag and model_type. 44 | model_tag : str, optional 45 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. 46 | model_bitrate: str 47 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". 48 | n_quantizers : int, optional 49 | Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model 50 | will compress at maximum bitrate. 51 | model_type : str, optional 52 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if 53 | `weights_path` is specified. 
54 | """ 55 | model, variables = load_model( 56 | model_type=model_type, 57 | model_bitrate=model_bitrate, 58 | tag=model_tag, 59 | load_path=weights_path, 60 | ) 61 | 62 | # Find all audio files in input path 63 | input = Path(input) 64 | audio_files = find_audio(input) 65 | 66 | output = Path(output) 67 | output.mkdir(parents=True, exist_ok=True) 68 | 69 | @jax.jit 70 | def compress_chunk(x): 71 | return model.apply(variables, x, method="compress_chunk") 72 | 73 | for audio_file in tqdm(audio_files, desc="Encoding files"): 74 | # Load file with original sample rate 75 | signal, sample_rate = librosa.load(audio_file, sr=None, mono=False) 76 | while signal.ndim < 3: 77 | signal = jnp.expand_dims(signal, axis=0) 78 | 79 | # Encode audio to .dac format 80 | dac_file = model.compress( 81 | compress_chunk, 82 | signal, 83 | sample_rate, 84 | win_duration=win_duration, 85 | verbose=verbose, 86 | n_quantizers=n_quantizers, 87 | ) 88 | 89 | # Compute output path 90 | relative_path = audio_file.relative_to(input) 91 | output_dir = output / relative_path.parent 92 | if not relative_path.name: 93 | output_dir = output 94 | relative_path = audio_file 95 | output_name = relative_path.with_suffix(".dac").name 96 | output_path = output_dir / output_name 97 | output_path.parent.mkdir(parents=True, exist_ok=True) 98 | 99 | dac_file.save(output_path) 100 | 101 | 102 | if __name__ == "__main__": 103 | args = argbind.parse_args() 104 | with argbind.scope(args): 105 | encode() 106 | -------------------------------------------------------------------------------- /src/dac_jax/utils/load_torch_weights.py: -------------------------------------------------------------------------------- 1 | def torch_to_linen( 2 | torch_params: dict, 3 | encoder_rates: tuple[int] = None, 4 | decoder_rates: tuple[int] = None, 5 | num_codebooks: int = 9, 6 | ) -> dict: 7 | """Convert PyTorch parameters to Linen nested dictionaries""" 8 | 9 | if encoder_rates is None: 10 | encoder_rates = [2, 4, 8, 8] 11 | if decoder_rates is None: 12 | decoder_rates = [8, 8, 4, 2] 13 | 14 | def parse_wn_conv(flax_params, from_prefix, to_i: int): 15 | d = {} 16 | d[f"Conv_0"] = { 17 | "bias": torch_params[f"{from_prefix}.bias"], 18 | "kernel": torch_params[f"{from_prefix}.weight_v"].T, 19 | } 20 | d[f"WeightNorm_0"] = { 21 | f"Conv_0/kernel/scale": torch_params[f"{from_prefix}.weight_g"].squeeze( 22 | (1, 2) 23 | ) 24 | } 25 | flax_params[f"WNConv1d_{to_i}"] = d 26 | 27 | def parse_wn_convtranspose(flax_params, from_prefix, to_i: int): 28 | d = {} 29 | d[f"ConvTranspose_0"] = { 30 | "bias": torch_params[f"{from_prefix}.bias"], 31 | "kernel": torch_params[f"{from_prefix}.weight_v"].transpose(), 32 | } 33 | d[f"WeightNorm_0"] = { 34 | f"ConvTranspose_0/kernel/scale": torch_params[ 35 | f"{from_prefix}.weight_g" 36 | ].squeeze((1, 2)) 37 | } 38 | flax_params[f"WNConvTranspose1d_{to_i}"] = d 39 | 40 | def parse_residual_unit(flax_params, from_prefix, to_i): 41 | d = {} 42 | d["Snake1d_0"] = { 43 | "alpha": torch_params[f"{from_prefix}.block.0.alpha"].transpose(0, 2, 1) 44 | } 45 | parse_wn_conv(d, f"{from_prefix}.block.1", 0) 46 | d["Snake1d_1"] = { 47 | "alpha": torch_params[f"{from_prefix}.block.2.alpha"].transpose(0, 2, 1) 48 | } 49 | parse_wn_conv(d, f"{from_prefix}.block.3", 1) 50 | flax_params[f"ResidualUnit_{to_i}"] = d 51 | 52 | def parse_encoder_block(flax_params, from_prefix, to_i): 53 | d = {} 54 | for i in range(3): 55 | parse_residual_unit(d, f"{from_prefix}.block.{i}", i) 56 | 57 | d["Snake1d_0"] = { 58 | "alpha": 
torch_params[f"{from_prefix}.block.3.alpha"].transpose(0, 2, 1) 59 | } 60 | 61 | parse_wn_conv(d, f"{from_prefix}.block.4", 0) 62 | flax_params[f"EncoderBlock_{to_i}"] = d 63 | 64 | def parse_decoder_block(flax_params, from_prefix, to_i): 65 | d = {} 66 | d["Snake1d_0"] = { 67 | "alpha": torch_params[f"{from_prefix}.block.0.alpha"].transpose(0, 2, 1) 68 | } 69 | 70 | parse_wn_convtranspose(d, f"{from_prefix}.block.1", 0) 71 | 72 | for i in range(3): 73 | parse_residual_unit(d, f"{from_prefix}.block.{i+2}", i) 74 | 75 | flax_params[f"DecoderBlock_{to_i}"] = d 76 | 77 | flax_params = {"encoder": {}, "decoder": {}, "quantizer": {}} 78 | 79 | i = 0 80 | # add Encoder 81 | parse_wn_conv(flax_params["encoder"], f"encoder.block.{i}", 0) 82 | 83 | # add EncoderBlocks 84 | for _ in encoder_rates: 85 | parse_encoder_block(flax_params["encoder"], f"encoder.block.{i+1}", i) 86 | i += 1 87 | 88 | i += 1 89 | flax_params["encoder"]["Snake1d_0"] = { 90 | "alpha": torch_params[f"encoder.block.{i}.alpha"].transpose(0, 2, 1) 91 | } 92 | 93 | i += 1 94 | parse_wn_conv(flax_params["encoder"], f"encoder.block.{i}", 1) 95 | 96 | # Add Quantizer 97 | for i in range(num_codebooks): 98 | quantizer = {} 99 | quantizer["in_proj"] = { 100 | "WeightNorm_0": { 101 | "Conv_0/kernel/scale": torch_params[ 102 | f"quantizer.quantizers.{i}.in_proj.weight_g" 103 | ].squeeze((1, 2)) 104 | }, 105 | "Conv_0": { 106 | "bias": torch_params[f"quantizer.quantizers.{i}.in_proj.bias"], 107 | "kernel": torch_params[f"quantizer.quantizers.{i}.in_proj.weight_v"].T, 108 | }, 109 | } 110 | quantizer["codebook"] = { 111 | "embedding": torch_params[f"quantizer.quantizers.{i}.codebook.weight"] 112 | } 113 | quantizer["out_proj"] = { 114 | "WeightNorm_0": { 115 | "Conv_0/kernel/scale": torch_params[ 116 | f"quantizer.quantizers.{i}.out_proj.weight_g" 117 | ].squeeze((1, 2)) 118 | }, 119 | "Conv_0": { 120 | "bias": torch_params[f"quantizer.quantizers.{i}.out_proj.bias"], 121 | "kernel": torch_params[f"quantizer.quantizers.{i}.out_proj.weight_v"].T, 122 | }, 123 | } 124 | flax_params["quantizer"][f"quantizers_{i}"] = quantizer 125 | 126 | i = 0 127 | # Add Decoder 128 | parse_wn_conv(flax_params["decoder"], f"decoder.model.{i}", 0) 129 | 130 | # Add DecoderBlocks 131 | for _ in decoder_rates: 132 | parse_decoder_block(flax_params["decoder"], f"decoder.model.{i+1}", i) 133 | i += 1 134 | 135 | i += 1 136 | flax_params["decoder"]["Snake1d_0"] = { 137 | "alpha": torch_params[f"decoder.model.{i}.alpha"].transpose(0, 2, 1) 138 | } 139 | 140 | i += 1 141 | parse_wn_conv(flax_params["decoder"], f"decoder.model.{i}", 1) 142 | 143 | return {"params": flax_params} 144 | -------------------------------------------------------------------------------- /src/dac_jax/utils/load_torch_weights_encodec.py: -------------------------------------------------------------------------------- 1 | from jax import numpy as jnp 2 | 3 | 4 | def streamable(torch_params, prefix: str): 5 | return { 6 | "NormConv1d_0": { 7 | "WeightNorm_0": { 8 | "Conv_0/kernel/scale": torch_params[ 9 | f"{prefix}.conv.conv.weight_g" 10 | ].squeeze((1, 2)), 11 | }, 12 | "Conv_0": { 13 | "bias": torch_params[f"{prefix}.conv.conv.bias"], 14 | "kernel": torch_params[f"{prefix}.conv.conv.weight_v"].T, 15 | }, 16 | } 17 | } 18 | 19 | 20 | def streamable_transpose(torch_params, prefix: str): 21 | return { 22 | "NormConvTranspose1d_0": { 23 | "WeightNorm_0": { 24 | "ConvTranspose_0/kernel/scale": torch_params[ 25 | f"{prefix}.convtr.convtr.weight_g" 26 | ].squeeze((1, 2)), 27 | }, 28 | 
"ConvTranspose_0": { 29 | "bias": torch_params[f"{prefix}.convtr.convtr.bias"], 30 | "kernel": torch_params[f"{prefix}.convtr.convtr.weight_v"].T, 31 | }, 32 | } 33 | } 34 | 35 | 36 | def lstm(torch_params, prefix: str, i: int): 37 | weight_ih_l0 = torch_params[f"{prefix}.lstm.weight_ih_l{i}"] 38 | weight_hh_l0 = torch_params[f"{prefix}.lstm.weight_hh_l{i}"] 39 | bias_ih_l0 = torch_params[f"{prefix}.lstm.bias_ih_l{i}"] 40 | bias_hh_l0 = torch_params[f"{prefix}.lstm.bias_hh_l{i}"] 41 | 42 | weight_hh_l0 = weight_hh_l0.transpose(1, 0) 43 | weight_ih_l0 = weight_ih_l0.transpose(1, 0) 44 | 45 | # https://github.com/pytorch/pytorch/blob/40de63be097ce6d499aac15fc58ed27ca33e5227/aten/src/ATen/native/RNN.cpp#L1560-L1564 46 | kernel_hi, kernel_hf, kernel_hg, kernel_ho = jnp.split(weight_hh_l0, 4, axis=1) 47 | kernel_ii, kernel_if, kernel_ig, kernel_io = jnp.split(weight_ih_l0, 4, axis=1) 48 | 49 | bias = bias_ih_l0 + bias_hh_l0 50 | 51 | bias_i, bias_f, bias_g, bias_o = jnp.split(bias, 4) 52 | 53 | return { 54 | "hi": { 55 | "bias": bias_i, 56 | "kernel": kernel_hi, 57 | }, 58 | "hf": { 59 | "bias": bias_f, 60 | "kernel": kernel_hf, 61 | }, 62 | "hg": { 63 | "bias": bias_g, 64 | "kernel": kernel_hg, 65 | }, 66 | "ho": { 67 | "bias": bias_o, 68 | "kernel": kernel_ho, 69 | }, 70 | "ii": { 71 | "kernel": kernel_ii, 72 | }, 73 | "if": { 74 | "kernel": kernel_if, 75 | }, 76 | "ig": { 77 | "kernel": kernel_ig, 78 | }, 79 | "io": { 80 | "kernel": kernel_io, 81 | }, 82 | } 83 | 84 | 85 | def torch_to_encoder(torch_params: dict, encoder_rates: tuple[int] = None): 86 | d = {} 87 | 88 | i = 0 89 | j = 0 90 | for _ in range(len(encoder_rates)): 91 | d[f"StreamableConv1d_{i}"] = streamable(torch_params, f"encoder.model.{j}") 92 | j += 1 93 | d[f"SEANetResnetBlock_{i}"] = { 94 | f"StreamableConv1d_0": streamable( 95 | torch_params, f"encoder.model.{j}.block.1" 96 | ), 97 | f"StreamableConv1d_1": streamable( 98 | torch_params, f"encoder.model.{j}.block.3" 99 | ), 100 | } 101 | i += 1 102 | j += 2 103 | 104 | d[f"StreamableConv1d_{i}"] = streamable(torch_params, f"encoder.model.{j}") 105 | 106 | j += 1 107 | lstm_layers = 2 # todo: 108 | d[f"StreamableLSTM_0"] = { 109 | f"LSTMCell_{k}": lstm(torch_params, f"encoder.model.{j}", k) 110 | for k in range(lstm_layers) 111 | } 112 | j += lstm_layers 113 | 114 | i += 1 115 | d[f"StreamableConv1d_{i}"] = streamable(torch_params, f"encoder.model.{j}") 116 | 117 | return d 118 | 119 | 120 | def torch_to_decoder(torch_params: dict, decoder_rates: tuple[int] = None): 121 | d = {} 122 | 123 | i = 0 124 | j = 0 125 | 126 | d[f"StreamableConv1d_{i}"] = streamable(torch_params, f"decoder.model.{j}") 127 | j += 1 128 | lstm_layers = 2 # todo: 129 | d[f"StreamableLSTM_0"] = { 130 | f"LSTMCell_{k}": lstm(torch_params, f"decoder.model.{j}", k) 131 | for k in range(lstm_layers) 132 | } 133 | j += lstm_layers 134 | for k in range(len(decoder_rates)): 135 | d[f"StreamableConvTranspose1d_{i}"] = streamable_transpose( 136 | torch_params, f"decoder.model.{j}" 137 | ) 138 | j += 1 139 | d[f"SEANetResnetBlock_{i}"] = { 140 | f"StreamableConv1d_0": streamable( 141 | torch_params, f"decoder.model.{j}.block.1" 142 | ), 143 | f"StreamableConv1d_1": streamable( 144 | torch_params, f"decoder.model.{j}.block.3" 145 | ), 146 | } 147 | i += 1 148 | j += 2 149 | 150 | d[f"StreamableConv1d_1"] = streamable(torch_params, f"decoder.model.{j}") 151 | 152 | return d 153 | 154 | 155 | def torch_to_quantizer(torch_params: dict, n_quantizers): 156 | d = { 157 | f"layers_{i}": { 158 | "_codebook": { 159 | 
"embed": torch_params[f"quantizer.vq.layers.{i}._codebook.embed"], 160 | "embed_avg": torch_params[ 161 | f"quantizer.vq.layers.{i}._codebook.embed_avg" 162 | ], 163 | } 164 | } 165 | for i in range(n_quantizers) 166 | } 167 | 168 | return {"vq": d} 169 | 170 | 171 | def torch_to_linen( 172 | torch_params: dict, 173 | encoder_rates: tuple[int] = None, 174 | decoder_rates: tuple[int] = None, 175 | num_codebooks: int = 9, 176 | ) -> dict: 177 | """Convert PyTorch parameters to Linen nested dictionaries""" 178 | 179 | if encoder_rates is None: 180 | encoder_rates = [2, 4, 8, 8] 181 | if decoder_rates is None: 182 | decoder_rates = [8, 8, 4, 2] 183 | 184 | return { 185 | "params": { 186 | "encoder": torch_to_encoder(torch_params, encoder_rates=encoder_rates), 187 | "decoder": torch_to_decoder(torch_params, decoder_rates=decoder_rates), 188 | "quantizer": torch_to_quantizer(torch_params, num_codebooks), 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | `60013__qubodup__whoosh.flac`: 2 | https://freesound.org/people/qubodup/sounds/60013/ 3 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DBraun/DAC-JAX/919ce4a2a9ec4c5c3fa7d10dcb2944259da00865/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_audio_utils.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | from einops import rearrange 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import pytest 7 | 8 | from dac_jax.audio_utils import stft, mel_spectrogram 9 | from dac_jax.nn.loss import mel_spectrogram_loss, multiscale_stft_loss 10 | 11 | from dac.nn.loss import MelSpectrogramLoss, MultiScaleSTFTLoss 12 | from audiotools import AudioSignal 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "match_stride,hop_factor,length", 17 | product( 18 | [False, True], 19 | [0.25, 0.5], 20 | [44100, 44101], 21 | ), 22 | ) 23 | def test_mel_same_as_audiotools(match_stride: bool, hop_factor: float, length: int): 24 | 25 | if hop_factor == 0.5 and match_stride: 26 | return # for some reason DAC torch disallows this 27 | 28 | sample_rate = 44100 29 | 30 | B = 1 31 | x = np.random.uniform(low=-1, high=1, size=(B, 1, length)) 32 | 33 | signal1 = AudioSignal(x, sample_rate=sample_rate) 34 | 35 | window_length = 2048 36 | hop_length = int(window_length * hop_factor) 37 | 38 | stft_kwargs = { 39 | "window_length": window_length, 40 | "hop_length": hop_length, 41 | "window_type": "hann", 42 | "match_stride": match_stride, 43 | "padding_type": "reflect", 44 | } 45 | 46 | n_mels = 80 47 | 48 | mel1 = signal1.mel_spectrogram(n_mels=n_mels, **stft_kwargs) 49 | 50 | stft1 = signal1.stft_data 51 | 52 | stft_data = stft( 53 | jnp.array(x), 54 | frame_length=stft_kwargs["window_length"], 55 | hop_factor=hop_factor, 56 | window=stft_kwargs["window_type"], 57 | match_stride=stft_kwargs["match_stride"], 58 | padding_type=stft_kwargs["padding_type"], 59 | ) 60 | 61 | assert np.allclose(np.abs(stft1), np.abs(stft_data), atol=1e-4) 62 | 63 | stft_data = rearrange(stft_data, "b c nf nt -> (b c) nt nf") 64 | 65 | spectrogram = jnp.abs(stft_data) 66 | 67 | mel2 = mel_spectrogram( 68 | spectrogram, 69 | log_scale=False, 70 | 
sample_rate=sample_rate, 71 | num_features=n_mels, 72 | frame_length=stft_kwargs["window_length"], 73 | ) 74 | 75 | mel2 = rearrange(mel2, "(b c) t bins -> b c bins t", b=B) 76 | 77 | assert np.allclose(mel1, np.array(mel2), atol=1e-4) 78 | 79 | 80 | @pytest.mark.parametrize( 81 | "length", 82 | (44100, 44101), 83 | ) 84 | def test_mel_loss_same_as_dac_torch(length: int): 85 | 86 | sample_rate = 44100 87 | 88 | x1 = np.random.uniform(low=-1, high=1, size=(1, 1, length)) 89 | x2 = x1 * 0.5 90 | 91 | signal1 = AudioSignal(x1, sample_rate=sample_rate) 92 | signal2 = AudioSignal(x2, sample_rate=sample_rate) 93 | 94 | loss1 = mel_spectrogram_loss(jnp.array(x1), jnp.array(x2), sample_rate=sample_rate) 95 | loss2 = MelSpectrogramLoss()(signal1, signal2) 96 | 97 | assert np.isclose(np.array(loss1), loss2) 98 | 99 | 100 | @pytest.mark.parametrize( 101 | "length", 102 | (44100, 44101), 103 | ) 104 | def test_multiscale_stft_loss_same_as_dac_torch(length: int): 105 | sample_rate = 44100 106 | 107 | x1 = np.random.uniform(low=-1, high=1, size=(1, 1, length)) 108 | x2 = x1 * 0.5 109 | 110 | signal1 = AudioSignal(x1, sample_rate=sample_rate) 111 | signal2 = AudioSignal(x2, sample_rate=sample_rate) 112 | 113 | loss1 = multiscale_stft_loss(jnp.array(x1), jnp.array(x2)) 114 | loss2 = MultiScaleSTFTLoss()(signal1, signal2) 115 | 116 | assert np.isclose(np.array(loss1), loss2) 117 | 118 | 119 | if __name__ == "__main__": 120 | # test_mel_same_as_audiotools() 121 | # test_mel_loss_same_as_dac_torch() 122 | # test_multiscale_stft_loss_same_as_dac_torch() 123 | # test_stft_equivalence(True) 124 | # test_stft_equivalence(False) 125 | # test_stft_equivalence2(0.5) 126 | test_mel_same_as_audiotools(False, 0.25, 44100) 127 | -------------------------------------------------------------------------------- /tests/test_binding.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from pathlib import Path 3 | import tempfile 4 | 5 | import jax.numpy as jnp 6 | import librosa 7 | 8 | import dac_jax 9 | 10 | 11 | def test_binding(): 12 | 13 | # Download a model and bind variables to it. 14 | model, variables = dac_jax.load_model(model_type="44khz") 15 | model = model.bind(variables) 16 | 17 | # Load audio file 18 | filepath = Path(__file__).parent / "assets" / "60013__qubodup__whoosh.flac" 19 | signal, sample_rate = librosa.load(filepath, sr=44100, mono=True, duration=0.5) 20 | 21 | signal = jnp.array(signal, dtype=jnp.float32) 22 | while signal.ndim < 3: 23 | signal = jnp.expand_dims(signal, axis=0) 24 | 25 | # Encode audio signal as one long file (may run out of GPU memory on long files) 26 | dac_file = model.encode_to_dac(signal, sample_rate) 27 | 28 | with tempfile.TemporaryDirectory() as tmpdirname: 29 | filepath = os.path.join(tmpdirname, "dac_file_001.dac") 30 | 31 | # Save to a file 32 | dac_file.save(filepath) 33 | 34 | # Load a file 35 | dac_file = dac_jax.DACFile.load(filepath) 36 | 37 | # Decode audio signal 38 | y = model.decode(dac_file) 39 | 40 | # reconstruction mean-square error 41 | mse = jnp.square(y - signal).mean() 42 | 43 | # Informal expected maximum MSE 44 | assert mse.item() < 0.005 45 | 46 | 47 | if __name__ == "__main__": 48 | test_binding() 49 | -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for CLI. 
3 | """ 4 | 5 | import subprocess 6 | from pathlib import Path 7 | 8 | import argbind 9 | import numpy as np 10 | import pytest 11 | import soundfile 12 | 13 | from dac_jax.__main__ import run 14 | 15 | 16 | def setup_module(module): 17 | data_dir = Path(__file__).parent / "tmp_assets" 18 | data_dir.mkdir(exist_ok=True, parents=True) 19 | input_dir = data_dir / "input" 20 | input_dir.mkdir(exist_ok=True, parents=True) 21 | 22 | for i in range(5): 23 | sample_rate = 44_100 24 | signal = np.random.randn(1000, sample_rate) 25 | soundfile.write(input_dir / f"sample_{i}.wav", signal, samplerate=sample_rate) 26 | return input_dir 27 | 28 | 29 | def teardown_module(module): 30 | repo_root = Path(__file__).parent.parent 31 | subprocess.check_output(["rm", "-rf", f"{repo_root}/tests/tmp_assets"]) 32 | 33 | 34 | @pytest.mark.parametrize("model_type", ["44khz", "24khz", "16khz"]) 35 | def test_reconstruction(model_type): 36 | # Test encoding 37 | input_dir = Path(__file__).parent / "tmp_assets" / "input" 38 | output_dir = input_dir.parent / model_type / "encoded_output" 39 | args = { 40 | "input": str(input_dir), 41 | "output": str(output_dir), 42 | "model_type": model_type, 43 | } 44 | with argbind.scope(args): 45 | run("encode") 46 | 47 | # Test decoding 48 | input_dir = output_dir 49 | output_dir = input_dir.parent / model_type / "decoded_output" 50 | args = { 51 | "input": str(input_dir), 52 | "output": str(output_dir), 53 | "model_type": model_type, 54 | } 55 | with argbind.scope(args): 56 | run("decode") 57 | 58 | 59 | def test_compression(): 60 | # Test encoding 61 | input_dir = Path(__file__).parent / "tmp_assets" / "input" 62 | output_dir = input_dir.parent / "encoded_output_quantizers" 63 | args = { 64 | "input": str(input_dir), 65 | "output": str(output_dir), 66 | "n_quantizers": 3, 67 | } 68 | with argbind.scope(args): 69 | run("encode") 70 | 71 | # Open .dac file 72 | dac_file = output_dir / "sample_0.dac" 73 | allow_pickle = True # todo: 74 | artifacts = np.load(dac_file, allow_pickle=allow_pickle)[()] 75 | codes = artifacts["codes"] 76 | 77 | # Ensure that the number of quantizers is correct 78 | assert codes.shape[2] == 3 79 | 80 | # Ensure that dtype of compression is uint16 81 | assert codes.dtype == np.uint16 82 | 83 | 84 | # CUDA_VISIBLE_DEVICES=0 python -m pytest tests/test_cli.py -s 85 | -------------------------------------------------------------------------------- /tests/test_dac_equivalence.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["XLA_FLAGS"] = ( 4 | " --xla_gpu_deterministic_ops=true" # todo: https://github.com/google/flax/discussions/3382 5 | ) 6 | os.environ["TF_CUDNN_DETERMINISTIC"] = "1" 7 | os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" 8 | 9 | from functools import partial 10 | from pathlib import Path 11 | 12 | import torch 13 | 14 | torch.use_deterministic_algorithms(True) 15 | 16 | import jax 17 | from jax import numpy as jnp 18 | from jax import random 19 | 20 | import dac as dac_torch 21 | from audiotools import AudioSignal 22 | 23 | import librosa 24 | import numpy as np 25 | 26 | import dac_jax 27 | from dac_jax import QuantizedResult 28 | 29 | 30 | def _torch_padding(np_data) -> dict[np.array]: 31 | 32 | model_path = dac_torch.utils.download(model_type="44khz") 33 | model = dac_torch.DAC.load(model_path) 34 | 35 | x = torch.from_numpy(np_data) 36 | sample_rate = model.sample_rate # note: not always true outside of this test 37 | x = model.preprocess(x, sample_rate) 38 | z, codes, latents, 
commitment_loss, codebook_loss = model.encode(x) 39 | 40 | # Decode audio signal 41 | audio = model.decode(z) 42 | 43 | d = { 44 | "audio": audio, 45 | "z": z, 46 | "codes": codes, 47 | "latents": latents, 48 | "vq/commitment_loss": commitment_loss, 49 | "vq/codebook_loss": codebook_loss, 50 | } 51 | 52 | d = {k: v.detach().cpu().numpy() for k, v in d.items()} 53 | 54 | return d 55 | 56 | 57 | def _torch_compress(np_data, win_duration: float): 58 | 59 | model = dac_torch.utils.load_model(model_type="44khz") 60 | 61 | sample_rate = model.sample_rate # note: not always true outside of this test 62 | x = AudioSignal(np_data, sample_rate=sample_rate) 63 | 64 | dac_file = model.compress(x, win_duration=win_duration) 65 | # get an embedding z for just a single chunk, only for the sake of comparing to jax 66 | c = dac_file.codes[..., : dac_file.chunk_length] 67 | z = model.quantizer.from_codes(c)[0] 68 | z = z.detach().cpu().numpy() 69 | 70 | recons = model.decompress(dac_file).audio_data 71 | recons = recons.cpu().numpy() 72 | 73 | return dac_file.codes, z, recons 74 | 75 | 76 | def _jax_padding(np_data) -> dict[np.array]: 77 | 78 | model, variables = dac_jax.load_model(model_type="44khz") 79 | 80 | q_res: QuantizedResult = model.apply( 81 | variables, jnp.array(np_data), model.sample_rate, train=False 82 | ) 83 | 84 | # Multiply by model.num_codebooks since we normalize by num_codebooks and torch doesn't. 85 | # q_res.commitment_loss = q_res.commitment_loss*model.num_codebooks 86 | # q_res.codebook_loss = q_res.codebook_loss * model.num_codebooks 87 | 88 | y = { 89 | "audio": q_res.recons, 90 | "z": q_res.z.transpose(0, 2, 1), 91 | "latents": q_res.latents, 92 | "codes": q_res.codes, 93 | "vq/codebook_loss": q_res.codebook_loss, 94 | "vq/commitment_loss": q_res.commitment_loss, 95 | } 96 | 97 | y = jax.tree.map(lambda x: np.array(x), y) 98 | return y 99 | 100 | 101 | def _jax_padding_jit(np_data): 102 | 103 | model, variables = dac_jax.load_model(model_type="44khz") 104 | 105 | @jax.jit 106 | def encode_to_codes(x: jnp.ndarray): 107 | codes, scale = model.apply( 108 | variables, 109 | x, 110 | method="encode", 111 | ) 112 | return codes, scale 113 | 114 | @partial(jax.jit, static_argnums=(1, 2)) 115 | def decode_from_codes(codes: jnp.ndarray, scale, length: int = None): 116 | recons = model.apply( 117 | variables, 118 | codes, 119 | scale, 120 | length, 121 | method="decode", 122 | ) 123 | 124 | return recons 125 | 126 | x = jnp.array(np_data) 127 | 128 | original_length = x.shape[-1] 129 | 130 | codes, scale = encode_to_codes(x) 131 | assert codes.shape[1] == model.num_codebooks 132 | 133 | recons = decode_from_codes(codes, scale, original_length) 134 | 135 | return np.array(recons), np.array(codes) 136 | 137 | 138 | def _jax_compress(np_data, win_duration: float): 139 | 140 | # set padding to False since we're using the chunk functions 141 | model, variables = dac_jax.load_model(model_type="44khz", padding=False) 142 | sample_rate = 44100 143 | 144 | @jax.jit 145 | def compress_chunk(x): 146 | return model.apply(variables, x, method="compress_chunk") 147 | 148 | @jax.jit 149 | def decompress_chunk(c): 150 | return model.apply(variables, c, method="decompress_chunk") 151 | 152 | @jax.jit 153 | def decode_latent(c): 154 | return model.apply(variables, c, method="decode_latent") 155 | 156 | key = jax.random.key(0) 157 | subkey1, subkey2, subkey3 = jax.random.split(key, 3) 158 | x = jax.random.normal(subkey1, shape=(1, 1, int(sample_rate * 2))) 159 | 160 | _ = model.init({"params": subkey2, 
"rng_stream": subkey3}, x, sample_rate) 161 | 162 | x = jnp.array(np_data) 163 | dac_file = model.compress(compress_chunk, x, sample_rate, win_duration=win_duration) 164 | 165 | codes = dac_file.codes 166 | 167 | # get an embedding z for just a single chunk, only for the sake of comparing to torch 168 | z = decode_latent(codes[:, :, : dac_file.chunk_length]).transpose(0, 2, 1) 169 | 170 | recons = model.decompress(decompress_chunk, dac_file) 171 | recons = np.array(recons) 172 | 173 | return codes, z, recons 174 | 175 | 176 | def test_equivalence_padding(): 177 | 178 | np.random.seed(0) 179 | np_data = np.random.normal(loc=0, scale=1, size=(1, 1, 4096)).astype(np.float32) 180 | 181 | jax_result = _jax_padding(np_data) 182 | torch_result = _torch_padding(np_data) 183 | assert set(jax_result.keys()) == set(torch_result.keys()) 184 | assert list(jax_result.keys()) 185 | for key in jax_result.keys(): 186 | # print(f"key: {key}, torch: {torch_result[key].shape}, jax: {jax_result[key].shape}") 187 | if key == "latents": 188 | # todo: why do we need to accept lower absolute tolerance for this key? 189 | atol = 1e-3 190 | elif key in ["vq/commitment_loss", "vq/codebook_loss"]: 191 | # todo: why do we need to accept lower absolute tolerance for these keys? 192 | atol = 1e-3 193 | elif key == "codes": 194 | atol = 1e-8 195 | elif key == "audio": 196 | atol = 1e-5 197 | elif key == "z": 198 | atol = 1e-5 199 | else: 200 | raise ValueError(f"Unexpected key '{key}'.") 201 | assert ( 202 | jax_result[key].shape == torch_result[key].shape 203 | ), f"key: {key}, torch: {torch_result[key].shape}, jax: {jax_result[key].shape}" 204 | assert np.allclose( 205 | jax_result[key], torch_result[key], atol=atol 206 | ), f"Failed to match outputs for key: {key} and atol: {atol}" 207 | 208 | jax_recons, jax_codes = _jax_padding_jit(np_data) 209 | 210 | assert np.allclose(torch_result["codes"], jax_codes) 211 | assert np.allclose( 212 | torch_result["audio"], jax_recons, atol=1e-4 213 | ) # todo: reduce atol to 1e-5 214 | 215 | 216 | def test_equivalence_compress(verbose=False): 217 | 218 | def compress_helper(np_data, atol, win_duration=0.38): 219 | 220 | jax_codes, jax_z, jax_recons = _jax_compress(np_data, win_duration) 221 | torch_codes, torch_z, torch_recons = _torch_compress(np_data, win_duration) 222 | assert np.allclose(jax_codes, torch_codes) 223 | np.testing.assert_almost_equal( 224 | torch_z, jax_z, decimal=5 225 | ) # todo: raise this to decimal=6 226 | if verbose: 227 | print("max diff: ", jnp.abs(jax_recons - torch_recons).max()) 228 | assert np.allclose(jax_recons, torch_recons, atol=atol) 229 | 230 | np_data, sr = librosa.load( 231 | Path(__file__).parent / "assets/60013__qubodup__whoosh.flac", sr=None, mono=True 232 | ) 233 | np_data = np.expand_dims(np.array(np_data), 0) 234 | np_data = np.expand_dims(np.array(np_data), 0) 235 | np_data = np.concatenate([np_data, np_data, np_data, np_data], axis=-1) 236 | compress_helper(np_data, atol=1e-5) 237 | 238 | np.random.seed(0) 239 | num_samples = int(44100 * 10) 240 | np_data = 0.5 * np.random.uniform(low=-1, high=1, size=(1, 1, num_samples)).astype( 241 | np.float32 242 | ) 243 | # todo: for compressing/decompressing noise, why must we use a higher absolute tolerance? 
244 | compress_helper(np_data, atol=0.003) 245 | 246 | 247 | if __name__ == "__main__": 248 | test_equivalence_padding() 249 | test_equivalence_compress() 250 | print("All Done!") 251 | -------------------------------------------------------------------------------- /tests/test_encodec_equivalence.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["XLA_FLAGS"] = ( 4 | " --xla_gpu_deterministic_ops=true" # todo: https://github.com/google/flax/discussions/3382 5 | ) 6 | os.environ["TF_CUDNN_DETERMINISTIC"] = "1" 7 | os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" 8 | 9 | from functools import partial 10 | from pathlib import Path 11 | 12 | from audiocraft.models import MusicGen 13 | import jax 14 | from jax import numpy as jnp 15 | from jax import random 16 | import librosa 17 | import numpy as np 18 | import torch 19 | 20 | from dac_jax import load_encodec_model, QuantizedResult 21 | 22 | 23 | def run_jax_model1(np_data): 24 | 25 | x = jnp.array(np_data) 26 | 27 | encodec_model, variables = load_encodec_model("facebook/musicgen-small") 28 | 29 | result: QuantizedResult = encodec_model.apply( 30 | variables, x, train=False, rngs={"rng_stream": random.key(0)} 31 | ) 32 | recons = result.recons 33 | codes = result.codes 34 | assert codes.shape[1] == encodec_model.num_codebooks 35 | 36 | return np.array(recons), np.array(codes) 37 | 38 | 39 | def run_jax_model2(np_data): 40 | """jax.jit version of run_jax_model1""" 41 | 42 | model, variables = load_encodec_model() 43 | 44 | @jax.jit 45 | def encode_to_codes(x: jnp.ndarray): 46 | codes, scale = model.apply( 47 | variables, 48 | x, 49 | method="encode", 50 | ) 51 | return codes, scale 52 | 53 | @partial(jax.jit, static_argnums=(1, 2)) 54 | def decode_from_codes(codes: jnp.ndarray, scale, length: int = None): 55 | recons = model.apply( 56 | variables, 57 | codes, 58 | scale, 59 | length, 60 | method="decode", 61 | ) 62 | 63 | return recons 64 | 65 | x = jnp.array(np_data) 66 | 67 | original_length = x.shape[-1] 68 | 69 | codes, scale = encode_to_codes(x) 70 | assert codes.shape[1] == model.num_codebooks 71 | 72 | recons = decode_from_codes(codes, scale, original_length) 73 | 74 | return np.array(recons), np.array(codes) 75 | 76 | 77 | def run_torch_model(np_data): 78 | model = MusicGen.get_pretrained("facebook/musicgen-small") 79 | x = torch.from_numpy(np_data).cuda() 80 | result = model.compression_model(x) 81 | 82 | recons = result.x.detach().cpu().numpy() 83 | codes = result.codes.detach().cpu().numpy() 84 | assert codes.shape[1] == model.compression_model.num_codebooks 85 | 86 | return recons, codes 87 | 88 | 89 | def test_encoded_equivalence(): 90 | np_data, sr = librosa.load( 91 | Path(__file__).parent / "assets/60013__qubodup__whoosh.flac", sr=None, mono=True 92 | ) 93 | np_data = np.expand_dims(np.array(np_data), 0) 94 | np_data = np.expand_dims(np.array(np_data), 0) 95 | np_data = np.concatenate([np_data, np_data, np_data, np_data], axis=-1) 96 | 97 | np_data *= 0.5 98 | 99 | torch_recons, torch_codes = run_torch_model(np_data) 100 | jax_recons, jax_codes = run_jax_model1(np_data) 101 | 102 | assert np.allclose(torch_codes, jax_codes) 103 | assert np.allclose(torch_recons, jax_recons, atol=1e-4) # todo: reduce atol to 1e-5 104 | 105 | jax_recons, jax_codes = run_jax_model2(np_data) 106 | 107 | assert np.allclose(torch_codes, jax_codes) 108 | assert np.allclose(torch_recons, jax_recons, atol=1e-4) # todo: reduce atol to 1e-5 109 | 110 | 111 | if __name__ == "__main__": 112 | 
test_encoded_equivalence() 113 | -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for training. 3 | """ 4 | 5 | import os 6 | import shlex 7 | import subprocess 8 | from pathlib import Path 9 | 10 | import argbind 11 | import numpy as np 12 | from audiotools import AudioSignal 13 | 14 | from dac_jax.__main__ import run 15 | 16 | 17 | def make_fake_data(data_dir=Path(__file__).parent / "tmp_assets"): 18 | data_dir.mkdir(exist_ok=True, parents=True) 19 | input_dir = data_dir / "input" 20 | input_dir.mkdir(exist_ok=True, parents=True) 21 | 22 | for i in range(100): 23 | signal = AudioSignal(np.random.randn(44_100 * 5), 44_100) 24 | signal.write(input_dir / f"sample_{i}.wav") 25 | return input_dir 26 | 27 | 28 | def make_fake_data_tree(): 29 | data_dir = Path(__file__).parent / "tmp_assets" 30 | 31 | for relative_dir in [ 32 | "train/speech", 33 | "train/music", 34 | "train/env", 35 | "val/speech", 36 | "val/music", 37 | "val/env", 38 | "test/speech", 39 | "test/music", 40 | "test/env", 41 | ]: 42 | leaf_dir = data_dir / relative_dir 43 | leaf_dir.mkdir(exist_ok=True, parents=True) 44 | make_fake_data(leaf_dir) 45 | return { 46 | split: { 47 | key: [str(data_dir / f"{split}/{key}")] 48 | for key in ["speech", "music", "env"] 49 | } 50 | for split in ["train", "val", "test"] 51 | } 52 | 53 | 54 | def setup_module(module): 55 | # Make fake dataset dir 56 | input_datasets = make_fake_data_tree() 57 | repo_root = Path(__file__).parent.parent 58 | 59 | # Load baseline conf and modify it for testing 60 | conf = argbind.load_args(repo_root / "conf" / "ablations" / "baseline.yml") 61 | 62 | for key in ["train", "val", "test"]: 63 | conf[f"{key}/build_dataset.folders"] = input_datasets[key] 64 | conf["num_iters"] = 1 65 | conf["val/AudioDataset.n_examples"] = 1 66 | conf["val_idx"] = [0] 67 | conf["val_batch_size"] = 1 68 | 69 | argbind.dump_args(conf, Path(__file__).parent / "tmp_assets" / "conf.yml") 70 | 71 | 72 | def teardown_module(module): 73 | repo_root = Path(__file__).parent.parent 74 | # Remove fake dataset dir 75 | subprocess.check_output(["rm", "-rf", f"{repo_root}/tests/tmp_assets"]) 76 | subprocess.check_output(["rm", "-rf", f"{repo_root}/tests/runs"]) 77 | 78 | 79 | def test_single_gpu_train(): 80 | env = os.environ.copy() 81 | repo_root = Path(__file__).parent.parent 82 | args = shlex.split( 83 | f"python {repo_root}/scripts/train.py --args.load {repo_root}/tests/tmp_assets/conf.yml --train.save_path {repo_root}/tests/runs/baseline" 84 | ) 85 | subprocess.check_output(args, env=env) 86 | 87 | 88 | def test_multi_gpu_train(): 89 | pass # todo: 90 | --------------------------------------------------------------------------------