├── .github └── workflows │ └── pypi-release.yml ├── .gitignore ├── LICENSE ├── README.md ├── configs ├── vocos-encodec.yaml ├── vocos-imdct.yaml ├── vocos-resnet.yaml └── vocos.yaml ├── metrics ├── UTMOS.py └── periodicity.py ├── notebooks └── Bark+Vocos.ipynb ├── requirements-train.txt ├── requirements.txt ├── setup.py ├── train.py └── vocos ├── __init__.py ├── dataset.py ├── discriminators.py ├── experiment.py ├── feature_extractors.py ├── heads.py ├── helpers.py ├── loss.py ├── models.py ├── modules.py ├── pretrained.py └── spectral_ops.py /.github/workflows/pypi-release.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | publish: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - name: Set up Python 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: "3.x" 16 | - name: Install pypa/setuptools 17 | run: >- 18 | python -m 19 | pip install wheel 20 | - name: Build a binary wheel 21 | run: >- 22 | python setup.py sdist bdist_wheel 23 | - name: Publish to PyPI 24 | uses: pypa/gh-action-pypi-publish@release/v1 25 | with: 26 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | logs/ 163 | *.pt 164 | *.ckpt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Charactr Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vocos: Closing the gap between time-domain and Fourier-based neural vocoders for high-quality audio synthesis 2 | 3 | [Audio samples](https://gemelo-ai.github.io/vocos/) | 4 | Paper [[abs]](https://arxiv.org/abs/2306.00814) [[pdf]](https://arxiv.org/pdf/2306.00814.pdf) 5 | 6 | Vocos is a fast neural vocoder designed to synthesize audio waveforms from acoustic features. Trained using a Generative 7 | Adversarial Network (GAN) objective, Vocos can generate waveforms in a single forward pass. Unlike other typical 8 | GAN-based vocoders, Vocos does not model audio samples in the time domain. Instead, it generates spectral 9 | coefficients, facilitating rapid audio reconstruction through inverse Fourier transform. 10 | 11 | ## Installation 12 | 13 | To use Vocos only in inference mode, install it using: 14 | 15 | ```bash 16 | pip install vocos 17 | ``` 18 | 19 | If you wish to train the model, install it with additional dependencies: 20 | 21 | ```bash 22 | pip install vocos[train] 23 | ``` 24 | 25 | ## Usage 26 | 27 | ### Reconstruct audio from mel-spectrogram 28 | 29 | ```python 30 | import torch 31 | 32 | from vocos import Vocos 33 | 34 | vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") 35 | 36 | mel = torch.randn(1, 100, 256) # B, C, T 37 | audio = vocos.decode(mel) 38 | ``` 39 | 40 | Copy-synthesis from a file: 41 | 42 | ```python 43 | import torchaudio 44 | 45 | y, sr = torchaudio.load(YOUR_AUDIO_FILE) 46 | if y.size(0) > 1: # mix to mono 47 | y = y.mean(dim=0, keepdim=True) 48 | y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=24000) 49 | y_hat = vocos(y) 50 | ``` 51 | 52 | ### Reconstruct audio from EnCodec tokens 53 | 54 | Additionally, you need to provide a `bandwidth_id` which corresponds to the embedding for bandwidth from the 55 | list: `[1.5, 3.0, 6.0, 12.0]`. 56 | 57 | ```python 58 | vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz") 59 | 60 | audio_tokens = torch.randint(low=0, high=1024, size=(8, 200)) # 8 codeboooks, 200 frames 61 | features = vocos.codes_to_features(audio_tokens) 62 | bandwidth_id = torch.tensor([2]) # 6 kbps 63 | 64 | audio = vocos.decode(features, bandwidth_id=bandwidth_id) 65 | ``` 66 | 67 | Copy-synthesis from a file: It extracts and quantizes features with EnCodec, then reconstructs them with Vocos in a 68 | single forward pass. 69 | 70 | ```python 71 | y, sr = torchaudio.load(YOUR_AUDIO_FILE) 72 | if y.size(0) > 1: # mix to mono 73 | y = y.mean(dim=0, keepdim=True) 74 | y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=24000) 75 | 76 | y_hat = vocos(y, bandwidth_id=bandwidth_id) 77 | ``` 78 | 79 | ### Integrate with 🐶 [Bark](https://github.com/suno-ai/bark) text-to-audio model 80 | 81 | See [example notebook](notebooks%2FBark%2BVocos.ipynb). 
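The notebook's core decoding step is short. A sketch of it is shown below (assuming `fine_tokens` is the `(8, n_frames)` array of EnCodec codes produced by Bark's fine-generation stage, as in the notebook):

```python
import torch

from vocos import Vocos

vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz")

audio_tokens = torch.from_numpy(fine_tokens)                    # (8, n_frames) EnCodec codes from Bark
features = vocos.codes_to_features(audio_tokens)                # dequantize codes into continuous features
audio = vocos.decode(features, bandwidth_id=torch.tensor([2]))  # decode at 6 kbps
```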
82 | 83 | ## Pre-trained models 84 | 85 | | Model Name | Dataset | Training Iterations | Parameters 86 | |-------------------------------------------------------------------------------------|---------------|-------------------|------------| 87 | | [charactr/vocos-mel-24khz](https://huggingface.co/charactr/vocos-mel-24khz) | LibriTTS | 1M | 13.5M 88 | | [charactr/vocos-encodec-24khz](https://huggingface.co/charactr/vocos-encodec-24khz) | DNS Challenge | 2M | 7.9M 89 | 90 | ## Training 91 | 92 | Prepare a filelist of audio files for the training and validation set: 93 | 94 | ```bash 95 | find $TRAIN_DATASET_DIR -name *.wav > filelist.train 96 | find $VAL_DATASET_DIR -name *.wav > filelist.val 97 | ``` 98 | 99 | Fill a config file, e.g. [vocos.yaml](configs%2Fvocos.yaml), with your filelist paths and start training with: 100 | 101 | ```bash 102 | python train.py -c configs/vocos.yaml 103 | ``` 104 | 105 | Refer to [Pytorch Lightning documentation](https://lightning.ai/docs/pytorch/stable/) for details about customizing the 106 | training pipeline. 107 | 108 | ## Citation 109 | 110 | If this code contributes to your research, please cite our work: 111 | 112 | ``` 113 | @article{siuzdak2023vocos, 114 | title={Vocos: Closing the gap between time-domain and Fourier-based neural vocoders for high-quality audio synthesis}, 115 | author={Siuzdak, Hubert}, 116 | journal={arXiv preprint arXiv:2306.00814}, 117 | year={2023} 118 | } 119 | ``` 120 | 121 | ## License 122 | 123 | The code in this repository is released under the MIT license as found in the 124 | [LICENSE](LICENSE) file. 125 | -------------------------------------------------------------------------------- /configs/vocos-encodec.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.6 2 | seed_everything: 4444 3 | 4 | data: 5 | class_path: vocos.dataset.VocosDataModule 6 | init_args: 7 | train_params: 8 | filelist_path: ??? 9 | sampling_rate: 24000 10 | num_samples: 24000 11 | batch_size: 16 12 | num_workers: 8 13 | 14 | val_params: 15 | filelist_path: ??? 
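      # Placeholder: replace ??? (here and in the other configs) with the path to your validation filelist,
      # e.g. the filelist.val created as described in the README "Training" section.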
16 | sampling_rate: 24000 17 | num_samples: 24000 18 | batch_size: 16 19 | num_workers: 8 20 | 21 | model: 22 | class_path: vocos.experiment.VocosEncodecExp 23 | init_args: 24 | sample_rate: 24000 25 | initial_learning_rate: 5e-4 26 | mel_loss_coeff: 45 27 | mrd_loss_coeff: 1.0 28 | num_warmup_steps: 0 # Optimizers warmup steps 29 | pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration 30 | 31 | # automatic evaluation 32 | evaluate_utmos: true 33 | evaluate_pesq: true 34 | evaluate_periodicty: true 35 | 36 | feature_extractor: 37 | class_path: vocos.feature_extractors.EncodecFeatures 38 | init_args: 39 | encodec_model: encodec_24khz 40 | bandwidths: [1.5, 3.0, 6.0, 12.0] 41 | train_codebooks: false 42 | 43 | backbone: 44 | class_path: vocos.models.VocosBackbone 45 | init_args: 46 | input_channels: 128 47 | dim: 384 48 | intermediate_dim: 1152 49 | num_layers: 8 50 | adanorm_num_embeddings: 4 # len(bandwidths) 51 | 52 | head: 53 | class_path: vocos.heads.ISTFTHead 54 | init_args: 55 | dim: 384 56 | n_fft: 1280 57 | hop_length: 320 58 | padding: same 59 | 60 | trainer: 61 | logger: 62 | class_path: pytorch_lightning.loggers.TensorBoardLogger 63 | init_args: 64 | save_dir: logs/ 65 | callbacks: 66 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 67 | - class_path: pytorch_lightning.callbacks.ModelSummary 68 | init_args: 69 | max_depth: 2 70 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 71 | init_args: 72 | monitor: val_loss 73 | filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} 74 | save_top_k: 3 75 | save_last: true 76 | - class_path: vocos.helpers.GradNormCallback 77 | 78 | # Lightning calculates max_steps across all optimizer steps (rather than number of batches) 79 | # This equals to 1M steps per generator and 1M per discriminator 80 | max_steps: 2000000 81 | # You might want to limit val batches when evaluating all the metrics, as they are time-consuming 82 | limit_val_batches: 100 83 | accelerator: gpu 84 | strategy: ddp 85 | devices: [0] 86 | log_every_n_steps: 100 87 | -------------------------------------------------------------------------------- /configs/vocos-imdct.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.6 2 | seed_everything: 4444 3 | 4 | data: 5 | class_path: vocos.dataset.VocosDataModule 6 | init_args: 7 | train_params: 8 | filelist_path: ??? 9 | sampling_rate: 24000 10 | num_samples: 16384 11 | batch_size: 16 12 | num_workers: 8 13 | 14 | val_params: 15 | filelist_path: ??? 
16 | sampling_rate: 24000 17 | num_samples: 48384 18 | batch_size: 16 19 | num_workers: 8 20 | 21 | model: 22 | class_path: vocos.experiment.VocosExp 23 | init_args: 24 | sample_rate: 24000 25 | initial_learning_rate: 5e-4 26 | mel_loss_coeff: 45 27 | mrd_loss_coeff: 0.1 28 | num_warmup_steps: 0 # Optimizers warmup steps 29 | pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration 30 | 31 | # automatic evaluation 32 | evaluate_utmos: true 33 | evaluate_pesq: true 34 | evaluate_periodicty: true 35 | 36 | feature_extractor: 37 | class_path: vocos.feature_extractors.MelSpectrogramFeatures 38 | init_args: 39 | sample_rate: 24000 40 | n_fft: 1024 41 | hop_length: 256 42 | n_mels: 100 43 | padding: center 44 | 45 | backbone: 46 | class_path: vocos.models.VocosBackbone 47 | init_args: 48 | input_channels: 100 49 | dim: 512 50 | intermediate_dim: 1536 51 | num_layers: 8 52 | 53 | head: 54 | class_path: vocos.heads.IMDCTCosHead 55 | init_args: 56 | dim: 512 57 | mdct_frame_len: 512 # mel-spec hop_length * 2 58 | padding: center 59 | 60 | trainer: 61 | logger: 62 | class_path: pytorch_lightning.loggers.TensorBoardLogger 63 | init_args: 64 | save_dir: logs/ 65 | callbacks: 66 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 67 | - class_path: pytorch_lightning.callbacks.ModelSummary 68 | init_args: 69 | max_depth: 2 70 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 71 | init_args: 72 | monitor: val_loss 73 | filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} 74 | save_top_k: 3 75 | save_last: true 76 | - class_path: vocos.helpers.GradNormCallback 77 | 78 | # Lightning calculates max_steps across all optimizer steps (rather than number of batches) 79 | # This equals to 1M steps per generator and 1M per discriminator 80 | max_steps: 2000000 81 | # You might want to limit val batches when evaluating all the metrics, as they are time-consuming 82 | limit_val_batches: 100 83 | accelerator: gpu 84 | strategy: ddp 85 | devices: [0] 86 | log_every_n_steps: 100 87 | -------------------------------------------------------------------------------- /configs/vocos-resnet.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.6 2 | seed_everything: 4444 3 | 4 | data: 5 | class_path: vocos.dataset.VocosDataModule 6 | init_args: 7 | train_params: 8 | filelist_path: ??? 9 | sampling_rate: 24000 10 | num_samples: 16384 11 | batch_size: 16 12 | num_workers: 8 13 | 14 | val_params: 15 | filelist_path: ??? 
16 | sampling_rate: 24000 17 | num_samples: 48384 18 | batch_size: 16 19 | num_workers: 8 20 | 21 | model: 22 | class_path: vocos.experiment.VocosExp 23 | init_args: 24 | sample_rate: 24000 25 | initial_learning_rate: 5e-4 26 | mel_loss_coeff: 45 27 | mrd_loss_coeff: 0.1 28 | num_warmup_steps: 0 # Optimizers warmup steps 29 | pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration 30 | 31 | # automatic evaluation 32 | evaluate_utmos: true 33 | evaluate_pesq: true 34 | evaluate_periodicty: true 35 | 36 | feature_extractor: 37 | class_path: vocos.feature_extractors.MelSpectrogramFeatures 38 | init_args: 39 | sample_rate: 24000 40 | n_fft: 1024 41 | hop_length: 256 42 | n_mels: 100 43 | padding: center 44 | 45 | backbone: 46 | class_path: vocos.models.VocosResNetBackbone 47 | init_args: 48 | input_channels: 100 49 | dim: 512 50 | num_blocks: 3 51 | 52 | head: 53 | class_path: vocos.heads.ISTFTHead 54 | init_args: 55 | dim: 512 56 | n_fft: 1024 57 | hop_length: 256 58 | padding: center 59 | 60 | trainer: 61 | logger: 62 | class_path: pytorch_lightning.loggers.TensorBoardLogger 63 | init_args: 64 | save_dir: logs/ 65 | callbacks: 66 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 67 | - class_path: pytorch_lightning.callbacks.ModelSummary 68 | init_args: 69 | max_depth: 2 70 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 71 | init_args: 72 | monitor: val_loss 73 | filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} 74 | save_top_k: 3 75 | save_last: true 76 | - class_path: vocos.helpers.GradNormCallback 77 | 78 | # Lightning calculates max_steps across all optimizer steps (rather than number of batches) 79 | # This equals to 1M steps per generator and 1M per discriminator 80 | max_steps: 2000000 81 | # You might want to limit val batches when evaluating all the metrics, as they are time-consuming 82 | limit_val_batches: 100 83 | accelerator: gpu 84 | strategy: ddp 85 | devices: [0] 86 | log_every_n_steps: 100 87 | -------------------------------------------------------------------------------- /configs/vocos.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.6 2 | seed_everything: 4444 3 | 4 | data: 5 | class_path: vocos.dataset.VocosDataModule 6 | init_args: 7 | train_params: 8 | filelist_path: ??? 9 | sampling_rate: 24000 10 | num_samples: 16384 11 | batch_size: 16 12 | num_workers: 8 13 | 14 | val_params: 15 | filelist_path: ??? 
16 | sampling_rate: 24000 17 | num_samples: 48384 18 | batch_size: 16 19 | num_workers: 8 20 | 21 | model: 22 | class_path: vocos.experiment.VocosExp 23 | init_args: 24 | sample_rate: 24000 25 | initial_learning_rate: 5e-4 26 | mel_loss_coeff: 45 27 | mrd_loss_coeff: 0.1 28 | num_warmup_steps: 0 # Optimizers warmup steps 29 | pretrain_mel_steps: 0 # 0 means GAN objective from the first iteration 30 | 31 | # automatic evaluation 32 | evaluate_utmos: true 33 | evaluate_pesq: true 34 | evaluate_periodicty: true 35 | 36 | feature_extractor: 37 | class_path: vocos.feature_extractors.MelSpectrogramFeatures 38 | init_args: 39 | sample_rate: 24000 40 | n_fft: 1024 41 | hop_length: 256 42 | n_mels: 100 43 | padding: center 44 | 45 | backbone: 46 | class_path: vocos.models.VocosBackbone 47 | init_args: 48 | input_channels: 100 49 | dim: 512 50 | intermediate_dim: 1536 51 | num_layers: 8 52 | 53 | head: 54 | class_path: vocos.heads.ISTFTHead 55 | init_args: 56 | dim: 512 57 | n_fft: 1024 58 | hop_length: 256 59 | padding: center 60 | 61 | trainer: 62 | logger: 63 | class_path: pytorch_lightning.loggers.TensorBoardLogger 64 | init_args: 65 | save_dir: logs/ 66 | callbacks: 67 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 68 | - class_path: pytorch_lightning.callbacks.ModelSummary 69 | init_args: 70 | max_depth: 2 71 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 72 | init_args: 73 | monitor: val_loss 74 | filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f} 75 | save_top_k: 3 76 | save_last: true 77 | - class_path: vocos.helpers.GradNormCallback 78 | 79 | # Lightning calculates max_steps across all optimizer steps (rather than number of batches) 80 | # This equals to 1M steps per generator and 1M per discriminator 81 | max_steps: 2000000 82 | # You might want to limit val batches when evaluating all the metrics, as they are time-consuming 83 | limit_val_batches: 100 84 | accelerator: gpu 85 | strategy: ddp 86 | devices: [0] 87 | log_every_n_steps: 100 88 | -------------------------------------------------------------------------------- /metrics/UTMOS.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import fairseq 4 | import pytorch_lightning as pl 5 | import requests 6 | import torch 7 | import torch.nn as nn 8 | from tqdm import tqdm 9 | 10 | UTMOS_CKPT_URL = "https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt" 11 | WAV2VEC_URL = "https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt" 12 | 13 | """ 14 | UTMOS score, automatic Mean Opinion Score (MOS) prediction system, 15 | adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo 16 | """ 17 | 18 | 19 | class UTMOSScore: 20 | """Predicting score for each audio clip.""" 21 | 22 | def __init__(self, device, ckpt_path="epoch=3-step=7459.ckpt"): 23 | self.device = device 24 | filepath = os.path.join(os.path.dirname(__file__), ckpt_path) 25 | if not os.path.exists(filepath): 26 | download_file(UTMOS_CKPT_URL, filepath) 27 | self.model = BaselineLightningModule.load_from_checkpoint(filepath).eval().to(device) 28 | 29 | def score(self, wavs: torch.Tensor) -> torch.Tensor: 30 | """ 31 | Args: 32 | wavs: audio waveform to be evaluated. When len(wavs) == 1 or 2, 33 | the model processes the input as a single audio clip. The model 34 | performs batch processing when len(wavs) == 3. 
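            Note: "len(wavs)" above refers to the number of tensor dimensions (wavs.ndim): a 1-D (T,)
            or 2-D (1, T) input is scored as a single clip, while a 3-D (batch, 1, T) input is scored
            as a batch.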
35 | """ 36 | if len(wavs.shape) == 1: 37 | out_wavs = wavs.unsqueeze(0).unsqueeze(0) 38 | elif len(wavs.shape) == 2: 39 | out_wavs = wavs.unsqueeze(0) 40 | elif len(wavs.shape) == 3: 41 | out_wavs = wavs 42 | else: 43 | raise ValueError("Dimension of input tensor needs to be <= 3.") 44 | bs = out_wavs.shape[0] 45 | batch = { 46 | "wav": out_wavs, 47 | "domains": torch.zeros(bs, dtype=torch.int).to(self.device), 48 | "judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288, 49 | } 50 | with torch.no_grad(): 51 | output = self.model(batch) 52 | 53 | return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3 54 | 55 | 56 | def download_file(url, filename): 57 | """ 58 | Downloads a file from the given URL 59 | 60 | Args: 61 | url (str): The URL of the file to download. 62 | filename (str): The name to save the file as. 63 | """ 64 | print(f"Downloading file {filename}...") 65 | response = requests.get(url, stream=True) 66 | response.raise_for_status() 67 | 68 | total_size_in_bytes = int(response.headers.get("content-length", 0)) 69 | progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) 70 | 71 | with open(filename, "wb") as f: 72 | for chunk in response.iter_content(chunk_size=8192): 73 | progress_bar.update(len(chunk)) 74 | f.write(chunk) 75 | 76 | progress_bar.close() 77 | 78 | 79 | def load_ssl_model(ckpt_path="wav2vec_small.pt"): 80 | filepath = os.path.join(os.path.dirname(__file__), ckpt_path) 81 | if not os.path.exists(filepath): 82 | download_file(WAV2VEC_URL, filepath) 83 | SSL_OUT_DIM = 768 84 | model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([filepath]) 85 | ssl_model = model[0] 86 | ssl_model.remove_pretraining_modules() 87 | return SSL_model(ssl_model, SSL_OUT_DIM) 88 | 89 | 90 | class BaselineLightningModule(pl.LightningModule): 91 | def __init__(self, cfg): 92 | super().__init__() 93 | self.cfg = cfg 94 | self.construct_model() 95 | self.save_hyperparameters() 96 | 97 | def construct_model(self): 98 | self.feature_extractors = nn.ModuleList( 99 | [load_ssl_model(ckpt_path="wav2vec_small.pt"), DomainEmbedding(3, 128),] 100 | ) 101 | output_dim = sum([feature_extractor.get_output_dim() for feature_extractor in self.feature_extractors]) 102 | output_layers = [LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)] 103 | output_dim = output_layers[-1].get_output_dim() 104 | output_layers.append( 105 | Projection(hidden_dim=2048, activation=torch.nn.ReLU(), range_clipping=False, input_dim=output_dim) 106 | ) 107 | 108 | self.output_layers = nn.ModuleList(output_layers) 109 | 110 | def forward(self, inputs): 111 | outputs = {} 112 | for feature_extractor in self.feature_extractors: 113 | outputs.update(feature_extractor(inputs)) 114 | x = outputs 115 | for output_layer in self.output_layers: 116 | x = output_layer(x, inputs) 117 | return x 118 | 119 | 120 | class SSL_model(nn.Module): 121 | def __init__(self, ssl_model, ssl_out_dim) -> None: 122 | super(SSL_model, self).__init__() 123 | self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim 124 | 125 | def forward(self, batch): 126 | wav = batch["wav"] 127 | wav = wav.squeeze(1) # [batches, audio_len] 128 | res = self.ssl_model(wav, mask=False, features_only=True) 129 | x = res["x"] 130 | return {"ssl-feature": x} 131 | 132 | def get_output_dim(self): 133 | return self.ssl_out_dim 134 | 135 | 136 | class DomainEmbedding(nn.Module): 137 | def __init__(self, n_domains, domain_dim) -> None: 138 | super().__init__() 139 | self.embedding = nn.Embedding(n_domains, 
domain_dim) 140 | self.output_dim = domain_dim 141 | 142 | def forward(self, batch): 143 | return {"domain-feature": self.embedding(batch["domains"])} 144 | 145 | def get_output_dim(self): 146 | return self.output_dim 147 | 148 | 149 | class LDConditioner(nn.Module): 150 | """ 151 | Conditions ssl output by listener embedding 152 | """ 153 | 154 | def __init__(self, input_dim, judge_dim, num_judges=None): 155 | super().__init__() 156 | self.input_dim = input_dim 157 | self.judge_dim = judge_dim 158 | self.num_judges = num_judges 159 | assert num_judges != None 160 | self.judge_embedding = nn.Embedding(num_judges, self.judge_dim) 161 | # concat [self.output_layer, phoneme features] 162 | 163 | self.decoder_rnn = nn.LSTM( 164 | input_size=self.input_dim + self.judge_dim, 165 | hidden_size=512, 166 | num_layers=1, 167 | batch_first=True, 168 | bidirectional=True, 169 | ) # linear? 170 | self.out_dim = self.decoder_rnn.hidden_size * 2 171 | 172 | def get_output_dim(self): 173 | return self.out_dim 174 | 175 | def forward(self, x, batch): 176 | judge_ids = batch["judge_id"] 177 | if "phoneme-feature" in x.keys(): 178 | concatenated_feature = torch.cat( 179 | (x["ssl-feature"], x["phoneme-feature"].unsqueeze(1).expand(-1, x["ssl-feature"].size(1), -1)), dim=2 180 | ) 181 | else: 182 | concatenated_feature = x["ssl-feature"] 183 | if "domain-feature" in x.keys(): 184 | concatenated_feature = torch.cat( 185 | (concatenated_feature, x["domain-feature"].unsqueeze(1).expand(-1, concatenated_feature.size(1), -1),), 186 | dim=2, 187 | ) 188 | if judge_ids != None: 189 | concatenated_feature = torch.cat( 190 | ( 191 | concatenated_feature, 192 | self.judge_embedding(judge_ids).unsqueeze(1).expand(-1, concatenated_feature.size(1), -1), 193 | ), 194 | dim=2, 195 | ) 196 | decoder_output, (h, c) = self.decoder_rnn(concatenated_feature) 197 | return decoder_output 198 | 199 | 200 | class Projection(nn.Module): 201 | def __init__(self, input_dim, hidden_dim, activation, range_clipping=False): 202 | super(Projection, self).__init__() 203 | self.range_clipping = range_clipping 204 | output_dim = 1 205 | if range_clipping: 206 | self.proj = nn.Tanh() 207 | 208 | self.net = nn.Sequential( 209 | nn.Linear(input_dim, hidden_dim), activation, nn.Dropout(0.3), nn.Linear(hidden_dim, output_dim), 210 | ) 211 | self.output_dim = output_dim 212 | 213 | def forward(self, x, batch): 214 | output = self.net(x) 215 | 216 | # range clipping 217 | if self.range_clipping: 218 | return self.proj(output) * 2.0 + 3 219 | else: 220 | return output 221 | 222 | def get_output_dim(self): 223 | return self.output_dim 224 | -------------------------------------------------------------------------------- /metrics/periodicity.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import torch 4 | import torchaudio 5 | import torchcrepe 6 | from torchcrepe.loudness import REF_DB 7 | 8 | SILENCE_THRESHOLD = -60 9 | UNVOICED_THRESHOLD = 0.21 10 | 11 | """ 12 | Periodicity metrics adapted from https://github.com/descriptinc/cargan 13 | """ 14 | 15 | 16 | def predict_pitch( 17 | audio: torch.Tensor, silence_threshold: float = SILENCE_THRESHOLD, unvoiced_treshold: float = UNVOICED_THRESHOLD 18 | ): 19 | """ 20 | Predicts pitch and periodicity for the given audio. 21 | 22 | Args: 23 | audio (Tensor): The audio waveform. 24 | silence_threshold (float): The threshold for silence detection. 25 | unvoiced_treshold (float): The threshold for unvoiced detection. 
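            silence_threshold is compared against the mean perceptually-weighted spectrogram level in dB
            (default -60), and unvoiced_treshold against torchcrepe's periodicity confidence in [0, 1]
            (default 0.21).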
26 | 27 | Returns: 28 | pitch (ndarray): The predicted pitch. 29 | periodicity (ndarray): The predicted periodicity. 30 | """ 31 | # torchcrepe inference 32 | pitch, periodicity = torchcrepe.predict( 33 | audio, 34 | fmin=50.0, 35 | fmax=550, 36 | sample_rate=torchcrepe.SAMPLE_RATE, 37 | model="full", 38 | return_periodicity=True, 39 | device=audio.device, 40 | pad=False, 41 | ) 42 | pitch = pitch.cpu().numpy() 43 | periodicity = periodicity.cpu().numpy() 44 | 45 | # Calculate dB-scaled spectrogram and set low energy frames to unvoiced 46 | hop_length = torchcrepe.SAMPLE_RATE // 100 # default CREPE 47 | stft = torchaudio.functional.spectrogram( 48 | audio, 49 | window=torch.hann_window(torchcrepe.WINDOW_SIZE, device=audio.device), 50 | n_fft=torchcrepe.WINDOW_SIZE, 51 | hop_length=hop_length, 52 | win_length=torchcrepe.WINDOW_SIZE, 53 | power=2, 54 | normalized=False, 55 | pad=0, 56 | center=False, 57 | ) 58 | 59 | # Perceptual weighting 60 | freqs = librosa.fft_frequencies(sr=torchcrepe.SAMPLE_RATE, n_fft=torchcrepe.WINDOW_SIZE) 61 | perceptual_stft = librosa.perceptual_weighting(stft.cpu().numpy(), freqs) - REF_DB 62 | silence = perceptual_stft.mean(axis=1) < silence_threshold 63 | 64 | periodicity[silence] = 0 65 | pitch[periodicity < unvoiced_treshold] = torchcrepe.UNVOICED 66 | 67 | return pitch, periodicity 68 | 69 | 70 | def calculate_periodicity_metrics(y: torch.Tensor, y_hat: torch.Tensor): 71 | """ 72 | Calculates periodicity metrics for the predicted and true audio data. 73 | 74 | Args: 75 | y (Tensor): The true audio data. 76 | y_hat (Tensor): The predicted audio data. 77 | 78 | Returns: 79 | periodicity_loss (float): The periodicity loss. 80 | pitch_loss (float): The pitch loss. 81 | f1 (float): The F1 score for voiced/unvoiced classification 82 | """ 83 | true_pitch, true_periodicity = predict_pitch(y) 84 | pred_pitch, pred_periodicity = predict_pitch(y_hat) 85 | 86 | true_voiced = ~np.isnan(true_pitch) 87 | pred_voiced = ~np.isnan(pred_pitch) 88 | 89 | periodicity_loss = np.sqrt(((pred_periodicity - true_periodicity) ** 2).mean(axis=1)).mean() 90 | 91 | # Update pitch rmse 92 | voiced = true_voiced & pred_voiced 93 | difference_cents = 1200 * (np.log2(true_pitch[voiced]) - np.log2(pred_pitch[voiced])) 94 | pitch_loss = np.sqrt((difference_cents ** 2).mean()) 95 | 96 | # voiced/unvoiced precision and recall 97 | true_positives = (true_voiced & pred_voiced).sum() 98 | false_positives = (~true_voiced & pred_voiced).sum() 99 | false_negatives = (true_voiced & ~pred_voiced).sum() 100 | 101 | precision = true_positives / (true_positives + false_positives) 102 | recall = true_positives / (true_positives + false_negatives) 103 | f1 = 2 * precision * recall / (precision + recall) 104 | 105 | return periodicity_loss, pitch_loss, f1 106 | -------------------------------------------------------------------------------- /notebooks/Bark+Vocos.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "private_outputs": true, 7 | "provenance": [], 8 | "gpuType": "T4", 9 | "authorship_tag": "ABX9TyMC53IsYoVJIVijVzw3ADvX", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | }, 19 | "accelerator": "GPU" 20 | }, 21 | "cells": [ 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "view-in-github", 26 | "colab_type": "text" 27 | }, 28 | 
"source": [ 29 | "\"Open" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "source": [ 35 | "# Text-to-Audio Synthesis using Bark and Vocos" 36 | ], 37 | "metadata": { 38 | "id": "NuRzVtHDZ_Gl" 39 | } 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "source": [ 44 | "In this notebook, we use [Bark](https://github.com/suno-ai/bark) generative model to turn a text prompt into EnCodec audio tokens. These tokens then go through two decoders, EnCodec and Vocos, to reconstruct the audio waveform. Compare the results to discover the differences in audio quality and characteristics." 45 | ], 46 | "metadata": { 47 | "id": "zJFDte0daDAz" 48 | } 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "source": [ 53 | "Make sure you have Bark and Vocos installed:" 54 | ], 55 | "metadata": { 56 | "id": "c9omqGDYnajY" 57 | } 58 | }, 59 | { 60 | "cell_type": "code", 61 | "source": [ 62 | "!pip install git+https://github.com/suno-ai/bark.git\n", 63 | "!pip install vocos" 64 | ], 65 | "metadata": { 66 | "id": "voH44g90NvtV" 67 | }, 68 | "execution_count": null, 69 | "outputs": [] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "source": [ 74 | "Download and load Bark models" 75 | ], 76 | "metadata": { 77 | "id": "s3cEjOIuj6tq" 78 | } 79 | }, 80 | { 81 | "cell_type": "code", 82 | "source": [ 83 | "from bark import preload_models\n", 84 | "\n", 85 | "preload_models()" 86 | ], 87 | "metadata": { 88 | "id": "1H7XtXRMjxUM" 89 | }, 90 | "execution_count": null, 91 | "outputs": [] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "source": [ 96 | "Download and load Vocos." 97 | ], 98 | "metadata": { 99 | "id": "YO1m0dJ1j-F5" 100 | } 101 | }, 102 | { 103 | "cell_type": "code", 104 | "source": [ 105 | "from vocos import Vocos\n", 106 | "import torch\n", 107 | "\n", 108 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 109 | "vocos = Vocos.from_pretrained(\"charactr/vocos-encodec-24khz\").to(device)" 110 | ], 111 | "metadata": { 112 | "id": "COQYTDDFkBCq" 113 | }, 114 | "execution_count": null, 115 | "outputs": [] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "source": [ 120 | "We are going to reuse `text_to_semantic` from Bark API, but to reconstruct audio waveform with a custom vododer, we need to slightly redefine the API to return `fine_tokens`." 
121 | ], 122 | "metadata": { 123 | "id": "--RjqW0rk5JQ" 124 | } 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "metadata": { 130 | "id": "OiUsuN2DNl5S" 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "from typing import Optional, Union, Dict\n", 135 | "\n", 136 | "import numpy as np\n", 137 | "from bark.generation import generate_coarse, generate_fine\n", 138 | "\n", 139 | "\n", 140 | "def semantic_to_audio_tokens(\n", 141 | " semantic_tokens: np.ndarray,\n", 142 | " history_prompt: Optional[Union[Dict, str]] = None,\n", 143 | " temp: float = 0.7,\n", 144 | " silent: bool = False,\n", 145 | " output_full: bool = False,\n", 146 | "):\n", 147 | " coarse_tokens = generate_coarse(\n", 148 | " semantic_tokens, history_prompt=history_prompt, temp=temp, silent=silent, use_kv_caching=True\n", 149 | " )\n", 150 | " fine_tokens = generate_fine(coarse_tokens, history_prompt=history_prompt, temp=0.5)\n", 151 | "\n", 152 | " if output_full:\n", 153 | " full_generation = {\n", 154 | " \"semantic_prompt\": semantic_tokens,\n", 155 | " \"coarse_prompt\": coarse_tokens,\n", 156 | " \"fine_prompt\": fine_tokens,\n", 157 | " }\n", 158 | " return full_generation\n", 159 | " return fine_tokens" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "source": [ 165 | "Let's create a text prompt and generate audio tokens:" 166 | ], 167 | "metadata": { 168 | "id": "Cv8KCzXlmoF9" 169 | } 170 | }, 171 | { 172 | "cell_type": "code", 173 | "source": [ 174 | "from bark import text_to_semantic\n", 175 | "\n", 176 | "history_prompt = None\n", 177 | "text_prompt = \"So, you've heard about neural vocoding? [laughs] We've been messing around with this new model called Vocos.\"\n", 178 | "semantic_tokens = text_to_semantic(text_prompt, history_prompt=history_prompt, temp=0.7, silent=False,)\n", 179 | "audio_tokens = semantic_to_audio_tokens(\n", 180 | " semantic_tokens, history_prompt=history_prompt, temp=0.7, silent=False, output_full=False,\n", 181 | ")" 182 | ], 183 | "metadata": { 184 | "id": "pDmSTutoOH_G" 185 | }, 186 | "execution_count": null, 187 | "outputs": [] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "source": [ 192 | "Reconstruct audio waveform with EnCodec:" 193 | ], 194 | "metadata": { 195 | "id": "UYMzI8svTNqI" 196 | } 197 | }, 198 | { 199 | "cell_type": "code", 200 | "source": [ 201 | "from bark.generation import codec_decode\n", 202 | "from IPython.display import Audio\n", 203 | "\n", 204 | "encodec_output = codec_decode(audio_tokens)\n", 205 | "\n", 206 | "import torchaudio\n", 207 | "# Upsample to 44100 Hz for better reproduction on audio hardware\n", 208 | "encodec_output = torchaudio.functional.resample(torch.from_numpy(encodec_output), orig_freq=24000, new_freq=44100)\n", 209 | "Audio(encodec_output, rate=44100)" 210 | ], 211 | "metadata": { 212 | "id": "PzdytlXFTNQ2" 213 | }, 214 | "execution_count": null, 215 | "outputs": [] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "source": [ 220 | "Reconstruct with Vocos:" 221 | ], 222 | "metadata": { 223 | "id": "BhUxBuP9TTTw" 224 | } 225 | }, 226 | { 227 | "cell_type": "code", 228 | "source": [ 229 | "audio_tokens_torch = torch.from_numpy(audio_tokens).to(device)\n", 230 | "features = vocos.codes_to_features(audio_tokens_torch)\n", 231 | "vocos_output = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device)) # 6 kbps\n", 232 | "# Upsample to 44100 Hz for better reproduction on audio hardware\n", 233 | "vocos_output = torchaudio.functional.resample(vocos_output, orig_freq=24000, 
new_freq=44100).cpu()\n", 234 | "Audio(vocos_output.numpy(), rate=44100)" 235 | ], 236 | "metadata": { 237 | "id": "8hzSWQ5-nBlV" 238 | }, 239 | "execution_count": null, 240 | "outputs": [] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "source": [ 245 | "Optionally save to mp3 files:" 246 | ], 247 | "metadata": { 248 | "id": "RjVXQIZRb1Re" 249 | } 250 | }, 251 | { 252 | "cell_type": "code", 253 | "source": [ 254 | "torchaudio.save(\"encodec.mp3\", encodec_output[None, :], 44100, compression=128)\n", 255 | "torchaudio.save(\"vocos.mp3\", vocos_output, 44100, compression=128)" 256 | ], 257 | "metadata": { 258 | "id": "PLFXpjUKb3WX" 259 | }, 260 | "execution_count": null, 261 | "outputs": [] 262 | } 263 | ] 264 | } -------------------------------------------------------------------------------- /requirements-train.txt: -------------------------------------------------------------------------------- 1 | pytorch_lightning==1.8.6 2 | jsonargparse[signatures] 3 | transformers 4 | matplotlib 5 | torchcrepe 6 | pesq 7 | fairseq 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchaudio 3 | numpy 4 | scipy 5 | einops 6 | pyyaml 7 | huggingface_hub 8 | encodec==0.1.1 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | from setuptools import find_packages, setup 5 | 6 | for line in open("vocos/__init__.py"): 7 | line = line.strip() 8 | if "__version__" in line: 9 | context = {} 10 | exec(line, context) 11 | VERSION = context["__version__"] 12 | 13 | 14 | def read(*paths, **kwargs): 15 | content = "" 16 | with io.open( 17 | os.path.join(os.path.dirname(__file__), *paths), encoding=kwargs.get("encoding", "utf8"), 18 | ) as open_file: 19 | content = open_file.read().strip() 20 | return content 21 | 22 | 23 | def read_requirements(path): 24 | return [line.strip() for line in read(path).split("\n") if not line.startswith(('"', "#", "-", "git+"))] 25 | 26 | 27 | setup( 28 | name="vocos", 29 | version=VERSION, 30 | author="Hubert Siuzdak", 31 | author_email="huberts@charactr.com", 32 | description="Fourier-based neural vocoder for high-quality audio synthesis", 33 | url="https://github.com/charactr-platform/vocos", 34 | long_description=read("README.md"), 35 | long_description_content_type="text/markdown", 36 | packages=find_packages(), 37 | install_requires=read_requirements("requirements.txt"), 38 | extras_require={"train": read_requirements("requirements-train.txt")}, 39 | ) 40 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.cli import LightningCLI 2 | 3 | 4 | if __name__ == "__main__": 5 | cli = LightningCLI(run=False) 6 | cli.trainer.fit(model=cli.model, datamodule=cli.datamodule) 7 | -------------------------------------------------------------------------------- /vocos/__init__.py: -------------------------------------------------------------------------------- 1 | from vocos.pretrained import Vocos 2 | 3 | 4 | __version__ = "0.1.0" 5 | -------------------------------------------------------------------------------- /vocos/dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses 
import dataclass 2 | 3 | import numpy as np 4 | import torch 5 | import torchaudio 6 | from pytorch_lightning import LightningDataModule 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | torch.set_num_threads(1) 10 | 11 | 12 | @dataclass 13 | class DataConfig: 14 | filelist_path: str 15 | sampling_rate: int 16 | num_samples: int 17 | batch_size: int 18 | num_workers: int 19 | 20 | 21 | class VocosDataModule(LightningDataModule): 22 | def __init__(self, train_params: DataConfig, val_params: DataConfig): 23 | super().__init__() 24 | self.train_config = train_params 25 | self.val_config = val_params 26 | 27 | def _get_dataloder(self, cfg: DataConfig, train: bool): 28 | dataset = VocosDataset(cfg, train=train) 29 | dataloader = DataLoader( 30 | dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=train, pin_memory=True, 31 | ) 32 | return dataloader 33 | 34 | def train_dataloader(self) -> DataLoader: 35 | return self._get_dataloder(self.train_config, train=True) 36 | 37 | def val_dataloader(self) -> DataLoader: 38 | return self._get_dataloder(self.val_config, train=False) 39 | 40 | 41 | class VocosDataset(Dataset): 42 | def __init__(self, cfg: DataConfig, train: bool): 43 | with open(cfg.filelist_path) as f: 44 | self.filelist = f.read().splitlines() 45 | self.sampling_rate = cfg.sampling_rate 46 | self.num_samples = cfg.num_samples 47 | self.train = train 48 | 49 | def __len__(self) -> int: 50 | return len(self.filelist) 51 | 52 | def __getitem__(self, index: int) -> torch.Tensor: 53 | audio_path = self.filelist[index] 54 | y, sr = torchaudio.load(audio_path) 55 | if y.size(0) > 1: 56 | # mix to mono 57 | y = y.mean(dim=0, keepdim=True) 58 | gain = np.random.uniform(-1, -6) if self.train else -3 59 | y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]]) 60 | if sr != self.sampling_rate: 61 | y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate) 62 | if y.size(-1) < self.num_samples: 63 | pad_length = self.num_samples - y.size(-1) 64 | padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1)) 65 | y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1) 66 | elif self.train: 67 | start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1) 68 | y = y[:, start : start + self.num_samples] 69 | else: 70 | # During validation, take always the first segment for determinism 71 | y = y[:, : self.num_samples] 72 | 73 | return y[0] 74 | -------------------------------------------------------------------------------- /vocos/discriminators.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | from einops import rearrange 5 | from torch import nn 6 | from torch.nn import Conv2d 7 | from torch.nn.utils import weight_norm 8 | from torchaudio.transforms import Spectrogram 9 | 10 | 11 | class MultiPeriodDiscriminator(nn.Module): 12 | """ 13 | Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan. 14 | Additionally, it allows incorporating conditional information with a learned embeddings table. 15 | 16 | Args: 17 | periods (tuple[int]): Tuple of periods for each discriminator. 18 | num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. 19 | Defaults to None. 20 | """ 21 | 22 | def __init__(self, periods: Tuple[int, ...] 
= (2, 3, 5, 7, 11), num_embeddings: Optional[int] = None): 23 | super().__init__() 24 | self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods]) 25 | 26 | def forward( 27 | self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None 28 | ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: 29 | y_d_rs = [] 30 | y_d_gs = [] 31 | fmap_rs = [] 32 | fmap_gs = [] 33 | for d in self.discriminators: 34 | y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) 35 | y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) 36 | y_d_rs.append(y_d_r) 37 | fmap_rs.append(fmap_r) 38 | y_d_gs.append(y_d_g) 39 | fmap_gs.append(fmap_g) 40 | 41 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 42 | 43 | 44 | class DiscriminatorP(nn.Module): 45 | def __init__( 46 | self, 47 | period: int, 48 | in_channels: int = 1, 49 | kernel_size: int = 5, 50 | stride: int = 3, 51 | lrelu_slope: float = 0.1, 52 | num_embeddings: Optional[int] = None, 53 | ): 54 | super().__init__() 55 | self.period = period 56 | self.convs = nn.ModuleList( 57 | [ 58 | weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 59 | weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 60 | weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 61 | weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 62 | weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))), 63 | ] 64 | ) 65 | if num_embeddings is not None: 66 | self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024) 67 | torch.nn.init.zeros_(self.emb.weight) 68 | 69 | self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 70 | self.lrelu_slope = lrelu_slope 71 | 72 | def forward( 73 | self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None 74 | ) -> Tuple[torch.Tensor, List[torch.Tensor]]: 75 | x = x.unsqueeze(1) 76 | fmap = [] 77 | # 1d to 2d 78 | b, c, t = x.shape 79 | if t % self.period != 0: # pad first 80 | n_pad = self.period - (t % self.period) 81 | x = torch.nn.functional.pad(x, (0, n_pad), "reflect") 82 | t = t + n_pad 83 | x = x.view(b, c, t // self.period, self.period) 84 | 85 | for i, l in enumerate(self.convs): 86 | x = l(x) 87 | x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) 88 | if i > 0: 89 | fmap.append(x) 90 | if cond_embedding_id is not None: 91 | emb = self.emb(cond_embedding_id) 92 | h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) 93 | else: 94 | h = 0 95 | x = self.conv_post(x) 96 | fmap.append(x) 97 | x += h 98 | x = torch.flatten(x, 1, -1) 99 | 100 | return x, fmap 101 | 102 | 103 | class MultiResolutionDiscriminator(nn.Module): 104 | def __init__( 105 | self, 106 | fft_sizes: Tuple[int, ...] = (2048, 1024, 512), 107 | num_embeddings: Optional[int] = None, 108 | ): 109 | """ 110 | Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec. 111 | Additionally, it allows incorporating conditional information with a learned embeddings table. 112 | 113 | Args: 114 | fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512). 115 | num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. 116 | Defaults to None. 
117 | """ 118 | 119 | super().__init__() 120 | self.discriminators = nn.ModuleList( 121 | [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes] 122 | ) 123 | 124 | def forward( 125 | self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None 126 | ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: 127 | y_d_rs = [] 128 | y_d_gs = [] 129 | fmap_rs = [] 130 | fmap_gs = [] 131 | 132 | for d in self.discriminators: 133 | y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) 134 | y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) 135 | y_d_rs.append(y_d_r) 136 | fmap_rs.append(fmap_r) 137 | y_d_gs.append(y_d_g) 138 | fmap_gs.append(fmap_g) 139 | 140 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 141 | 142 | 143 | class DiscriminatorR(nn.Module): 144 | def __init__( 145 | self, 146 | window_length: int, 147 | num_embeddings: Optional[int] = None, 148 | channels: int = 32, 149 | hop_factor: float = 0.25, 150 | bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)), 151 | ): 152 | super().__init__() 153 | self.window_length = window_length 154 | self.hop_factor = hop_factor 155 | self.spec_fn = Spectrogram( 156 | n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None 157 | ) 158 | n_fft = window_length // 2 + 1 159 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] 160 | self.bands = bands 161 | convs = lambda: nn.ModuleList( 162 | [ 163 | weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), 164 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), 165 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), 166 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), 167 | weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))), 168 | ] 169 | ) 170 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) 171 | 172 | if num_embeddings is not None: 173 | self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels) 174 | torch.nn.init.zeros_(self.emb.weight) 175 | 176 | self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))) 177 | 178 | def spectrogram(self, x): 179 | # Remove DC offset 180 | x = x - x.mean(dim=-1, keepdims=True) 181 | # Peak normalize the volume of input audio 182 | x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) 183 | x = self.spec_fn(x) 184 | x = torch.view_as_real(x) 185 | x = rearrange(x, "b f t c -> b c t f") 186 | # Split into bands 187 | x_bands = [x[..., b[0] : b[1]] for b in self.bands] 188 | return x_bands 189 | 190 | def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None): 191 | x_bands = self.spectrogram(x) 192 | fmap = [] 193 | x = [] 194 | for band, stack in zip(x_bands, self.band_convs): 195 | for i, layer in enumerate(stack): 196 | band = layer(band) 197 | band = torch.nn.functional.leaky_relu(band, 0.1) 198 | if i > 0: 199 | fmap.append(band) 200 | x.append(band) 201 | x = torch.cat(x, dim=-1) 202 | if cond_embedding_id is not None: 203 | emb = self.emb(cond_embedding_id) 204 | h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) 205 | else: 206 | h = 0 207 | x = self.conv_post(x) 208 | fmap.append(x) 209 | x += h 210 | 211 | return x, fmap 212 | -------------------------------------------------------------------------------- 
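Both discriminators above share the calling convention that the training loop in vocos/experiment.py (next file) relies on: they take a batch of real and generated waveforms and return per-sub-discriminator logits together with intermediate feature maps. A minimal usage sketch (input shapes assumed to be (batch, samples) mono audio):

import torch

from vocos.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator

mpd = MultiPeriodDiscriminator()
mrd = MultiResolutionDiscriminator()

y = torch.randn(4, 16384)      # real audio, (batch, samples)
y_hat = torch.randn(4, 16384)  # generated audio, (batch, samples)

# Each call returns real logits, generated logits, real feature maps and generated feature maps,
# with one list entry per sub-discriminator (per period / per FFT resolution).
real_mp, gen_mp, fmap_real_mp, fmap_gen_mp = mpd(y=y, y_hat=y_hat)
real_mrd, gen_mrd, fmap_real_mrd, fmap_gen_mrd = mrd(y=y, y_hat=y_hat)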
/vocos/experiment.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | import torchaudio 7 | import transformers 8 | 9 | from vocos.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator 10 | from vocos.feature_extractors import FeatureExtractor 11 | from vocos.heads import FourierHead 12 | from vocos.helpers import plot_spectrogram_to_numpy 13 | from vocos.loss import DiscriminatorLoss, GeneratorLoss, FeatureMatchingLoss, MelSpecReconstructionLoss 14 | from vocos.models import Backbone 15 | from vocos.modules import safe_log 16 | 17 | 18 | class VocosExp(pl.LightningModule): 19 | # noinspection PyUnusedLocal 20 | def __init__( 21 | self, 22 | feature_extractor: FeatureExtractor, 23 | backbone: Backbone, 24 | head: FourierHead, 25 | sample_rate: int, 26 | initial_learning_rate: float, 27 | num_warmup_steps: int = 0, 28 | mel_loss_coeff: float = 45, 29 | mrd_loss_coeff: float = 1.0, 30 | pretrain_mel_steps: int = 0, 31 | decay_mel_coeff: bool = False, 32 | evaluate_utmos: bool = False, 33 | evaluate_pesq: bool = False, 34 | evaluate_periodicty: bool = False, 35 | ): 36 | """ 37 | Args: 38 | feature_extractor (FeatureExtractor): An instance of FeatureExtractor to extract features from audio signals. 39 | backbone (Backbone): An instance of Backbone model. 40 | head (FourierHead): An instance of Fourier head to generate spectral coefficients and reconstruct a waveform. 41 | sample_rate (int): Sampling rate of the audio signals. 42 | initial_learning_rate (float): Initial learning rate for the optimizer. 43 | num_warmup_steps (int): Number of steps for the warmup phase of learning rate scheduler. Default is 0. 44 | mel_loss_coeff (float, optional): Coefficient for Mel-spectrogram loss in the loss function. Default is 45. 45 | mrd_loss_coeff (float, optional): Coefficient for Multi Resolution Discriminator loss. Default is 1.0. 46 | pretrain_mel_steps (int, optional): Number of steps to pre-train the model without the GAN objective. Default is 0. 47 | decay_mel_coeff (bool, optional): If True, the Mel-spectrogram loss coefficient is decayed during training. Default is False. 48 | evaluate_utmos (bool, optional): If True, UTMOS scores are computed for each validation run. 49 | evaluate_pesq (bool, optional): If True, PESQ scores are computed for each validation run. 50 | evaluate_periodicty (bool, optional): If True, periodicity scores are computed for each validation run. 
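        Example:
            A construction sketch mirroring configs/vocos.yaml (class names and arguments taken from that config):

            >>> from vocos.feature_extractors import MelSpectrogramFeatures
            >>> from vocos.models import VocosBackbone
            >>> from vocos.heads import ISTFTHead
            >>> exp = VocosExp(
            ...     feature_extractor=MelSpectrogramFeatures(
            ...         sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"
            ...     ),
            ...     backbone=VocosBackbone(input_channels=100, dim=512, intermediate_dim=1536, num_layers=8),
            ...     head=ISTFTHead(dim=512, n_fft=1024, hop_length=256, padding="center"),
            ...     sample_rate=24000,
            ...     initial_learning_rate=5e-4,
            ...     mel_loss_coeff=45,
            ...     mrd_loss_coeff=0.1,
            ... )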
51 | """ 52 | super().__init__() 53 | self.save_hyperparameters(ignore=["feature_extractor", "backbone", "head"]) 54 | 55 | self.feature_extractor = feature_extractor 56 | self.backbone = backbone 57 | self.head = head 58 | 59 | self.multiperioddisc = MultiPeriodDiscriminator() 60 | self.multiresddisc = MultiResolutionDiscriminator() 61 | 62 | self.disc_loss = DiscriminatorLoss() 63 | self.gen_loss = GeneratorLoss() 64 | self.feat_matching_loss = FeatureMatchingLoss() 65 | self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate) 66 | 67 | self.train_discriminator = False 68 | self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff 69 | 70 | def configure_optimizers(self): 71 | disc_params = [ 72 | {"params": self.multiperioddisc.parameters()}, 73 | {"params": self.multiresddisc.parameters()}, 74 | ] 75 | gen_params = [ 76 | {"params": self.feature_extractor.parameters()}, 77 | {"params": self.backbone.parameters()}, 78 | {"params": self.head.parameters()}, 79 | ] 80 | 81 | opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9)) 82 | opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9)) 83 | 84 | max_steps = self.trainer.max_steps // 2 # Max steps per optimizer 85 | scheduler_disc = transformers.get_cosine_schedule_with_warmup( 86 | opt_disc, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps, 87 | ) 88 | scheduler_gen = transformers.get_cosine_schedule_with_warmup( 89 | opt_gen, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps, 90 | ) 91 | 92 | return ( 93 | [opt_disc, opt_gen], 94 | [{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}], 95 | ) 96 | 97 | def forward(self, audio_input, **kwargs): 98 | features = self.feature_extractor(audio_input, **kwargs) 99 | x = self.backbone(features, **kwargs) 100 | audio_output = self.head(x) 101 | return audio_output 102 | 103 | def training_step(self, batch, batch_idx, optimizer_idx, **kwargs): 104 | audio_input = batch 105 | 106 | # train discriminator 107 | if optimizer_idx == 0 and self.train_discriminator: 108 | with torch.no_grad(): 109 | audio_hat = self(audio_input, **kwargs) 110 | 111 | real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(y=audio_input, y_hat=audio_hat, **kwargs,) 112 | real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(y=audio_input, y_hat=audio_hat, **kwargs,) 113 | loss_mp, loss_mp_real, _ = self.disc_loss( 114 | disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp 115 | ) 116 | loss_mrd, loss_mrd_real, _ = self.disc_loss( 117 | disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd 118 | ) 119 | loss_mp /= len(loss_mp_real) 120 | loss_mrd /= len(loss_mrd_real) 121 | loss = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd 122 | 123 | self.log("discriminator/total", loss, prog_bar=True) 124 | self.log("discriminator/multi_period_loss", loss_mp) 125 | self.log("discriminator/multi_res_loss", loss_mrd) 126 | return loss 127 | 128 | # train generator 129 | if optimizer_idx == 1: 130 | audio_hat = self(audio_input, **kwargs) 131 | if self.train_discriminator: 132 | _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc( 133 | y=audio_input, y_hat=audio_hat, **kwargs, 134 | ) 135 | _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc( 136 | y=audio_input, y_hat=audio_hat, **kwargs, 137 | ) 138 | loss_gen_mp, list_loss_gen_mp = 
self.gen_loss(disc_outputs=gen_score_mp) 139 | loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd) 140 | loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp) 141 | loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd) 142 | loss_fm_mp = self.feat_matching_loss(fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp) 143 | loss_fm_mrd = self.feat_matching_loss(fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd) 144 | 145 | self.log("generator/multi_period_loss", loss_gen_mp) 146 | self.log("generator/multi_res_loss", loss_gen_mrd) 147 | self.log("generator/feature_matching_mp", loss_fm_mp) 148 | self.log("generator/feature_matching_mrd", loss_fm_mrd) 149 | else: 150 | loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = 0 151 | 152 | mel_loss = self.melspec_loss(audio_hat, audio_input) 153 | loss = ( 154 | loss_gen_mp 155 | + self.hparams.mrd_loss_coeff * loss_gen_mrd 156 | + loss_fm_mp 157 | + self.hparams.mrd_loss_coeff * loss_fm_mrd 158 | + self.mel_loss_coeff * mel_loss 159 | ) 160 | 161 | self.log("generator/total_loss", loss, prog_bar=True) 162 | self.log("mel_loss_coeff", self.mel_loss_coeff) 163 | self.log("generator/mel_loss", mel_loss) 164 | 165 | if self.global_step % 1000 == 0 and self.global_rank == 0: 166 | self.logger.experiment.add_audio( 167 | "train/audio_in", audio_input[0].data.cpu(), self.global_step, self.hparams.sample_rate 168 | ) 169 | self.logger.experiment.add_audio( 170 | "train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.hparams.sample_rate 171 | ) 172 | with torch.no_grad(): 173 | mel = safe_log(self.melspec_loss.mel_spec(audio_input[0])) 174 | mel_hat = safe_log(self.melspec_loss.mel_spec(audio_hat[0])) 175 | self.logger.experiment.add_image( 176 | "train/mel_target", 177 | plot_spectrogram_to_numpy(mel.data.cpu().numpy()), 178 | self.global_step, 179 | dataformats="HWC", 180 | ) 181 | self.logger.experiment.add_image( 182 | "train/mel_pred", 183 | plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()), 184 | self.global_step, 185 | dataformats="HWC", 186 | ) 187 | 188 | return loss 189 | 190 | def on_validation_epoch_start(self): 191 | if self.hparams.evaluate_utmos: 192 | from metrics.UTMOS import UTMOSScore 193 | 194 | if not hasattr(self, "utmos_model"): 195 | self.utmos_model = UTMOSScore(device=self.device) 196 | 197 | def validation_step(self, batch, batch_idx, **kwargs): 198 | audio_input = batch 199 | audio_hat = self(audio_input, **kwargs) 200 | 201 | audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.hparams.sample_rate, new_freq=16000) 202 | audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.hparams.sample_rate, new_freq=16000) 203 | 204 | if self.hparams.evaluate_periodicty: 205 | from metrics.periodicity import calculate_periodicity_metrics 206 | 207 | periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz) 208 | else: 209 | periodicity_loss = pitch_loss = f1_score = 0 210 | 211 | if self.hparams.evaluate_utmos: 212 | utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean() 213 | else: 214 | utmos_score = torch.zeros(1, device=self.device) 215 | 216 | if self.hparams.evaluate_pesq: 217 | from pesq import pesq 218 | 219 | pesq_score = 0 220 | for ref, deg in zip(audio_16_khz.cpu().numpy(), audio_hat_16khz.cpu().numpy()): 221 | pesq_score += pesq(16000, ref, deg, "wb", on_error=1) 222 | pesq_score /= len(audio_16_khz) 223 | pesq_score = torch.tensor(pesq_score) 224 | else: 225 | pesq_score = 
torch.zeros(1, device=self.device) 226 | 227 | mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1)) 228 | total_loss = mel_loss + (5 - utmos_score) + (5 - pesq_score) 229 | 230 | return { 231 | "val_loss": total_loss, 232 | "mel_loss": mel_loss, 233 | "utmos_score": utmos_score, 234 | "pesq_score": pesq_score, 235 | "periodicity_loss": periodicity_loss, 236 | "pitch_loss": pitch_loss, 237 | "f1_score": f1_score, 238 | "audio_input": audio_input[0], 239 | "audio_pred": audio_hat[0], 240 | } 241 | 242 | def validation_epoch_end(self, outputs): 243 | if self.global_rank == 0: 244 | *_, audio_in, audio_pred = outputs[0].values() 245 | self.logger.experiment.add_audio( 246 | "val_in", audio_in.data.cpu().numpy(), self.global_step, self.hparams.sample_rate 247 | ) 248 | self.logger.experiment.add_audio( 249 | "val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.hparams.sample_rate 250 | ) 251 | mel_target = safe_log(self.melspec_loss.mel_spec(audio_in)) 252 | mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred)) 253 | self.logger.experiment.add_image( 254 | "val_mel_target", 255 | plot_spectrogram_to_numpy(mel_target.data.cpu().numpy()), 256 | self.global_step, 257 | dataformats="HWC", 258 | ) 259 | self.logger.experiment.add_image( 260 | "val_mel_hat", 261 | plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()), 262 | self.global_step, 263 | dataformats="HWC", 264 | ) 265 | avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() 266 | mel_loss = torch.stack([x["mel_loss"] for x in outputs]).mean() 267 | utmos_score = torch.stack([x["utmos_score"] for x in outputs]).mean() 268 | pesq_score = torch.stack([x["pesq_score"] for x in outputs]).mean() 269 | periodicity_loss = np.array([x["periodicity_loss"] for x in outputs]).mean() 270 | pitch_loss = np.array([x["pitch_loss"] for x in outputs]).mean() 271 | f1_score = np.array([x["f1_score"] for x in outputs]).mean() 272 | 273 | self.log("val_loss", avg_loss, sync_dist=True) 274 | self.log("val/mel_loss", mel_loss, sync_dist=True) 275 | self.log("val/utmos_score", utmos_score, sync_dist=True) 276 | self.log("val/pesq_score", pesq_score, sync_dist=True) 277 | self.log("val/periodicity_loss", periodicity_loss, sync_dist=True) 278 | self.log("val/pitch_loss", pitch_loss, sync_dist=True) 279 | self.log("val/f1_score", f1_score, sync_dist=True) 280 | 281 | @property 282 | def global_step(self): 283 | """ 284 | Override global_step so that it returns the total number of batches processed 285 | """ 286 | return self.trainer.fit_loop.epoch_loop.total_batch_idx 287 | 288 | def on_train_batch_start(self, *args): 289 | if self.global_step >= self.hparams.pretrain_mel_steps: 290 | self.train_discriminator = True 291 | else: 292 | self.train_discriminator = False 293 | 294 | def on_train_batch_end(self, *args): 295 | def mel_loss_coeff_decay(current_step, num_cycles=0.5): 296 | max_steps = self.trainer.max_steps // 2 297 | if current_step < self.hparams.num_warmup_steps: 298 | return 1.0 299 | progress = float(current_step - self.hparams.num_warmup_steps) / float( 300 | max(1, max_steps - self.hparams.num_warmup_steps) 301 | ) 302 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 303 | 304 | if self.hparams.decay_mel_coeff: 305 | self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1) 306 | 307 | 308 | class VocosEncodecExp(VocosExp): 309 | """ 310 | VocosEncodecExp is a subclass of VocosExp that overrides the parent experiment to function as a 
conditional GAN. 311 | It manages an additional `bandwidth_id` attribute, which denotes a learnable embedding corresponding to 312 | a specific bandwidth value of EnCodec. During training, a random bandwidth_id is generated for each step, 313 | while during validation, a fixed bandwidth_id is used. 314 | """ 315 | 316 | def __init__( 317 | self, 318 | feature_extractor: FeatureExtractor, 319 | backbone: Backbone, 320 | head: FourierHead, 321 | sample_rate: int, 322 | initial_learning_rate: float, 323 | num_warmup_steps: int, 324 | mel_loss_coeff: float = 45, 325 | mrd_loss_coeff: float = 1.0, 326 | pretrain_mel_steps: int = 0, 327 | decay_mel_coeff: bool = False, 328 | evaluate_utmos: bool = False, 329 | evaluate_pesq: bool = False, 330 | evaluate_periodicty: bool = False, 331 | ): 332 | super().__init__( 333 | feature_extractor, 334 | backbone, 335 | head, 336 | sample_rate, 337 | initial_learning_rate, 338 | num_warmup_steps, 339 | mel_loss_coeff, 340 | mrd_loss_coeff, 341 | pretrain_mel_steps, 342 | decay_mel_coeff, 343 | evaluate_utmos, 344 | evaluate_pesq, 345 | evaluate_periodicty, 346 | ) 347 | # Override with conditional discriminators 348 | self.multiperioddisc = MultiPeriodDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths)) 349 | self.multiresddisc = MultiResolutionDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths)) 350 | 351 | def training_step(self, *args): 352 | bandwidth_id = torch.randint(low=0, high=len(self.feature_extractor.bandwidths), size=(1,), device=self.device,) 353 | output = super().training_step(*args, bandwidth_id=bandwidth_id) 354 | return output 355 | 356 | def validation_step(self, *args): 357 | bandwidth_id = torch.tensor([0], device=self.device) 358 | output = super().validation_step(*args, bandwidth_id=bandwidth_id) 359 | return output 360 | 361 | def validation_epoch_end(self, outputs): 362 | if self.global_rank == 0: 363 | *_, audio_in, _ = outputs[0].values() 364 | # Resynthesis with encodec for reference 365 | self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0]) 366 | encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :]) 367 | self.logger.experiment.add_audio( 368 | "encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.hparams.sample_rate, 369 | ) 370 | 371 | super().validation_epoch_end(outputs) 372 | -------------------------------------------------------------------------------- /vocos/feature_extractors.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torchaudio 5 | from encodec import EncodecModel 6 | from torch import nn 7 | 8 | from vocos.modules import safe_log 9 | 10 | 11 | class FeatureExtractor(nn.Module): 12 | """Base class for feature extractors.""" 13 | 14 | def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor: 15 | """ 16 | Extract features from the given audio. 17 | 18 | Args: 19 | audio (Tensor): Input audio waveform. 20 | 21 | Returns: 22 | Tensor: Extracted features of shape (B, C, L), where B is the batch size, 23 | C denotes output features, and L is the sequence length. 
24 | """ 25 | raise NotImplementedError("Subclasses must implement the forward method.") 26 | 27 | 28 | class MelSpectrogramFeatures(FeatureExtractor): 29 | def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"): 30 | super().__init__() 31 | if padding not in ["center", "same"]: 32 | raise ValueError("Padding must be 'center' or 'same'.") 33 | self.padding = padding 34 | self.mel_spec = torchaudio.transforms.MelSpectrogram( 35 | sample_rate=sample_rate, 36 | n_fft=n_fft, 37 | hop_length=hop_length, 38 | n_mels=n_mels, 39 | center=padding == "center", 40 | power=1, 41 | ) 42 | 43 | def forward(self, audio, **kwargs): 44 | if self.padding == "same": 45 | pad = self.mel_spec.win_length - self.mel_spec.hop_length 46 | audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect") 47 | mel = self.mel_spec(audio) 48 | features = safe_log(mel) 49 | return features 50 | 51 | 52 | class EncodecFeatures(FeatureExtractor): 53 | def __init__( 54 | self, 55 | encodec_model: str = "encodec_24khz", 56 | bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0], 57 | train_codebooks: bool = False, 58 | ): 59 | super().__init__() 60 | if encodec_model == "encodec_24khz": 61 | encodec = EncodecModel.encodec_model_24khz 62 | elif encodec_model == "encodec_48khz": 63 | encodec = EncodecModel.encodec_model_48khz 64 | else: 65 | raise ValueError( 66 | f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz' and 'encodec_48khz'." 67 | ) 68 | self.encodec = encodec(pretrained=True) 69 | for param in self.encodec.parameters(): 70 | param.requires_grad = False 71 | self.num_q = self.encodec.quantizer.get_num_quantizers_for_bandwidth( 72 | self.encodec.frame_rate, bandwidth=max(bandwidths) 73 | ) 74 | codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0) 75 | self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks) 76 | self.bandwidths = bandwidths 77 | 78 | @torch.no_grad() 79 | def get_encodec_codes(self, audio): 80 | audio = audio.unsqueeze(1) 81 | emb = self.encodec.encoder(audio) 82 | codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth) 83 | return codes 84 | 85 | def forward(self, audio: torch.Tensor, **kwargs): 86 | bandwidth_id = kwargs.get("bandwidth_id") 87 | if bandwidth_id is None: 88 | raise ValueError("The 'bandwidth_id' argument is required") 89 | self.encodec.eval() # Force eval mode as Pytorch Lightning automatically sets child modules to training mode 90 | self.encodec.set_target_bandwidth(self.bandwidths[bandwidth_id]) 91 | codes = self.get_encodec_codes(audio) 92 | # Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights` 93 | # with offsets given by the number of bins, and finally summed in a vectorized operation. 
94 | offsets = torch.arange( 95 | 0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device 96 | ) 97 | embeddings_idxs = codes + offsets.view(-1, 1, 1) 98 | features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0) 99 | return features.transpose(1, 2) 100 | -------------------------------------------------------------------------------- /vocos/heads.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz 6 | 7 | from vocos.spectral_ops import IMDCT, ISTFT 8 | from vocos.modules import symexp 9 | 10 | 11 | class FourierHead(nn.Module): 12 | """Base class for inverse fourier modules.""" 13 | 14 | def forward(self, x: torch.Tensor) -> torch.Tensor: 15 | """ 16 | Args: 17 | x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, 18 | L is the sequence length, and H denotes the model dimension. 19 | 20 | Returns: 21 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 22 | """ 23 | raise NotImplementedError("Subclasses must implement the forward method.") 24 | 25 | 26 | class ISTFTHead(FourierHead): 27 | """ 28 | ISTFT Head module for predicting STFT complex coefficients. 29 | 30 | Args: 31 | dim (int): Hidden dimension of the model. 32 | n_fft (int): Size of Fourier transform. 33 | hop_length (int): The distance between neighboring sliding window frames, which should align with 34 | the resolution of the input features. 35 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 36 | """ 37 | 38 | def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): 39 | super().__init__() 40 | out_dim = n_fft + 2 41 | self.out = torch.nn.Linear(dim, out_dim) 42 | self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) 43 | 44 | def forward(self, x: torch.Tensor) -> torch.Tensor: 45 | """ 46 | Forward pass of the ISTFTHead module. 47 | 48 | Args: 49 | x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, 50 | L is the sequence length, and H denotes the model dimension. 51 | 52 | Returns: 53 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 54 | """ 55 | x = self.out(x).transpose(1, 2) 56 | mag, p = x.chunk(2, dim=1) 57 | mag = torch.exp(mag) 58 | mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes 59 | # wrapping happens here. These two lines produce real and imaginary value 60 | x = torch.cos(p) 61 | y = torch.sin(p) 62 | # recalculating phase here does not produce anything new 63 | # only costs time 64 | # phase = torch.atan2(y, x) 65 | # S = mag * torch.exp(phase * 1j) 66 | # better directly produce the complex value 67 | S = mag * (x + 1j * y) 68 | audio = self.istft(S) 69 | return audio 70 | 71 | 72 | class IMDCTSymExpHead(FourierHead): 73 | """ 74 | IMDCT Head module for predicting MDCT coefficients with symmetric exponential function 75 | 76 | Args: 77 | dim (int): Hidden dimension of the model. 78 | mdct_frame_len (int): Length of the MDCT frame. 79 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 80 | sample_rate (int, optional): The sample rate of the audio. 
If provided, the last layer will be initialized 81 | based on perceptual scaling. Defaults to None. 82 | clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. 83 | """ 84 | 85 | def __init__( 86 | self, 87 | dim: int, 88 | mdct_frame_len: int, 89 | padding: str = "same", 90 | sample_rate: Optional[int] = None, 91 | clip_audio: bool = False, 92 | ): 93 | super().__init__() 94 | out_dim = mdct_frame_len // 2 95 | self.out = nn.Linear(dim, out_dim) 96 | self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) 97 | self.clip_audio = clip_audio 98 | 99 | if sample_rate is not None: 100 | # optionally init the last layer following mel-scale 101 | m_max = _hz_to_mel(sample_rate // 2) 102 | m_pts = torch.linspace(0, m_max, out_dim) 103 | f_pts = _mel_to_hz(m_pts) 104 | scale = 1 - (f_pts / f_pts.max()) 105 | 106 | with torch.no_grad(): 107 | self.out.weight.mul_(scale.view(-1, 1)) 108 | 109 | def forward(self, x: torch.Tensor) -> torch.Tensor: 110 | """ 111 | Forward pass of the IMDCTSymExpHead module. 112 | 113 | Args: 114 | x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, 115 | L is the sequence length, and H denotes the model dimension. 116 | 117 | Returns: 118 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 119 | """ 120 | x = self.out(x) 121 | x = symexp(x) 122 | x = torch.clip(x, min=-1e2, max=1e2) # safeguard to prevent excessively large magnitudes 123 | audio = self.imdct(x) 124 | if self.clip_audio: 125 | audio = torch.clip(x, min=-1.0, max=1.0) 126 | 127 | return audio 128 | 129 | 130 | class IMDCTCosHead(FourierHead): 131 | """ 132 | IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p) 133 | 134 | Args: 135 | dim (int): Hidden dimension of the model. 136 | mdct_frame_len (int): Length of the MDCT frame. 137 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 138 | clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. 139 | """ 140 | 141 | def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False): 142 | super().__init__() 143 | self.clip_audio = clip_audio 144 | self.out = nn.Linear(dim, mdct_frame_len) 145 | self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) 146 | 147 | def forward(self, x: torch.Tensor) -> torch.Tensor: 148 | """ 149 | Forward pass of the IMDCTCosHead module. 150 | 151 | Args: 152 | x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, 153 | L is the sequence length, and H denotes the model dimension. 154 | 155 | Returns: 156 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 
157 | """ 158 | x = self.out(x) 159 | m, p = x.chunk(2, dim=2) 160 | m = torch.exp(m).clip(max=1e2) # safeguard to prevent excessively large magnitudes 161 | audio = self.imdct(m * torch.cos(p)) 162 | if self.clip_audio: 163 | audio = torch.clip(x, min=-1.0, max=1.0) 164 | return audio 165 | -------------------------------------------------------------------------------- /vocos/helpers.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import numpy as np 3 | import torch 4 | from matplotlib import pyplot as plt 5 | from pytorch_lightning import Callback 6 | 7 | matplotlib.use("Agg") 8 | 9 | 10 | def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray: 11 | """ 12 | Save a matplotlib figure to a numpy array. 13 | 14 | Args: 15 | fig (Figure): Matplotlib figure object. 16 | 17 | Returns: 18 | ndarray: Numpy array representing the figure. 19 | """ 20 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 21 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 22 | return data 23 | 24 | 25 | def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray: 26 | """ 27 | Plot a spectrogram and convert it to a numpy array. 28 | 29 | Args: 30 | spectrogram (ndarray): Spectrogram data. 31 | 32 | Returns: 33 | ndarray: Numpy array representing the plotted spectrogram. 34 | """ 35 | spectrogram = spectrogram.astype(np.float32) 36 | fig, ax = plt.subplots(figsize=(12, 3)) 37 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 38 | plt.colorbar(im, ax=ax) 39 | plt.xlabel("Frames") 40 | plt.ylabel("Channels") 41 | plt.tight_layout() 42 | 43 | fig.canvas.draw() 44 | data = save_figure_to_numpy(fig) 45 | plt.close() 46 | return data 47 | 48 | 49 | class GradNormCallback(Callback): 50 | """ 51 | Callback to log the gradient norm. 52 | """ 53 | 54 | def on_after_backward(self, trainer, model): 55 | model.log("grad_norm", gradient_norm(model)) 56 | 57 | 58 | def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor: 59 | """ 60 | Compute the gradient norm. 61 | 62 | Args: 63 | model (Module): PyTorch model. 64 | norm_type (float, optional): Type of the norm. Defaults to 2.0. 65 | 66 | Returns: 67 | Tensor: Gradient norm. 68 | """ 69 | grads = [p.grad for p in model.parameters() if p.grad is not None] 70 | total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type) 71 | return total_norm 72 | -------------------------------------------------------------------------------- /vocos/loss.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | import torchaudio 5 | from torch import nn 6 | 7 | from vocos.modules import safe_log 8 | 9 | 10 | class MelSpecReconstructionLoss(nn.Module): 11 | """ 12 | L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample 13 | """ 14 | 15 | def __init__( 16 | self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100, 17 | ): 18 | super().__init__() 19 | self.mel_spec = torchaudio.transforms.MelSpectrogram( 20 | sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1, 21 | ) 22 | 23 | def forward(self, y_hat, y) -> torch.Tensor: 24 | """ 25 | Args: 26 | y_hat (Tensor): Predicted audio waveform. 27 | y (Tensor): Ground truth audio waveform. 
28 | 29 | Returns: 30 | Tensor: L1 loss between the mel-scaled magnitude spectrograms. 31 | """ 32 | mel_hat = safe_log(self.mel_spec(y_hat)) 33 | mel = safe_log(self.mel_spec(y)) 34 | 35 | loss = torch.nn.functional.l1_loss(mel, mel_hat) 36 | 37 | return loss 38 | 39 | 40 | class GeneratorLoss(nn.Module): 41 | """ 42 | Generator Loss module. Calculates the loss for the generator based on discriminator outputs. 43 | """ 44 | 45 | def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]: 46 | """ 47 | Args: 48 | disc_outputs (List[Tensor]): List of discriminator outputs. 49 | 50 | Returns: 51 | Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from 52 | the sub-discriminators 53 | """ 54 | loss = torch.zeros(1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype) 55 | gen_losses = [] 56 | for dg in disc_outputs: 57 | l = torch.mean(torch.clamp(1 - dg, min=0)) 58 | gen_losses.append(l) 59 | loss += l 60 | 61 | return loss, gen_losses 62 | 63 | 64 | class DiscriminatorLoss(nn.Module): 65 | """ 66 | Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs. 67 | """ 68 | 69 | def forward( 70 | self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor] 71 | ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: 72 | """ 73 | Args: 74 | disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples. 75 | disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples. 76 | 77 | Returns: 78 | Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from 79 | the sub-discriminators for real outputs, and a list of 80 | loss values for generated outputs. 81 | """ 82 | loss = torch.zeros(1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype) 83 | r_losses = [] 84 | g_losses = [] 85 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 86 | r_loss = torch.mean(torch.clamp(1 - dr, min=0)) 87 | g_loss = torch.mean(torch.clamp(1 + dg, min=0)) 88 | loss += r_loss + g_loss 89 | r_losses.append(r_loss) 90 | g_losses.append(g_loss) 91 | 92 | return loss, r_losses, g_losses 93 | 94 | 95 | class FeatureMatchingLoss(nn.Module): 96 | """ 97 | Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators. 98 | """ 99 | 100 | def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor: 101 | """ 102 | Args: 103 | fmap_r (List[List[Tensor]]): List of feature maps from real samples. 104 | fmap_g (List[List[Tensor]]): List of feature maps from generated samples. 105 | 106 | Returns: 107 | Tensor: The calculated feature matching loss. 108 | """ 109 | loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype) 110 | for dr, dg in zip(fmap_r, fmap_g): 111 | for rl, gl in zip(dr, dg): 112 | loss += torch.mean(torch.abs(rl - gl)) 113 | 114 | return loss 115 | -------------------------------------------------------------------------------- /vocos/models.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn.utils import weight_norm 6 | 7 | from vocos.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm 8 | 9 | 10 | class Backbone(nn.Module): 11 | """Base class for the generator's backbone. 
It preserves the same temporal resolution across all layers.""" 12 | 13 | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: 14 | """ 15 | Args: 16 | x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, 17 | C denotes output features, and L is the sequence length. 18 | 19 | Returns: 20 | Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, 21 | and H denotes the model dimension. 22 | """ 23 | raise NotImplementedError("Subclasses must implement the forward method.") 24 | 25 | 26 | class VocosBackbone(Backbone): 27 | """ 28 | Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization 29 | 30 | Args: 31 | input_channels (int): Number of input features channels. 32 | dim (int): Hidden dimension of the model. 33 | intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. 34 | num_layers (int): Number of ConvNeXtBlock layers. 35 | layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. 36 | adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. 37 | None means non-conditional model. Defaults to None. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | input_channels: int, 43 | dim: int, 44 | intermediate_dim: int, 45 | num_layers: int, 46 | layer_scale_init_value: Optional[float] = None, 47 | adanorm_num_embeddings: Optional[int] = None, 48 | ): 49 | super().__init__() 50 | self.input_channels = input_channels 51 | self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) 52 | self.adanorm = adanorm_num_embeddings is not None 53 | if adanorm_num_embeddings: 54 | self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) 55 | else: 56 | self.norm = nn.LayerNorm(dim, eps=1e-6) 57 | layer_scale_init_value = layer_scale_init_value or 1 / num_layers 58 | self.convnext = nn.ModuleList( 59 | [ 60 | ConvNeXtBlock( 61 | dim=dim, 62 | intermediate_dim=intermediate_dim, 63 | layer_scale_init_value=layer_scale_init_value, 64 | adanorm_num_embeddings=adanorm_num_embeddings, 65 | ) 66 | for _ in range(num_layers) 67 | ] 68 | ) 69 | self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) 70 | self.apply(self._init_weights) 71 | 72 | def _init_weights(self, m): 73 | if isinstance(m, (nn.Conv1d, nn.Linear)): 74 | nn.init.trunc_normal_(m.weight, std=0.02) 75 | nn.init.constant_(m.bias, 0) 76 | 77 | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: 78 | bandwidth_id = kwargs.get('bandwidth_id', None) 79 | x = self.embed(x) 80 | if self.adanorm: 81 | assert bandwidth_id is not None 82 | x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id) 83 | else: 84 | x = self.norm(x.transpose(1, 2)) 85 | x = x.transpose(1, 2) 86 | for conv_block in self.convnext: 87 | x = conv_block(x, cond_embedding_id=bandwidth_id) 88 | x = self.final_layer_norm(x.transpose(1, 2)) 89 | return x 90 | 91 | 92 | class VocosResNetBackbone(Backbone): 93 | """ 94 | Vocos backbone module built with ResBlocks. 95 | 96 | Args: 97 | input_channels (int): Number of input features channels. 98 | dim (int): Hidden dimension of the model. 99 | num_blocks (int): Number of ResBlock1 blocks. 100 | layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None. 
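Example (a minimal sketch; the channel count, hidden size and block count are illustrative, not a released configuration):

    backbone = VocosResNetBackbone(input_channels=100, dim=512, num_blocks=3)
    x = torch.randn(1, 100, 200)   # (B, C, L) input features
    h = backbone(x)                # (B, L, H) == (1, 200, 512)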
101 | """ 102 | 103 | def __init__( 104 | self, input_channels, dim, num_blocks, layer_scale_init_value=None, 105 | ): 106 | super().__init__() 107 | self.input_channels = input_channels 108 | self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)) 109 | layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3 110 | self.resnet = nn.Sequential( 111 | *[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)] 112 | ) 113 | 114 | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: 115 | x = self.embed(x) 116 | x = self.resnet(x) 117 | x = x.transpose(1, 2) 118 | return x 119 | -------------------------------------------------------------------------------- /vocos/modules.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | 8 | class ConvNeXtBlock(nn.Module): 9 | """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. 10 | 11 | Args: 12 | dim (int): Number of input channels. 13 | intermediate_dim (int): Dimensionality of the intermediate layer. 14 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. 15 | Defaults to None. 16 | adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. 17 | None means non-conditional LayerNorm. Defaults to None. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | dim: int, 23 | intermediate_dim: int, 24 | layer_scale_init_value: float, 25 | adanorm_num_embeddings: Optional[int] = None, 26 | ): 27 | super().__init__() 28 | self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 29 | self.adanorm = adanorm_num_embeddings is not None 30 | if adanorm_num_embeddings: 31 | self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) 32 | else: 33 | self.norm = nn.LayerNorm(dim, eps=1e-6) 34 | self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers 35 | self.act = nn.GELU() 36 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 37 | self.gamma = ( 38 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) 39 | if layer_scale_init_value > 0 40 | else None 41 | ) 42 | 43 | def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor: 44 | residual = x 45 | x = self.dwconv(x) 46 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) 47 | if self.adanorm: 48 | assert cond_embedding_id is not None 49 | x = self.norm(x, cond_embedding_id) 50 | else: 51 | x = self.norm(x) 52 | x = self.pwconv1(x) 53 | x = self.act(x) 54 | x = self.pwconv2(x) 55 | if self.gamma is not None: 56 | x = self.gamma * x 57 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) 58 | 59 | x = residual + x 60 | return x 61 | 62 | 63 | class AdaLayerNorm(nn.Module): 64 | """ 65 | Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes 66 | 67 | Args: 68 | num_embeddings (int): Number of embeddings. 69 | embedding_dim (int): Dimension of the embeddings. 
70 | """ 71 | 72 | def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): 73 | super().__init__() 74 | self.eps = eps 75 | self.dim = embedding_dim 76 | self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) 77 | self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) 78 | torch.nn.init.ones_(self.scale.weight) 79 | torch.nn.init.zeros_(self.shift.weight) 80 | 81 | def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: 82 | scale = self.scale(cond_embedding_id) 83 | shift = self.shift(cond_embedding_id) 84 | x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) 85 | x = x * scale + shift 86 | return x 87 | 88 | 89 | class ResBlock1(nn.Module): 90 | """ 91 | ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, 92 | but without upsampling layers. 93 | 94 | Args: 95 | dim (int): Number of input channels. 96 | kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. 97 | dilation (tuple[int], optional): Dilation factors for the dilated convolutions. 98 | Defaults to (1, 3, 5). 99 | lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. 100 | Defaults to 0.1. 101 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. 102 | Defaults to None. 103 | """ 104 | 105 | def __init__( 106 | self, 107 | dim: int, 108 | kernel_size: int = 3, 109 | dilation: Tuple[int, int, int] = (1, 3, 5), 110 | lrelu_slope: float = 0.1, 111 | layer_scale_init_value: Optional[float] = None, 112 | ): 113 | super().__init__() 114 | self.lrelu_slope = lrelu_slope 115 | self.convs1 = nn.ModuleList( 116 | [ 117 | weight_norm( 118 | nn.Conv1d( 119 | dim, 120 | dim, 121 | kernel_size, 122 | 1, 123 | dilation=dilation[0], 124 | padding=self.get_padding(kernel_size, dilation[0]), 125 | ) 126 | ), 127 | weight_norm( 128 | nn.Conv1d( 129 | dim, 130 | dim, 131 | kernel_size, 132 | 1, 133 | dilation=dilation[1], 134 | padding=self.get_padding(kernel_size, dilation[1]), 135 | ) 136 | ), 137 | weight_norm( 138 | nn.Conv1d( 139 | dim, 140 | dim, 141 | kernel_size, 142 | 1, 143 | dilation=dilation[2], 144 | padding=self.get_padding(kernel_size, dilation[2]), 145 | ) 146 | ), 147 | ] 148 | ) 149 | 150 | self.convs2 = nn.ModuleList( 151 | [ 152 | weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), 153 | weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), 154 | weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), 155 | ] 156 | ) 157 | 158 | self.gamma = nn.ParameterList( 159 | [ 160 | nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) 161 | if layer_scale_init_value is not None 162 | else None, 163 | nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) 164 | if layer_scale_init_value is not None 165 | else None, 166 | nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) 167 | if layer_scale_init_value is not None 168 | else None, 169 | ] 170 | ) 171 | 172 | def forward(self, x: torch.Tensor) -> torch.Tensor: 173 | for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): 174 | xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) 175 | xt = c1(xt) 176 | xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) 177 
| xt = c2(xt) 178 | if gamma is not None: 179 | xt = gamma * xt 180 | x = xt + x 181 | return x 182 | 183 | def remove_weight_norm(self): 184 | for l in self.convs1: 185 | remove_weight_norm(l) 186 | for l in self.convs2: 187 | remove_weight_norm(l) 188 | 189 | @staticmethod 190 | def get_padding(kernel_size: int, dilation: int = 1) -> int: 191 | return int((kernel_size * dilation - dilation) / 2) 192 | 193 | 194 | def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: 195 | """ 196 | Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. 197 | 198 | Args: 199 | x (Tensor): Input tensor. 200 | clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. 201 | 202 | Returns: 203 | Tensor: Element-wise logarithm of the input tensor with clipping applied. 204 | """ 205 | return torch.log(torch.clip(x, min=clip_val)) 206 | 207 | 208 | def symlog(x: torch.Tensor) -> torch.Tensor: 209 | return torch.sign(x) * torch.log1p(x.abs()) 210 | 211 | 212 | def symexp(x: torch.Tensor) -> torch.Tensor: 213 | return torch.sign(x) * (torch.exp(x.abs()) - 1) 214 | -------------------------------------------------------------------------------- /vocos/pretrained.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Dict, Tuple, Union, Optional 4 | 5 | import torch 6 | import yaml 7 | from huggingface_hub import hf_hub_download 8 | from torch import nn 9 | from vocos.feature_extractors import FeatureExtractor, EncodecFeatures 10 | from vocos.heads import FourierHead 11 | from vocos.models import Backbone 12 | 13 | 14 | def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: 15 | """Instantiates a class with the given args and init. 16 | 17 | Args: 18 | args: Positional arguments required for instantiation. 19 | init: Dict of the form {"class_path":...,"init_args":...}. 20 | 21 | Returns: 22 | The instantiated class object. 23 | """ 24 | kwargs = init.get("init_args", {}) 25 | if not isinstance(args, tuple): 26 | args = (args,) 27 | class_module, class_name = init["class_path"].rsplit(".", 1) 28 | module = __import__(class_module, fromlist=[class_name]) 29 | args_class = getattr(module, class_name) 30 | return args_class(*args, **kwargs) 31 | 32 | 33 | class Vocos(nn.Module): 34 | """ 35 | The Vocos class represents a Fourier-based neural vocoder for audio synthesis. 36 | This class is primarily designed for inference, with support for loading from pretrained 37 | model checkpoints. It consists of three main components: a feature extractor, 38 | a backbone, and a head. 39 | """ 40 | 41 | def __init__( 42 | self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead, 43 | ): 44 | super().__init__() 45 | self.feature_extractor = feature_extractor 46 | self.backbone = backbone 47 | self.head = head 48 | 49 | @classmethod 50 | def from_hparams(cls, config_path: str) -> Vocos: 51 | """ 52 | Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. 
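A minimal usage sketch (the path is illustrative; any YAML whose `feature_extractor`, `backbone` and `head` sections follow the `class_path`/`init_args` format consumed by `instantiate_class` will work):

    vocos = Vocos.from_hparams("configs/vocos.yaml")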
53 | """ 54 | with open(config_path, "r") as f: 55 | config = yaml.safe_load(f) 56 | feature_extractor = instantiate_class(args=(), init=config["feature_extractor"]) 57 | backbone = instantiate_class(args=(), init=config["backbone"]) 58 | head = instantiate_class(args=(), init=config["head"]) 59 | model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head) 60 | return model 61 | 62 | @classmethod 63 | def from_pretrained(cls, repo_id: str, revision: Optional[str] = None) -> Vocos: 64 | """ 65 | Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. 66 | """ 67 | config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml", revision=revision) 68 | model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", revision=revision) 69 | model = cls.from_hparams(config_path) 70 | state_dict = torch.load(model_path, map_location="cpu") 71 | if isinstance(model.feature_extractor, EncodecFeatures): 72 | encodec_parameters = { 73 | "feature_extractor.encodec." + key: value 74 | for key, value in model.feature_extractor.encodec.state_dict().items() 75 | } 76 | state_dict.update(encodec_parameters) 77 | model.load_state_dict(state_dict) 78 | model.eval() 79 | return model 80 | 81 | @torch.inference_mode() 82 | def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: 83 | """ 84 | Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input, 85 | which is then passed through the backbone and the head to reconstruct the audio output. 86 | 87 | Args: 88 | audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T), 89 | where B is the batch size and L is the waveform length. 90 | 91 | 92 | Returns: 93 | Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T). 94 | """ 95 | features = self.feature_extractor(audio_input, **kwargs) 96 | audio_output = self.decode(features, **kwargs) 97 | return audio_output 98 | 99 | @torch.inference_mode() 100 | def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: 101 | """ 102 | Method to decode audio waveform from already calculated features. The features input is passed through 103 | the backbone and the head to reconstruct the audio output. 104 | 105 | Args: 106 | features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size, 107 | C denotes the feature dimension, and L is the sequence length. 108 | 109 | Returns: 110 | Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T). 111 | """ 112 | x = self.backbone(features_input, **kwargs) 113 | audio_output = self.head(x) 114 | return audio_output 115 | 116 | @torch.inference_mode() 117 | def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor: 118 | """ 119 | Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's 120 | codebook weights. 121 | 122 | Args: 123 | codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L), 124 | where K is the number of codebooks, B is the batch size and L is the sequence length. 125 | 126 | Returns: 127 | Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension, 128 | and L is the sequence length. 
129 | """ 130 | assert isinstance( 131 | self.feature_extractor, EncodecFeatures 132 | ), "Feature extractor should be an instance of EncodecFeatures" 133 | 134 | if codes.dim() == 2: 135 | codes = codes.unsqueeze(1) 136 | 137 | n_bins = self.feature_extractor.encodec.quantizer.bins 138 | offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device) 139 | embeddings_idxs = codes + offsets.view(-1, 1, 1) 140 | features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0) 141 | features = features.transpose(1, 2) 142 | 143 | return features 144 | -------------------------------------------------------------------------------- /vocos/spectral_ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import torch 4 | from torch import nn, view_as_real, view_as_complex 5 | 6 | 7 | class ISTFT(nn.Module): 8 | """ 9 | Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with 10 | windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. 11 | See issue: https://github.com/pytorch/pytorch/issues/62323 12 | Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. 13 | The NOLA constraint is met as we trim padded samples anyway. 14 | 15 | Args: 16 | n_fft (int): Size of Fourier transform. 17 | hop_length (int): The distance between neighboring sliding window frames. 18 | win_length (int): The size of window frame and STFT filter. 19 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 20 | """ 21 | 22 | def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): 23 | super().__init__() 24 | if padding not in ["center", "same"]: 25 | raise ValueError("Padding must be 'center' or 'same'.") 26 | self.padding = padding 27 | self.n_fft = n_fft 28 | self.hop_length = hop_length 29 | self.win_length = win_length 30 | window = torch.hann_window(win_length) 31 | self.register_buffer("window", window) 32 | 33 | def forward(self, spec: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. 36 | 37 | Args: 38 | spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, 39 | N is the number of frequency bins, and T is the number of time frames. 40 | 41 | Returns: 42 | Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. 
43 | """ 44 | if self.padding == "center": 45 | # Fallback to pytorch native implementation 46 | return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) 47 | elif self.padding == "same": 48 | pad = (self.win_length - self.hop_length) // 2 49 | else: 50 | raise ValueError("Padding must be 'center' or 'same'.") 51 | 52 | assert spec.dim() == 3, "Expected a 3D tensor as input" 53 | B, N, T = spec.shape 54 | 55 | # Inverse FFT 56 | ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") 57 | ifft = ifft * self.window[None, :, None] 58 | 59 | # Overlap and Add 60 | output_size = (T - 1) * self.hop_length + self.win_length 61 | y = torch.nn.functional.fold( 62 | ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), 63 | )[:, 0, 0, pad:-pad] 64 | 65 | # Window envelope 66 | window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) 67 | window_envelope = torch.nn.functional.fold( 68 | window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), 69 | ).squeeze()[pad:-pad] 70 | 71 | # Normalize 72 | assert (window_envelope > 1e-11).all() 73 | y = y / window_envelope 74 | 75 | return y 76 | 77 | 78 | class MDCT(nn.Module): 79 | """ 80 | Modified Discrete Cosine Transform (MDCT) module. 81 | 82 | Args: 83 | frame_len (int): Length of the MDCT frame. 84 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 85 | """ 86 | 87 | def __init__(self, frame_len: int, padding: str = "same"): 88 | super().__init__() 89 | if padding not in ["center", "same"]: 90 | raise ValueError("Padding must be 'center' or 'same'.") 91 | self.padding = padding 92 | self.frame_len = frame_len 93 | N = frame_len // 2 94 | n0 = (N + 1) / 2 95 | window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() 96 | self.register_buffer("window", window) 97 | 98 | pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len) 99 | post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N) 100 | # view_as_real: NCCL Backend does not support ComplexFloat data type 101 | # https://github.com/pytorch/pytorch/issues/71613 102 | self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) 103 | self.register_buffer("post_twiddle", view_as_real(post_twiddle)) 104 | 105 | def forward(self, audio: torch.Tensor) -> torch.Tensor: 106 | """ 107 | Apply the Modified Discrete Cosine Transform (MDCT) to the input audio. 108 | 109 | Args: 110 | audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size 111 | and T is the length of the audio. 112 | 113 | Returns: 114 | Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames 115 | and N is the number of frequency bins. 
116 | """ 117 | if self.padding == "center": 118 | audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2)) 119 | elif self.padding == "same": 120 | # hop_length is 1/2 frame_len 121 | audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4)) 122 | else: 123 | raise ValueError("Padding must be 'center' or 'same'.") 124 | 125 | x = audio.unfold(-1, self.frame_len, self.frame_len // 2) 126 | N = self.frame_len // 2 127 | x = x * self.window.expand(x.shape) 128 | X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N] 129 | res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N) 130 | return torch.real(res) * np.sqrt(2) 131 | 132 | 133 | class IMDCT(nn.Module): 134 | """ 135 | Inverse Modified Discrete Cosine Transform (IMDCT) module. 136 | 137 | Args: 138 | frame_len (int): Length of the MDCT frame. 139 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 140 | """ 141 | 142 | def __init__(self, frame_len: int, padding: str = "same"): 143 | super().__init__() 144 | if padding not in ["center", "same"]: 145 | raise ValueError("Padding must be 'center' or 'same'.") 146 | self.padding = padding 147 | self.frame_len = frame_len 148 | N = frame_len // 2 149 | n0 = (N + 1) / 2 150 | window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() 151 | self.register_buffer("window", window) 152 | 153 | pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N) 154 | post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2)) 155 | self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) 156 | self.register_buffer("post_twiddle", view_as_real(post_twiddle)) 157 | 158 | def forward(self, X: torch.Tensor) -> torch.Tensor: 159 | """ 160 | Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients. 161 | 162 | Args: 163 | X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size, 164 | L is the number of frames, and N is the number of frequency bins. 165 | 166 | Returns: 167 | Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio. 168 | """ 169 | B, L, N = X.shape 170 | Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device) 171 | Y[..., :N] = X 172 | Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,))) 173 | y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1) 174 | y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2) 175 | result = y * self.window.expand(y.shape) 176 | output_size = (1, (L + 1) * N) 177 | audio = torch.nn.functional.fold( 178 | result.transpose(1, 2), 179 | output_size=output_size, 180 | kernel_size=(1, self.frame_len), 181 | stride=(1, self.frame_len // 2), 182 | )[:, 0, 0, :] 183 | 184 | if self.padding == "center": 185 | pad = self.frame_len // 2 186 | elif self.padding == "same": 187 | pad = self.frame_len // 4 188 | else: 189 | raise ValueError("Padding must be 'center' or 'same'.") 190 | 191 | audio = audio[:, pad:-pad] 192 | return audio 193 | --------------------------------------------------------------------------------