├── .gitignore ├── LICENSE ├── README.md ├── _LICENSES ├── SpecDiff-GAN.LICENSE ├── aeiou.LICENSE ├── melgan-neurips.LICENSE ├── stable-audio-tools.LICENSE └── wavegrad.LICENSE ├── assets └── fig │ └── wavefit.png ├── configs ├── data │ └── default.yaml ├── default.yaml ├── model │ ├── wavefit-3.yaml │ └── wavefit-3_mem-efficient.yaml ├── optimizer │ └── default.yaml └── trainer │ └── default.yaml ├── container ├── build_singularity.bash └── wavefit.def ├── requirements.txt └── src ├── data ├── __init__.py ├── audio_io.py ├── dataset.py └── modification.py ├── inference.py ├── loss ├── __init__.py └── mrstft.py ├── model ├── __init__.py ├── discriminator.py ├── generator.py └── wavefit.py ├── test.py ├── train.py ├── trainer.py └── utils ├── __init__.py ├── audio_utils.py ├── logging.py ├── scheduler.py ├── torch_common.py └── viz.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | # Outputs 165 | runs/ 166 | wandb/ 167 | *.wav 168 | 169 | # Container 170 | *.sif 171 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yukara Ikemiya 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 💬 WaveFit | A lightweight and fast speech vocoder 2 | 3 | This is an unofficial implementation of **`WaveFit`**[1] 4 | which is **a state-of-the-art lightweight/fast speech vocoder** from Google Research. 
5 | 
6 | ![WaveFit](./assets/fig/wavefit.png)
7 | 
8 | This repository supports:
9 | - 🔥 Full implementation and training code for the `WaveFit` model
10 | - 🔥 Memory-efficient architecture used in [**Miipher-2**](https://arxiv.org/abs/2505.04457) [3]
11 | - 🔥 Distributed training with multiple GPUs / multiple Nodes
12 | 
13 | # 📢 UPDATES
14 | 
15 | - May 18, 2025: 👋 Now supports the memory-efficient WaveFit architecture used in [**Miipher-2**](https://arxiv.org/abs/2505.04457) [3]
16 | 
17 | # Requirements
18 | 
19 | - Python 3.8.10 or later
20 | - PyTorch 2.1 or later
21 | 
22 | ## Building a training environment
23 | 
24 | To simplify setting up the training environment, I recommend using container systems such as `Docker` or `Singularity` instead of installing dependencies on each GPU machine. Below are the steps for creating a `Singularity` container.
25 | 
26 | All example scripts are stored in the [container](container/) folder.
27 | 
28 | ### 1. Install Singularity
29 | 
30 | Install the latest Singularity by following the official instructions.
31 | - https://docs.sylabs.io/guides/main/user-guide/quick_start.html#quick-installation-steps
32 | 
33 | ### 2. Create a Singularity image file
34 | 
35 | Create (build) a Singularity image file from the definition file.
36 | ```bash
37 | singularity build --fakeroot wavefit.sif wavefit.def
38 | ```
39 | 
40 | ** NOTE: You might need to change the NVIDIA base image in the definition file to match your GPU machine.
41 | 
42 | You now have a container file for training and inference of WaveFit.
43 | 
44 | ## Setting a WandB account for logging
45 | 
46 | The training code also requires a Weights & Biases account to log training outputs and demos.
47 | Please create an account and follow the instructions.
48 | 
49 | Once you have created your WandB account,
50 | you can obtain an API key from https://wandb.ai/authorize after logging in.
51 | The API key can then be passed to a training job as the environment variable `WANDB_API_KEY`
52 | to enable logging of training information.
53 | 
54 | ```bash
55 | $ WANDB_API_KEY="12345x6789y..."
56 | ```
57 | 
58 | # Training
59 | 
60 | ## Training from scratch
61 | 
62 | In this repository, all training parameters are configured with `Hydra`,
63 | allowing them to be set as command-line arguments.
64 | 
65 | The following is an example of a job script for training on the LibriTTS dataset.
66 | ```bash
67 | ROOT_DIR="/path/to/this/repository/"
68 | DATASET_DIR="/path/to/LibriTTS/"
69 | CONTAINER_PATH="/path/to/wavefit.sif"
70 | TRAIN_DIRS=${DATASET_DIR}/train-clean-100/,${DATASET_DIR}/train-clean-360/
71 | TEST_DIRS=${DATASET_DIR}/test-clean/
72 | 
73 | WANDB_API_KEY="12345x6789y..."
74 | PORT=12345
75 | JOB_ID="job_name"
76 | 
77 | MODEL="wavefit-3"
78 | # MODEL="wavefit-3_mem-efficient" # <- Memory-efficient architecture
79 | OUTPUT_DIR=${ROOT_DIR}/output/${MODEL}/${JOB_ID}/
80 | 
81 | BATCH_SIZE=512 # This must be a multiple of the GPU count. Please adjust to your environment.
82 | NUM_WORKERS=8
83 | 
84 | mkdir -p ${OUTPUT_DIR}
85 | 
86 | # Execution
87 | singularity exec --nv --pwd $ROOT_DIR -B $ROOT_DIR -B $DATASET_DIR \
88 | --env MASTER_PORT=${PORT} --env WANDB_API_KEY=$WANDB_API_KEY \
89 | ${CONTAINER_PATH} \
90 | torchrun --nproc_per_node gpu ${ROOT_DIR}/src/train.py \
91 | model=${MODEL} \
92 | data.train.dir_list=[${TRAIN_DIRS}] data.test.dir_list=[${TEST_DIRS}] \
93 | trainer.output_dir=${OUTPUT_DIR} \
94 | trainer.batch_size=${BATCH_SIZE} \
95 | trainer.num_workers=${NUM_WORKERS} \
96 | trainer.logger.project_name=${MODEL} \
97 | trainer.logger.run_name=job-${JOB_ID}
98 | ```
99 | ** Please note that the dataset directories must be provided as lists.
100 | 
101 | ## Resume training from a checkpoint
102 | 
103 | During training, checkpoints (state_dicts) of the models, optimizers and schedulers are saved under the output directory specified in the configuration, as follows.
104 | ```
105 | output_dir/
106 | ├─ ckpt/
107 | │ ├─ latest/
108 | │ │ ├─ opt_d.pth
109 | │ │ ├─ opt_g.pth
110 | │ │ ├─ generator.pth
111 | │ │ ├─ discriminator.pth
112 | │ │ ├─ ...
113 | ```
114 | 
115 | By specifying the checkpoint directory, you can easily resume training from that checkpoint.
116 | ```bash
117 | CKPT_DIR="output_dir/ckpt/latest/"
118 | 
119 | # Execution
120 | singularity exec --nv --pwd $ROOT_DIR -B $ROOT_DIR -B $DATASET_DIR \
121 | --env MASTER_PORT=${PORT} --env WANDB_API_KEY=$WANDB_API_KEY \
122 | ${CONTAINER_PATH} \
123 | torchrun --nproc_per_node gpu ${ROOT_DIR}/src/train.py trainer.ckpt_dir=${CKPT_DIR}
124 | ```
125 | 
126 | ### Overriding parameters
127 | 
128 | When resuming training, you might want to override some configuration parameters.
129 | In my implementation, only the parameters explicitly specified in the job script override the configuration loaded from the checkpoint directory.
130 | 
131 | For example, in the following case the checkpoint is loaded from `CKPT_DIR`,
132 | but the training outputs are saved under `OUTPUT_DIR`.
133 | ```bash
134 | CKPT_DIR="output_dir/ckpt/latest/"
135 | OUTPUT_DIR="another/directory/"
136 | 
137 | # Execution
138 | singularity exec --nv --pwd $ROOT_DIR -B $ROOT_DIR -B $DATASET_DIR \
139 | --env MASTER_PORT=${PORT} --env WANDB_API_KEY=$WANDB_API_KEY \
140 | ${CONTAINER_PATH} \
141 | torchrun --nproc_per_node gpu ${ROOT_DIR}/src/train.py \
142 | trainer.ckpt_dir=${CKPT_DIR} \
143 | trainer.output_dir=${OUTPUT_DIR}
144 | ```
145 | 
146 | # Inference
147 | 
148 | Using a pre-trained WaveFit model, you can run inference on audio signals
149 | (e.g. for evaluation).
150 | 
151 | The [`inference.py`](src/inference.py) script performs inference for all audio files in a target directory.
152 | To check the other options of the script, please use the `-h` option.
153 | 
154 | ```bash
155 | CKPT_DIR="output_dir/ckpt/latest/"
156 | AUDIO_DIR="path/to/target/speech/directory/"
157 | 
158 | singularity exec --nv --pwd $ROOT_DIR -B $ROOT_DIR -B $AUDIO_DIR \
159 | --env MASTER_PORT=${PORT} \
160 | ${CONTAINER_PATH} \
161 | torchrun --nproc_per_node gpu --master_port ${PORT} \
162 | ${ROOT_DIR}/src/inference.py \
163 | --ckpt-dir ${CKPT_DIR} \
164 | --input-audio-dir ${AUDIO_DIR} \
165 | --output-dir ${OUTPUT_DIR}
166 | ```
167 | 
168 | # 🤔 Unclear points in the implementation
169 | 
170 | ## 1. Parameter size of the generator
171 | 
172 | WaveFit uses the "WaveGrad Base model" [2] as its generator.
173 | However, the two papers report different parameter sizes for the same generator.
174 | 
175 | - WaveGrad paper -> 15 million
176 | - WaveFit paper -> 13.8 million
177 | 
178 | In order to implement the model as faithfully as possible to the original, I checked the parameter size of my implementation, but I noticed that there are ambiguities in the architecture even in the WaveGrad paper.
179 | The most unclear part is the Downsample module within the downsampling block (DBlock).
180 | 
181 | The paper states that the downsampling in a DBlock is implemented as a convolutional layer whose stride equals the downsampling factor.
182 | Naively, a `kernel_size=3` setting should suffice to perform the downsampling, and in this case the parameter size turned out as follows.
183 | 
184 | - 15.12 million (kernel_size = 3)
185 | 
186 | This roughly matches the size reported in the WaveGrad paper, so I believe this part is implemented faithfully.
187 | However, when the downsampling factor exceeds 3, the stride exceeds the kernel size and some regions are never covered by the convolution, which I found unsettling.
188 | Therefore, I recalculated the `kernel_size` as follows, which results in a slightly larger parameter size.
189 | 
190 | - 15.51 million (kernel_size = down_factor // 2 * 2 + 1)
191 | 
192 | As for the parameter size reported in the WaveFit paper, I was unable to reproduce it, so the details of that implementation remain unclear to me.
193 | 
194 | ## 2. Step condition (positional embedding)
195 | 
196 | Similar to WaveGrad, the generator loop in WaveFit is conditioned on the step number (eq. 15).
197 | However, while WaveGrad, as a diffusion model, is conditioned on continuous coefficients obtained from the noise scheduler, WaveFit uses a fixed number of steps even during training, resulting in discrete conditioning.
198 | 
199 | Assuming that WaveFit also uses the Transformer's sinusoidal positional embedding for conditioning, I compute the embedding by spreading the position coefficient of each step over the total number of steps $T$ as follows:
200 | ```python
201 | import torch
202 | T = 3
203 | scale = 5000.
204 | pos = torch.linspace(0., scale, T)  # -> [0., 2500., 5000.]
205 | # e.g. WaveFit-5 case (T = 5): pos = [0., 1250., 2500., 3750., 5000.]
206 | ```
207 | The scale value of 5000 is taken from the WaveGrad paper.
208 | 
209 | The positional embeddings calculated from these coefficients are fixed vectors, so they can be precomputed.
210 | 
211 | ## 3. Gain adjustment
212 | 
213 | In each WaveFit iteration, the power of the estimated signal is adjusted by a gain adjustment function
214 | using a power spectrogram computed from the input mel-spectrogram (Sec. 4.1).
215 | The paper states that the gain adjustment is calculated from the power of the estimated spectrogram $P_z$ and of the input (pseudo) spectrogram $P_c$ according to the following formula.
216 | ```math
217 | y_t = (P_c / (P_z + 10^{-8}))z_t
218 | ```
219 | However, since the signal amplitude is proportional to the square root of the power ratio, my understanding is that it should actually be implemented as follows. Without this correction, the model does not learn correctly.
220 | ```math
221 | y_t = \sqrt{(P_c / (P_z + 10^{-8}))}z_t
222 | ```
223 | 
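The following is a minimal, self-contained sketch of this square-root gain adjustment, not the actual code in this repository: the function name, the STFT parameters (borrowed from the mel settings in `configs/model/wavefit-3.yaml`), and the reduction of each spectrogram to a single per-example power are illustrative assumptions.

```python
import torch


def adjust_gain(z_t: torch.Tensor, power_target: torch.Tensor,
                n_fft: int = 2048, hop_size: int = 300, win_size: int = 1200,
                eps: float = 1.0e-8) -> torch.Tensor:
    """Scale the estimate z_t (batch, samples) so that its mean spectral power
    matches power_target (batch,), i.e. P_c computed from the input mel-spectrogram."""
    window = torch.hann_window(win_size, device=z_t.device)
    spec = torch.stft(z_t, n_fft, hop_length=hop_size, win_length=win_size,
                      window=window, return_complex=True)
    power_z = spec.abs().pow(2).mean(dim=(-2, -1))     # P_z : mean power of the estimate
    gain = torch.sqrt(power_target / (power_z + eps))  # sqrt of the power ratio
    return gain.unsqueeze(-1) * z_t                    # y_t
```

Here `power_target` corresponds to $P_c$ and is assumed to be computed beforehand from the (pseudo) power spectrogram derived from the input mel-spectrogram.

224 | ## 4. Optimizer of training
225 | 
226 | The WaveFit paper states that the same optimizer settings as WaveGrad were used, but I could not find such a description in the WaveGrad paper. Therefore, I referred to the following paper to configure the learning rate.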
227 | - "Noise Level Limited Sub-Modeling for Diffusion Probabilistic Vocoders", Okamoto et al., 2021 228 | 229 | Additionally, based on insights into stability in GAN training, I have incorporated the following techniques. 230 | - Gradient clipping 231 | - Learning rate warm-up 232 | - Learning rate decaying 233 | 234 | # TODO 235 | 236 | - [x] Useful inference scripts 237 | 238 | # References 239 | 240 | 1. "WaveFit: An Iterative and Non-autoregressive Neural Vocoder based on Fixed-Point Iteration", Y. Koizumi et al., IEEE SLT, 2022 241 | 1. "WaveGrad: Estimating Gradients for Waveform Generation", N. Chen et al., ICLR, 2021 242 | 1. "Miipher-2: A Universal Speech Restoration Model for Million-Hour Scale Data Restoration", S. Karita et al., 2025 -------------------------------------------------------------------------------- /_LICENSES/SpecDiff-GAN.LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 SpecDiff-GAN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /_LICENSES/aeiou.LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /_LICENSES/melgan-neurips.LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Descript Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /_LICENSES/stable-audio-tools.LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Stability AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /_LICENSES/wavegrad.LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 LMNT, Inc. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /assets/fig/wavefit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukara-ikemiya/wavefit-pytorch/13d6a0dc87dd6e5bd35fba69f889dc5fc7d53bee/assets/fig/wavefit.png -------------------------------------------------------------------------------- /configs/data/default.yaml: -------------------------------------------------------------------------------- 1 | 2 | batchsize: 1 3 | 4 | train: 5 | _target_: data.AudioFilesDataset 6 | dir_list: [] 7 | sample_size: 36000 8 | sample_rate: 24000 9 | out_channels: 'mono' 10 | exts: ['wav'] 11 | augment_shift: True 12 | augment_flip: True 13 | augment_volume: True 14 | max_samples: null 15 | 16 | test: 17 | _target_: data.AudioFilesDataset 18 | dir_list: [] 19 | sample_size: 36000 20 | sample_rate: 24000 21 | out_channels: 'mono' 22 | exts: ['wav'] 23 | augment_shift: False 24 | augment_flip: False 25 | augment_volume: False 26 | max_samples: null -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model: wavefit-3 3 | - data: default 4 | - trainer: default 5 | - optimizer: default 6 | - _self_ 7 | 8 | # no hydra logging 9 | hydra: 10 | output_subdir: null 11 | run: 12 | dir: . 13 | # multirun 14 | sweep: 15 | dir: . 16 | 17 | -------------------------------------------------------------------------------- /configs/model/wavefit-3.yaml: -------------------------------------------------------------------------------- 1 | 2 | generator: 3 | 4 | model: 5 | _target_: model.WaveFit 6 | # Number of WaveFit iteration (e.g. WaveFit-5 -> 5) 7 | num_iteration: 3 8 | # Architecture type 9 | memory_efficient_miipher2: false 10 | # Mel-spectrogram setting 11 | args_mel: 12 | sr: 24000 13 | n_fft: 2048 14 | win_size: 1200 15 | hop_size: 300 # You cannot change this parameter as long as you use the default WaveGrad generator 16 | n_mels: 128 17 | fmin: 20.0 18 | fmax: 12000.0 19 | 20 | 21 | discriminator: 22 | 23 | model: 24 | _target_: model.Discriminator 25 | num_D: 3 26 | ndf: 16 27 | n_layers: 4 28 | downsampling_factor: 4 29 | 30 | 31 | loss: 32 | 33 | mrstft: 34 | _target_: loss.MRSTFTLoss 35 | n_ffts: [512, 1024, 2048] 36 | win_sizes: [360, 900, 1800] 37 | hop_sizes: [80, 150, 300] 38 | 39 | melmae: 40 | _target_: loss.MELMAELoss 41 | sr: 24000 42 | n_fft: 1024 43 | win_size: 900 44 | hop_size: 150 45 | n_mels: 128 46 | fmin: 20.0 47 | fmax: 12000. 
48 | 49 | # Weights of each loss 50 | # This is the LibriTTS setting in the WaveFit paper (Sec.5.1) 51 | lambdas: 52 | G/disc_gan_loss: 1.0 53 | G/disc_feat_loss: 10.0 54 | G/mrstft_sc_loss: 2.5 55 | G/mrstft_mag_loss: 2.5 56 | G/mel_mae_loss: 0.0 -------------------------------------------------------------------------------- /configs/model/wavefit-3_mem-efficient.yaml: -------------------------------------------------------------------------------- 1 | 2 | generator: 3 | 4 | model: 5 | _target_: model.WaveFit 6 | # Number of WaveFit iteration (e.g. WaveFit-5 -> 5) 7 | num_iteration: 3 8 | # Architecture type 9 | memory_efficient_miipher2: true 10 | # Mel-spectrogram setting 11 | args_mel: 12 | sr: 24000 13 | n_fft: 2048 14 | win_size: 1200 15 | hop_size: 300 # You cannot change this parameter as long as you use the default WaveGrad generator 16 | n_mels: 128 17 | fmin: 20.0 18 | fmax: 12000.0 19 | 20 | 21 | discriminator: 22 | 23 | model: 24 | _target_: model.Discriminator 25 | num_D: 3 26 | ndf: 16 27 | n_layers: 4 28 | downsampling_factor: 4 29 | 30 | 31 | loss: 32 | 33 | mrstft: 34 | _target_: loss.MRSTFTLoss 35 | n_ffts: [512, 1024, 2048] 36 | win_sizes: [360, 900, 1800] 37 | hop_sizes: [80, 150, 300] 38 | 39 | melmae: 40 | _target_: loss.MELMAELoss 41 | sr: 24000 42 | n_fft: 1024 43 | win_size: 900 44 | hop_size: 150 45 | n_mels: 128 46 | fmin: 20.0 47 | fmax: 12000. 48 | 49 | # Weights of each loss 50 | # This is the LibriTTS setting in the WaveFit paper (Sec.5.1) 51 | lambdas: 52 | G/disc_gan_loss: 1.0 53 | G/disc_feat_loss: 10.0 54 | G/mrstft_sc_loss: 2.5 55 | G/mrstft_mag_loss: 2.5 56 | G/mel_mae_loss: 0.0 -------------------------------------------------------------------------------- /configs/optimizer/default.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Generator 3 | 4 | G: 5 | optimizer: 6 | _partial_: true 7 | _target_: torch.optim.AdamW 8 | betas: [0.8, 0.99] 9 | lr: 0.0001 10 | weight_decay: 0.001 11 | 12 | scheduler: 13 | _partial_: true 14 | _target_: utils.scheduler.InverseLR 15 | inv_gamma: 200000 16 | power: 0.5 17 | warmup: 0.999 18 | 19 | # Discriminator 20 | 21 | D: 22 | optimizer: 23 | _partial_: true 24 | _target_: torch.optim.AdamW 25 | betas: [0.8, 0.99] 26 | lr: 0.0002 27 | weight_decay: 0.001 28 | 29 | scheduler: 30 | _partial_: true 31 | _target_: utils.scheduler.InverseLR 32 | inv_gamma: 200000 33 | power: 0.5 34 | warmup: 0.999 35 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Root directory for outputs 3 | output_dir: ??? 4 | 5 | # Checkpoint directory for resuming training 6 | ckpt_dir: null 7 | 8 | # Batch size 9 | # NOTE: Batch size must be a multiple of GPU number 10 | batch_size: 1 11 | num_workers: 2 12 | 13 | # Seed value used for rng initialization 14 | seed: 0 15 | 16 | # Automatic mixed precision 17 | # Choose from ‘no’,‘fp16’,‘bf16’ or ‘fp8’. 
18 | amp: 'no' # 'fp16' 19 | 20 | # Max norm of gradient clipping 21 | max_grad_norm: 1.0 22 | 23 | logger: 24 | project_name: 'project_name' 25 | run_name: 'run_name' 26 | 27 | logging: 28 | # Step interval for logging metrics / saving checkpoints 29 | # / generating samples / test (validation) / printing metrics 30 | n_step_log: 20 31 | n_step_ckpt: 10000 32 | n_step_sample: 2000 33 | n_step_test: 10000 34 | n_step_print: 1000 35 | # Number of generated samples 36 | n_samples: 2 37 | 38 | metrics_for_best_ckpt: ['G/mrstft_sc_loss','G/mrstft_mag_loss'] -------------------------------------------------------------------------------- /container/build_singularity.bash: -------------------------------------------------------------------------------- 1 | # Create `wavefit.sif` file 2 | singularity build --fakeroot wavefit.sif wavefit.def -------------------------------------------------------------------------------- /container/wavefit.def: -------------------------------------------------------------------------------- 1 | # Pytorch 24.01 -> CUDA 12.3.2, PyTorch 2.2.0 2 | Bootstrap: docker 3 | From: nvcr.io/nvidia/pytorch:24.01-py3 4 | 5 | %post 6 | apt-get update && python -m pip install --upgrade pip 7 | 8 | # pip 9 | python -m pip install --no-cache-dir setuptools wheel 10 | 11 | python -m pip install --no-cache-dir torchaudio 12 | 13 | python -m pip install --no-cache-dir numpy librosa pillow 14 | 15 | python -m pip install --no-cache-dir -U accelerate hydra-core wandb 16 | 17 | 18 | %environment 19 | export PYTHONIOENCODING=utf-8 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.1.0 2 | torchaudio 3 | numpy 4 | librosa 5 | pillow 6 | 7 | accelerate 8 | hydra-core 9 | wandb -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import AudioFilesDataset 2 | -------------------------------------------------------------------------------- /src/data/audio_io.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | 5 | import math 6 | import json 7 | 8 | from torch.nn import functional as F 9 | import torchaudio 10 | from torchaudio import transforms as T 11 | 12 | 13 | def get_audio_metadata(filepath, cache=True): 14 | try: 15 | with open(filepath + '.json', 'r') as f: 16 | info = json.load(f) 17 | return info 18 | except Exception: 19 | try: 20 | info_ = torchaudio.info(filepath) 21 | sample_rate = info_.sample_rate 22 | num_channels = info_.num_channels 23 | num_frames = info_.num_frames 24 | 25 | info = { 26 | 'sample_rate': sample_rate, 27 | 'num_frames': num_frames, 28 | 'num_channels': num_channels 29 | } 30 | except Exception: 31 | # error : cannot open an audio file 32 | info = {'sample_rate': 0, 'num_frames': 0, 'num_channels': 0} 33 | 34 | if cache: 35 | with open(filepath + '.json', 'w') as f: 36 | json.dump(info, f, indent=2) 37 | 38 | return info 39 | 40 | 41 | def load_audio_with_pad(filepath, info: dict, sr: int, n_samples: int, offset: int): 42 | sr_in, num_frames = info['sample_rate'], info['num_frames'] 43 | n_samples_in = int(math.ceil(n_samples * (sr_in / sr))) 44 | 45 | # load audio 46 | ext = filepath.split(".")[-1] 47 | out_frames = min(n_samples_in, num_frames - offset) 48 | 49 | audio, _ = 
torchaudio.load( 50 | filepath, frame_offset=offset, num_frames=out_frames, 51 | format=ext, backend='soundfile') 52 | 53 | # resample 54 | if sr_in != sr: 55 | resample_tf = T.Resample(sr_in, sr) 56 | audio = resample_tf(audio)[..., :n_samples] 57 | 58 | # zero pad 59 | L = audio.shape[-1] 60 | if L < n_samples: 61 | audio = F.pad(audio, (0, n_samples - L), value=0.) 62 | 63 | return audio 64 | -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | 5 | import os 6 | import random 7 | import typing as tp 8 | 9 | import torch 10 | import numpy as np 11 | 12 | from utils.torch_common import print_once 13 | from .modification import Stereo, Mono, PhaseFlipper, VolumeChanger 14 | from .audio_io import get_audio_metadata, load_audio_with_pad 15 | 16 | 17 | def fast_scandir(dir: str, ext: tp.List[str]): 18 | """ Very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243 19 | 20 | fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py 21 | 22 | Args: 23 | dir (str): top-level directory at which to begin scanning. 24 | ext (tp.List[str]): list of allowed file extensions. 25 | """ 26 | subfolders, files = [], [] 27 | # add starting period to extensions if needed 28 | ext = ['.' + x if x[0] != '.' else x for x in ext] 29 | 30 | try: # hope to avoid 'permission denied' by this try 31 | for f in os.scandir(dir): 32 | try: # 'hope to avoid too many levels of symbolic links' error 33 | if f.is_dir(): 34 | subfolders.append(f.path) 35 | elif f.is_file(): 36 | is_hidden = os.path.basename(f.path).startswith(".") 37 | has_ext = os.path.splitext(f.name)[1].lower() in ext 38 | 39 | if has_ext and (not is_hidden): 40 | files.append(f.path) 41 | except Exception: 42 | pass 43 | except Exception: 44 | pass 45 | 46 | for dir in list(subfolders): 47 | sf, f = fast_scandir(dir, ext) 48 | subfolders.extend(sf) 49 | files.extend(f) 50 | 51 | return subfolders, files 52 | 53 | 54 | def get_audio_filenames( 55 | paths: tp.List[str], # directories in which to search 56 | exts: tp.List[str] = ['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus'] 57 | ): 58 | """recursively get a list of audio filenames""" 59 | if isinstance(paths, str): 60 | paths = [paths] 61 | 62 | # get a list of relevant filenames 63 | filenames = [] 64 | nums_file = [] 65 | for p in paths: 66 | _, files = fast_scandir(p, exts) 67 | files.sort() 68 | filenames.extend(files) 69 | nums_file.append(len(files)) 70 | 71 | return filenames, nums_file 72 | 73 | 74 | class AudioFilesDataset(torch.utils.data.Dataset): 75 | def __init__( 76 | self, 77 | dir_list: tp.List[str], 78 | sample_size: int = 36000, 79 | sample_rate: int = 24000, 80 | out_channels="mono", 81 | exts: tp.List[str] = ['wav'], 82 | # augmentation 83 | augment_shift: bool = True, 84 | augment_flip: bool = True, 85 | augment_volume: bool = True, 86 | # Others 87 | max_samples: tp.Optional[int] = None 88 | ): 89 | assert out_channels in ['mono', 'stereo'] 90 | 91 | super().__init__() 92 | self.sample_size = sample_size 93 | self.sr = sample_rate 94 | self.augment_shift = augment_shift 95 | self.out_channels = out_channels 96 | 97 | self.ch_encoding = torch.nn.Sequential( 98 | Stereo() if self.out_channels == "stereo" else torch.nn.Identity(), 99 | Mono() if self.out_channels == "mono" else torch.nn.Identity(), 100 | 
) 101 | 102 | self.augs = torch.nn.Sequential( 103 | PhaseFlipper() if augment_flip else torch.nn.Identity(), 104 | VolumeChanger() if augment_volume else torch.nn.Identity() 105 | ) 106 | 107 | # find all audio files 108 | print_once('->-> Searching audio files...') 109 | self.filenames, _ = get_audio_filenames(dir_list, exts=exts) 110 | # sort 111 | self.filenames.sort() 112 | 113 | max_samples = max_samples if max_samples else len(self.filenames) 114 | self.filenames = self.filenames[:max_samples] 115 | print_once(f'->-> Found {len(self.filenames)} files.') 116 | 117 | # This may take long time if file number is large. 118 | print_once('->-> Loading audio metadata... (this may take a long time)') 119 | self.metas = [] 120 | self.durs = [] 121 | for idx, filepath in enumerate(self.filenames): 122 | info = get_audio_metadata(filepath, cache=True) 123 | self.metas.append(info) 124 | self.durs.append(info['num_frames']) 125 | 126 | self.cs_durs = np.cumsum(self.durs) 127 | 128 | def get_index_offset(self, item): 129 | """ 130 | Return a track index and frame offset 131 | """ 132 | # For a given dataset item and shift, return song index and offset within song 133 | half_size = self.sample_size // 2 134 | shift = np.random.randint(-half_size, half_size) if self.augment_shift else 0 135 | offset = item * self.sample_size + shift # Note we centred shifts, so adding now 136 | midpoint = offset + half_size 137 | 138 | index = np.searchsorted(self.cs_durs, midpoint) 139 | start, end = self.cs_durs[index - 1] if index > 0 else 0, self.cs_durs[index] # start and end of current song 140 | assert start <= midpoint <= end 141 | 142 | if offset > end - self.sample_size: # Going over song 143 | offset = max(start, offset - half_size) 144 | elif offset < start: # Going under song 145 | offset = start 146 | 147 | offset -= start 148 | return index, offset 149 | 150 | def __len__(self): 151 | return int(np.floor(self.cs_durs[-1] / self.sample_size)) 152 | 153 | def __getitem__(self, idx): 154 | idx_file, offset = self.get_index_offset(idx) 155 | filename = self.filenames[idx_file] 156 | info = self.metas[idx_file] 157 | 158 | try: 159 | audio = load_audio_with_pad(filename, info, self.sr, self.sample_size, offset) 160 | 161 | # Fix channel number 162 | audio = self.ch_encoding(audio) 163 | 164 | # Audio augmentations 165 | audio = self.augs(audio) 166 | audio = audio.clamp(-1, 1) 167 | 168 | return (audio, info) 169 | 170 | except Exception as e: 171 | print(f'Couldn\'t load file {filename} (INFO: {info}, offset: {offset}): {e}') 172 | return self[random.randrange(len(self))] 173 | -------------------------------------------------------------------------------- /src/data/modification.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | 5 | import random 6 | 7 | import torch 8 | from torch import nn 9 | import numpy as np 10 | 11 | 12 | # Channels 13 | 14 | class Mono(nn.Module): 15 | def __call__(self, x: torch.Tensor): 16 | assert len(x.shape) <= 2 17 | return torch.mean(x, dim=0, keepdims=True) if len(x.shape) > 1 else x 18 | 19 | 20 | class Stereo(nn.Module): 21 | def __call__(self, x: torch.Tensor): 22 | x_shape = x.shape 23 | assert len(x_shape) <= 2 24 | # Check if it's mono 25 | if len(x_shape) == 1: # s -> 2, s 26 | x = x.unsqueeze(0).repeat(2, 1) 27 | elif len(x_shape) == 2: 28 | if x_shape[0] == 1: # 1, s -> 2, s 29 | x = x.repeat(2, 1) 30 | elif x_shape[0] > 2: # ?, s -> 2,s 31 | x = x[:2, :] 32 | 33 | 
return x 34 | 35 | 36 | # Augmentation 37 | 38 | class PhaseFlipper(nn.Module): 39 | """Randomly invert the phase of a signal""" 40 | 41 | def __init__(self, p=0.5): 42 | super().__init__() 43 | self.p = p 44 | 45 | def __call__(self, x: torch.Tensor): 46 | assert len(x.shape) <= 2 47 | return -x if (random.random() < self.p) else x 48 | 49 | 50 | class VolumeChanger(nn.Module): 51 | """Randomly change volume (amplitude) of a signal""" 52 | 53 | def __init__(self, min_db: float = -3., max_db: float = 6.): 54 | super().__init__() 55 | self.min_db = min_db 56 | self.max_db = max_db 57 | self.rng = np.random.default_rng() 58 | 59 | def __call__(self, x: torch.Tensor): 60 | assert len(x.shape) <= 2 61 | 62 | max_db = min(self.max_db, 20 * np.log10(1. / (np.abs(x).max() + 1.0e-8))) # amp <= 1.0 63 | db = self.rng.uniform(self.min_db, max_db) 64 | gain = 10 ** (db / 20.) 65 | return x * gain 66 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | 5 | import os 6 | import sys 7 | sys.dont_write_bytecode = True 8 | import argparse 9 | import math 10 | 11 | import hydra 12 | import torch 13 | import torchaudio 14 | from accelerate import Accelerator 15 | from omegaconf import DictConfig, OmegaConf 16 | 17 | from utils.torch_common import get_rank, get_world_size, print_once 18 | from data.dataset import get_audio_filenames 19 | 20 | 21 | def make_audio_batch(audio, sample_size: int, overlap: int): 22 | """ 23 | audio : (ch, L) 24 | """ 25 | assert 0 <= overlap < sample_size 26 | L = audio.shape[-1] 27 | shift = sample_size - overlap 28 | 29 | n_split = math.ceil(max(L - sample_size, 0) / shift) + 1 30 | # to mono 31 | audio = audio.mean(0) # (L) 32 | batch = [] 33 | for n in range(n_split): 34 | b = audio[n * shift: n * shift + sample_size] 35 | if n == n_split - 1: 36 | b = torch.nn.functional.pad(b, (0, sample_size - len(b))) 37 | batch.append(b) 38 | 39 | batch = torch.stack(batch, dim=0).unsqueeze(1) # (n_split, 1, sample_size) 40 | return batch, L 41 | 42 | 43 | def cross_fade(preds, overlap: int, L: int): 44 | """ 45 | preds: (bs, 1, sample_size) 46 | """ 47 | bs, _, sample_size = preds.shape 48 | shift = sample_size - overlap 49 | full_L = sample_size + (bs - 1) * shift 50 | win = torch.bartlett_window(overlap * 2, device=preds.device) 51 | 52 | buf = torch.zeros(1, full_L, device=preds.device) 53 | for idx in range(bs): 54 | pred = preds[idx] # (1, sample_size) 55 | ofs = idx * shift 56 | if idx != 0: 57 | pred[:, :overlap] *= win[None, :overlap] 58 | if idx != bs - 1: 59 | pred[:, -overlap:] *= win[None, overlap:] 60 | 61 | buf[:, ofs:ofs + sample_size] += pred 62 | 63 | buf = buf[..., :L] 64 | return buf 65 | 66 | 67 | def main(): 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('--ckpt-dir', type=str, help="Checkpoint directory.") 70 | parser.add_argument('--input-audio-dir', type=str, help="Root directory which contains input audio files.") 71 | parser.add_argument('--output-dir', type=str, help="Output directory.") 72 | parser.add_argument('--sample-size', type=int, default=36000, help="Input sample size.") 73 | parser.add_argument('--max-batch-size', type=int, default=10, help="Max batch size for inference.") 74 | parser.add_argument('--overlap-rate', type=float, default=0.02, help="Overlap rate for inference.") 75 | parser.add_argument('--use-original-name', default=True, type=bool, 
help="Whether to use an original file name as an output name.") 76 | args = parser.parse_args() 77 | 78 | ckpt_dir = args.ckpt_dir 79 | input_audio_dir = args.input_audio_dir 80 | output_dir = args.output_dir 81 | sample_size = args.sample_size 82 | max_batch_size = args.max_batch_size 83 | overlap_rate = args.overlap_rate 84 | use_original_name = args.use_original_name 85 | 86 | # Distributed inference 87 | accel = Accelerator() 88 | device = accel.device 89 | rank = get_rank() 90 | world_size = get_world_size() 91 | 92 | print_once(f"Checkpoint dir : {ckpt_dir}") 93 | print_once(f"Input audio dir : {input_audio_dir}") 94 | print_once(f"Output dir : {output_dir}") 95 | 96 | # Load WaveFit model 97 | cfg_ckpt = OmegaConf.load(f'{ckpt_dir}/config.yaml') 98 | wavefit = hydra.utils.instantiate(cfg_ckpt.model.generator.model) 99 | wavefit.load_state_dict(torch.load(f"{ckpt_dir}/generator.pth", weights_only=False)) 100 | wavefit.to(device) 101 | wavefit.eval() 102 | print_once("->-> Successfully loaded WaveFit model from checkpoint.") 103 | 104 | hop_size = wavefit.args_mel['hop_size'] 105 | overlap = math.ceil(sample_size * overlap_rate / hop_size) * hop_size 106 | 107 | # Get audio files 108 | files, _ = get_audio_filenames(input_audio_dir) 109 | print_once(f"->-> Found {len(files)} audio files.") 110 | # Split files 111 | files = files[rank::world_size] 112 | 113 | for idx, f_path in enumerate(files): 114 | # load and split audio 115 | audio, sr = torchaudio.load(f_path) 116 | audio_batch, L = make_audio_batch(audio, sample_size, overlap) 117 | n_iter = math.ceil(audio_batch.shape[0] / max_batch_size) 118 | 119 | audio_batch = audio_batch.to(device) 120 | 121 | # execute 122 | preds = [] 123 | for n in range(n_iter): 124 | batch_ = audio_batch[n * max_batch_size:(n + 1) * max_batch_size] 125 | with torch.no_grad(): 126 | mel_spec, r_noise, pstft_spec = wavefit.mel.get_shaped_noise(batch_) 127 | pred = wavefit(r_noise, torch.log(mel_spec.clamp(min=1e-8)), pstft_spec, return_only_last=True)[-1] 128 | 129 | preds.append(pred) 130 | 131 | preds = torch.cat(preds, dim=0) 132 | 133 | # cross-fade 134 | pred_audio = cross_fade(preds, overlap, L).cpu() 135 | 136 | # save audio 137 | out_name = os.path.splitext(os.path.basename(f_path))[0] if use_original_name else f"sample_{idx}" 138 | out_path = f"{output_dir}/{out_name}.wav" 139 | torchaudio.save(out_path, pred_audio, sample_rate=sr, encoding="PCM_F") 140 | 141 | print(f"--- Rank-{rank} : Finished. 
---") 142 | 143 | 144 | if __name__ == '__main__': 145 | main() 146 | -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .mrstft import MRSTFTLoss, MELMAELoss 2 | -------------------------------------------------------------------------------- /src/loss/mrstft.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | 5 | import typing as tp 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from utils.audio_utils import get_amplitude_spec, MelSpectrogram 12 | 13 | 14 | class MRSTFTLoss(nn.Module): 15 | """ 16 | Multi-resolution STFT loss corresponding to the eq.(9) 17 | """ 18 | 19 | def __init__( 20 | self, 21 | n_ffts: tp.List[int] = [512, 1024, 2048], 22 | win_sizes: tp.List[int] = [360, 900, 1800], 23 | hop_sizes: tp.List[int] = [80, 150, 300], 24 | EPS: float = 1e-5 25 | ): 26 | super().__init__() 27 | assert len(n_ffts) == len(win_sizes) == len(hop_sizes) 28 | self.n_ffts = n_ffts 29 | self.win_sizes = win_sizes 30 | self.hop_sizes = hop_sizes 31 | 32 | # NOTE: Since spectral convergence is quite sensitive to small values in the spectrum, 33 | # I believe setting a higher lower bound will result in more stable training. 34 | self.EPS = EPS 35 | 36 | def forward( 37 | self, 38 | pred: torch.Tensor, 39 | target: torch.Tensor 40 | ): 41 | losses = { 42 | 'G/mrstft_sc_loss': 0., 43 | 'G/mrstft_mag_loss': 0. 44 | } 45 | 46 | for n_fft, win_size, hop_size in zip(self.n_ffts, self.win_sizes, self.hop_sizes): 47 | window = torch.hann_window(win_size, device=pred.device) 48 | spec_t = get_amplitude_spec(target.squeeze(1), n_fft, win_size, hop_size, window) 49 | spec_p = get_amplitude_spec(pred.squeeze(1), n_fft, win_size, hop_size, window) 50 | 51 | # spectral convergence 52 | sc_loss = (spec_t - spec_p).norm(p=2) / (spec_t.norm(p=2) + self.EPS) 53 | 54 | # magnitude loss 55 | mag_loss = F.l1_loss(torch.log(spec_t.clamp(min=self.EPS)), torch.log(spec_p.clamp(min=self.EPS))) 56 | 57 | losses['G/mrstft_sc_loss'] += sc_loss 58 | losses['G/mrstft_mag_loss'] += mag_loss 59 | 60 | losses['G/mrstft_sc_loss'] /= len(self.n_ffts) 61 | losses['G/mrstft_mag_loss'] /= len(self.n_ffts) 62 | 63 | return losses 64 | 65 | 66 | class MELMAELoss(nn.Module): 67 | """ 68 | MAE(L1) loss of Mel spectrogram corresponding to the second term of the eq.(19) 69 | """ 70 | 71 | def __init__( 72 | self, 73 | sr: int = 24000, 74 | n_fft: int = 1024, 75 | win_size: int = 900, 76 | hop_size: int = 150, 77 | n_mels: int = 128, 78 | fmin: float = 20., 79 | fmax: float = 12000. 
80 | ): 81 | super().__init__() 82 | 83 | self.mel = MelSpectrogram(sr, n_fft, win_size, hop_size, n_mels, fmin, fmax) 84 | 85 | def forward( 86 | self, 87 | pred: torch.Tensor, 88 | target: torch.Tensor 89 | ): 90 | losses = {'G/mel_mae_loss': 0.} 91 | 92 | # Mel MAE (L1) loss 93 | mel_p = self.mel.compute_mel(pred.squeeze(1)) 94 | mel_t = self.mel.compute_mel(target.squeeze(1)) 95 | 96 | losses['G/mel_mae_loss'] = F.l1_loss(mel_t, mel_p) 97 | 98 | return losses 99 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .generator import Generator 2 | from .discriminator import Discriminator 3 | from .wavefit import WaveFit 4 | -------------------------------------------------------------------------------- /src/model/discriminator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | 4 | Adapted from the following repo's code under MIT License. 5 | https://github.com/descriptinc/melgan-neurips/ 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn.utils.parametrizations import weight_norm 12 | 13 | 14 | def weights_init(m): 15 | classname = m.__class__.__name__ 16 | if classname.find("Conv") != -1: 17 | m.weight.data.normal_(0.0, 0.02) 18 | elif classname.find("BatchNorm2d") != -1: 19 | m.weight.data.normal_(1.0, 0.02) 20 | m.bias.data.fill_(0) 21 | 22 | 23 | def WNConv1d(*args, **kwargs): 24 | return weight_norm(nn.Conv1d(*args, **kwargs)) 25 | 26 | 27 | class NLayerDiscriminator(nn.Module): 28 | def __init__(self, ndf, n_layers, downsampling_factor): 29 | super().__init__() 30 | model = nn.ModuleDict() 31 | 32 | model["layer_0"] = nn.Sequential( 33 | nn.ReflectionPad1d(7), 34 | WNConv1d(1, ndf, kernel_size=15), 35 | nn.LeakyReLU(0.2, True), 36 | ) 37 | 38 | nf = ndf 39 | stride = downsampling_factor 40 | for n in range(1, n_layers + 1): 41 | nf_prev = nf 42 | nf = min(nf * stride, 1024) 43 | 44 | model["layer_%d" % n] = nn.Sequential( 45 | WNConv1d( 46 | nf_prev, 47 | nf, 48 | kernel_size=stride * 10 + 1, 49 | stride=stride, 50 | padding=stride * 5, 51 | groups=nf_prev // 4, 52 | ), 53 | nn.LeakyReLU(0.2, True), 54 | ) 55 | 56 | nf = min(nf * 2, 1024) 57 | model["layer_%d" % (n_layers + 1)] = nn.Sequential( 58 | WNConv1d(nf_prev, nf, kernel_size=5, stride=1, padding=2), 59 | nn.LeakyReLU(0.2, True), 60 | ) 61 | 62 | model["layer_%d" % (n_layers + 2)] = WNConv1d( 63 | nf, 1, kernel_size=3, stride=1, padding=1 64 | ) 65 | 66 | self.model = model 67 | 68 | def forward(self, x: torch.Tensor, return_feature: bool = True): 69 | """ 70 | Args: 71 | x: input audio, (bs, 1, L) 72 | """ 73 | n_layer = len(self.model) 74 | results = [] 75 | for idx, (key, layer) in enumerate(self.model.items()): 76 | x = layer(x) 77 | if return_feature or (idx == n_layer - 1): 78 | results.append(x) 79 | 80 | return results 81 | 82 | 83 | class Discriminator(nn.Module): 84 | def __init__( 85 | self, 86 | num_D: int = 3, 87 | ndf: int = 16, 88 | n_layers: int = 4, 89 | downsampling_factor: int = 4 90 | ): 91 | super().__init__() 92 | self.model = nn.ModuleDict() 93 | for i in range(num_D): 94 | self.model[f"disc_{i}"] = NLayerDiscriminator( 95 | ndf, n_layers, downsampling_factor 96 | ) 97 | 98 | self.downsample = nn.AvgPool1d(4, stride=2, padding=1, count_include_pad=False) 99 | self.apply(weights_init) 100 | 101 | def forward(self, x: 
torch.Tensor, return_feature: bool = True): 102 | """ 103 | Args: 104 | x: input audio, (bs, 1, L) 105 | """ 106 | results = [] 107 | for key, disc in self.model.items(): 108 | results.append(disc(x, return_feature)) 109 | x = self.downsample(x) 110 | 111 | return results 112 | 113 | def compute_G_loss(self, x_fake, x_real): 114 | """ 115 | The eq.(18) loss 116 | """ 117 | assert x_fake.shape == x_real.shape 118 | 119 | out_f = self(x_fake, return_feature=True) 120 | with torch.no_grad(): 121 | out_r = self(x_real, return_feature=True) 122 | 123 | num_D = len(self.model) 124 | losses = { 125 | 'G/disc_gan_loss': 0., 126 | 'G/disc_feat_loss': 0. 127 | } 128 | 129 | for i_d in range(num_D): 130 | n_layer = len(out_f[i_d]) 131 | 132 | # GAN loss 133 | losses['G/disc_gan_loss'] += (1 - out_f[i_d][-1]).relu().mean() 134 | 135 | # Feature-matching loss 136 | # eq.(8) 137 | feat_loss = 0. 138 | for i_l in range(n_layer - 1): 139 | feat_loss += F.l1_loss(out_f[i_d][i_l], out_r[i_d][i_l]) 140 | 141 | losses['G/disc_feat_loss'] += feat_loss / (n_layer - 1) 142 | 143 | losses['G/disc_gan_loss'] /= num_D 144 | losses['G/disc_feat_loss'] /= num_D 145 | 146 | return losses 147 | 148 | def compute_D_loss(self, x, mode: str): 149 | """ 150 | The eq.(7) loss 151 | """ 152 | assert mode in ['fake', 'real'] 153 | sign = 1 if mode == 'fake' else -1 154 | 155 | out = self(x, return_feature=False) 156 | 157 | num_D = len(self.model) 158 | losses = {'D/loss': 0.} 159 | 160 | for i_d in range(num_D): 161 | # Hinge loss 162 | losses['D/loss'] += (1 + sign * out[i_d][-1]).relu().mean() 163 | 164 | losses['D/loss'] /= num_D 165 | 166 | return losses 167 | -------------------------------------------------------------------------------- /src/model/generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | 4 | Adapted from the following repo's code under Apache License 2.0. 5 | https://github.com/lmnt-com/wavegrad/ 6 | """ 7 | 8 | import typing as tp 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | class Conv1d(nn.Conv1d): 17 | def __init__(self, *args, **kwargs): 18 | super().__init__(*args, **kwargs) 19 | self.reset_parameters() 20 | 21 | def reset_parameters(self): 22 | nn.init.orthogonal_(self.weight) 23 | nn.init.zeros_(self.bias) 24 | 25 | 26 | class SinusoidalPositionalEncoding(nn.Module): 27 | def __init__(self, dim: int, max_iter: int, use_conv: bool): 28 | super().__init__() 29 | self.dim = dim 30 | self.max_iter = max_iter 31 | self.use_conv = use_conv 32 | assert dim % 2 == 0 33 | 34 | if use_conv: 35 | # 1x1 conv 36 | self.conv = nn.Conv1d(dim, dim, 1) 37 | nn.init.xavier_uniform_(self.conv.weight) 38 | nn.init.zeros_(self.conv.bias) 39 | 40 | # pre-compute positional embedding 41 | pos_embs = self.prepare_embedding() # (max_iter, dim) 42 | self.register_buffer('pos_embs', pos_embs) 43 | 44 | def forward(self, x, t: int): 45 | """ 46 | Args: 47 | x: (bs, dim, T) 48 | t: Step index 49 | 50 | Returns: 51 | x_with_pos: (bs, dim, T) 52 | """ 53 | assert 0 <= t < self.max_iter, f"Invalid step index {t}. It must be 0 <= t < {self.max_iter} = max_iter." 
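        # Look up the pre-computed embedding for step t and broadcast it over batch and
        # time: self.pos_embs[t] has shape (dim,), so [None, :, None] yields (1, dim, 1),
        # which is then added to x of shape (bs, dim, T).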
54 | pos_emb = self.pos_embs[t][None, :, None] 55 | if self.use_conv: 56 | pos_emb = self.conv(pos_emb) 57 | 58 | return x + pos_emb 59 | 60 | def prepare_embedding(self, scale: float = 5000.): 61 | dim_h = self.dim // 2 62 | pos = torch.linspace(0., scale, self.max_iter) 63 | div_term = torch.exp(- math.log(10000.0) * torch.arange(dim_h) / dim_h) 64 | pos = pos[:, None] @ div_term[None, :] # (max_iter, dim_h) 65 | pos_embs = torch.cat([torch.sin(pos), torch.cos(pos)], dim=-1) # (max_iter, dim) 66 | return pos_embs 67 | 68 | 69 | class FiLM(nn.Module): 70 | def __init__(self, input_size: int, output_size: int, max_iter: int, memory_efficient: bool): 71 | super().__init__() 72 | self.step_condition = SinusoidalPositionalEncoding(input_size, max_iter, use_conv=memory_efficient) 73 | self.memory_efficient = memory_efficient 74 | self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1) 75 | self.output_conv_1 = nn.Conv1d(input_size, output_size, 3, padding=1) 76 | if not memory_efficient: 77 | self.output_conv_2 = nn.Conv1d(input_size, output_size, 3, padding=1) 78 | self.reset_parameters() 79 | 80 | def reset_parameters(self): 81 | nn.init.xavier_uniform_(self.input_conv.weight) 82 | nn.init.zeros_(self.input_conv.bias) 83 | nn.init.xavier_uniform_(self.output_conv_1.weight) 84 | nn.init.zeros_(self.output_conv_1.bias) 85 | if not self.memory_efficient: 86 | nn.init.xavier_uniform_(self.output_conv_2.weight) 87 | nn.init.zeros_(self.output_conv_2.bias) 88 | 89 | def forward(self, x, t: int): 90 | x = self.input_conv(x) 91 | x = F.leaky_relu(x, 0.2) 92 | x = self.step_condition(x, t) 93 | shift = self.output_conv_1(x) 94 | scale = self.output_conv_2(x) if not self.memory_efficient else None 95 | 96 | return shift, scale 97 | 98 | 99 | class EmptyFiLM(nn.Module): 100 | def __init__(self): 101 | super().__init__() 102 | 103 | def forward(self, x, t): 104 | return 0, 1 105 | 106 | 107 | class UBlock(nn.Module): 108 | def __init__(self, input_size, hidden_size, factor, dilation): 109 | super().__init__() 110 | assert isinstance(dilation, (list, tuple)) 111 | assert len(dilation) == 4 112 | 113 | self.factor = factor 114 | self.block1 = Conv1d(input_size, hidden_size, 1) 115 | self.block2 = nn.ModuleList([ 116 | Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]), 117 | Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1]) 118 | ]) 119 | self.block3 = nn.ModuleList([ 120 | Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]), 121 | Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]) 122 | ]) 123 | 124 | def forward(self, x, film_shift, film_scale: tp.Optional[torch.Tensor]): 125 | if film_scale is None: 126 | film_scale = 1.0 127 | 128 | block1 = F.interpolate(x, size=x.shape[-1] * self.factor) 129 | block1 = self.block1(block1) 130 | 131 | block2 = F.leaky_relu(x, 0.2) 132 | block2 = F.interpolate(block2, size=x.shape[-1] * self.factor) 133 | block2 = self.block2[0](block2) 134 | block2 = film_shift + film_scale * block2 135 | block2 = F.leaky_relu(block2, 0.2) 136 | block2 = self.block2[1](block2) 137 | 138 | x = block1 + block2 139 | 140 | block3 = film_shift + film_scale * x 141 | block3 = F.leaky_relu(block3, 0.2) 142 | block3 = self.block3[0](block3) 143 | block3 = film_shift + film_scale * block3 144 | block3 = F.leaky_relu(block3, 0.2) 145 | block3 = self.block3[1](block3) 146 | 147 | x = x + block3 148 | 149 | return x 150 | 151 | 152 | class DBlock(nn.Module): 153 | def 
__init__(self, input_size, hidden_size, factor): 154 | super().__init__() 155 | self.factor = factor 156 | 157 | # self.residual_dense = Conv1d(input_size, hidden_size, 1) 158 | # self.conv = nn.ModuleList([ 159 | # Conv1d(input_size, hidden_size, 3, dilation=1, padding=1), 160 | # Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2), 161 | # Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4), 162 | # ]) 163 | 164 | # NOTE : This might be the correct architecture rather than the above one 165 | # since parameter size is quite closer to the reported size in the WaveGrad paper (15M). 166 | self.residual_dense = Conv1d(input_size, input_size, 1) 167 | self.conv = nn.ModuleList([ 168 | Conv1d(input_size, input_size, 3, dilation=1, padding=1), 169 | Conv1d(input_size, input_size, 3, dilation=2, padding=2), 170 | Conv1d(input_size, hidden_size, 3, dilation=4, padding=4), 171 | ]) 172 | 173 | # downsampling module using Conv1d 174 | # NOTE: When using kernel_size=3 for all downsampling factors, 175 | # the parameter size of generator is 15.12 millions. 176 | kernel_size = factor // 2 * 2 + 1 177 | padding = kernel_size // 2 178 | self.down1 = Conv1d(input_size, hidden_size, kernel_size, padding=padding, stride=factor) 179 | self.down2 = Conv1d(input_size, input_size, kernel_size, padding=padding, stride=factor) 180 | 181 | def forward(self, x): 182 | residual = self.residual_dense(x) 183 | residual = self.down1(residual) 184 | 185 | x = self.down2(x) 186 | for layer in self.conv: 187 | x = F.leaky_relu(x, 0.2) 188 | x = layer(x) 189 | 190 | return x + residual 191 | 192 | 193 | class Generator(nn.Module): 194 | """ 195 | This generator upsamples mel spectrogram with scale factor of 300. 196 | Specifically, input and output tensors must have the following size. 
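    (The factor of 300 is the product of the cascaded UBlock upsampling factors,
    5 x 5 x 3 x 2 x 2, matching the default mel hop_size of 300.)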
197 | 198 | Inputs: 199 | mel_spec: (bs, 128, num_frame) 200 | y_t: (bs, 1, num_frame x 300) 201 | Outputs: 202 | n_hat: (bs, 1, num_frame x 300) 203 | """ 204 | 205 | def __init__( 206 | self, 207 | num_iteration: int, 208 | memory_efficient: bool 209 | ): 210 | super().__init__() 211 | self.memory_efficient = memory_efficient 212 | self.downsample = nn.ModuleList([ 213 | Conv1d(1, 32, 5, padding=2), 214 | DBlock(32, 128, 2), 215 | DBlock(128, 128, 2), 216 | DBlock(128, 256, 3), 217 | DBlock(256, 512, 5), 218 | ]) 219 | self.film = nn.ModuleList([ 220 | FiLM(32, 128, num_iteration, memory_efficient=memory_efficient) if not memory_efficient else EmptyFiLM(), 221 | FiLM(128, 128, num_iteration, memory_efficient=memory_efficient), 222 | FiLM(128, 256, num_iteration, memory_efficient=memory_efficient), 223 | FiLM(256, 512, num_iteration, memory_efficient=memory_efficient), 224 | FiLM(512, 512, num_iteration, memory_efficient=memory_efficient), 225 | ]) 226 | self.upsample = nn.ModuleList([ 227 | UBlock(768, 512, 5, [1, 2, 1, 2]), 228 | UBlock(512, 512, 5, [1, 2, 1, 2]), 229 | UBlock(512, 256, 3, [1, 2, 4, 8]), 230 | UBlock(256, 128, 2, [1, 2, 4, 8]), 231 | UBlock(128, 128, 2, [1, 2, 4, 8]), 232 | ]) 233 | self.first_conv = Conv1d(128, 768, 3, padding=1) 234 | self.last_conv = Conv1d(128, 1, 3, padding=1) 235 | 236 | def forward(self, y_t, log_mel_spec, t: int): 237 | """ 238 | Args: 239 | y_t: Noisy input, (bs, 1, num_frame x 300) 240 | log_mel_spec: Log mel spectrogram, (bs, 128, num_frame) 241 | t: Step index 242 | Returns: 243 | n_hat: Estimated noise, (bs, 1, num_frame x 300) 244 | """ 245 | x = y_t 246 | 247 | downsampled = [] 248 | for film, layer in zip(self.film, self.downsample): 249 | x = layer(x) 250 | downsampled.append(film(x, t)) 251 | 252 | x = self.first_conv(log_mel_spec) 253 | for layer, (film_shift, film_scale) in zip(self.upsample, reversed(downsampled)): 254 | x = layer(x, film_shift, film_scale) 255 | x = self.last_conv(x) 256 | 257 | return x 258 | -------------------------------------------------------------------------------- /src/model/wavefit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from utils.audio_utils import MelSpectrogram, get_amplitude_spec 9 | from model import Generator 10 | 11 | 12 | class WaveFit(nn.Module): 13 | def __init__( 14 | self, 15 | num_iteration: int, 16 | memory_efficient_miipher2: bool = False, 17 | args_mel: dict = { 18 | 'sr': 24000, 19 | 'n_fft': 2048, 20 | 'win_size': 1200, 21 | 'hop_size': 300, 22 | 'n_mels': 128, 23 | 'fmin': 20., 24 | 'fmax': 12000. 25 | } 26 | ): 27 | super().__init__() 28 | 29 | self.T = num_iteration 30 | self.memory_efficient_miipher2 = memory_efficient_miipher2 31 | self.args_mel = args_mel 32 | self.mel = MelSpectrogram(**args_mel) 33 | self.generator = Generator(num_iteration, memory_efficient=memory_efficient_miipher2) 34 | self.EPS = 1e-8 35 | 36 | def forward( 37 | self, 38 | initial_noise: torch.Tensor, 39 | log_mel_spec: torch.Tensor, 40 | pstft_spec: torch.Tensor, 41 | # You can use this option at inference time 42 | return_only_last: bool = False 43 | ): 44 | """ 45 | Args: 46 | initial_noise: Initial noise, (bs, 1, L). 47 | log_mel_spec: Log Mel spectrogram, (bs, n_mels, L//hop_size). 48 | pstft_spec: Pseudo spectrogram used for gain adjustment. 49 | return_only_last: If true, only the last output (y_0) is returned. 
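            Note: initial_noise.shape[-1] must equal log_mel_spec.shape[-1] * hop_size
                  (this is asserted below).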
50 | Returns: 51 | preds: List of predictions (y_t) 52 | """ 53 | assert initial_noise.dim() == log_mel_spec.dim() == 3 54 | assert initial_noise.shape[-1] == log_mel_spec.shape[-1] * self.args_mel['hop_size'] 55 | 56 | preds = [] 57 | y_t = initial_noise 58 | for t in range(self.T): 59 | # estimate noise 60 | est = self.generator(y_t, log_mel_spec, t) 61 | y_t = y_t - est 62 | 63 | # adjust gain 64 | y_t = self.adjust_gain(y_t, pstft_spec) 65 | 66 | if (not return_only_last) or (t == self.T - 1): 67 | preds.append(y_t) 68 | 69 | # To avoid gradient loop 70 | y_t = y_t.detach() 71 | 72 | return preds 73 | 74 | def adjust_gain(self, z_t, pstft_spec): 75 | num_frame = pstft_spec.shape[-1] 76 | power_spec_z = get_amplitude_spec( 77 | z_t.squeeze(1), self.args_mel['n_fft'], self.args_mel['win_size'], 78 | self.args_mel['hop_size'], self.mel.fft_win, return_power=True 79 | )[..., :num_frame] 80 | 81 | assert power_spec_z.shape == pstft_spec.shape 82 | pow_z = power_spec_z.mean(dim=[1, 2]) 83 | pow_c = pstft_spec.pow(2).mean(dim=[1, 2]) 84 | 85 | return z_t * torch.sqrt(pow_c[:, None, None] / (pow_z[:, None, None] + self.EPS)) 86 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | import sys 5 | sys.dont_write_bytecode = True 6 | 7 | import torch 8 | import torchaudio 9 | import hydra 10 | from omegaconf import DictConfig, OmegaConf 11 | 12 | from utils.audio_utils import MelSpectrogram 13 | from model import Generator 14 | from utils.torch_common import count_parameters 15 | 16 | 17 | @hydra.main(version_base=None, config_path='../configs/', config_name="default.yaml") 18 | def main(cfg: DictConfig): 19 | 20 | sr = 24000 21 | n_fft = 2048 22 | win_size = 1200 23 | hop_size = 300 24 | n_mels = 128 25 | 26 | train_dataset = hydra.utils.instantiate(cfg.data.train) 27 | test_dataset = hydra.utils.instantiate(cfg.data.train) 28 | 29 | print(len(train_dataset), len(test_dataset)) 30 | 31 | mel_module = MelSpectrogram( 32 | sr=sr, n_fft=n_fft, win_size=win_size, hop_size=hop_size, 33 | n_mels=n_mels, fmin=20., fmax=12000.) 
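    # compute_mel expects the signal length to be a multiple of hop_size and returns
    # L // hop_size frames, e.g. a 36000-sample chunk -> a mel of shape (1, 128, 120).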
34 | 35 | data, _ = test_dataset[1] 36 | print(data.shape) 37 | data = data.mean(0).unsqueeze(0) # (1, n_samples) 38 | print(data.shape) 39 | 40 | mel_spec = mel_module.compute_mel(data) 41 | print(mel_spec.shape) 42 | 43 | spec_env = mel_module.get_spec_env_from_mel(mel_spec, cep_order=24, return_min_phase=True) 44 | print(f"Spec_env: {spec_env.shape}") 45 | 46 | # noise reshape test 47 | noise = torch.randn(*data.shape, device=data.device) 48 | 49 | print(f"Input audio std: {data.std().item()}") 50 | print(f"noise(before) std: {noise.std().item()}") 51 | print(f"noise(before): max: {noise.max().item()}, min: {noise.min().item()}") 52 | 53 | num_frame = mel_spec.shape[-1] 54 | fft_win = mel_module.fft_win 55 | noise_spec = torch.stft( 56 | noise, n_fft, hop_length=hop_size, win_length=win_size, window=fft_win, 57 | center=True, normalized=False, onesided=True, return_complex=True) 58 | 59 | noise_spec = noise_spec[..., :num_frame] 60 | print(noise_spec.shape, spec_env.shape) 61 | assert noise_spec.shape == spec_env.shape 62 | 63 | noise_spec *= spec_env 64 | 65 | r_noise = torch.istft( 66 | noise_spec, n_fft, hop_length=hop_size, win_length=win_size, window=fft_win, 67 | center=True, normalized=False, length=data.shape[-1]) 68 | 69 | print(r_noise.shape) 70 | print(f"noise(after) std: {r_noise.std().item()}") 71 | print(f"noise(after): max: {r_noise.max().item()}, min: {r_noise.min().item()}") 72 | 73 | torchaudio.save("src.wav", data, sample_rate=sr, encoding="PCM_F") 74 | torchaudio.save("r_noise.wav", r_noise, sample_rate=sr, encoding="PCM_F") 75 | 76 | wavefit = Generator() 77 | wavefit.train() 78 | num_params_d = count_parameters(wavefit.downsample) 79 | num_params_u = count_parameters(wavefit.upsample) 80 | num_params_f = count_parameters(wavefit.film) 81 | print(f"Num params : {num_params_d + num_params_u + num_params_f}") 82 | print(f"Num params (down) : {num_params_d}") 83 | print(f"Num params (up) : {num_params_u}") 84 | print(f"Num params (film) : {num_params_f}") 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | 5 | import sys 6 | sys.dont_write_bytecode = True 7 | 8 | # DDP 9 | from accelerate import Accelerator, DistributedDataParallelKwargs, DataLoaderConfiguration 10 | from accelerate.utils import ProjectConfiguration 11 | 12 | import torch 13 | import hydra 14 | from hydra.core.hydra_config import HydraConfig 15 | from omegaconf import DictConfig, OmegaConf 16 | 17 | from utils.torch_common import get_world_size, count_parameters, set_seed 18 | from trainer import Trainer 19 | 20 | 21 | @hydra.main(version_base=None, config_path='../configs/', config_name="default.yaml") 22 | def main(cfg: DictConfig): 23 | 24 | # Update config if ckpt_dir is specified (training resumption) 25 | if cfg.trainer.ckpt_dir is not None: 26 | overrides = HydraConfig.get().overrides.task 27 | overrides = [e for e in overrides if isinstance(e, str)] 28 | override_conf = OmegaConf.from_dotlist(overrides) 29 | cfg = OmegaConf.merge(cfg, override_conf) 30 | 31 | # Load checkpoint configuration 32 | cfg_ckpt = OmegaConf.load(f'{cfg.trainer.ckpt_dir}/config.yaml') 33 | cfg = OmegaConf.merge(cfg_ckpt, override_conf) 34 | 35 | # HuggingFace Accelerate for distributed training 36 | 37 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 38 | 
dl_config = DataLoaderConfiguration(split_batches=True) 39 | p_config = ProjectConfiguration(project_dir=cfg.trainer.output_dir) 40 | accel = Accelerator( 41 | mixed_precision=cfg.trainer.amp, 42 | dataloader_config=dl_config, 43 | project_config=p_config, 44 | kwargs_handlers=[ddp_kwargs], 45 | log_with='wandb' 46 | ) 47 | 48 | accel.init_trackers(cfg.trainer.logger.project_name, config=OmegaConf.to_container(cfg), 49 | init_kwargs={"wandb": {"name": cfg.trainer.logger.run_name, "dir": cfg.trainer.output_dir}}) 50 | 51 | if accel.is_main_process: 52 | print("->->-> DDP Initialized.") 53 | print(f"->->-> World size (Number of GPUs): {get_world_size()}") 54 | 55 | set_seed(cfg.trainer.seed) 56 | 57 | # Dataset 58 | 59 | batch_size = cfg.trainer.batch_size 60 | num_workers = cfg.trainer.num_workers 61 | train_dataset = hydra.utils.instantiate(cfg.data.train) 62 | test_dataset = hydra.utils.instantiate(cfg.data.test) 63 | train_dataloader = torch.utils.data.DataLoader( 64 | train_dataset, batch_size=batch_size, shuffle=True, 65 | num_workers=num_workers, pin_memory=True, persistent_workers=(num_workers > 0)) 66 | test_dataloader = torch.utils.data.DataLoader( 67 | test_dataset, batch_size=batch_size, shuffle=False, 68 | num_workers=num_workers, pin_memory=True, persistent_workers=(num_workers > 0)) 69 | 70 | # Model 71 | 72 | generator = hydra.utils.instantiate(cfg.model.generator.model) 73 | discriminator = hydra.utils.instantiate(cfg.model.discriminator.model) 74 | 75 | # Loss modules 76 | 77 | mrstft_loss = hydra.utils.instantiate(cfg.model.loss.mrstft) 78 | melmae_loss = hydra.utils.instantiate(cfg.model.loss.melmae) 79 | 80 | # Optimizer 81 | 82 | opt_G = hydra.utils.instantiate(cfg.optimizer.G.optimizer)(params=generator.parameters()) 83 | sche_G = hydra.utils.instantiate(cfg.optimizer.G.scheduler)(optimizer=opt_G) 84 | opt_D = hydra.utils.instantiate(cfg.optimizer.D.optimizer)(params=discriminator.parameters()) 85 | sche_D = hydra.utils.instantiate(cfg.optimizer.D.scheduler)(optimizer=opt_D) 86 | 87 | # Log 88 | 89 | generator.train() 90 | discriminator.train() 91 | num_params_g = count_parameters(generator.generator) / 1e6 92 | num_params_d = count_parameters(discriminator) / 1e6 93 | if accel.is_main_process: 94 | print("=== Parameters ===") 95 | print(f"\tGenerator:\t{num_params_g:.2f} [million]") 96 | print(f"\tDiscriminator:\t{num_params_d:.2f} [million]") 97 | print("=== Dataset ===") 98 | print(f"\tBatch size: {cfg.trainer.batch_size}") 99 | print("\tTrain data:") 100 | print(f"\t\tFiles:\t{len(train_dataset.filenames)}") 101 | print(f"\t\tChunks:\t{len(train_dataset)}") 102 | print(f"\t\tBatches:\t{len(train_dataset)//cfg.trainer.batch_size}") 103 | print("\tTest data:") 104 | print(f"\t\tFiles:\t{len(test_dataset.filenames)}") 105 | print(f"\t\tChunks:\t{len(test_dataset)}") 106 | print(f"\t\tBatches:\t{len(test_dataset)//cfg.trainer.batch_size}") 107 | 108 | # Prepare for DDP 109 | 110 | (train_dataloader, test_dataloader, generator, discriminator, 111 | mrstft_loss, melmae_loss, opt_G, sche_G, opt_D, sche_D) = accel.prepare( 112 | train_dataloader, test_dataloader, generator, discriminator, 113 | mrstft_loss, melmae_loss, opt_G, sche_G, opt_D, sche_D) 114 | 115 | # Start training 116 | 117 | trainer = Trainer( 118 | generator, discriminator, {'mrstft': mrstft_loss, 'melmae': melmae_loss}, 119 | opt_G, sche_G, opt_D, sche_D, train_dataloader, test_dataloader, 120 | accel, cfg, ckpt_dir=cfg.trainer.ckpt_dir 121 | ) 122 | 123 | trainer.start_training() 124 | 125 | 126 | if 
__name__ == '__main__': 127 | main() 128 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | 5 | import os 6 | 7 | import torch 8 | import torchaudio 9 | import wandb 10 | 11 | from model import WaveFit, Discriminator 12 | from utils.logging import MetricsLogger 13 | from utils.torch_common import sort_dict, print_once 14 | from utils.viz import spectrogram_image 15 | 16 | 17 | class Trainer: 18 | def __init__( 19 | self, 20 | G, # Generator (WaveFit) 21 | D, # Discriminator 22 | loss_modules, # Loss modules 23 | opt_G, sche_G, # Optimizer/scheduler of G 24 | opt_D, sche_D, # Optimizer/scheduler of D 25 | train_dataloader, 26 | test_dataloader, 27 | accel, # Accelerator object 28 | cfg, # Configurations 29 | ckpt_dir=None 30 | ): 31 | self.G: WaveFit = accel.unwrap_model(G) 32 | self.D: Discriminator = accel.unwrap_model(D) 33 | self.loss_modules = loss_modules 34 | self.opt_G = opt_G 35 | self.sche_G = sche_G 36 | self.opt_D = opt_D 37 | self.sche_D = sche_D 38 | self.train_dataloader = train_dataloader 39 | self.test_dataloader = test_dataloader 40 | self.accel = accel 41 | self.cfg = cfg 42 | self.cfg_t = cfg.trainer 43 | self.loss_lambdas = cfg.model.loss.lambdas 44 | self.EPS = 1e-8 45 | 46 | self.logger = MetricsLogger() # Logger for WandB 47 | self.logger_print = MetricsLogger() # Logger for printing 48 | self.logger_test = MetricsLogger() # Logger for test 49 | 50 | self.states = {'global_step': 0, 'best_metrics': float('inf'), 'latest_metrics': float('inf')} 51 | 52 | # time measurement 53 | self.s_event = torch.cuda.Event(enable_timing=True) 54 | self.e_event = torch.cuda.Event(enable_timing=True) 55 | 56 | # resume training 57 | if ckpt_dir is not None: 58 | self.__load_ckpt(ckpt_dir) 59 | 60 | def start_training(self): 61 | """ 62 | Start training with infinite loops 63 | """ 64 | self.G.train() 65 | self.D.train() 66 | self.s_event.record() 67 | 68 | print_once("\n[ Started training ]\n") 69 | 70 | while True: 71 | for batch in self.train_dataloader: 72 | # Update 73 | metrics = self.run_step(batch) 74 | 75 | # Test (validation) 76 | if self.__its_time(self.cfg_t.logging.n_step_test): 77 | self.__test() 78 | 79 | if self.accel.is_main_process: 80 | self.logger.add(metrics) 81 | self.logger_print.add(metrics) 82 | 83 | # Log 84 | if self.__its_time(self.cfg_t.logging.n_step_log): 85 | self.__log_metrics() 86 | 87 | # Print 88 | if self.__its_time(self.cfg_t.logging.n_step_print): 89 | self.__print_metrics() 90 | 91 | # Save checkpoint 92 | if self.__its_time(self.cfg_t.logging.n_step_ckpt): 93 | self.__save_ckpt() 94 | 95 | # Sample 96 | if self.__its_time(self.cfg_t.logging.n_step_sample): 97 | self.__sampling() 98 | 99 | self.states['global_step'] += 1 100 | 101 | def run_step(self, batch, train: bool = True): 102 | """ One training step """ 103 | 104 | audios, _ = batch 105 | 106 | # Prepare inputs 107 | 108 | mel_spec, r_noise, pstft_spec = self.G.mel.get_shaped_noise(audios) 109 | 110 | # Generator update 111 | 112 | if train: 113 | self.opt_G.zero_grad() 114 | 115 | preds = self.G(r_noise, torch.log(mel_spec.clamp(min=self.EPS)), pstft_spec) 116 | 117 | losses = {} 118 | for idx, pred in enumerate(preds): 119 | losses_i = {} 120 | losses_i.update(self.loss_modules['mrstft'](pred, audios)) 121 | losses_i.update(self.loss_modules['melmae'](pred, audios)) 122 | 
losses_i.update(self.D.compute_G_loss(pred, audios)) 123 | for k, v in losses_i.items(): 124 | losses[k] = losses.get(k, 0.) + v / len(preds) 125 | losses[f'{k}/iter-{idx+1}'] = v.detach() # for logging 126 | 127 | loss_g = 0. 128 | for k in self.loss_lambdas.keys(): 129 | loss_g += losses[k] * self.loss_lambdas[k] 130 | 131 | if train: 132 | self.accel.backward(loss_g) 133 | if self.accel.sync_gradients: 134 | self.accel.clip_grad_norm_(self.G.parameters(), self.cfg_t.max_grad_norm) 135 | self.opt_G.step() 136 | self.sche_G.step() 137 | 138 | losses['G/loss'] = loss_g.detach() 139 | 140 | # Discriminator update 141 | 142 | if train: 143 | self.opt_D.zero_grad() 144 | 145 | loss_d_real = self.D.compute_D_loss(audios, mode='real')['D/loss'] 146 | 147 | # NOTE: Discriminator loss is also computed for all intermediate predictions (Sec.4.2) 148 | loss_d_fake = 0. 149 | for idx, pred in enumerate(preds): 150 | loss_d_fake_ = self.D.compute_D_loss(pred.detach(), mode='fake')['D/loss'] 151 | loss_d_fake += loss_d_fake_ / len(preds) 152 | losses[f'D/loss/iter-{idx+1}'] = loss_d_fake_.detach() # for logging 153 | 154 | losses['D/loss/real'] = loss_d_real.detach() 155 | losses['D/loss/fake'] = loss_d_fake.detach() 156 | losses['D/loss'] = loss_d_real + loss_d_fake 157 | 158 | if train: 159 | self.accel.backward(losses['D/loss']) 160 | if self.accel.sync_gradients: 161 | self.accel.clip_grad_norm_(self.D.parameters(), self.cfg_t.max_grad_norm) 162 | self.opt_D.step() 163 | self.sche_D.step() 164 | 165 | return {k: v.detach() for k, v in losses.items()} 166 | 167 | @torch.no_grad() 168 | def __test(self): 169 | self.G.eval() 170 | self.D.eval() 171 | 172 | start = torch.cuda.Event(enable_timing=True) 173 | end = torch.cuda.Event(enable_timing=True) 174 | start.record() 175 | 176 | for batch in self.test_dataloader: 177 | metrics = self.run_step(batch, train=False) 178 | self.logger_test.add(metrics) 179 | 180 | end.record() 181 | torch.cuda.synchronize() 182 | p_time = start.elapsed_time(end) / 1000. 
# [sec] 183 | 184 | metrics = self.logger_test.pop() 185 | # gather from all processes 186 | metrics_g = {} 187 | for k, v in metrics.items(): 188 | metrics_g[k] = self.accel.gather(v).mean() 189 | 190 | if self.accel.is_main_process: 191 | # log and print 192 | step = self.states['global_step'] 193 | metrics_g = sort_dict(metrics_g) 194 | self.accel.log({f"test/{k}": v for k, v in metrics_g.items()}, step=step) 195 | metrics_g = {k: v.item() for k, v in metrics_g.items() if 'iter' not in k} 196 | s = f"[Test] ({p_time:.1e} [sec]): " + ' / '.join([f"[{k}] - {v:.3e}" for k, v in metrics_g.items()]) 197 | print(s) 198 | 199 | # update states 200 | m_for_ckpt = self.cfg_t.logging.metrics_for_best_ckpt 201 | m_latest = sum([metrics_g[k] * self.loss_lambdas[k] for k in m_for_ckpt]) 202 | self.states['latest_metrics'] = m_latest 203 | if m_latest < self.states['best_metrics']: 204 | self.states['best_metrics'] = m_latest 205 | 206 | self.G.train() 207 | self.D.train() 208 | 209 | @torch.no_grad() 210 | def __sampling(self): 211 | self.G.eval() 212 | 213 | # randomly select samples 214 | dataset = self.test_dataloader.dataset 215 | n_sample = self.cfg_t.logging.n_samples 216 | idxs = torch.randint(len(dataset), size=(n_sample,)) 217 | audios = torch.stack([dataset[idx][0] for idx in idxs], dim=0).to(self.accel.device) 218 | 219 | # sample 220 | mel_spec, r_noise, pstft_spec = self.G.mel.get_shaped_noise(audios) 221 | preds = self.G(r_noise, torch.log(mel_spec.clamp(min=self.EPS)), pstft_spec, return_only_last=True)[-1] 222 | mel_spec_pred = self.G.mel.compute_mel(preds.squeeze(1)) 223 | mel_spec_inoise = self.G.mel.compute_mel(r_noise.squeeze(1)) 224 | 225 | # save/log audios 226 | out_dir = self.cfg_t.output_dir + '/sample' 227 | os.makedirs(out_dir, exist_ok=True) 228 | columns = ['gt (audio)', 'gt (spec)', 'init_noise (audio)', 'init_noise (spec)', 'pred (audio)', 'pred (spec)'] 229 | table_audio = wandb.Table(columns=columns) 230 | for idx in range(audios.shape[0]): 231 | torchaudio.save(f"{out_dir}/item-{idx}_gt.wav", audios[idx].cpu(), sample_rate=dataset.sr, encoding='PCM_F') 232 | torchaudio.save(f"{out_dir}/item-{idx}_pred.wav", preds[idx].cpu(), sample_rate=dataset.sr, encoding='PCM_F') 233 | 234 | data = [ 235 | wandb.Audio(audios[idx].cpu().numpy().T, sample_rate=dataset.sr), 236 | wandb.Image(spectrogram_image(mel_spec[idx].cpu().numpy())), 237 | wandb.Audio(r_noise[idx].cpu().numpy().T, sample_rate=dataset.sr), 238 | wandb.Image(spectrogram_image(mel_spec_inoise[idx].cpu().numpy())), 239 | wandb.Audio(preds[idx].cpu().numpy().T, sample_rate=dataset.sr), 240 | wandb.Image(spectrogram_image(mel_spec_pred[idx].cpu().numpy())) 241 | ] 242 | 243 | table_audio.add_data(*data) 244 | 245 | self.accel.log({'Samples': table_audio}, step=self.states['global_step']) 246 | 247 | self.G.train() 248 | 249 | print("\t->->-> Sampled.") 250 | 251 | def __save_ckpt(self): 252 | import shutil 253 | import json 254 | from omegaconf import OmegaConf 255 | 256 | out_dir = self.cfg_t.output_dir + '/ckpt' 257 | 258 | # save latest ckpt 259 | latest_dir = out_dir + '/latest' 260 | os.makedirs(latest_dir, exist_ok=True) 261 | ckpts = {'generator': self.G, 262 | 'discriminator': self.D, 263 | 'opt_g': self.opt_G, 264 | 'opt_d': self.opt_D, 265 | 'sche_g': self.sche_G, 266 | 'sche_d': self.sche_D} 267 | for name, m in ckpts.items(): 268 | torch.save(m.state_dict(), f"{latest_dir}/{name}.pth") 269 | 270 | # save states and configuration 271 | OmegaConf.save(self.cfg, f"{latest_dir}/config.yaml") 272 | with 
open(f"{latest_dir}/states.json", mode="wt", encoding="utf-8") as f: 273 | json.dump(self.states, f, indent=2) 274 | 275 | # save best ckpt 276 | if self.states['latest_metrics'] == self.states['best_metrics']: 277 | shutil.copytree(latest_dir, out_dir + '/best', dirs_exist_ok=True) 278 | 279 | print("\t->->-> Saved checkpoints.") 280 | 281 | def __load_ckpt(self, dir: str): 282 | import json 283 | 284 | print_once(f"\n[Resuming training from the checkpoint directory] -> {dir}") 285 | ckpts = {'generator': self.G, 286 | 'discriminator': self.D, 287 | 'opt_g': self.opt_G, 288 | 'opt_d': self.opt_D, 289 | 'sche_g': self.sche_G, 290 | 'sche_d': self.sche_D} 291 | for k, v in ckpts.items(): 292 | v.load_state_dict(torch.load(f"{dir}/{k}.pth", weights_only=False)) 293 | 294 | with open(f"{dir}/states.json", mode="rt", encoding="utf-8") as f: 295 | self.states.update(json.load(f)) 296 | 297 | def __log_metrics(self, sort_by_key: bool = True): 298 | metrics = self.logger.pop() 299 | # learning rate 300 | metrics['G/lr'] = self.sche_G.get_last_lr()[0] 301 | metrics['D/lr'] = self.sche_D.get_last_lr()[0] 302 | if sort_by_key: 303 | metrics = sort_dict(metrics) 304 | 305 | self.accel.log(metrics, step=self.states['global_step']) 306 | 307 | def __print_metrics(self, sort_by_key: bool = True): 308 | self.e_event.record() 309 | torch.cuda.synchronize() 310 | p_time = self.s_event.elapsed_time(self.e_event) / 1000. # [sec] 311 | 312 | metrics = self.logger_print.pop() 313 | # tensor to scalar 314 | metrics = {k: v.item() for k, v in metrics.items() if 'iter' not in k} 315 | # learning rate 316 | metrics['G/lr'] = self.sche_G.get_last_lr()[0] 317 | metrics['D/lr'] = self.sche_D.get_last_lr()[0] 318 | if sort_by_key: 319 | metrics = sort_dict(metrics) 320 | 321 | step = self.states['global_step'] 322 | s = f"Step {step} ({p_time:.1e} [sec]): " + ' / '.join([f"[{k}] - {v:.3e}" for k, v in metrics.items()]) 323 | print(s) 324 | 325 | self.s_event.record() 326 | 327 | def __its_time(self, itv: int): 328 | return (self.states['global_step'] - 1) % itv == 0 329 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukara-ikemiya/wavefit-pytorch/13d6a0dc87dd6e5bd35fba69f889dc5fc7d53bee/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/audio_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | 5 | import math 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from librosa.filters import mel as librosa_mel 11 | 12 | EPS = 1e-10 13 | 14 | 15 | def get_min_phase_filter(amplitude): 16 | """ 17 | Adapted from the following repo's code. 
18 | https://github.com/SpecDiff-GAN/ 19 | """ 20 | def concat_negative_freq(tensor): 21 | return torch.concat((tensor[..., :-1], tensor[..., 1:].flip(dims=(-1,))), -1) 22 | 23 | device = amplitude.device 24 | rank = amplitude.ndim 25 | num_bins = amplitude.shape[-1] 26 | amplitude = concat_negative_freq(amplitude) 27 | 28 | fftsize = (num_bins - 1) * 2 29 | m0 = torch.zeros((fftsize // 2 - 1,), dtype=torch.complex64, device=device) 30 | m1 = torch.ones((1,), dtype=torch.complex64, device=device) 31 | m2 = torch.ones((fftsize // 2 - 1,), dtype=torch.complex64, device=device) * 2.0 32 | minimum_phase_window = torch.concat([m1, m2, m1, m0], axis=0) 33 | 34 | if rank > 1: 35 | new_shape = [1] * (rank - 1) + [fftsize] 36 | minimum_phase_window = torch.reshape(minimum_phase_window, new_shape) 37 | 38 | cepstrum = torch.fft.ifft(torch.log(amplitude).to(torch.complex64)) 39 | windowed_cepstrum = cepstrum * minimum_phase_window 40 | imag_phase = torch.imag(torch.fft.fft(windowed_cepstrum)) 41 | phase = torch.exp(torch.complex(imag_phase * 0.0, imag_phase)) 42 | minimum_phase = amplitude.to(torch.complex64) * phase 43 | return minimum_phase[..., :num_bins] 44 | 45 | 46 | def get_amplitude_spec(x, n_fft, win_size, hop_size, window, return_power: bool = False): 47 | stft_spec = torch.stft( 48 | x, n_fft, hop_length=hop_size, win_length=win_size, window=window, 49 | center=True, normalized=False, onesided=True, return_complex=True) 50 | 51 | power_spec = torch.view_as_real(stft_spec).pow(2).sum(-1) 52 | 53 | return power_spec if return_power else torch.sqrt(power_spec + EPS) 54 | 55 | 56 | class MelSpectrogram(nn.Module): 57 | def __init__( 58 | self, 59 | sr: int, 60 | # STFT setting 61 | n_fft: int, win_size: int, hop_size: int, 62 | # MelSpec setting 63 | n_mels: int, fmin: float, fmax: float, 64 | ): 65 | super().__init__() 66 | 67 | self.sr = sr 68 | self.n_fft = n_fft 69 | self.win_size = win_size 70 | self.hop_size = hop_size 71 | self.n_mels = n_mels 72 | self.fmin = fmin 73 | self.fmax = fmax 74 | 75 | mel_basis = librosa_mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) 76 | mel_basis = torch.from_numpy(mel_basis).float() 77 | mel_inv_basis = torch.linalg.pinv(mel_basis) 78 | 79 | self.register_buffer('fft_win', torch.hann_window(win_size)) 80 | self.register_buffer('mel_basis', mel_basis) 81 | self.register_buffer('mel_inv_basis', mel_inv_basis) 82 | 83 | def compute_mel(self, x: torch.Tensor): 84 | """ 85 | Compute Mel-spectrogram. 86 | 87 | Args: 88 | x: time_signal, (bs, length) 89 | Returns: 90 | mel_spec: Mel spectrogram, (bs, n_mels, num_frame) 91 | """ 92 | assert x.dim() == 2 93 | L = x.shape[-1] 94 | # NOTE : To prevent different signal length in the final frame of the STFT between training and inference time, 95 | # input signal length must be a multiple of hop_size. 96 | assert L % self.hop_size == 0, f"Input signal length must be a multiple of hop_size {self.hop_size}." + \ 97 | f"Input shape -> {x.shape}" 98 | 99 | num_frame = L // self.hop_size 100 | 101 | # STFT 102 | stft_spec = get_amplitude_spec(x, self.n_fft, self.win_size, self.hop_size, self.fft_win) 103 | 104 | # Mel Spec 105 | mel_spec = torch.matmul(self.mel_basis, stft_spec) 106 | 107 | # NOTE : The last frame is removed here. 108 | # When using center=True setting, output from torch.stft has frame length of (L//hopsize+1). 109 | # For training WaveGrad-based architecture, the frame length must be (L//hopsize). 
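        # (e.g. with hop_size=300 and L=36000, torch.stft with center=True yields 121 frames and only the first 120 are kept.)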
110 | # There might be a better way, but I believe this has little to no impact on training 111 | # since the whole signal information is contained in the previous frames even when removing the last one. 112 | mel_spec = mel_spec[..., :num_frame] 113 | 114 | return mel_spec 115 | 116 | def get_spec_env_from_mel( 117 | self, 118 | mel_spec: torch.Tensor, 119 | cep_order: int = 24, 120 | min_clamp: float = 1e-5, 121 | return_min_phase: bool = True, 122 | return_pseudo_stft: bool = False 123 | ): 124 | """ 125 | Get spectral envelope from Mel spectrogram 126 | 127 | Args: 128 | mel_spec: Mel spectrogram, (bs, n_mels, num_frame) 129 | cep_order: Order of cepstrum lifter 130 | return_min_phase: If true, minimum phase filter (complex) is returned. 131 | If false, amplitude spectrum envelope (float) is returned. 132 | 133 | Returns: 134 | spec_env: Spectral envelope, (bs, binsize, num_frame) 135 | """ 136 | # pseudo-inverse 137 | pstft_spec = torch.matmul(self.mel_inv_basis, mel_spec) # (bs, n_fft/2 + 1, num_frame) 138 | 139 | # n_fft should be power of 2 140 | binsize = pstft_spec.shape[-2] 141 | n_fft = int(2 ** math.floor(math.log2(binsize))) 142 | 143 | # cepstrum 144 | cepstrum = torch.fft.ifft(torch.log(torch.clamp(pstft_spec[..., :n_fft, :], min=min_clamp)).to(torch.complex64), dim=-2) 145 | cepstrum[..., cep_order:, :] = 0 146 | 147 | # spectral envelope 148 | spec_env = torch.exp(torch.real(torch.fft.fft(cepstrum, dim=-2))) 149 | spec_env = F.pad(spec_env, (0, 0, 0, binsize - n_fft)) # zero-pad 150 | 151 | if return_min_phase: 152 | spec_env = get_min_phase_filter(torch.clamp(spec_env.transpose(-2, -1), min=min_clamp)).transpose(-2, -1) 153 | 154 | return (spec_env, pstft_spec) if return_pseudo_stft else spec_env 155 | 156 | def get_shaped_noise(self, audios: torch.Tensor): 157 | """ 158 | Get inputs for WaveFit training including spectral-envelope shaped noise 159 | 160 | Args: 161 | audios: Target audio signal, (bs, 1, L) 162 | """ 163 | # mel-spectrogram 164 | mel_spec = self.compute_mel(audios.squeeze(1)) 165 | # spectral envelope 166 | spec_env, pstft_spec = self.get_spec_env_from_mel( 167 | mel_spec, cep_order=24, return_min_phase=True, return_pseudo_stft=True) 168 | # prepare noise 169 | num_frame = mel_spec.shape[-1] 170 | noise = torch.randn(*audios.shape, device=audios.device) 171 | noise_spec = torch.stft( 172 | noise.squeeze(1), self.n_fft, hop_length=self.hop_size, 173 | win_length=self.win_size, window=self.fft_win, 174 | center=True, normalized=False, onesided=True, return_complex=True)[..., :num_frame] 175 | # shaping 176 | noise_spec *= spec_env 177 | r_noise = torch.istft( 178 | noise_spec, self.n_fft, hop_length=self.hop_size, 179 | win_length=self.win_size, window=self.fft_win, 180 | center=True, normalized=False, length=audios.shape[-1]) 181 | r_noise = r_noise.unsqueeze(1) 182 | 183 | return mel_spec, r_noise, pstft_spec 184 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | 4 | Convenient modules for logging metrics. 
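Typical usage (sketch): call `add({'metric': tensor})` once per step and `pop()` when it
is time to log; `pop()` returns the running means and resets the internal buffers.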
5 | """ 6 | 7 | import typing as tp 8 | 9 | import torch 10 | 11 | 12 | class MetricsLogger: 13 | def __init__(self): 14 | self.counts = {} 15 | self.metrics = {} 16 | 17 | def add(self, metrics: tp.Dict[str, torch.Tensor]) -> None: 18 | for k, v in metrics.items(): 19 | if k in self.counts.keys(): 20 | self.counts[k] += 1 21 | self.metrics[k] += v.detach().clone() 22 | else: 23 | self.counts[k] = 1 24 | self.metrics[k] = v.detach().clone() 25 | 26 | def pop(self, mean: bool = True) -> tp.Dict[str, torch.Tensor]: 27 | metrics = {} 28 | for k, v in self.metrics.items(): 29 | metrics[k] = v / self.counts[k] if mean else v 30 | 31 | # reset 32 | self.counts = {} 33 | self.metrics = {} 34 | 35 | return metrics 36 | -------------------------------------------------------------------------------- /src/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used from the following repo under MIT license. 3 | https://github.com/Stability-AI/stable-audio-tools/ 4 | """ 5 | 6 | import torch 7 | 8 | 9 | class InverseLR(torch.optim.lr_scheduler._LRScheduler): 10 | """Implements an inverse decay learning rate schedule with an optional exponential 11 | warmup. When last_epoch=-1, sets initial lr as lr. 12 | inv_gamma is the number of steps/epochs required for the learning rate to decay to 13 | (1 / 2)**power of its original value. 14 | Args: 15 | optimizer (Optimizer): Wrapped optimizer. 16 | inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. 17 | power (float): Exponential factor of learning rate decay. Default: 1. 18 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) 19 | Default: 0. 20 | final_lr (float): The final learning rate. Default: 0. 21 | last_epoch (int): The index of last epoch. Default: -1. 22 | verbose (bool): If ``True``, prints a message to stdout for 23 | each update. Default: ``False``. 24 | """ 25 | 26 | def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0., 27 | last_epoch=-1, verbose=False): 28 | self.inv_gamma = inv_gamma 29 | self.power = power 30 | if not 0. 
<= warmup < 1: 31 | raise ValueError('Invalid value for warmup') 32 | self.warmup = warmup 33 | self.final_lr = final_lr 34 | super().__init__(optimizer, last_epoch, verbose) 35 | 36 | def get_lr(self): 37 | if not self._get_lr_called_within_step: 38 | import warnings 39 | warnings.warn("To get the last learning rate computed by the scheduler, " 40 | "please use `get_last_lr()`.") 41 | 42 | return self._get_closed_form_lr() 43 | 44 | def _get_closed_form_lr(self): 45 | warmup = 1 - self.warmup ** (self.last_epoch + 1) 46 | lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power 47 | return [warmup * max(self.final_lr, base_lr * lr_mult) 48 | for base_lr in self.base_lrs] 49 | -------------------------------------------------------------------------------- /src/utils/torch_common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | """ 4 | 5 | import os 6 | import random 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def exists(x: torch.Tensor): 13 | return x is not None 14 | 15 | 16 | def get_world_size(): 17 | if not torch.distributed.is_available() or not torch.distributed.is_initialized(): 18 | return 1 19 | else: 20 | return torch.distributed.get_world_size() 21 | 22 | 23 | def get_rank(): 24 | """Get rank of current process.""" 25 | 26 | if not torch.distributed.is_available() or not torch.distributed.is_initialized(): 27 | return 0 28 | else: 29 | return torch.distributed.get_rank() 30 | 31 | 32 | def print_once(*args): 33 | if get_rank() == 0: 34 | print(*args) 35 | 36 | 37 | def set_seed(seed: int = 0): 38 | torch.manual_seed(seed) 39 | if torch.cuda.is_available(): 40 | torch.cuda.manual_seed_all(seed) 41 | np.random.seed(seed) 42 | random.seed(seed) 43 | os.environ["PYTHONHASHSEED"] = str(seed) 44 | 45 | 46 | def count_parameters(model: torch.nn.Module, include_buffers: bool = False): 47 | n_trainable_params = sum(p.numel() for p in model.parameters()) 48 | n_buffers = sum(p.numel() for p in model.buffers()) if include_buffers else 0 49 | return n_trainable_params + n_buffers 50 | 51 | 52 | def sort_dict(D: dict): 53 | s_keys = sorted(D.keys()) 54 | return {k: D[k] for k in s_keys} 55 | -------------------------------------------------------------------------------- /src/utils/viz.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2024 Yukara Ikemiya 3 | 4 | Adapted from the following repo under Apache License 2.0. 5 | https://github.com/drscotthawley/aeiou/ 6 | """ 7 | 8 | from PIL import Image 9 | import numpy as np 10 | from matplotlib import pyplot as plt 11 | from matplotlib.backends.backend_agg import FigureCanvasAgg 12 | 13 | 14 | def spectrogram_image( 15 | spec, 16 | db_range=[-45, 8], 17 | figsize=(6, 3), # size of plot (if justimage==False) 18 | ): 19 | from librosa import power_to_db 20 | 21 | fig = plt.figure(figsize=figsize, dpi=100) 22 | canvas = FigureCanvasAgg(fig) 23 | axs = fig.add_subplot() 24 | spec = spec.squeeze() 25 | im = axs.imshow(power_to_db(spec), origin='lower', aspect='auto', vmin=db_range[0], vmax=db_range[1]) 26 | 27 | axs.axis('off') 28 | plt.tight_layout() 29 | 30 | canvas.draw() 31 | rgba = np.asarray(canvas.buffer_rgba()) 32 | im = Image.fromarray(rgba) 33 | 34 | b = 15 # border size 35 | im = im.crop((b, b, im.size[0] - b, im.size[1] - b)) 36 | 37 | plt.clf() 38 | plt.close() 39 | 40 | return im 41 | --------------------------------------------------------------------------------