├── .gitignore ├── .gitmodules ├── .project-root ├── .python-version ├── .vscode └── launch.json ├── LICENSE ├── README.md ├── config ├── config.yaml ├── data │ ├── hubert_base.yaml │ ├── hubert_base_l3.yaml │ ├── hubert_large.yaml │ ├── hubert_large_l6.yaml │ ├── mel.yaml │ ├── necobert_base.yaml │ ├── necobert_masked.yaml │ ├── wav2vec2_base.yaml │ ├── wav2vec2_base_l3.yaml │ ├── wav2vec2_base_l6.yaml │ ├── wav2vec2_base_l9.yaml │ ├── wav2vec2_large.yaml │ ├── wav2vec2_large_l6.yaml │ ├── wavlm_base.yaml │ ├── wavlm_base_l3.yaml │ ├── wavlm_large.yaml │ ├── wavlm_large_l6.yaml │ └── wavlm_large_l8_xvector.yaml ├── model │ ├── hifigan_mel.yaml │ ├── hifigan_ssl.yaml │ ├── hifigan_ssl_large.yaml │ ├── hifigan_ssl_large_xvector.yaml │ └── wavegrad_ssl_large.yaml ├── preprocess │ ├── default.yaml │ ├── hubert_base.yaml │ ├── hubert_base_l3.yaml │ ├── hubert_large.yaml │ ├── hubert_large_l6.yaml │ ├── necobert_eng.yaml │ ├── necobert_masked.yaml │ ├── preprocess_dataset │ │ ├── glob_wav_dataset.yaml │ │ └── glob_wav_dataset_is2022.yaml │ ├── wav2vec2_base.yaml │ ├── wav2vec2_base_l3.yaml │ ├── wav2vec2_base_l6.yaml │ ├── wav2vec2_base_l9.yaml │ ├── wav2vec2_large.yaml │ ├── wav2vec2_large_l6.yaml │ ├── wavlm_base.yaml │ ├── wavlm_base_l3.yaml │ ├── wavlm_large.yaml │ ├── wavlm_large_l6.yaml │ └── wavlm_large_l8.yaml └── train │ └── default.yaml ├── notebooks ├── download_results.ipynb ├── evaluate_models.ipynb ├── get_upsample_rate.ipynb ├── plot_mel_feature.ipynb └── plot_ssl_feature.ipynb ├── pyproject.toml ├── requirements-dev.lock ├── requirements.lock ├── scripts ├── run_preprocessing.sh ├── run_preprocessing_l3.sh ├── run_preprocessing_layers.sh ├── run_synthesize_chime.sh ├── run_synthesize_cmu_arctic.sh ├── run_synthesize_pnl.sh ├── run_synthesize_street.sh ├── run_training_hubert_base.sh ├── run_training_hubert_base_continue.sh ├── run_training_hubert_base_l3.sh ├── run_training_hubert_base_l3_continue.sh ├── run_training_hubert_large.sh ├── run_training_hubert_large_continue.sh ├── run_training_hubert_large_l6.sh ├── run_training_hubert_large_l6_continue.sh ├── run_training_mel.sh ├── run_training_wav2vec2_base.sh ├── run_training_wav2vec2_base_continue.sh ├── run_training_wav2vec2_base_l3.sh ├── run_training_wav2vec2_base_l3_continue.sh ├── run_training_wav2vec2_base_l6.sh ├── run_training_wav2vec2_base_l6_continue.sh ├── run_training_wav2vec2_base_l9.sh ├── run_training_wav2vec2_base_l9_continue.sh ├── run_training_wav2vec2_large.sh ├── run_training_wav2vec2_large_continue.sh ├── run_training_wav2vec2_large_l6.sh ├── run_training_wav2vec2_large_l6_continue.sh ├── run_training_wavlm_base.sh ├── run_training_wavlm_base_continue.sh ├── run_training_wavlm_base_l3.sh ├── run_training_wavlm_base_l3_continue.sh ├── run_training_wavlm_large.sh ├── run_training_wavlm_large_continue.sh ├── run_training_wavlm_large_l6.sh ├── run_training_wavlm_large_l6_continue.sh └── run_training_wavlm_large_l8_xvector.sh └── src ├── lightning_vocoders ├── __init__.py ├── data │ └── datamodule.py ├── models │ ├── hifigan │ │ ├── generator_xvector.py │ │ ├── hifigan.py │ │ ├── lightning_module.py │ │ └── xvector_lightning_module.py │ └── wavegrad │ │ ├── lightning_module.py │ │ └── wavegrad.py └── preprocessor │ ├── dataset │ └── glob_wav_dataset.py │ └── preprocessor.py ├── preprocess.py ├── synthesize.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | 
# C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | preprocessed_data/ 163 | outputs/ 164 | tb_logs/ 165 | wandb/ 166 | events.out.tfevents* 167 | hparams.yaml 168 | *.ckpt 169 | synthesized/ 170 | *.wav 171 | *.log 172 | notebooks/ 173 | .idea/ 174 | *.tar.gz -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wataru-Nakata/ssl-vocoders/87ee8a4239c587b10b04c8b310684d35a9c64d25/.gitmodules -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wataru-Nakata/ssl-vocoders/87ee8a4239c587b10b04c8b310684d35a9c64d25/.project-root -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.10.11 2 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: training", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "src/train.py", 12 | "console": "integratedTerminal", 13 | "justMyCode": true 14 | } 15 | ] 16 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Wataru-Nakata 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ssl-vocoders 2 | This repository contains training scripts for creating vocoders from the features of speech self-supervised learning (SSL) models. 3 | 4 | # Installation 5 | ```bash 6 | git clone https://github.com/Wataru-Nakata/ssl-vocoders 7 | cd ssl-vocoders 8 | pip install -e . 9 | ``` 10 | 11 | # Pretrained models 12 | Pretrained models are distributed on [huggingface](https://huggingface.co/Wataru/ssl-vocoder/tree/main) 13 | 14 | # How to use the pretrained models 15 | To load the HiFi-GAN model trained on the final hidden-layer features of wavlm-large, run the code below. 16 | ```python 17 | import lightning_vocoders 18 | from lightning_vocoders.models.hifigan.lightning_module import HiFiGANLightningModule 19 | model = HiFiGANLightningModule.load_from_checkpoint(lightning_vocoders.MODEL_URLS['wavlm-large'],map_location='cpu') 20 | ``` 21 |
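The checkpoint expects features extracted the same way as during preprocessing; for `wavlm-large` that is layer 24 of `microsoft/wavlm-large`, computed on 16 kHz audio (see `config/preprocess/wavlm_large.yaml`). Below is a minimal, non-official sketch of extracting such a feature with `transformers`; how the tensor is then assembled into a batch and fed to the loaded module is shown in `notebooks/evaluate_models.ipynb`.

```python
# Minimal sketch (not the repository's official inference API): extract the
# WavLM-large layer-24 features that the 'wavlm-large' checkpoint was trained on,
# following config/preprocess/wavlm_large.yaml (16 kHz input, layer 24).
import torch
import torchaudio
from transformers import AutoModel, AutoFeatureExtractor

ssl_model = AutoModel.from_pretrained("microsoft/wavlm-large").eval()
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/wavlm-large")

wav, sr = torchaudio.load("sample.wav")  # hypothetical input file, mono speech
wav = torchaudio.functional.resample(wav, sr, 16_000)

with torch.no_grad():
    inputs = feature_extractor(wav.squeeze(0).numpy(), sampling_rate=16_000, return_tensors="pt")
    hidden_states = ssl_model(**inputs, output_hidden_states=True).hidden_states
    feature = hidden_states[24]  # shape (1, frames, 1024), roughly 50 frames per second

# `feature` is the input the vocoder consumes; see notebooks/evaluate_models.ipynb
# for how a batch is built and passed to the HiFiGANLightningModule.
```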
22 | A Colab example is also provided for better understanding. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-Rj6eBGc-0owr8q1u7KR9ca0V20ws8n4?usp=sharing) 23 | # Provided checkpoints 24 | 25 | |SSL model name | layer 3 | layer 6| layer 9 | layer 12 | layer 24 | 26 | |---|---|---|---|---|---| 27 | | wav2vec2-base | ☑️ | ☑️ | ☑️ | ☑️ | N/A | 28 | | wav2vec2-large | ❌ | ❌ | ☑️ | ❌ | ☑️ | 29 | | hubert-base | ☑️ | ❌ | ❌ | ☑️ | ❌ | 30 | | hubert-large | ❌ | ☑️ | ❌ | ❌ | ☑️ | 31 | | wavlm-base | ☑️ | ❌ | ❌ | ☑️ | ❌ | 32 | | wavlm-large | ❌ | ☑️ | ❌ | ❌ | ☑️ | 33 | 34 | 35 | # How to train by yourself 36 | First, run the preprocessing script to extract the required features from the waveforms: 37 | ```bash 38 | python3 src/preprocess.py preprocess=wav2vec2_large 39 | ``` 40 | This will create WebDataset-format shard files at the path specified in config/preprocess/name_of_your_config.yaml. 41 | 42 | Next, you can run training with 43 | ``` 44 | python3 src/train.py data=your_config_name model=your_config_name train=your_config_name 45 | ``` 46 | For example, `python3 src/train.py data=wavlm_large model=hifigan_ssl_large train=default` trains a vocoder on wavlm-large layer-24 features (the hifigan_ssl_large model config matches the 1024-dimensional features of the large SSL models). 47 | 48 | # Acknowledgements 49 | I'd like to express my sincere gratitude to 50 | * jik876's HiFi-GAN [paper](https://arxiv.org/abs/2010.05646) and their [official implementation](https://github.com/jik876/hifi-gan) 51 | 52 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - preprocess: default 4 | - model: hifigan_ssl 5 | - train: default 6 | - data: wav2vec2_base 7 | sample_rate: 22050 8 | compile: False -------------------------------------------------------------------------------- /config/data/hubert_base.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: /scratch/acc12576tt/hubert_base/hubert_base-train-{000000..000119}.tar.gz 4 | val_dataset_path: /scratch/acc12576tt/hubert_base/hubert_base-val-{000000..000001}.tar.gz 5 | 6 | 7 | train_batch_size: 32 8 | val_batch_size: 8 9 | 10 | target_feature: 11 | key: hubert-base.pth 12 | samples_per_sec: 50 # 22050/256 13 | bias: 0.0025 14 | layer: 12 15 | 16 | segment_size: 17 | train: 50 18 | val: -1 # -1 for not segmenting the sample -------------------------------------------------------------------------------- /config/data/hubert_base_l3.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: /scratch/acc12576tt/hubert_base_l3/hubert_base_l3-train-{000000..000119}.tar.gz 4 | val_dataset_path: /scratch/acc12576tt/hubert_base_l3/hubert_base_l3-val-{000000..000001}.tar.gz 5 | 6 | 7 | train_batch_size: 32 8 | val_batch_size: 8 9 | 10 | target_feature: 11 | key: hubert-base-3.pth 12 | samples_per_sec: 50 # 22050/256 13 | bias: 0.0025 14 | layer: 12 15 | 16 | segment_size: 17 | train: 50 18 | val: -1 # -1 for not segmenting the sample 19 | -------------------------------------------------------------------------------- /config/data/hubert_large.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: /scratch/acc12576tt/hubert_large/hubert_large-train-{000000..000138}.tar.gz 4 | val_dataset_path:
/scratch/acc12576tt/hubert_large/hubert_large-val-{000000..000002}.tar.gz 5 | 6 | 7 | train_batch_size: 32 8 | val_batch_size: 8 9 | 10 | target_feature: 11 | key: hubert-large.pth 12 | samples_per_sec: 50 # 22050/256 13 | bias: 0.0025 14 | layer: 12 15 | 16 | segment_size: 17 | train: 50 18 | val: -1 # -1 for not segmenting the sample -------------------------------------------------------------------------------- /config/data/hubert_large_l6.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: /scratch/acc12576tt/hubert_large_l6/hubert_large_l6-train-{000000..000138}.tar.gz 4 | val_dataset_path: /scratch/acc12576tt/hubert_large_l6/hubert_large_l6-val-{000000..000002}.tar.gz 5 | 6 | 7 | train_batch_size: 32 8 | val_batch_size: 8 9 | 10 | target_feature: 11 | key: hubert-large-6.pth 12 | samples_per_sec: 50 # 22050/256 13 | bias: 0.0025 14 | layer: 12 15 | 16 | segment_size: 17 | train: 50 18 | val: -1 # -1 for not segmenting the sample 19 | -------------------------------------------------------------------------------- /config/data/mel.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: /scratch/acc12576tt/wav2vec2_base/wav2vec2_base-train-{000000..000119}.tar.gz 4 | val_dataset_path: /scratch/acc12576tt/wav2vec2_base/wav2vec2_base-val-{000000..000001}.tar.gz 5 | 6 | 7 | train_batch_size: 32 8 | val_batch_size: 8 9 | 10 | target_feature: 11 | key: mel.pth 12 | samples_per_sec: 86.1328125 # 22050/256 13 | bias: 0 14 | 15 | segment_size: 16 | train: 32 17 | val: -1 # -1 for not segmenting the sample -------------------------------------------------------------------------------- /config/data/necobert_base.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: ./necobert_base/necobert_base-train-{000000..000120}.tar.gz 4 | val_dataset_path: ./necobert_base/necobert_base-val-{000000..000001}.tar.gz 5 | 6 | 7 | train_batch_size: 32 8 | val_batch_size: 8 9 | 10 | target_feature: 11 | key: necobert-base.pth 12 | samples_per_sec: 50 # 22050/256 13 | bias: 0.0025 14 | layer: 12 15 | 16 | segment_size: 17 | train: 50 18 | val: -1 # -1 for not segmenting the sample 19 | -------------------------------------------------------------------------------- /config/data/necobert_masked.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: ./necobert_masked/necobert_masked-train-{000000..000120}.tar.gz 4 | val_dataset_path: ./necobert_masked/necobert_masked-val-{000000..000001}.tar.gz 5 | 6 | 7 | train_batch_size: 32 8 | val_batch_size: 8 9 | 10 | target_feature: 11 | key: necobert-base.pth 12 | samples_per_sec: 50 # 22050/256 13 | bias: 0.0025 14 | layer: 12 15 | 16 | segment_size: 17 | train: 50 18 | val: -1 # -1 for not segmenting the sample 19 | -------------------------------------------------------------------------------- /config/data/wav2vec2_base.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: 
/scratch/acc12576tt/wav2vec2_base/wav2vec2_base-train-{000000..000119}.tar.gz 4 | val_dataset_path: /scratch/acc12576tt/wav2vec2_base/wav2vec2_base-val-{000000..000001}.tar.gz 5 | 6 | train_batch_size: 32 7 | val_batch_size: 8 8 | 9 | target_feature: 10 | key: wav2vec2-base.pth 11 | samples_per_sec: 50 # 22050/256 12 | bias: 0.0025 13 | layer: 12 14 | 15 | segment_size: 16 | train: 50 17 | val: -1 # -1 for not segmenting the sample 18 | 19 | -------------------------------------------------------------------------------- /config/data/wav2vec2_base_l3.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: /scratch/acc12576tt/wav2vec2_base_l3/wav2vec2_base_l3-train-{000000..000119}.tar.gz 4 | val_dataset_path: /scratch/acc12576tt/wav2vec2_base_l3/wav2vec2_base_l3-val-{000000..000001}.tar.gz 5 | 6 | train_batch_size: 32 7 | val_batch_size: 8 8 | 9 | target_feature: 10 | key: wav2vec2-base-3.pth 11 | samples_per_sec: 50 # 22050/256 12 | bias: 0.0025 13 | layer: 12 14 | 15 | segment_size: 16 | train: 50 17 | val: -1 # -1 for not segmenting the sample 18 | 19 | -------------------------------------------------------------------------------- /config/data/wav2vec2_base_l6.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: /scratch/acc12576tt/wav2vec2_base_l6/wav2vec2_base_l6-train-{000000..000119}.tar.gz 4 | val_dataset_path: /scratch/acc12576tt/wav2vec2_base_l6/wav2vec2_base_l6-val-{000000..000001}.tar.gz 5 | 6 | train_batch_size: 32 7 | val_batch_size: 8 8 | 9 | target_feature: 10 | key: wav2vec2-base-6.pth 11 | samples_per_sec: 50 # 22050/256 12 | bias: 0.0025 13 | layer: 12 14 | 15 | segment_size: 16 | train: 50 17 | val: -1 # -1 for not segmenting the sample 18 | 19 | -------------------------------------------------------------------------------- /config/data/wav2vec2_base_l9.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: /scratch/acc12576tt/wav2vec2_base_l9/wav2vec2_base_l9-train-{000000..000119}.tar.gz 4 | val_dataset_path: /scratch/acc12576tt/wav2vec2_base_l9/wav2vec2_base_l9-val-{000000..000001}.tar.gz 5 | 6 | train_batch_size: 32 7 | val_batch_size: 8 8 | 9 | target_feature: 10 | key: wav2vec2-base-9.pth 11 | samples_per_sec: 50 # 22050/256 12 | bias: 0.0025 13 | layer: 12 14 | 15 | segment_size: 16 | train: 50 17 | val: -1 # -1 for not segmenting the sample 18 | 19 | -------------------------------------------------------------------------------- /config/data/wav2vec2_large.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: /scratch/acc12576tt/wav2vec2_large/wav2vec2_large-train-{000000..000138}.tar.gz 4 | val_dataset_path: /scratch/acc12576tt/wav2vec2_large/wav2vec2_large-val-{000000..000002}.tar.gz 5 | 6 | 7 | train_batch_size: 32 8 | val_batch_size: 8 9 | 10 | target_feature: 11 | key: wav2vec2-large.pth 12 | samples_per_sec: 50 # 22050/256 13 | bias: 0.0025 14 | layer: 12 15 | 16 | segment_size: 17 | train: 50 18 | val: -1 # -1 for not segmenting the sample 19 | 20 | 
-------------------------------------------------------------------------------- /config/data/wav2vec2_large_l6.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: /scratch/acc12576tt/wav2vec2_large_l6/wav2vec2_large_l6-train-{000000..000138}.tar.gz 4 | val_dataset_path: /scratch/acc12576tt/wav2vec2_large_l6/wav2vec2_large_l6-val-{000000..000002}.tar.gz 5 | 6 | 7 | train_batch_size: 32 8 | val_batch_size: 8 9 | 10 | target_feature: 11 | key: wav2vec2-large-6.pth 12 | samples_per_sec: 50 # 22050/256 13 | bias: 0.0025 14 | layer: 12 15 | 16 | segment_size: 17 | train: 50 18 | val: -1 # -1 for not segmenting the sample 19 | 20 | -------------------------------------------------------------------------------- /config/data/wavlm_base.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: /scratch/acc12576tt/wavlm_base/wavlm_base-train-{000000..000119}.tar.gz 4 | val_dataset_path: /scratch/acc12576tt/wavlm_base/wavlm_base-val-{000000..000001}.tar.gz 5 | 6 | 7 | train_batch_size: 32 8 | val_batch_size: 8 9 | 10 | target_feature: 11 | key: wavlm-base.pth 12 | samples_per_sec: 50 # 22050/256 13 | bias: 0.0025 14 | layer: 12 15 | 16 | segment_size: 17 | train: 50 18 | val: -1 # -1 for not segmenting the sample -------------------------------------------------------------------------------- /config/data/wavlm_base_l3.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: /scratch/acc12576tt/wavlm_base_l3/wavlm_base_l3-train-{000000..000119}.tar.gz 4 | val_dataset_path: /scratch/acc12576tt/wavlm_base_l3/wavlm_base_l3-val-{000000..000001}.tar.gz 5 | 6 | 7 | train_batch_size: 32 8 | val_batch_size: 8 9 | 10 | target_feature: 11 | key: wavlm-base-3.pth 12 | samples_per_sec: 50 # 22050/256 13 | bias: 0.0025 14 | layer: 12 15 | 16 | segment_size: 17 | train: 50 18 | val: -1 # -1 for not segmenting the sample 19 | -------------------------------------------------------------------------------- /config/data/wavlm_large.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: /scratch/acc12576tt/wavlm_large/wavlm_large-train-{000000..000138}.tar.gz 4 | val_dataset_path: /scratch/acc12576tt/wavlm_large/wavlm_large-val-{000000..000002}.tar.gz 5 | 6 | train_batch_size: 32 7 | val_batch_size: 8 8 | 9 | target_feature: 10 | key: wavlm-large.pth 11 | samples_per_sec: 50 # 22050/256 12 | bias: 0.0025 13 | layer: 12 14 | 15 | segment_size: 16 | train: 50 17 | val: -1 # -1 for not segmenting the sample -------------------------------------------------------------------------------- /config/data/wavlm_large_l6.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: /scratch/acc12576tt/wavlm_large_l6/wavlm_large_l6-train-{000000..000138}.tar.gz 4 | val_dataset_path: /scratch/acc12576tt/wavlm_large_l6/wavlm_large_l6-val-{000000..000002}.tar.gz 5 | 6 | train_batch_size: 32 7 | val_batch_size: 8 8 | 9 | target_feature: 10 | key: 
wavlm-large-6.pth 11 | samples_per_sec: 50 # 22050/256 12 | bias: 0.0025 13 | layer: 6 14 | 15 | segment_size: 16 | train: 50 17 | val: -1 # -1 for not segmenting the sample 18 | -------------------------------------------------------------------------------- /config/data/wavlm_large_l8_xvector.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: lightning_vocoders.data.datamodule.VocoderDataModule 3 | train_dataset_path: /home/wnakata/lightning-vocoders/wavlm_large_l8_xvector/wavlm_large_l8-train-{000000..000138}.tar.gz 4 | val_dataset_path: /home/wnakata/lightning-vocoders/wavlm_large_l8_xvector/wavlm_large_l8-val-{000000..000002}.tar.gz 5 | 6 | xvector: 7 | use_xvector: True 8 | model: 9 | _target_: speechbrain.pretrained.EncoderClassifier.from_hparams 10 | source: "speechbrain/spkrec-xvect-voxceleb" 11 | savedir: "pretrained_models/spkrec-xvect-voxceleb" 12 | sr: 16_000 13 | extract_secs: 5.0 14 | embedding_size: 512 15 | 16 | train_batch_size: 32 17 | val_batch_size: 1 18 | 19 | target_feature: 20 | key: wavlm-large-8.pth 21 | samples_per_sec: 50 # 22050/256 22 | bias: 0.0025 23 | layer: 8 24 | 25 | segment_size: 26 | train: 50 27 | val: -1 # -1 for not segmenting the sample 28 | -------------------------------------------------------------------------------- /config/model/hifigan_mel.yaml: -------------------------------------------------------------------------------- 1 | lightning_module: 2 | _target_: lightning_vocoders.models.hifigan.lightning_module.HiFiGANLightningModule 3 | 4 | generator: 5 | num_input_channels: 80 6 | upsample_rates: [8,8,2,2] 7 | upsample_initial_channel: 512 8 | upsample_kernel_sizes: [16,16,4,4] 9 | resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] 10 | resblock_kernel_sizes: [3,7,11] 11 | resblock: "1" 12 | 13 | optim: 14 | opt_g: 15 | _target_: torch.optim.AdamW 16 | lr: 0.0002 17 | betas: [0.8,0.99] 18 | opt_d: 19 | _target_: torch.optim.AdamW 20 | lr: 0.0002 21 | betas: [0.8,0.99] 22 | scheduler_g: 23 | _target_: torch.optim.lr_scheduler.ExponentialLR 24 | gamma: 0.999998 25 | scheduler_d: 26 | _target_: torch.optim.lr_scheduler.ExponentialLR 27 | gamma: 0.999998 28 | adversarial_start_step: 10_000 29 | 30 | loss: 31 | recons_coef: 45 32 | fm_mpd_coef: 1 33 | fm_msd_coef: 1 34 | g_mpd_coef: 1 35 | g_msd_coef: 1 36 | logging_wav_samples: 10 37 | train_segment_size: 8192 -------------------------------------------------------------------------------- /config/model/hifigan_ssl.yaml: -------------------------------------------------------------------------------- 1 | lightning_module: 2 | _target_: lightning_vocoders.models.hifigan.lightning_module.HiFiGANLightningModule 3 | 4 | generator: 5 | num_input_channels: 768 6 | upsample_rates: [7,7,3,3] 7 | upsample_initial_channel: 512 8 | upsample_kernel_sizes: [15,15,7,7] 9 | resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] 10 | resblock_kernel_sizes: [3,7,11] 11 | resblock: "1" 12 | 13 | optim: 14 | opt_g: 15 | _target_: torch.optim.AdamW 16 | lr: 0.0002 17 | betas: [0.8,0.99] 18 | opt_d: 19 | _target_: torch.optim.AdamW 20 | lr: 0.0002 21 | betas: [0.8,0.99] 22 | scheduler_g: 23 | _target_: torch.optim.lr_scheduler.ExponentialLR 24 | gamma: 0.999998 25 | scheduler_d: 26 | _target_: torch.optim.lr_scheduler.ExponentialLR 27 | gamma: 0.999998 28 | adversarial_start_step: 10_000 29 | 30 | loss: 31 | recons_coef: 45 32 | fm_mpd_coef: 1 33 | fm_msd_coef: 1 34 | g_mpd_coef: 1 35 | g_msd_coef: 1 36 | logging_wav_samples: 10 
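A note on the feature-rate and upsampling numbers used in the configs above: the SSL models (wav2vec2, HuBERT, WavLM) emit features at roughly 50 frames per second (16 000 Hz input over a 320-sample encoder stride), which is what the `samples_per_sec: 50` entries in the data configs refer to; the `# 22050/256` comment really describes the mel configuration (22 050 / 256 ≈ 86.13 frames per second). In the generator configs, the product of `upsample_rates` has to turn that frame rate back into the audio sample rate: 7 x 7 x 3 x 3 = 441 and 441 x 50 = 22 050 for the SSL vocoders, while 8 x 8 x 2 x 2 = 256 matches the 256-sample hop length of the mel features (`notebooks/get_upsample_rate.ipynb` presumably derives these values). A minimal sketch of this consistency check (the helper below is hypothetical, not part of the repository):

```python
# Hypothetical sanity check: the generator's total upsampling factor must map the
# input feature rate back to the target audio sample rate (22 050 Hz here).
import math

def check_upsample_rates(upsample_rates, samples_per_sec, sample_rate=22_050):
    total = math.prod(upsample_rates)
    assert abs(total * samples_per_sec - sample_rate) < 1.0, (
        f"{total} * {samples_per_sec} != {sample_rate}"
    )
    return total

check_upsample_rates([7, 7, 3, 3], 50)             # SSL features: 441 * 50 = 22050
check_upsample_rates([8, 8, 2, 2], 22_050 / 256)   # mel features: 256 * ~86.13 = 22050
```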
-------------------------------------------------------------------------------- /config/model/hifigan_ssl_large.yaml: -------------------------------------------------------------------------------- 1 | lightning_module: 2 | _target_: lightning_vocoders.models.hifigan.lightning_module.HiFiGANLightningModule 3 | 4 | generator: 5 | num_input_channels: 1024 6 | upsample_rates: [7,7,3,3] 7 | upsample_initial_channel: 512 8 | upsample_kernel_sizes: [15,15,7,7] 9 | resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] 10 | resblock_kernel_sizes: [3,7,11] 11 | resblock: "1" 12 | 13 | optim: 14 | opt_g: 15 | _target_: torch.optim.AdamW 16 | lr: 0.0002 17 | betas: [0.8,0.99] 18 | opt_d: 19 | _target_: torch.optim.AdamW 20 | lr: 0.0002 21 | betas: [0.8,0.99] 22 | scheduler_g: 23 | _target_: torch.optim.lr_scheduler.ExponentialLR 24 | gamma: 0.999998 25 | scheduler_d: 26 | _target_: torch.optim.lr_scheduler.ExponentialLR 27 | gamma: 0.999998 28 | adversarial_start_step: 10_000 29 | 30 | loss: 31 | recons_coef: 45 32 | fm_mpd_coef: 1 33 | fm_msd_coef: 1 34 | g_mpd_coef: 1 35 | g_msd_coef: 1 36 | logging_wav_samples: 10 -------------------------------------------------------------------------------- /config/model/hifigan_ssl_large_xvector.yaml: -------------------------------------------------------------------------------- 1 | lightning_module: 2 | _target_: lightning_vocoders.models.hifigan.xvector_lightning_module.HiFiGANXvectorLightningModule 3 | 4 | generator: 5 | num_input_channels: 1024 6 | upsample_rates: [7,7,3,3] 7 | upsample_initial_channel: 512 8 | upsample_kernel_sizes: [15,15,7,7] 9 | resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] 10 | resblock_kernel_sizes: [3,7,11] 11 | resblock: "1" 12 | xvector_dim: 512 13 | 14 | optim: 15 | opt_g: 16 | _target_: torch.optim.AdamW 17 | lr: 0.0002 18 | betas: [0.8,0.99] 19 | opt_d: 20 | _target_: torch.optim.AdamW 21 | lr: 0.0002 22 | betas: [0.8,0.99] 23 | scheduler_g: 24 | _target_: torch.optim.lr_scheduler.ExponentialLR 25 | gamma: 0.999998 26 | scheduler_d: 27 | _target_: torch.optim.lr_scheduler.ExponentialLR 28 | gamma: 0.999998 29 | adversarial_start_step: 10_000 30 | 31 | loss: 32 | recons_coef: 45 33 | fm_mpd_coef: 1 34 | fm_msd_coef: 1 35 | g_mpd_coef: 1 36 | g_msd_coef: 1 37 | logging_wav_samples: 10 38 | -------------------------------------------------------------------------------- /config/model/wavegrad_ssl_large.yaml: -------------------------------------------------------------------------------- 1 | lightning_module: 2 | _target_: lightning_vocoders.models.wavegrad.lightning_module.WaveGradLightningModule 3 | 4 | model_params: 5 | n_input_channels: 1024 6 | upsamples: 7 | - [768,512,7, [1,2,1,2]] 8 | - [512,256,7, [1,2,4,8]] 9 | - [256,128,3, [1,2,4,8]] 10 | - [128,128,3, [1,2,4,8]] 11 | downsamples: 12 | - [32, 128,3] 13 | - [128,128,3] 14 | - [128,256,7] 15 | downsample_conv: [ 1, 32, 5] 16 | film_layers: 17 | - [32,128] 18 | - [128,128] 19 | - [128,256] 20 | - [256,512] 21 | 22 | 23 | 24 | 25 | optim: 26 | _target_: torch.optim.AdamW 27 | lr: 0.0002 28 | noise_schedule: 29 | start: 1e-6 30 | stop: 0.01 31 | num: 1000 32 | n_logging_wav_samples: 10 -------------------------------------------------------------------------------- /config/preprocess/default.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | 
n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "preprocessed_data/glob/glob-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "preprocessed_data/glob/glob-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "facebook/wav2vec2-base" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "facebook/wav2vec2-base" 37 | - sr: 16_000 38 | key: wav2vec2-base.pth 39 | layer: 12 -------------------------------------------------------------------------------- /config/preprocess/hubert_base.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "./hubert_base/hubert_base-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "./hubert_base/hubert_base-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "facebook/hubert-base-ls960" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "facebook/hubert-base-ls960" 37 | - sr: 16_000 38 | key: hubert-base.pth 39 | layer: 12 40 | -------------------------------------------------------------------------------- /config/preprocess/hubert_base_l3.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "/mnt/hdd/preprocessed_data/hubert_base_l3/hubert_base_l3-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "/mnt/hdd/preprocessed_data/hubert_base_l3/hubert_base_l3-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "facebook/hubert-base-ls960" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "facebook/hubert-base-ls960" 37 | - sr: 16_000 38 | key: hubert-base-3.pth 39 | layer: 3 40 | -------------------------------------------------------------------------------- /config/preprocess/hubert_large.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | 
power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "preprocessed_data/hubert_large/hubert_large-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "preprocessed_data/hubert_large/hubert_large-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "facebook/hubert-large-ll60k" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "facebook/hubert-large-ll60k" 37 | - sr: 16_000 38 | key: hubert-large.pth 39 | layer: 24 -------------------------------------------------------------------------------- /config/preprocess/hubert_large_l6.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "/mnt/hdd/preprocessed_data/hubert_large_l6/hubert_large_l6-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "/mnt/hdd/preprocessed_data/hubert_large_l6/hubert_large_l6-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "facebook/hubert-large-ll60k" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "facebook/hubert-large-ll60k" 37 | - sr: 16_000 38 | key: hubert-large-6.pth 39 | layer: 6 40 | -------------------------------------------------------------------------------- /config/preprocess/necobert_eng.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "./necobert_base/necobert_base-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "./necobert_base/necobert_base-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "Wataru/necobert-base-ls" 35 | trust_remote_code: True 36 | - null 37 | - sr: 24_000 38 | key: necobert-base.pth 39 | layer: 12 40 | -------------------------------------------------------------------------------- /config/preprocess/necobert_masked.yaml: 
-------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "./necobert_masked/necobert_masked-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "./necobert_masked/necobert_masked-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "Wataru/necobert-base-masked" 35 | trust_remote_code: True 36 | - null 37 | - sr: 24_000 38 | key: necobert-base.pth 39 | layer: 12 40 | -------------------------------------------------------------------------------- /config/preprocess/preprocess_dataset/glob_wav_dataset.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning_vocoders.preprocessor.dataset.glob_wav_dataset.GlobWavDataset 2 | roots: 3 | - /mnt/hdd/datasets/LJSpeech-1.1 4 | - /mnt/hdd/datasets/VCTK-Corpus 5 | - /mnt/hdd/datasets/libritts/LibriTTS/train-clean-100 6 | - /mnt/hdd/datasets/libritts/LibriTTS/train-clean-360 7 | patterns: 8 | - "**/*.wav" 9 | - "**/*.wav" 10 | - "**/*.wav" 11 | - "**/*.wav" -------------------------------------------------------------------------------- /config/preprocess/preprocess_dataset/glob_wav_dataset_is2022.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | _target_: lightning_vocoders.preprocessor.dataset.glob_wav_dataset.GlobWavDataset 3 | roots: 4 | - /work/ge43/e43001/datasets/libritts/LibriTTS/train-clean-100 5 | - /work/ge43/e43001/datasets/libritts/LibriTTS/train-clean-360 6 | - /work/ge43/e43001/datasets/libritts/LibriTTS/train-other-500 7 | patterns: 8 | - "**/*.wav" 9 | - "**/*.wav" 10 | - "**/*.wav" 11 | val: 12 | _target_: lightning_vocoders.preprocessor.dataset.glob_wav_dataset.GlobWavDataset 13 | roots: 14 | - /work/ge43/e43001/datasets/libritts/LibriTTS/dev-clean 15 | - /work/ge43/e43001/datasets/libritts/LibriTTS/dev-other 16 | patterns: 17 | - "**/*.wav" 18 | - "**/*.wav" 19 | -------------------------------------------------------------------------------- /config/preprocess/wav2vec2_base.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "/mnt/hdd/preprocessed_data/wav2vec2_base/wav2vec2_base-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "/mnt/hdd/preprocessed_data/wav2vec2_base/wav2vec2_base-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: 
"facebook/wav2vec2-base" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "facebook/wav2vec2-base" 37 | - sr: 16_000 38 | key: wav2vec2-base.pth 39 | layer: 12 40 | -------------------------------------------------------------------------------- /config/preprocess/wav2vec2_base_l3.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "/mnt/hdd/preprocessed_data/wav2vec2_base_l3/wav2vec2_base_l3-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "/mnt/hdd/preprocessed_data/wav2vec2_base_l3/wav2vec2_base_l3-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "facebook/wav2vec2-base" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "facebook/wav2vec2-base" 37 | - sr: 16_000 38 | key: wav2vec2-base-3.pth 39 | layer: 3 40 | -------------------------------------------------------------------------------- /config/preprocess/wav2vec2_base_l6.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "/mnt/hdd/preprocessed_data/wav2vec2_base_l6/wav2vec2_base_l6-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "/mnt/hdd/preprocessed_data/wav2vec2_base_l6/wav2vec2_base_l6-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "facebook/wav2vec2-base" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "facebook/wav2vec2-base" 37 | - sr: 16_000 38 | key: wav2vec2-base-6.pth 39 | layer: 6 40 | -------------------------------------------------------------------------------- /config/preprocess/wav2vec2_base_l9.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "/mnt/hdd/preprocessed_data/wav2vec2_base_l9/wav2vec2_base_l9-train-%06d.tar.gz" 25 | val_tar_sink: 26 | 
_target_: webdataset.ShardWriter 27 | pattern: "/mnt/hdd/preprocessed_data/wav2vec2_base_l9/wav2vec2_base_l9-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "facebook/wav2vec2-base" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "facebook/wav2vec2-base" 37 | - sr: 16_000 38 | key: wav2vec2-base-9.pth 39 | layer: 9 40 | -------------------------------------------------------------------------------- /config/preprocess/wav2vec2_large.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "/mnt/hdd/preprocessed_data/wav2vec2_large/wav2vec2_large-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "/mnt/hdd/preprocessed_data/wav2vec2_large/wav2vec2_large-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "facebook/wav2vec2-large" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "facebook/wav2vec2-large" 37 | - sr: 16_000 38 | key: wav2vec2-large.pth 39 | layer: 24 40 | -------------------------------------------------------------------------------- /config/preprocess/wav2vec2_large_l6.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "/mnt/hdd/preprocessed_data/wav2vec2_large_l6/wav2vec2_large_l6-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "/mnt/hdd/preprocessed_data/wav2vec2_large_l6/wav2vec2_large_l6-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "facebook/wav2vec2-large" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "facebook/wav2vec2-large" 37 | - sr: 16_000 38 | key: wav2vec2-large-6.pth 39 | layer: 6 40 | -------------------------------------------------------------------------------- /config/preprocess/wavlm_base.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: 
${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "/mnt/hdd/preprocessed_data/wavlm_base/wavlm_base-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "/mnt/hdd/preprocessed_data/wavlm_base/wavlm_base-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "microsoft/wavlm-base" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "microsoft/wavlm-base" 37 | - sr: 16_000 38 | key: wavlm-base.pth 39 | layer: 12 40 | -------------------------------------------------------------------------------- /config/preprocess/wavlm_base_l3.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "/mnt/hdd/preprocessed_data/wavlm_base_l3/wavlm_base_l3-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "/mnt/hdd/preprocessed_data/wavlm_base_l3/wavlm_base_l3-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "microsoft/wavlm-base" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "microsoft/wavlm-base" 37 | - sr: 16_000 38 | key: wavlm-base-3.pth 39 | layer: 3 40 | -------------------------------------------------------------------------------- /config/preprocess/wavlm_large.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "/mnt/hdd/preprocessed_data/wavlm_large/wavlm_large-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "/mnt/hdd/preprocessed_data/wavlm_large/wavlm_large-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "microsoft/wavlm-large" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "microsoft/wavlm-large" 37 | - sr: 16_000 38 | key: wavlm-large.pth 39 | layer: 24 40 | -------------------------------------------------------------------------------- /config/preprocess/wavlm_large_l6.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | 
sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "/mnt/hdd/preprocessed_data/wavlm_large_l6/wavlm_large_l6-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "/mnt/hdd/preprocessed_data/wavlm_large_l6/wavlm_large_l6-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "microsoft/wavlm-large" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "microsoft/wavlm-large" 37 | - sr: 16_000 38 | key: wavlm-large-6.pth 39 | layer: 6 40 | -------------------------------------------------------------------------------- /config/preprocess/wavlm_large_l8.yaml: -------------------------------------------------------------------------------- 1 | stft: 2 | n_fft: 1024 3 | win_length: 1024 4 | hop_length: 256 5 | power: 1 6 | center: True 7 | mel: 8 | n_mels: 80 9 | sample_rate: ${sample_rate} 10 | f_min: 0 11 | f_max: 8000 12 | n_stft: 513 #${preprocess.stft.n_fft} // 2 + 1 13 | norm: "slaney" 14 | mel_scale: "slaney" 15 | 16 | audio: 17 | sample_rate: ${sample_rate} 18 | defaults: 19 | - preprocess_dataset: glob_wav_dataset 20 | 21 | 22 | train_tar_sink: 23 | _target_: webdataset.ShardWriter 24 | pattern: "/mnt/hdd/lightning_vocoder/wavlm_large_l8_xvector/wavlm_large_l8-train-%06d.tar.gz" 25 | val_tar_sink: 26 | _target_: webdataset.ShardWriter 27 | pattern: "/mnt/hdd/lightning_vocoder/wavlm_large_l8_xvector/wavlm_large_l8-val-%06d.tar.gz" 28 | val_size: 3000 29 | 30 | 31 | ssl_models: 32 | - 33 | - _target_: transformers.AutoModel.from_pretrained 34 | pretrained_model_name_or_path: "microsoft/wavlm-large" 35 | - _target_: transformers.AutoFeatureExtractor.from_pretrained 36 | pretrained_model_name_or_path: "microsoft/wavlm-large" 37 | - sr: 16_000 38 | key: wavlm-large-8.pth 39 | layer: 8 40 | -------------------------------------------------------------------------------- /config/train/default.yaml: -------------------------------------------------------------------------------- 1 | trainer: 2 | _target_: lightning.Trainer 3 | accelerator: "gpu" 4 | devices: [0] 5 | precision: 32 6 | check_val_every_n_epoch: 1 7 | max_epochs: 3300 8 | 9 | ckpt_path: 10 | 11 | loggers: 12 | - _target_: lightning.pytorch.loggers.TensorBoardLogger 13 | save_dir: "tb_logs" 14 | - _target_: lightning.pytorch.loggers.WandbLogger 15 | project: "hifigan-lightning" 16 | log_model: "all" 17 | 18 | -------------------------------------------------------------------------------- /notebooks/download_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import wandb\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 4, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "api = wandb.Api()" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 5, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "run_dict = {\n", 28 | " \"wav2vec2-base\": \"model-o4tpn4wd\", #\"leafy-sound-179\",\n", 29 | " 
\"wav2vec2-large\": \"model-90j4be4y\", #\"driven-tree-180\",\n", 30 | " \"hubert-base\": \"model-7yumzed9\",\n", 31 | " \"hubert-large\": \"model-kimjphhs\",\n", 32 | " \"wavlm-base\": \"model-ugwfwmm3\",\n", 33 | " \"wavlm-large\": \"model-v042svn1\",\n", 34 | " \"wav2vec2_l3\": \"model-5n6os6t5\" ,\n", 35 | " \"wav2vec2_l6\": \"model-a3kd0hkn\",\n", 36 | " \"wav2vec2_l9\": \"model-kh36yjxg\",\n", 37 | " \"mel\": \"model-e43i978z\",\n", 38 | "}" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 6, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stderr", 48 | "output_type": "stream", 49 | "text": [ 50 | "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-o4tpn4wd:latest, 996.84MB. 1 files... \n", 51 | "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", 52 | "Done. 0:0:1.1\n", 53 | "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-90j4be4y:latest, 1007.34MB. 1 files... \n", 54 | "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", 55 | "Done. 0:0:1.2\n", 56 | "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-7yumzed9:latest, 996.84MB. 1 files... \n", 57 | "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", 58 | "Done. 0:0:1.1\n", 59 | "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-kimjphhs:latest, 1007.34MB. 1 files... \n", 60 | "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", 61 | "Done. 0:0:1.1\n", 62 | "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-ugwfwmm3:latest, 996.84MB. 1 files... \n", 63 | "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", 64 | "Done. 0:0:1.2\n", 65 | "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-v042svn1:latest, 1007.34MB. 1 files... \n", 66 | "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", 67 | "Done. 0:0:1.2\n", 68 | "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-5n6os6t5:latest, 996.84MB. 1 files... \n", 69 | "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", 70 | "Done. 0:0:43.3\n", 71 | "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-a3kd0hkn:latest, 996.84MB. 1 files... \n", 72 | "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", 73 | "Done. 0:0:33.7\n", 74 | "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-kh36yjxg:latest, 996.84MB. 1 files... \n", 75 | "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", 76 | "Done. 0:0:33.9\n", 77 | "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-e43i978z:latest, 970.14MB. 1 files... \n", 78 | "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", 79 | "Done. 
0:0:1.1\n" 80 | ] 81 | } 82 | ], 83 | "source": [ 84 | "for k,v in run_dict.items():\n", 85 | " api.artifact(f\"wataru9871/hifigan-lightning/{v}:latest\").download(f\"./checkpoints/{k}\")" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [] 94 | } 95 | ], 96 | "metadata": { 97 | "kernelspec": { 98 | "display_name": ".venv", 99 | "language": "python", 100 | "name": "python3" 101 | }, 102 | "language_info": { 103 | "codemirror_mode": { 104 | "name": "ipython", 105 | "version": 3 106 | }, 107 | "file_extension": ".py", 108 | "mimetype": "text/x-python", 109 | "name": "python", 110 | "nbconvert_exporter": "python", 111 | "pygments_lexer": "ipython3", 112 | "version": "3.10.11" 113 | }, 114 | "orig_nbformat": 4 115 | }, 116 | "nbformat": 4, 117 | "nbformat_minor": 2 118 | } 119 | -------------------------------------------------------------------------------- /notebooks/evaluate_models.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/wnakata/lightning-vocoders/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import torch\n", 19 | "import io\n", 20 | "from lightning_vocoders.models.hifigan.lightning_module import HiFiGANLightningModule\n", 21 | "from lightning_vocoders.preprocessor.dataset.glob_wav_dataset import GlobWavDataset\n", 22 | "from lightning_vocoders.preprocessor.preprocessor import Preprocessor\n", 23 | "from torch.utils.data.dataloader import DataLoader\n", 24 | "import lightning.pytorch as pl\n", 25 | "\n", 26 | "def synthesize(ckpt_path,wav_path,pattern,output_path):\n", 27 | " lightning_module = HiFiGANLightningModule\n", 28 | " lightning_module = lightning_module.load_from_checkpoint(ckpt_path)\n", 29 | " cfg = lightning_module.cfg\n", 30 | "\n", 31 | " dataset = GlobWavDataset([wav_path],[pattern],shuffled=False,add_random_string=False)\n", 32 | " preprocessor = Preprocessor(lightning_module.cfg)\n", 33 | "\n", 34 | " @torch.no_grad()\n", 35 | " def test_collate_fn(sample):\n", 36 | " assert len(sample) == 1 # only expect batch size of 1\n", 37 | " wav_name, (wav_data,sr), wav_path = sample[0]\n", 38 | " wav_data = wav_data[0].unsqueeze(0)\n", 39 | " preprocessed_sample = preprocessor.process_utterance(wav_name,wav_data,sr,wav_path)\n", 40 | " for k,v in preprocessed_sample.items():\n", 41 | " if k.endswith(\".pth\"):\n", 42 | " preprocessed_sample[k] = torch.load(io.BytesIO(v))\n", 43 | " batch = {\n", 44 | " \"resampled_speech.pth\": [preprocessed_sample[\"resampled_speech.pth\"]],\n", 45 | " \"input_feature\": preprocessed_sample[cfg.data.target_feature.key].unsqueeze(0),\n", 46 | " \"filenames\": [preprocessed_sample[\"__key__\"]],\n", 47 | " \"wav_lens\": None\n", 48 | " }\n", 49 | " return batch\n", 50 | " test_dataloader = DataLoader(dataset,collate_fn=test_collate_fn)\n", 51 | " lightning_module.output_path = output_path\n", 52 | " trainer = pl.Trainer(enable_progress_bar=False)\n", 53 | " trainer.test(lightning_module,test_dataloader)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 
2, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "from pathlib import Path\n", 63 | "jnv_wavs = ('jnv','/mnt/hdd/datasets/jnv_corpus_ver1/',\"**/*.wav\")\n", 64 | "arctic_wavs = ('arctic',\"/mnt/hdd/datasets/cmu_arctic/\",\"**/arctic_a0[0-1][0-9][0-9].wav\")\n", 65 | "pnl_wavs = ('pnl',\"/mnt/hdd/datasets/Nonspeech/\",\"**/*.wav\")\n", 66 | "jvs_wavs = ('jvs',\"/mnt/hdd/datasets/jvs_ver1/jvs001/parallel100/\",\"**/*.wav\")" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "model_ckpt_path_dict = {\n", 76 | " \"wav2vec2-base\": \"checkpoints/wav2vec2-base/model.ckpt\",\n", 77 | " \"wav2vec2-large\": \"checkpoints/wav2vec2-large/model.ckpt\",\n", 78 | " \"wav2vec2-base-l3\": \"checkpoints/wav2vec2_l3/model.ckpt\",\n", 79 | " \"wav2vec2-base-l6\": \"checkpoints/wav2vec2_l6/model.ckpt\",\n", 80 | " \"wav2vec2-base-l9\": \"checkpoints/wav2vec2_l9/model.ckpt\",\n", 81 | " \"hubert-base\": \"checkpoints/hubert-base/model.ckpt\",\n", 82 | " \"hubert-large\": \"checkpoints/hubert-large/model.ckpt\",\n", 83 | " \"wavlm-base\": \"checkpoints/wavlm-base/model.ckpt\",\n", 84 | " \"wavlm-large\": \"checkpoints/wavlm-large/model.ckpt\"\n", 85 | "}" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "name": "stderr", 95 | "output_type": "stream", 96 | "text": [ 97 | "/home/wnakata/lightning-vocoders/.venv/lib/python3.10/site-packages/transformers/configuration_utils.py:380: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.\n", 98 | " warnings.warn(\n", 99 | "Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model: ['project_hid.weight', 'quantizer.weight_proj.bias', 'project_hid.bias', 'project_q.bias', 'project_q.weight', 'quantizer.weight_proj.weight', 'quantizer.codevectors']\n", 100 | "- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 101 | "- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 102 | "GPU available: True (cuda), used: True\n", 103 | "TPU available: False, using: 0 TPU cores\n", 104 | "IPU available: False, using: 0 IPUs\n", 105 | "HPU available: False, using: 0 HPUs\n", 106 | "You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", 107 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", 108 | "/home/wnakata/lightning-vocoders/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, test_dataloader, does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", 109 | " rank_zero_warn(\n", 110 | "/home/wnakata/lightning-vocoders/.venv/lib/python3.10/site-packages/transformers/configuration_utils.py:380: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.\n", 111 | " warnings.warn(\n", 112 | "Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model: ['project_hid.weight', 'quantizer.weight_proj.bias', 'project_hid.bias', 'project_q.bias', 'project_q.weight', 'quantizer.weight_proj.weight', 'quantizer.codevectors']\n", 113 | "- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 114 | "- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 115 | "GPU available: True (cuda), used: True\n", 116 | "TPU available: False, using: 0 TPU cores\n", 117 | "IPU available: False, using: 0 IPUs\n", 118 | "HPU available: False, using: 0 HPUs\n", 119 | "You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", 120 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", 121 | "Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model: ['project_hid.weight', 'quantizer.weight_proj.bias', 'project_hid.bias', 'project_q.bias', 'project_q.weight', 'quantizer.weight_proj.weight', 'quantizer.codevectors']\n", 122 | "- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 123 | "- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 124 | "GPU available: True (cuda), used: True\n", 125 | "TPU available: False, using: 0 TPU cores\n", 126 | "IPU available: False, using: 0 IPUs\n", 127 | "HPU available: False, using: 0 HPUs\n", 128 | "You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. 
For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", 129 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", 130 | "Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model: ['project_hid.weight', 'quantizer.weight_proj.bias', 'project_hid.bias', 'project_q.bias', 'project_q.weight', 'quantizer.weight_proj.weight', 'quantizer.codevectors']\n", 131 | "- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 132 | "- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 133 | "GPU available: True (cuda), used: True\n", 134 | "TPU available: False, using: 0 TPU cores\n", 135 | "IPU available: False, using: 0 IPUs\n", 136 | "HPU available: False, using: 0 HPUs\n", 137 | "You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", 138 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", 139 | "Some weights of the model checkpoint at facebook/wav2vec2-large were not used when initializing Wav2Vec2Model: ['project_hid.weight', 'quantizer.weight_proj.bias', 'project_hid.bias', 'project_q.bias', 'project_q.weight', 'quantizer.weight_proj.weight', 'quantizer.codevectors']\n", 140 | "- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 141 | "- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 142 | "GPU available: True (cuda), used: True\n", 143 | "TPU available: False, using: 0 TPU cores\n", 144 | "IPU available: False, using: 0 IPUs\n", 145 | "HPU available: False, using: 0 HPUs\n", 146 | "You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", 147 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" 148 | ] 149 | }, 150 | { 151 | "ename": "OutOfMemoryError", 152 | "evalue": "CUDA out of memory. Tried to allocate 38.00 MiB (GPU 0; 23.62 GiB total capacity; 1.64 GiB already allocated; 37.69 MiB free; 1.79 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. 
See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF", 153 | "output_type": "error", 154 | "traceback": [ 155 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 156 | "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)", 157 | "Cell \u001b[0;32mIn[4], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[39mfor\u001b[39;00m corpus \u001b[39min\u001b[39;00m [jnv_wavs,arctic_wavs,pnl_wavs,jvs_wavs]:\n\u001b[1;32m 3\u001b[0m name, path, pattern \u001b[39m=\u001b[39m corpus\n\u001b[0;32m----> 4\u001b[0m synthesize(ckpt_path,path,pattern,\u001b[39mf\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mtest_wavs/\u001b[39;49m\u001b[39m{\u001b[39;49;00mmodel_name\u001b[39m}\u001b[39;49;00m\u001b[39m/\u001b[39;49m\u001b[39m{\u001b[39;49;00mname\u001b[39m}\u001b[39;49;00m\u001b[39m\"\u001b[39;49m)\n", 158 | "Cell \u001b[0;32mIn[1], line 36\u001b[0m, in \u001b[0;36msynthesize\u001b[0;34m(ckpt_path, wav_path, pattern, output_path)\u001b[0m\n\u001b[1;32m 34\u001b[0m lightning_module\u001b[39m.\u001b[39moutput_path \u001b[39m=\u001b[39m output_path\n\u001b[1;32m 35\u001b[0m trainer \u001b[39m=\u001b[39m pl\u001b[39m.\u001b[39mTrainer(enable_progress_bar\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n\u001b[0;32m---> 36\u001b[0m trainer\u001b[39m.\u001b[39;49mtest(lightning_module,test_dataloader)\n", 159 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:706\u001b[0m, in \u001b[0;36mTrainer.test\u001b[0;34m(self, model, dataloaders, ckpt_path, verbose, datamodule)\u001b[0m\n\u001b[1;32m 704\u001b[0m model \u001b[39m=\u001b[39m _maybe_unwrap_optimized(model)\n\u001b[1;32m 705\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstrategy\u001b[39m.\u001b[39m_lightning_module \u001b[39m=\u001b[39m model\n\u001b[0;32m--> 706\u001b[0m \u001b[39mreturn\u001b[39;00m call\u001b[39m.\u001b[39;49m_call_and_handle_interrupt(\n\u001b[1;32m 707\u001b[0m \u001b[39mself\u001b[39;49m, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_test_impl, model, dataloaders, ckpt_path, verbose, datamodule\n\u001b[1;32m 708\u001b[0m )\n", 160 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:44\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[39mreturn\u001b[39;00m trainer\u001b[39m.\u001b[39mstrategy\u001b[39m.\u001b[39mlauncher\u001b[39m.\u001b[39mlaunch(trainer_fn, \u001b[39m*\u001b[39margs, trainer\u001b[39m=\u001b[39mtrainer, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[1;32m 43\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 44\u001b[0m \u001b[39mreturn\u001b[39;00m trainer_fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 46\u001b[0m \u001b[39mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 47\u001b[0m _call_teardown_hook(trainer)\n", 161 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:749\u001b[0m, in \u001b[0;36mTrainer._test_impl\u001b[0;34m(self, model, dataloaders, ckpt_path, verbose, datamodule)\u001b[0m\n\u001b[1;32m 744\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_data_connector\u001b[39m.\u001b[39mattach_data(model, test_dataloaders\u001b[39m=\u001b[39mdataloaders, datamodule\u001b[39m=\u001b[39mdatamodule)\n\u001b[1;32m 746\u001b[0m ckpt_path \u001b[39m=\u001b[39m 
\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_checkpoint_connector\u001b[39m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 747\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mfn, ckpt_path, model_provided\u001b[39m=\u001b[39mmodel_provided, model_connected\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlightning_module \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 748\u001b[0m )\n\u001b[0;32m--> 749\u001b[0m results \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run(model, ckpt_path\u001b[39m=\u001b[39;49mckpt_path)\n\u001b[1;32m 750\u001b[0m \u001b[39m# remove the tensors from the test results\u001b[39;00m\n\u001b[1;32m 751\u001b[0m results \u001b[39m=\u001b[39m convert_tensors_to_scalars(results)\n", 162 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:935\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 930\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_signal_connector\u001b[39m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 932\u001b[0m \u001b[39m# ----------------------------\u001b[39;00m\n\u001b[1;32m 933\u001b[0m \u001b[39m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 934\u001b[0m \u001b[39m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 935\u001b[0m results \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_stage()\n\u001b[1;32m 937\u001b[0m \u001b[39m# ----------------------------\u001b[39;00m\n\u001b[1;32m 938\u001b[0m \u001b[39m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 939\u001b[0m \u001b[39m# ----------------------------\u001b[39;00m\n\u001b[1;32m 940\u001b[0m log\u001b[39m.\u001b[39mdebug(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__class__\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m: trainer tearing down\u001b[39m\u001b[39m\"\u001b[39m)\n", 163 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:971\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 968\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstrategy\u001b[39m.\u001b[39mbarrier(\u001b[39m\"\u001b[39m\u001b[39mrun-stage\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 970\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mevaluating:\n\u001b[0;32m--> 971\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_evaluation_loop\u001b[39m.\u001b[39;49mrun()\n\u001b[1;32m 972\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpredicting:\n\u001b[1;32m 973\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpredict_loop\u001b[39m.\u001b[39mrun()\n", 164 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:177\u001b[0m, in \u001b[0;36m_no_grad_context.._decorator\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 175\u001b[0m context_manager \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mno_grad\n\u001b[1;32m 176\u001b[0m \u001b[39mwith\u001b[39;00m context_manager():\n\u001b[0;32m--> 177\u001b[0m \u001b[39mreturn\u001b[39;00m loop_run(\u001b[39mself\u001b[39;49m, \u001b[39m*\u001b[39;49margs, 
\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", 165 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py:108\u001b[0m, in \u001b[0;36m_EvaluationLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mTrue\u001b[39;00m:\n\u001b[1;32m 107\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 108\u001b[0m batch, batch_idx, dataloader_idx \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39;49m(data_fetcher)\n\u001b[1;32m 109\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbatch_progress\u001b[39m.\u001b[39mis_last_batch \u001b[39m=\u001b[39m data_fetcher\u001b[39m.\u001b[39mdone\n\u001b[1;32m 110\u001b[0m \u001b[39mif\u001b[39;00m previous_dataloader_idx \u001b[39m!=\u001b[39m dataloader_idx:\n\u001b[1;32m 111\u001b[0m \u001b[39m# the dataloader has changed, notify the logger connector\u001b[39;00m\n", 166 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/lightning/pytorch/loops/fetchers.py:136\u001b[0m, in \u001b[0;36m_PrefetchDataFetcher.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdone:\n\u001b[1;32m 134\u001b[0m \u001b[39m# this will run only when no pre-fetching was done.\u001b[39;00m\n\u001b[1;32m 135\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 136\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fetch_next_batch(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataloader_iter)\n\u001b[1;32m 137\u001b[0m \u001b[39m# consume the batch we just fetched\u001b[39;00m\n\u001b[1;32m 138\u001b[0m batch \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbatches\u001b[39m.\u001b[39mpop(\u001b[39m0\u001b[39m)\n", 167 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/lightning/pytorch/loops/fetchers.py:150\u001b[0m, in \u001b[0;36m_PrefetchDataFetcher._fetch_next_batch\u001b[0;34m(self, iterator)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_start_profiler()\n\u001b[1;32m 149\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 150\u001b[0m batch \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39;49m(iterator)\n\u001b[1;32m 151\u001b[0m \u001b[39mfinally\u001b[39;00m:\n\u001b[1;32m 152\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_stop_profiler()\n", 168 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/combined_loader.py:276\u001b[0m, in \u001b[0;36mCombinedLoader.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 274\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__next__\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Any:\n\u001b[1;32m 275\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_iterator \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m--> 276\u001b[0m out \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_iterator)\n\u001b[1;32m 277\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_iterator, _Sequential):\n\u001b[1;32m 278\u001b[0m \u001b[39mreturn\u001b[39;00m out\n", 169 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/combined_loader.py:122\u001b[0m, in 
\u001b[0;36m_Sequential.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mStopIteration\u001b[39;00m\n\u001b[1;32m 121\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 122\u001b[0m out \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49miterators[\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_iterator_idx])\n\u001b[1;32m 123\u001b[0m index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_idx\n\u001b[1;32m 124\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_idx \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n", 170 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py:633\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 630\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 631\u001b[0m \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 632\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset() \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 633\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[1;32m 634\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m 635\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 636\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 637\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n", 171 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py:677\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 675\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[1;32m 676\u001b[0m index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index() \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 677\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index) \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 678\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[1;32m 679\u001b[0m data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory_device)\n", 172 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:54\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 53\u001b[0m data \u001b[39m=\u001b[39m 
\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset[possibly_batched_index]\n\u001b[0;32m---> 54\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcollate_fn(data)\n", 173 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_context\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[39mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", 174 | "Cell \u001b[0;32mIn[1], line 22\u001b[0m, in \u001b[0;36msynthesize..test_collate_fn\u001b[0;34m(sample)\u001b[0m\n\u001b[1;32m 20\u001b[0m wav_name, (wav_data,sr), wav_path \u001b[39m=\u001b[39m sample[\u001b[39m0\u001b[39m]\n\u001b[1;32m 21\u001b[0m wav_data \u001b[39m=\u001b[39m wav_data[\u001b[39m0\u001b[39m]\u001b[39m.\u001b[39munsqueeze(\u001b[39m0\u001b[39m)\n\u001b[0;32m---> 22\u001b[0m preprocessed_sample \u001b[39m=\u001b[39m preprocessor\u001b[39m.\u001b[39;49mprocess_utterance(wav_name,wav_data,sr,wav_path)\n\u001b[1;32m 23\u001b[0m \u001b[39mfor\u001b[39;00m k,v \u001b[39min\u001b[39;00m preprocessed_sample\u001b[39m.\u001b[39mitems():\n\u001b[1;32m 24\u001b[0m \u001b[39mif\u001b[39;00m k\u001b[39m.\u001b[39mendswith(\u001b[39m\"\u001b[39m\u001b[39m.pth\u001b[39m\u001b[39m\"\u001b[39m):\n", 175 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_context\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[39mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", 176 | "File \u001b[0;32m~/lightning-vocoders/src/lightning_vocoders/preprocessor/preprocessor.py:63\u001b[0m, in \u001b[0;36mPreprocessor.process_utterance\u001b[0;34m(self, basename, orig_waveform, sample_rate, audio_file_path)\u001b[0m\n\u001b[1;32m 61\u001b[0m inputs\u001b[39m.\u001b[39mto(\u001b[39m\"\u001b[39m\u001b[39mcuda\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 62\u001b[0m ssl_model\u001b[39m.\u001b[39mto(\u001b[39m\"\u001b[39m\u001b[39mcuda\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> 63\u001b[0m output \u001b[39m=\u001b[39m ssl_model(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49minputs, output_hidden_states\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n\u001b[1;32m 64\u001b[0m sample[feature_cfg\u001b[39m.\u001b[39mkey] \u001b[39m=\u001b[39m webdataset\u001b[39m.\u001b[39mtorch_dumps(\n\u001b[1;32m 65\u001b[0m output\u001b[39m.\u001b[39mhidden_states[feature_cfg\u001b[39m.\u001b[39mlayer][\u001b[39m0\u001b[39m]\u001b[39m.\u001b[39mcpu()\n\u001b[1;32m 66\u001b[0m )\n\u001b[1;32m 68\u001b[0m \u001b[39mreturn\u001b[39;00m sample\n", 177 | "File 
\u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", 178 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:1306\u001b[0m, in \u001b[0;36mWav2Vec2Model.forward\u001b[0;34m(self, input_values, attention_mask, mask_time_indices, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1301\u001b[0m output_hidden_states \u001b[39m=\u001b[39m (\n\u001b[1;32m 1302\u001b[0m output_hidden_states \u001b[39mif\u001b[39;00m output_hidden_states \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_hidden_states\n\u001b[1;32m 1303\u001b[0m )\n\u001b[1;32m 1304\u001b[0m return_dict \u001b[39m=\u001b[39m return_dict \u001b[39mif\u001b[39;00m return_dict \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39muse_return_dict\n\u001b[0;32m-> 1306\u001b[0m extract_features \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfeature_extractor(input_values)\n\u001b[1;32m 1307\u001b[0m extract_features \u001b[39m=\u001b[39m extract_features\u001b[39m.\u001b[39mtranspose(\u001b[39m1\u001b[39m, \u001b[39m2\u001b[39m)\n\u001b[1;32m 1309\u001b[0m \u001b[39mif\u001b[39;00m attention_mask \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 1310\u001b[0m \u001b[39m# compute reduced attention_mask corresponding to feature vectors\u001b[39;00m\n", 179 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m 
(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", 180 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:453\u001b[0m, in \u001b[0;36mWav2Vec2FeatureEncoder.forward\u001b[0;34m(self, input_values)\u001b[0m\n\u001b[1;32m 448\u001b[0m hidden_states \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mutils\u001b[39m.\u001b[39mcheckpoint\u001b[39m.\u001b[39mcheckpoint(\n\u001b[1;32m 449\u001b[0m create_custom_forward(conv_layer),\n\u001b[1;32m 450\u001b[0m hidden_states,\n\u001b[1;32m 451\u001b[0m )\n\u001b[1;32m 452\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 453\u001b[0m hidden_states \u001b[39m=\u001b[39m conv_layer(hidden_states)\n\u001b[1;32m 455\u001b[0m \u001b[39mreturn\u001b[39;00m hidden_states\n", 181 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", 182 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:354\u001b[0m, in \u001b[0;36mWav2Vec2GroupNormConvLayer.forward\u001b[0;34m(self, hidden_states)\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, hidden_states):\n\u001b[1;32m 353\u001b[0m hidden_states \u001b[39m=\u001b[39m 
\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconv(hidden_states)\n\u001b[0;32m--> 354\u001b[0m hidden_states \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlayer_norm(hidden_states)\n\u001b[1;32m 355\u001b[0m hidden_states \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mactivation(hidden_states)\n\u001b[1;32m 356\u001b[0m \u001b[39mreturn\u001b[39;00m hidden_states\n", 183 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n", 184 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/torch/nn/modules/normalization.py:273\u001b[0m, in \u001b[0;36mGroupNorm.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 272\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Tensor) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tensor:\n\u001b[0;32m--> 273\u001b[0m \u001b[39mreturn\u001b[39;00m F\u001b[39m.\u001b[39;49mgroup_norm(\n\u001b[1;32m 274\u001b[0m \u001b[39minput\u001b[39;49m, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mnum_groups, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49meps)\n", 185 | "File \u001b[0;32m~/lightning-vocoders/.venv/lib/python3.10/site-packages/torch/nn/functional.py:2530\u001b[0m, in \u001b[0;36mgroup_norm\u001b[0;34m(input, num_groups, weight, bias, eps)\u001b[0m\n\u001b[1;32m 2528\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mExpected at least 2 dimensions for input tensor but received \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39minput\u001b[39m\u001b[39m.\u001b[39mdim()\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 2529\u001b[0m _verify_batch_size([\u001b[39minput\u001b[39m\u001b[39m.\u001b[39msize(\u001b[39m0\u001b[39m) \u001b[39m*\u001b[39m \u001b[39minput\u001b[39m\u001b[39m.\u001b[39msize(\u001b[39m1\u001b[39m) \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m num_groups, num_groups] \u001b[39m+\u001b[39m 
\u001b[39mlist\u001b[39m(\u001b[39minput\u001b[39m\u001b[39m.\u001b[39msize()[\u001b[39m2\u001b[39m:]))\n\u001b[0;32m-> 2530\u001b[0m \u001b[39mreturn\u001b[39;00m torch\u001b[39m.\u001b[39;49mgroup_norm(\u001b[39minput\u001b[39;49m, num_groups, weight, bias, eps, torch\u001b[39m.\u001b[39;49mbackends\u001b[39m.\u001b[39;49mcudnn\u001b[39m.\u001b[39;49menabled)\n", 186 | "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 38.00 MiB (GPU 0; 23.62 GiB total capacity; 1.64 GiB already allocated; 37.69 MiB free; 1.79 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF" 187 | ] 188 | } 189 | ], 190 | "source": [ 191 | "for model_name, ckpt_path in model_ckpt_path_dict.items():\n", 192 | " for corpus in [jnv_wavs,arctic_wavs,pnl_wavs,jvs_wavs]:\n", 193 | " name, path, pattern = corpus\n", 194 | " synthesize(ckpt_path,path,pattern,f\"test_wavs/{model_name}/{name}\")\n" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [] 203 | } 204 | ], 205 | "metadata": { 206 | "kernelspec": { 207 | "display_name": ".venv", 208 | "language": "python", 209 | "name": "python3" 210 | }, 211 | "language_info": { 212 | "codemirror_mode": { 213 | "name": "ipython", 214 | "version": 3 215 | }, 216 | "file_extension": ".py", 217 | "mimetype": "text/x-python", 218 | "name": "python", 219 | "nbconvert_exporter": "python", 220 | "pygments_lexer": "ipython3", 221 | "version": "3.10.11" 222 | }, 223 | "orig_nbformat": 4 224 | }, 225 | "nbformat": 4, 226 | "nbformat_minor": 2 227 | } 228 | -------------------------------------------------------------------------------- /notebooks/get_upsample_rate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/wnakata/lightning-vocoders/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n", 14 | "/home/wnakata/lightning-vocoders/.venv/lib/python3.10/site-packages/transformers/configuration_utils.py:380: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.\n", 15 | " warnings.warn(\n", 16 | "Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model: ['quantizer.codevectors', 'quantizer.weight_proj.bias', 'project_q.weight', 'project_hid.bias', 'quantizer.weight_proj.weight', 'project_hid.weight', 'project_q.bias']\n", 17 | "- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 18 | "- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" 19 | ] 20 | } 21 | ], 22 | "source": [ 23 | "from lightning_vocoders.models.hifigan.lightning_module import HiFiGANLightningModule\n", 24 | "lightning_module = HiFiGANLightningModule.load_from_checkpoint(\"../tb_logs/lightning_logs/version_159/checkpoints/epoch=9-step=29480.ckpt\")\n", 25 | "from lightning_vocoders.models.hifigan.hifigan import Generator\n" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 3, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "generator_cfg = lightning_module.cfg.model.generator\n", 35 | "generator = Generator(generator_cfg)" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 4, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "data": { 45 | "text/plain": [ 46 | "{'num_input_channels': 768, 'upsample_rates': [7, 7, 3, 3], 'upsample_initial_channel': 512, 'upsample_kernel_sizes': [15, 15, 7, 7], 'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 'resblock_kernel_sizes': [3, 7, 11], 'resblock': '1'}" 47 | ] 48 | }, 49 | "execution_count": 4, 50 | "metadata": {}, 51 | "output_type": "execute_result" 52 | } 53 | ], 54 | "source": [ 55 | "generator_cfg" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 5, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/plain": [ 66 | "torch.Size([1, 1, 22050])" 67 | ] 68 | }, 69 | "execution_count": 5, 70 | "metadata": {}, 71 | "output_type": "execute_result" 72 | } 73 | ], 74 | "source": [ 75 | "import torch\n", 76 | "generator_cfg.upsample_rates = [7,7,3,3]\n", 77 | "generator_cfg.upsample_kernel_sizes = [15,15,7,7]\n", 78 | "generator = Generator(generator_cfg)\n", 79 | "generator(torch.randn((1,50,768))).size()" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 9, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "name": "stderr", 89 | "output_type": "stream", 90 | "text": [ 91 | " 0%| | 80/16360 [00:00<00:45, 358.50it/s]\n", 92 | " 0%| | 80/16040 [00:05<18:36, 14.30it/s]\n", 93 | " 1%| | 80/15720 [00:05<18:13, 14.31it/s]\n", 94 | " 1%| | 80/15400 [00:05<17:52, 14.29it/s]\n", 95 | " 1%| | 80/15080 [00:05<17:29, 14.29it/s]\n", 96 | " 1%| | 80/14760 [00:05<17:08, 14.27it/s]\n", 97 | " 1%| | 80/14440 [00:05<16:45, 14.28it/s]\n", 98 | " 1%| | 80/14120 [00:05<16:25, 14.25it/s]\n", 99 | " 1%| | 80/13800 [00:05<16:02, 14.26it/s]\n", 100 | " 1%| | 80/13480 [00:05<15:39, 14.27it/s]\n", 101 | " 1%| | 80/13160 [00:05<15:16, 14.28it/s]\n", 102 | " 1%| | 80/12840 [00:05<14:54, 14.26it/s]\n", 103 | " 1%| | 80/12520 [00:05<14:31, 14.27it/s]\n", 104 | " 1%| | 80/12200 [00:05<14:10, 14.25it/s]\n", 105 | " 1%| | 80/11880 [00:05<13:47, 14.26it/s]\n", 106 | " 1%| | 80/11560 [00:05<13:26, 14.23it/s]\n", 107 | " 0%| | 10/11240 [00:00<14:13, 13.16it/s]\n" 108 | ] 109 | }, 110 | { 111 | "ename": "KeyboardInterrupt", 112 | "evalue": "", 113 | "output_type": "error", 114 | "traceback": [ 115 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 116 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 117 | "Cell \u001b[0;32mIn[9], line 7\u001b[0m\n\u001b[1;32m 5\u001b[0m start \u001b[39m=\u001b[39m 
\u001b[39m16_000\u001b[39m\u001b[39m/\u001b[39m\u001b[39m/\u001b[39m\u001b[39m50\u001b[39m \u001b[39m*\u001b[39m i\n\u001b[1;32m 6\u001b[0m \u001b[39mfor\u001b[39;00m j \u001b[39min\u001b[39;00m tqdm(\u001b[39mrange\u001b[39m(start,\u001b[39m17_000\u001b[39m)):\n\u001b[0;32m----> 7\u001b[0m \u001b[39mif\u001b[39;00m lightning_module\u001b[39m.\u001b[39mpreprocessor\u001b[39m.\u001b[39mssl_model(input_values\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39;49mrandn(\u001b[39m1\u001b[39;49m,j)\u001b[39m.\u001b[39;49mcuda())\u001b[39m.\u001b[39mlast_hidden_state\u001b[39m.\u001b[39msize(\u001b[39m1\u001b[39m) \u001b[39m==\u001b[39m i:\n\u001b[1;32m 8\u001b[0m lengths\u001b[39m.\u001b[39mappend(j )\n\u001b[1;32m 9\u001b[0m \u001b[39mbreak\u001b[39;00m\n", 118 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "from tqdm import tqdm\n", 124 | "lengths = []\n", 125 | "lightning_module.preprocessor.ssl_model.eval()\n", 126 | "for i in range(2,50):\n", 127 | " start = 16_000//50 * i\n", 128 | " for j in tqdm(range(start,17_000)):\n", 129 | " if lightning_module.preprocessor.ssl_model(input_values=torch.randn(1,j).cuda()).last_hidden_state.size(1) == i:\n", 130 | " lengths.append(j )\n", 131 | " break" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 10, 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "data": { 141 | "text/plain": [ 142 | "[720,\n", 143 | " 1040,\n", 144 | " 1360,\n", 145 | " 1680,\n", 146 | " 2000,\n", 147 | " 2320,\n", 148 | " 2640,\n", 149 | " 2960,\n", 150 | " 3280,\n", 151 | " 3600,\n", 152 | " 3920,\n", 153 | " 4240,\n", 154 | " 4560,\n", 155 | " 4880,\n", 156 | " 5200,\n", 157 | " 5520]" 158 | ] 159 | }, 160 | "execution_count": 10, 161 | "metadata": {}, 162 | "output_type": "execute_result" 163 | } 164 | ], 165 | "source": [ 166 | "lengths" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 12, 172 | "metadata": {}, 173 | "outputs": [ 174 | { 175 | "data": { 176 | "text/plain": [ 177 | "[80, 400, 720, 1040, 1360, 1680, 2000, 2320, 2640, 2960]" 178 | ] 179 | }, 180 | "execution_count": 12, 181 | "metadata": {}, 182 | "output_type": "execute_result" 183 | } 184 | ], 185 | "source": [ 186 | "def get_length(x):\n", 187 | " return 16_000//50 * x + 80\n", 188 | "[get_length(i) for i in range(10)]" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 22, 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "data": { 198 | "text/plain": [ 199 | "49" 200 | ] 201 | }, 202 | "execution_count": 22, 203 | "metadata": {}, 204 | "output_type": "execute_result" 205 | } 206 | ], 207 | "source": [ 208 | "lightning_module.preprocessor.ssl_model(input_values=torch.randn(1,get_length(49)).cuda()).last_hidden_state.size(1)" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "ge" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "from matplotlib import pyplot as plt\n", 227 | "plt.plot(lengths)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "name": "stderr", 237 | "output_type": "stream", 238 | "text": [ 239 | "It is strongly recommended to pass the ``sampling_rate`` argument to this function. 
Failing to do so can result in silent errors that might be hard to debug.\n" 240 | ] 241 | }, 242 | { 243 | "data": { 244 | "text/plain": [ 245 | "{'input_values': [array([ 0.15781322, -1.3593025 , 0.07017981, ..., 0.6999978 ,\n", 246 | " 0.65397877, 0.6198929 ], dtype=float32)]}" 247 | ] 248 | }, 249 | "execution_count": 31, 250 | "metadata": {}, 251 | "output_type": "execute_result" 252 | } 253 | ], 254 | "source": [ 255 | "lightning_module.preprocessor.ssl_prepreocessor(torch.randn(16_000))" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "data": { 265 | "text/plain": [ 266 | "torch.Size([1, 1, 22016])" 267 | ] 268 | }, 269 | "execution_count": 31, 270 | "metadata": {}, 271 | "output_type": "execute_result" 272 | } 273 | ], 274 | "source": [ 275 | "generator_cfg.upsample_rates = [8,8,2,2]\n", 276 | "generator_cfg.upsample_kernel_sizes = [16,16,8,8]\n", 277 | "generator = Generator(generator_cfg)\n", 278 | "generator(torch.randn((1,86,80))).size()" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "data": { 288 | "text/plain": [ 289 | "torch.Size([1, 1, 32000])" 290 | ] 291 | }, 292 | "execution_count": 28, 293 | "metadata": {}, 294 | "output_type": "execute_result" 295 | } 296 | ], 297 | "source": [ 298 | "generator_cfg.upsample_rates = [5,4,4,2,2,2]\n", 299 | "generator_cfg.upsample_kernel_sizes = [11,8,8,4,4,4]\n", 300 | "generator = Generator(generator_cfg)\n", 301 | "generator(torch.randn((1,50,80))).size()" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [] 310 | } 311 | ], 312 | "metadata": { 313 | "kernelspec": { 314 | "display_name": ".venv", 315 | "language": "python", 316 | "name": "python3" 317 | }, 318 | "language_info": { 319 | "codemirror_mode": { 320 | "name": "ipython", 321 | "version": 3 322 | }, 323 | "file_extension": ".py", 324 | "mimetype": "text/x-python", 325 | "name": "python", 326 | "nbconvert_exporter": "python", 327 | "pygments_lexer": "ipython3", 328 | "version": "3.10.11" 329 | }, 330 | "orig_nbformat": 4 331 | }, 332 | "nbformat": 4, 333 | "nbformat_minor": 2 334 | } 335 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "lightning-vocoders" 3 | version = "0.1.0" 4 | description = "Add a short description here" 5 | authors = [ 6 | { name = "Wataru Nakata", email = "wataru9871@gmail.com" } 7 | ] 8 | dependencies = [ "soundfile~=0.12.1", "transformers~=4.35.2", "torchaudio~=2.0.2", "lightning~=2.0.2", "hydra-core~=1.3.2", "pyrootutils~=1.0.4", "webdataset~=0.2.48", "tensorboard~=2.13.0", "wandb~=0.15.3", "torch~=2.0.1", "pandarallel~=1.6.5", "speechbrain~=0.5.15", "descript-audio-codec~=1.0.0", "xvector-jtubespeech~=0.0.2"] 9 | readme = "README.md" 10 | requires-python = ">= 3.9" 11 | 12 | [build-system] 13 | requires = ["hatchling"] 14 | build-backend = "hatchling.build" 15 | 16 | [tool.rye] 17 | managed = true 18 | dev-dependencies = ["ipykernel~=6.23.1", "black~=23.3.0", "nnmnkwii~=0.1.2", "pysptk~=0.2.0", "librosa~=0.10.0.post2", "pyworld~=0.3.3", "ipywidgets~=8.0.6", "seaborn~=0.12.2", "umap-learn~=0.5.5"] 19 | [tool.hatch.metadata] 20 | allow-direct-references = true 21 | 
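Note on the upsample-rate calculation: the get_upsample_rate.ipynb notebook above empirically confirms that the product of the HiFi-GAN generator's upsample_rates must equal the number of waveform samples emitted per input feature frame. A minimal sketch of that relationship follows; the helper name required_upsample_product is illustrative only (not part of the repository), and the ~50 frames/s figure for the SSL features is inferred from the notebook's get_length helper (16_000 // 50 * x + 80, i.e. one frame per 320 samples at 16 kHz).

    import math

    def required_upsample_product(sample_rate_hz: int, feature_frame_rate_hz: int) -> int:
        # Number of waveform samples the generator must produce per input feature frame.
        return sample_rate_hz // feature_frame_rate_hz

    # SSL features (wav2vec 2.0 / HuBERT / WavLM) arrive at roughly 50 frames/s, so a
    # 22.05 kHz output needs a factor of 441 -- matching upsample_rates [7, 7, 3, 3].
    assert required_upsample_product(22050, 50) == math.prod([7, 7, 3, 3]) == 441

    # The 80-bin mel configuration uses hop_length 256, matching upsample_rates [8, 8, 2, 2].
    assert math.prod([8, 8, 2, 2]) == 256

The same check accounts for the [5, 4, 4, 2, 2, 2] configuration tried at the end of that notebook: its product is 640, i.e. 50 frames/s upsampled to the 32 000-sample output the cell reports.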
-------------------------------------------------------------------------------- /requirements-dev.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: false 8 | 9 | -e file:. 10 | absl-py==1.4.0 11 | aiohttp==3.8.4 12 | aiosignal==1.3.1 13 | antlr4-python3-runtime==4.9.3 14 | anyio==3.7.0 15 | appdirs==1.4.4 16 | argbind==0.3.7 17 | arrow==1.2.3 18 | asttokens==2.2.1 19 | async-timeout==4.0.2 20 | attrs==23.1.0 21 | audioread==3.0.0 22 | backcall==0.2.0 23 | beautifulsoup4==4.12.2 24 | black==23.3.0 25 | blessed==1.20.0 26 | braceexpand==0.1.7 27 | cachetools==5.3.1 28 | certifi==2023.5.7 29 | cffi==1.15.1 30 | charset-normalizer==3.1.0 31 | click==8.1.3 32 | cmake==3.26.3 33 | comm==0.1.3 34 | contourpy==1.1.1 35 | croniter==1.3.15 36 | cycler==0.11.0 37 | cython==0.29.35 38 | dateutils==0.6.12 39 | debugpy==1.6.7 40 | decorator==5.1.1 41 | deepdiff==6.3.0 42 | descript-audio-codec==1.0.0 43 | descript-audiotools==0.7.2 44 | dill==0.3.6 45 | docker-pycreds==0.4.0 46 | docstring-parser==0.15 47 | einops==0.7.0 48 | exceptiongroup==1.1.1 49 | executing==1.2.0 50 | fastapi==0.88.0 51 | fastdtw==0.3.4 52 | ffmpy==0.3.1 53 | filelock==3.12.0 54 | fire==0.5.0 55 | flatten-dict==0.4.2 56 | fonttools==4.42.1 57 | frozenlist==1.3.3 58 | fsspec==2023.5.0 59 | future==0.18.3 60 | gitdb==4.0.10 61 | gitpython==3.1.31 62 | google-auth==2.17.3 63 | google-auth-oauthlib==1.0.0 64 | grpcio==1.54.2 65 | h11==0.14.0 66 | huggingface-hub==0.19.4 67 | hydra-core==1.3.2 68 | hyperpyyaml==1.2.1 69 | idna==3.4 70 | importlib-resources==6.1.1 71 | inquirer==3.1.3 72 | ipykernel==6.23.1 73 | ipython==8.13.2 74 | ipywidgets==8.0.6 75 | itsdangerous==2.1.2 76 | jedi==0.18.2 77 | jinja2==3.1.2 78 | joblib==1.2.0 79 | julius==0.2.7 80 | jupyter-client==8.2.0 81 | jupyter-core==5.3.0 82 | jupyterlab-widgets==3.0.7 83 | kiwisolver==1.4.5 84 | lazy-loader==0.2 85 | librosa==0.10.0.post2 86 | lightning==2.0.2 87 | lightning-cloud==0.5.36 88 | lightning-utilities==0.8.0 89 | lit==16.0.5 90 | llvmlite==0.40.0 91 | markdown==3.4.3 92 | markdown-it-py==2.2.0 93 | markdown2==2.4.11 94 | markupsafe==2.1.2 95 | matplotlib==3.8.0 96 | matplotlib-inline==0.1.6 97 | mdurl==0.1.2 98 | mpmath==1.3.0 99 | msgpack==1.0.5 100 | multidict==6.0.4 101 | mypy-extensions==1.0.0 102 | nest-asyncio==1.5.6 103 | networkx==3.1 104 | nnmnkwii==0.1.2 105 | numba==0.57.0 106 | numpy==1.24.3 107 | nvidia-cublas-cu11==11.10.3.66 108 | nvidia-cuda-cupti-cu11==11.7.101 109 | nvidia-cuda-nvrtc-cu11==11.7.99 110 | nvidia-cuda-runtime-cu11==11.7.99 111 | nvidia-cudnn-cu11==8.5.0.96 112 | nvidia-cufft-cu11==10.9.0.58 113 | nvidia-curand-cu11==10.2.10.91 114 | nvidia-cusolver-cu11==11.4.0.1 115 | nvidia-cusparse-cu11==11.7.4.91 116 | nvidia-nccl-cu11==2.14.3 117 | nvidia-nvtx-cu11==11.7.91 118 | oauthlib==3.2.2 119 | omegaconf==2.3.0 120 | ordered-set==4.1.0 121 | packaging==23.1 122 | pandarallel==1.6.5 123 | pandas==2.0.2 124 | parso==0.8.3 125 | pathspec==0.11.1 126 | pathtools==0.1.2 127 | pexpect==4.8.0 128 | pickleshare==0.7.5 129 | pillow==10.0.1 130 | platformdirs==3.5.1 131 | pooch==1.6.0 132 | prompt-toolkit==3.0.38 133 | protobuf==3.19.6 134 | psutil==5.9.5 135 | ptyprocess==0.7.0 136 | pure-eval==0.2.2 137 | pyasn1==0.5.0 138 | pyasn1-modules==0.3.0 139 | pycparser==2.21 140 | pydantic==1.10.8 141 | pygments==2.15.1 142 | 
pyjwt==2.7.0 143 | pyloudnorm==0.1.1 144 | pyparsing==3.1.1 145 | pyrootutils==1.0.4 146 | pysptk==0.2.0 147 | pystoi==0.3.3 148 | python-dateutil==2.8.2 149 | python-dotenv==1.0.0 150 | python-editor==1.0.4 151 | python-multipart==0.0.6 152 | pytorch-lightning==2.0.2 153 | pytz==2023.3 154 | pyworld==0.3.3 155 | pyyaml==6.0 156 | pyzmq==25.1.0 157 | randomname==0.2.1 158 | readchar==4.0.5 159 | regex==2023.5.5 160 | requests==2.31.0 161 | requests-oauthlib==1.3.1 162 | rich==13.3.5 163 | rsa==4.9 164 | ruamel-yaml==0.17.28 165 | ruamel-yaml-clib==0.2.7 166 | safetensors==0.4.1 167 | scikit-learn==1.2.2 168 | scipy==1.10.1 169 | seaborn==0.12.2 170 | sentencepiece==0.1.99 171 | sentry-sdk==1.21.1 172 | setproctitle==1.3.2 173 | six==1.16.0 174 | smmap==5.0.0 175 | sniffio==1.3.0 176 | soundfile==0.12.1 177 | soupsieve==2.4.1 178 | soxr==0.3.5 179 | speechbrain==0.5.15 180 | stack-data==0.6.2 181 | starlette==0.22.0 182 | starsessions==1.3.0 183 | sympy==1.12 184 | tensorboard==2.13.0 185 | tensorboard-data-server==0.7.0 186 | termcolor==2.4.0 187 | threadpoolctl==3.1.0 188 | tokenizers==0.15.0 189 | tomli==2.0.1 190 | torch==2.0.1 191 | torch-stoi==0.1.2 192 | torchaudio==2.0.2 193 | torchmetrics==0.11.4 194 | tornado==6.3.2 195 | tqdm==4.65.0 196 | traitlets==5.9.0 197 | transformers==4.35.2 198 | triton==2.0.0 199 | typing-extensions==4.6.2 200 | tzdata==2023.3 201 | urllib3==2.0.2 202 | uvicorn==0.22.0 203 | wandb==0.15.3 204 | wcwidth==0.2.6 205 | webdataset==0.2.48 206 | websocket-client==1.5.2 207 | websockets==11.0.3 208 | werkzeug==2.3.4 209 | wheel==0.40.0 210 | widgetsnbextension==4.0.7 211 | yarl==1.9.2 212 | # The following packages are considered to be unsafe in a requirements file: 213 | setuptools==67.8.0 214 | -------------------------------------------------------------------------------- /requirements.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: false 8 | 9 | -e file:. 
10 | absl-py==1.4.0 11 | aiohttp==3.8.4 12 | aiosignal==1.3.1 13 | antlr4-python3-runtime==4.9.3 14 | anyio==3.7.0 15 | appdirs==1.4.4 16 | argbind==0.3.7 17 | arrow==1.2.3 18 | asttokens==2.4.1 19 | async-timeout==4.0.2 20 | attrs==23.1.0 21 | audioread==3.0.1 22 | beautifulsoup4==4.12.2 23 | blessed==1.20.0 24 | braceexpand==0.1.7 25 | cachetools==5.3.1 26 | certifi==2023.5.7 27 | cffi==1.15.1 28 | charset-normalizer==3.1.0 29 | click==8.1.3 30 | cmake==3.26.3 31 | contourpy==1.2.0 32 | croniter==1.3.15 33 | cycler==0.12.1 34 | dateutils==0.6.12 35 | decorator==5.1.1 36 | deepdiff==6.3.0 37 | descript-audio-codec==1.0.0 38 | descript-audiotools==0.7.2 39 | dill==0.3.6 40 | docker-pycreds==0.4.0 41 | docstring-parser==0.15 42 | einops==0.7.0 43 | exceptiongroup==1.1.1 44 | executing==2.0.1 45 | fastapi==0.88.0 46 | ffmpy==0.3.1 47 | filelock==3.12.0 48 | fire==0.5.0 49 | flatten-dict==0.4.2 50 | fonttools==4.46.0 51 | frozenlist==1.3.3 52 | fsspec==2023.5.0 53 | future==0.18.3 54 | gitdb==4.0.10 55 | gitpython==3.1.31 56 | google-auth==2.17.3 57 | google-auth-oauthlib==1.0.0 58 | grpcio==1.54.2 59 | h11==0.14.0 60 | huggingface-hub==0.19.4 61 | hydra-core==1.3.2 62 | hyperpyyaml==1.2.1 63 | idna==3.4 64 | importlib-resources==6.1.1 65 | inquirer==3.1.3 66 | ipython==8.18.1 67 | itsdangerous==2.1.2 68 | jedi==0.19.1 69 | jinja2==3.1.2 70 | joblib==1.2.0 71 | julius==0.2.7 72 | kiwisolver==1.4.5 73 | lazy-loader==0.3 74 | librosa==0.10.1 75 | lightning==2.0.2 76 | lightning-cloud==0.5.36 77 | lightning-utilities==0.8.0 78 | lit==16.0.5 79 | llvmlite==0.41.1 80 | markdown==3.4.3 81 | markdown-it-py==2.2.0 82 | markdown2==2.4.11 83 | markupsafe==2.1.2 84 | matplotlib==3.8.2 85 | matplotlib-inline==0.1.6 86 | mdurl==0.1.2 87 | mpmath==1.3.0 88 | msgpack==1.0.7 89 | multidict==6.0.4 90 | networkx==3.1 91 | numba==0.58.1 92 | numpy==1.24.3 93 | nvidia-cublas-cu11==11.10.3.66 94 | nvidia-cuda-cupti-cu11==11.7.101 95 | nvidia-cuda-nvrtc-cu11==11.7.99 96 | nvidia-cuda-runtime-cu11==11.7.99 97 | nvidia-cudnn-cu11==8.5.0.96 98 | nvidia-cufft-cu11==10.9.0.58 99 | nvidia-curand-cu11==10.2.10.91 100 | nvidia-cusolver-cu11==11.4.0.1 101 | nvidia-cusparse-cu11==11.7.4.91 102 | nvidia-nccl-cu11==2.14.3 103 | nvidia-nvtx-cu11==11.7.91 104 | oauthlib==3.2.2 105 | omegaconf==2.3.0 106 | ordered-set==4.1.0 107 | packaging==23.1 108 | pandarallel==1.6.5 109 | pandas==2.0.2 110 | parso==0.8.3 111 | pathtools==0.1.2 112 | pexpect==4.9.0 113 | pillow==10.1.0 114 | platformdirs==4.0.0 115 | pooch==1.8.0 116 | prompt-toolkit==3.0.41 117 | protobuf==3.19.6 118 | psutil==5.9.5 119 | ptyprocess==0.7.0 120 | pure-eval==0.2.2 121 | pyasn1==0.5.0 122 | pyasn1-modules==0.3.0 123 | pycparser==2.21 124 | pydantic==1.10.8 125 | pygments==2.15.1 126 | pyjwt==2.7.0 127 | pyloudnorm==0.1.1 128 | pyparsing==3.1.1 129 | pyrootutils==1.0.4 130 | pystoi==0.3.3 131 | python-dateutil==2.8.2 132 | python-dotenv==1.0.0 133 | python-editor==1.0.4 134 | python-multipart==0.0.6 135 | pytorch-lightning==2.0.2 136 | pytz==2023.3 137 | pyyaml==6.0 138 | randomname==0.2.1 139 | readchar==4.0.5 140 | regex==2023.5.5 141 | requests==2.31.0 142 | requests-oauthlib==1.3.1 143 | rich==13.3.5 144 | rsa==4.9 145 | ruamel-yaml==0.17.28 146 | ruamel-yaml-clib==0.2.7 147 | safetensors==0.4.1 148 | scikit-learn==1.3.2 149 | scipy==1.10.1 150 | sentencepiece==0.1.99 151 | sentry-sdk==1.21.1 152 | setproctitle==1.3.2 153 | six==1.16.0 154 | smmap==5.0.0 155 | sniffio==1.3.0 156 | soundfile==0.12.1 157 | soupsieve==2.4.1 158 | soxr==0.3.7 159 | 
speechbrain==0.5.15 160 | stack-data==0.6.3 161 | starlette==0.22.0 162 | starsessions==1.3.0 163 | sympy==1.12 164 | tensorboard==2.13.0 165 | tensorboard-data-server==0.7.0 166 | termcolor==2.4.0 167 | threadpoolctl==3.2.0 168 | tokenizers==0.15.0 169 | torch==2.0.1 170 | torch-stoi==0.1.2 171 | torchaudio==2.0.2 172 | torchmetrics==0.11.4 173 | tqdm==4.65.0 174 | traitlets==5.9.0 175 | transformers==4.35.2 176 | triton==2.0.0 177 | typing-extensions==4.6.2 178 | tzdata==2023.3 179 | urllib3==2.0.2 180 | uvicorn==0.22.0 181 | wandb==0.15.3 182 | wcwidth==0.2.6 183 | webdataset==0.2.48 184 | websocket-client==1.5.2 185 | websockets==11.0.3 186 | werkzeug==2.3.4 187 | wheel==0.40.0 188 | yarl==1.9.2 189 | # The following packages are considered to be unsafe in a requirements file: 190 | setuptools==67.8.0 191 | -------------------------------------------------------------------------------- /scripts/run_preprocessing.sh: -------------------------------------------------------------------------------- 1 | python3 src/preprocess.py preprocess=wav2vec2-base > wav2vec2_base.log 2 | python3 src/preprocess.py preprocess=wav2vec2-large > wav2vec2_large.log 3 | python3 src/preprocess.py preprocess=wavlm-base > wavlm_base.log 4 | python3 src/preprocess.py preprocess=wavlm-large > wavlm_large.log 5 | python3 src/preprocess.py preprocess=hubert-base > hubert_base.log 6 | python3 src/preprocess.py preprocess=hubert-large > hubert_large.log 7 | -------------------------------------------------------------------------------- /scripts/run_preprocessing_l3.sh: -------------------------------------------------------------------------------- 1 | #python3 src/preprocess.py preprocess=wav2vec2_base_l3 2 | #python3 src/preprocess.py preprocess=hubert_base_l3 3 | #python3 src/preprocess.py preprocess=wavlm_base_l3 4 | python3 src/preprocess.py preprocess=wav2vec2_large_l6 5 | python3 src/preprocess.py preprocess=hubert_large_l6 6 | python3 src/preprocess.py preprocess=wavlm_large_l6 7 | -------------------------------------------------------------------------------- /scripts/run_preprocessing_layers.sh: -------------------------------------------------------------------------------- 1 | python3 src/preprocess.py preprocess=wav2vec2_base_l3 > wav2vec2_base_l3.log 2 | python3 src/preprocess.py preprocess=wav2vec2_base_l6 > wav2vec2_base_l6.log 3 | python3 src/preprocess.py preprocess=wav2vec2_base_l9 > wav2vec2_base_l9.log -------------------------------------------------------------------------------- /scripts/run_synthesize_chime.sh: -------------------------------------------------------------------------------- 1 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wav2vec2-base/model.ckpt --wav /mnt/hdd/datasets/chime_home/chunks/ --output_path chime/wav2vec2-base/ 2 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wav2vec2-large/model.ckpt --wav /mnt/hdd/datasets/chime_home/chunks/ --output_path chime/wav2vec2-large/ 3 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/hubert-base/model.ckpt --wav /mnt/hdd/datasets/chime_home/chunks/ --output_path chime/hubert-base/ 4 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/hubert-large/model.ckpt --wav /mnt/hdd/datasets/chime_home/chunks/ --output_path chime/hubert-large/ 5 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wavlm-base/model.ckpt --wav /mnt/hdd/datasets/chime_home/chunks/ --output_path chime/wavlm-base/ 6 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wavlm-large/model.ckpt 
--wav /mnt/hdd/datasets/chime_home/chunks/ --output_path chime/wavlm-large/ 7 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/mel/model.ckpt --wav /mnt/hdd/datasets/chime_home/chunks/ --output_path chime/mel/ 8 | -------------------------------------------------------------------------------- /scripts/run_synthesize_cmu_arctic.sh: -------------------------------------------------------------------------------- 1 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wav2vec2-base/model.ckpt --wav /mnt/hdd/datasets/cmu_arctic/cmu_us_slt_arctic/wav/ --output_path cmu_us_slt_arctic/wav2vec2-base/ 2 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wav2vec2-large/model.ckpt --wav /mnt/hdd/datasets/cmu_arctic/cmu_us_slt_arctic/wav/ --output_path cmu_us_slt_arctic/wav2vec2-large/ 3 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/hubert-base/model.ckpt --wav /mnt/hdd/datasets/cmu_arctic/cmu_us_slt_arctic/wav/ --output_path cmu_us_slt_arctic/hubert-base/ 4 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/hubert-large/model.ckpt --wav /mnt/hdd/datasets/cmu_arctic/cmu_us_slt_arctic/wav/ --output_path cmu_us_slt_arctic/hubert-large/ 5 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wavlm-base/model.ckpt --wav /mnt/hdd/datasets/cmu_arctic/cmu_us_slt_arctic/wav/ --output_path cmu_us_slt_arctic/wavlm-base/ 6 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wavlm-large/model.ckpt --wav /mnt/hdd/datasets/cmu_arctic/cmu_us_slt_arctic/wav/ --output_path cmu_us_slt_arctic/wavlm-large/ 7 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/mel/model.ckpt --wav /mnt/hdd/datasets/cmu_arctic/cmu_us_slt_arctic/wav/ --output_path cmu_us_slt_arctic/mel/ 8 | 9 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wav2vec2-base/model.ckpt --wav /mnt/hdd/datasets/cmu_arctic/cmu_us_aew_arctic/wav/ --output_path cmu_us_aew_arctic/wav2vec2-base/ 10 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wav2vec2-large/model.ckpt --wav /mnt/hdd/datasets/cmu_arctic/cmu_us_aew_arctic/wav/ --output_path cmu_us_aew_arctic/wav2vec2-large/ 11 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/hubert-base/model.ckpt --wav /mnt/hdd/datasets/cmu_arctic/cmu_us_aew_arctic/wav/ --output_path cmu_us_aew_arctic/hubert-base/ 12 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/hubert-large/model.ckpt --wav /mnt/hdd/datasets/cmu_arctic/cmu_us_aew_arctic/wav/ --output_path cmu_us_aew_arctic/hubert-large/ 13 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wavlm-base/model.ckpt --wav /mnt/hdd/datasets/cmu_arctic/cmu_us_aew_arctic/wav/ --output_path cmu_us_aew_arctic/wavlm-base/ 14 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wavlm-large/model.ckpt --wav /mnt/hdd/datasets/cmu_arctic/cmu_us_aew_arctic/wav/ --output_path cmu_us_aew_arctic/wavlm-large/ 15 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/mel/model.ckpt --wav /mnt/hdd/datasets/cmu_arctic/cmu_us_aew_arctic/wav/ --output_path cmu_us_aew_arctic/mel/ 16 | -------------------------------------------------------------------------------- /scripts/run_synthesize_pnl.sh: -------------------------------------------------------------------------------- 1 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wav2vec2-base/model.ckpt --wav /mnt/hdd/datasets/Nonspeech/ --output_path pnl/wav2vec2-base/ 2 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wav2vec2-large/model.ckpt --wav 
/mnt/hdd/datasets/Nonspeech/ --output_path pnl/wav2vec2-large/ 3 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/hubert-base/model.ckpt --wav /mnt/hdd/datasets/Nonspeech/ --output_path pnl/hubert-base/ 4 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/hubert-large/model.ckpt --wav /mnt/hdd/datasets/Nonspeech/ --output_path pnl/hubert-large/ 5 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wavlm-base/model.ckpt --wav /mnt/hdd/datasets/Nonspeech --output_path pnl/wavlm-base/ 6 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wavlm-large/model.ckpt --wav /mnt/hdd/datasets/Nonspeech --output_path pnl/wavlm-large/ 7 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/mel/model.ckpt --wav /mnt/hdd/datasets/Nonspeech --output_path pnl/mel/ 8 | -------------------------------------------------------------------------------- /scripts/run_synthesize_street.sh: -------------------------------------------------------------------------------- 1 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wav2vec2-base/model.ckpt --wav /mnt/hdd/datasets/NOIZEUS/15dB --output_path street/wav2vec2-base/ 2 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wav2vec2-large/model.ckpt --wav /mnt/hdd/datasets/NOIZEUS/15dB --output_path street/wav2vec2-large/ 3 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/hubert-base/model.ckpt --wav /mnt/hdd/datasets/NOIZEUS/15dB --output_path street/hubert-base/ 4 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/hubert-large/model.ckpt --wav /mnt/hdd/datasets/NOIZEUS/15dB --output_path street/hubert-large/ 5 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wavlm-base/model.ckpt --wav /mnt/hdd/datasets/NOIZEUS/15dB --output_path street/wavlm-base/ 6 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/wavlm-large/model.ckpt --wav /mnt/hdd/datasets/NOIZEUS/15dB --output_path street/wavlm-large/ 7 | python3 src/synthesize.py --ckpt_path notebooks/checkpoints/mel/model.ckpt --wav /mnt/hdd/datasets/NOIZEUS/15dB --output_path street/mel/ 8 | -------------------------------------------------------------------------------- /scripts/run_training_hubert_base.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=hubert_base model=hifigan_ssl data=hubert_base -------------------------------------------------------------------------------- /scripts/run_training_hubert_base_continue.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=hubert_base model=hifigan_ssl data=hubert_base 'train.ckpt_path="/home/acc12576tt/model-hubert-base.ckpt"' 13 | -------------------------------------------------------------------------------- /scripts/run_training_hubert_base_l3.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=hubert_base_l3 model=hifigan_ssl data=hubert_base_l3 13 | -------------------------------------------------------------------------------- /scripts/run_training_hubert_base_l3_continue.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=hubert_base_l3 model=hifigan_ssl data=hubert_base_l3 'train.ckpt_path="/home/acc12576tt/hubert-base-l3/model.ckpt"' 13 | -------------------------------------------------------------------------------- /scripts/run_training_hubert_large.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=hubert_large model=hifigan_ssl_large data=hubert_large -------------------------------------------------------------------------------- /scripts/run_training_hubert_large_continue.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=hubert_large model=hifigan_ssl_large data=hubert_large train.ckpt_path=/home/acc12576tt/model.ckpt 13 | -------------------------------------------------------------------------------- /scripts/run_training_hubert_large_l6.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=hubert_large_l6 model=hifigan_ssl_large data=hubert_large_l6 13 | -------------------------------------------------------------------------------- /scripts/run_training_hubert_large_l6_continue.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=hubert_large_l6 model=hifigan_ssl_large data=hubert_large_l6 'train.ckpt_path="/home/acc12576tt/hubert-large-l6/model.ckpt"' 13 | -------------------------------------------------------------------------------- /scripts/run_training_mel.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wav2vec2_base model=hifigan_mel data=mel -------------------------------------------------------------------------------- /scripts/run_training_wav2vec2_base.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wav2vec2_base model=hifigan_ssl data=wav2vec2_base -------------------------------------------------------------------------------- /scripts/run_training_wav2vec2_base_continue.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wav2vec2_base model=hifigan_ssl data=wav2vec2_base 'train.ckpt_path="/home/acc12576tt/wav2vec2_base_mid.ckpt"' 13 | -------------------------------------------------------------------------------- /scripts/run_training_wav2vec2_base_l3.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wav2vec2_base_l3 model=hifigan_ssl data=wav2vec2_base_l3 13 | -------------------------------------------------------------------------------- /scripts/run_training_wav2vec2_base_l3_continue.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wav2vec2_base_l3 model=hifigan_ssl data=wav2vec2_base_l3 'train.ckpt_path="/home/acc12576tt/checkpoints/wav2vec2_l3/model.ckpt"' 13 | -------------------------------------------------------------------------------- /scripts/run_training_wav2vec2_base_l6.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wav2vec2_base_l6 model=hifigan_ssl data=wav2vec2_base_l6 13 | -------------------------------------------------------------------------------- /scripts/run_training_wav2vec2_base_l6_continue.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wav2vec2_base_l6 model=hifigan_ssl data=wav2vec2_base_l6 'train.ckpt_path="/home/acc12576tt/checkpoints/wav2vec2_l6/model.ckpt"' 13 | -------------------------------------------------------------------------------- /scripts/run_training_wav2vec2_base_l9.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wav2vec2_base_l9 model=hifigan_ssl data=wav2vec2_base_l9 13 | -------------------------------------------------------------------------------- /scripts/run_training_wav2vec2_base_l9_continue.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wav2vec2_base_l9 model=hifigan_ssl data=wav2vec2_base_l9 'train.ckpt_path="/home/acc12576tt/checkpoints/wav2vec2_l9/model.ckpt"' 13 | -------------------------------------------------------------------------------- /scripts/run_training_wav2vec2_large.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wav2vec2_large model=hifigan_ssl_large data=wav2vec2_large -------------------------------------------------------------------------------- /scripts/run_training_wav2vec2_large_continue.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wav2vec2_large model=hifigan_ssl_large data=wav2vec2_large 'train.ckpt_path="/home/acc12576tt/wav2vec2_large_mid.ckpt"' 13 | -------------------------------------------------------------------------------- /scripts/run_training_wav2vec2_large_l6.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wav2vec2_large_l6 model=hifigan_ssl_large data=wav2vec2_large_l6 13 | -------------------------------------------------------------------------------- /scripts/run_training_wav2vec2_large_l6_continue.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wav2vec2_large_l6 model=hifigan_ssl_large data=wav2vec2_large_l6 'train.ckpt_path="/home/acc12576tt/wav2vec2-large-l6/model.ckpt"' 13 | -------------------------------------------------------------------------------- /scripts/run_training_wavlm_base.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wavlm_base model=hifigan_ssl data=wavlm_base -------------------------------------------------------------------------------- /scripts/run_training_wavlm_base_continue.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wavlm_base model=hifigan_ssl data=wavlm_base 'train.ckpt_path="/home/acc12576tt/wavlm_base_mid.ckpt"' 13 | -------------------------------------------------------------------------------- /scripts/run_training_wavlm_base_l3.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wavlm_base_l3 model=hifigan_ssl data=wavlm_base_l3 13 | -------------------------------------------------------------------------------- /scripts/run_training_wavlm_base_l3_continue.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wavlm_base_l3 model=hifigan_ssl data=wavlm_base_l3 'train.ckpt_path="/home/acc12576tt/wavlm-base-l3/model.ckpt"' 13 | -------------------------------------------------------------------------------- /scripts/run_training_wavlm_large.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wavlm_large model=hifigan_ssl_large data=wavlm_large -------------------------------------------------------------------------------- /scripts/run_training_wavlm_large_continue.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wavlm_large model=hifigan_ssl_large data=wavlm_large 'train.ckpt_path="/home/acc12576tt/wavlm_large_mid.ckpt"' 13 | -------------------------------------------------------------------------------- /scripts/run_training_wavlm_large_l6.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wavlm_large_l6 model=hifigan_ssl_large data=wavlm_large_l6 13 | -------------------------------------------------------------------------------- /scripts/run_training_wavlm_large_l6_continue.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py preprocess=wavlm_large_l6 model=hifigan_ssl_large data=wavlm_large_l6 'train.ckpt_path="/home/acc12576tt/wavlm-large-l6/model.ckpt"' 13 | -------------------------------------------------------------------------------- /scripts/run_training_wavlm_large_l8_xvector.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #$ -l rt_AG.small=1 3 | #$ -l h_rt=72:00:00 4 | #$ -j y 5 | #$-cwd 6 | source /etc/profile.d/modules.sh 7 | module load python/3.10/3.10.10 8 | module load cuda/12.1/12.1.1 9 | module load cudnn/8.9/8.9.2 10 | module load nccl/2.18/2.18.1-1 11 | source venv/bin/activate 12 | python3 src/train.py data=wavlm_large_l8_xvector model=hifigan_ssl_large_xvector preprocess=wavlm_large_l8 -------------------------------------------------------------------------------- /src/lightning_vocoders/__init__.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | from lightning.pytorch import LightningModule 4 | MODEL_URLS = { 5 | "hubert-base-l3": "https://huggingface.co/Wataru/ssl-vocoder/resolve/main/hubert-base-l3/hubert-base-l3.ckpt", 6 | "hubert-base": "https://huggingface.co/Wataru/ssl-vocoder/resolve/main/hubert-base/hubert-base.ckpt", 7 | "hubert-large-l6": "https://huggingface.co/Wataru/ssl-vocoder/resolve/main/hubert-large-l6/hubert-large-l6.ckpt", 8 | "hubert-large": "https://huggingface.co/Wataru/ssl-vocoder/resolve/main/hubert-large/hubert-large.ckpt", 9 | "mel": "https://huggingface.co/Wataru/ssl-vocoder/resolve/main/mel/mel.ckpt", 10 | "wav2vec2-base": "https://huggingface.co/Wataru/ssl-vocoder/resolve/main/wav2vec2-base/wav2vec2-base.ckpt", 11 | "wav2vec2-large": "https://huggingface.co/Wataru/ssl-vocoder/resolve/main/wav2vec2-large/wav2vec2-large.ckpt", 12 | "wav2vec2-large-l6": "https://huggingface.co/Wataru/ssl-vocoder/resolve/main/wav2vec2-large-l6/wav2vec2-large-l6.ckpt", 13 | "wav2vec2_l3": "https://huggingface.co/Wataru/ssl-vocoder/resolve/main/wav2vec2_l3/wav2vec2_l3.ckpt", 14 | "wav2vec2_l6": "https://huggingface.co/Wataru/ssl-vocoder/resolve/main/wav2vec2_l6/wav2vec2_l6.ckpt", 15 | "wav2vec2_l9": "https://huggingface.co/Wataru/ssl-vocoder/resolve/main/wav2vec2_l9/wav2vec2_l9.ckpt", 16 | "wavlm-base-l3": "https://huggingface.co/Wataru/ssl-vocoder/resolve/main/wavlm-base-l3/wavlm-base-l3.ckpt", 17 | "wavlm-large-l6": "https://huggingface.co/Wataru/ssl-vocoder/resolve/main/wavlm-large-l6/wavlm-large-l6.ckpt", 18 | "wavlm-large": "https://huggingface.co/Wataru/ssl-vocoder/resolve/main/wavlm-large/wavlm-large.ckpt", 19 | } 20 | 
-------------------------------------------------------------------------------- /src/lightning_vocoders/data/datamodule.py: -------------------------------------------------------------------------------- 1 | import webdataset as wds 2 | import lightning 3 | from omegaconf import DictConfig 4 | from torch.utils.data import DataLoader 5 | from torch.nn.utils.rnn import pad_sequence 6 | from pathlib import Path 7 | import torch 8 | from torch.utils.data import random_split 9 | import json 10 | import math 11 | import random 12 | import torchaudio 13 | import transformers 14 | import hydra 15 | 16 | class VocoderDataModule(lightning.LightningDataModule): 17 | def __init__(self, cfg: DictConfig) -> None: 18 | super().__init__() 19 | self.cfg = cfg 20 | 21 | def setup(self, stage: str): 22 | self.train_dataset = ( 23 | wds.WebDataset(self.cfg.data.train_dataset_path) 24 | .shuffle(1000) 25 | .decode(wds.torch_audio) 26 | ) 27 | self.val_dataset = wds.WebDataset(self.cfg.data.val_dataset_path).decode( 28 | wds.torch_audio 29 | ) 30 | 31 | def train_dataloader(self): 32 | return DataLoader( 33 | self.train_dataset, 34 | batch_size=self.cfg.data.train_batch_size, 35 | collate_fn=lambda batch: self.collate_fn( 36 | batch, self.cfg.data.segment_size.train 37 | ), 38 | num_workers=20, 39 | ) 40 | 41 | def val_dataloader(self): 42 | return DataLoader( 43 | self.val_dataset, 44 | batch_size=self.cfg.data.val_batch_size, 45 | collate_fn=lambda batch: self.collate_fn( 46 | batch, self.cfg.data.segment_size.val 47 | ), 48 | num_workers=20, 49 | ) 50 | 51 | @torch.no_grad() 52 | def collate_fn(self, batch, segment_size: int = -1): 53 | 54 | outputs = dict() 55 | if segment_size != -1: 56 | cropped_speeches = [] 57 | input_features = [] 58 | for sample in batch: 59 | wav = sample["resampled_speech.pth"] 60 | input_feature= sample[self.cfg.data.target_feature.key] 61 | feature_len = input_feature.size(0) 62 | if feature_len > (segment_size+1): 63 | feature_start = random.randint( 64 | 0, feature_len - segment_size - 1 65 | ) 66 | feature_end = segment_size + feature_start 67 | speech_start_sec = feature_start / self.cfg.data.target_feature.samples_per_sec + self.cfg.data.target_feature.bias 68 | speech_end_sec = (feature_start + segment_size) / self.cfg.data.target_feature.samples_per_sec + self.cfg.data.target_feature.bias 69 | cropped_speeches.append( 70 | wav.squeeze()[ 71 | int(speech_start_sec * self.cfg.sample_rate) : int(speech_end_sec * self.cfg.sample_rate) 72 | ] 73 | ) 74 | input_features.append( 75 | input_feature[ 76 | feature_start:feature_end 77 | ] 78 | ) 79 | else: 80 | cropped_speeches.append(wav.squeeze()) 81 | input_features.append( 82 | input_feature 83 | ) 84 | outputs["resampled_speech.pth"] = pad_sequence( 85 | cropped_speeches, batch_first=True 86 | ) 87 | outputs["input_feature"] = pad_sequence( 88 | input_features, batch_first=True 89 | ) 90 | else: 91 | outputs["resampled_speech.pth"] = pad_sequence( 92 | [b["resampled_speech.pth"].squeeze() for b in batch], batch_first=True 93 | ) 94 | outputs["input_feature"] = pad_sequence( 95 | [b[self.cfg.data.target_feature.key].squeeze() for b in batch], batch_first=True 96 | ) 97 | 98 | outputs["wav_lens"] = torch.tensor( 99 | [b["resampled_speech.pth"].size(0) for b in batch] 100 | ) 101 | if hasattr(self.cfg.data,'xvector'): 102 | if self.cfg.data.xvector.use_xvector: 103 | outputs["xvector"] = torch.stack( 104 | [b["xvector.pth"] for b in batch] 105 | ) 106 | 107 | outputs["filenames"] = [b["__key__"] for b in batch] 108 | return 
outputs 109 | -------------------------------------------------------------------------------- /src/lightning_vocoders/models/hifigan/generator_xvector.py: -------------------------------------------------------------------------------- 1 | from .hifigan import weight_norm,Conv1d,ResBlock1,ResBlock2,ConvTranspose1d,init_weights,remove_weight_norm, LRELU_SLOPE 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import torch 5 | 6 | class FiLMLayer(nn.Module): 7 | def __init__(self,input_channels,intermediate_channels) -> None: 8 | super().__init__() 9 | self.conv1 = nn.Conv1d(input_channels,intermediate_channels,kernel_size=3,stride=1,padding=1) 10 | self.conv2 = nn.Conv1d(intermediate_channels, input_channels,kernel_size=3,stride=1,padding=1) 11 | self.leaky_relu = nn.LeakyReLU(0.1) 12 | 13 | def forward(self,a:torch.Tensor,b:torch.Tensor): 14 | batch_size, K, D = a.size() 15 | Q = b.size(1) 16 | a = a.transpose(1,2) 17 | output = self.conv2((self.leaky_relu(self.conv1(a)).transpose(1,2) + b).transpose(1,2)) 18 | output = output.permute(0,2,1) 19 | assert output.size() == (batch_size,K,D) 20 | return output 21 | class GeneratorWithXvector(torch.nn.Module): 22 | def __init__(self, h): 23 | super().__init__() 24 | self.h = h 25 | self.num_kernels = len(h.resblock_kernel_sizes) 26 | self.num_upsamples = len(h.upsample_rates) 27 | self.conv_pre = weight_norm( 28 | Conv1d(h.num_input_channels, h.upsample_initial_channel, 7, 1, padding=3) 29 | ) 30 | resblock = ResBlock1 if h.resblock == "1" else ResBlock2 31 | 32 | self.ups = nn.ModuleList() 33 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 34 | self.ups.append( 35 | weight_norm( 36 | ConvTranspose1d( 37 | h.upsample_initial_channel // (2**i), 38 | h.upsample_initial_channel // (2 ** (i + 1)), 39 | k, 40 | u, 41 | padding=(k - u) // 2, 42 | ) 43 | ) 44 | ) 45 | 46 | self.resblocks = nn.ModuleList() 47 | 48 | self.feature_xvector_film = FiLMLayer(h.num_input_channels,h.xvector_dim) 49 | for i in range(len(self.ups)): 50 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 51 | for j, (k, d) in enumerate( 52 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 53 | ): 54 | self.resblocks.append(resblock(h, ch, k, d)) 55 | 56 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 57 | self.ups.apply(init_weights) 58 | self.conv_post.apply(init_weights) 59 | 60 | def forward(self, feature,xvector): 61 | x = self.feature_xvector_film(feature,xvector.unsqueeze(1)) 62 | x = self.conv_pre(x.transpose(1, 2)) 63 | for i in range(self.num_upsamples): 64 | x = F.leaky_relu(x, LRELU_SLOPE) 65 | x = self.ups[i](x) 66 | xs = None 67 | for j in range(self.num_kernels): 68 | if xs is None: 69 | xs = self.resblocks[i * self.num_kernels + j](x) 70 | else: 71 | xs += self.resblocks[i * self.num_kernels + j](x) 72 | x = xs / self.num_kernels 73 | x = F.leaky_relu(x) 74 | x = self.conv_post(x) 75 | x = torch.tanh(x) 76 | 77 | return x 78 | 79 | def remove_weight_norm(self): 80 | print("Removing weight norm...") 81 | for l in self.ups: 82 | remove_weight_norm(l) 83 | for l in self.resblocks: 84 | l.remove_weight_norm() 85 | remove_weight_norm(self.conv_pre) 86 | remove_weight_norm(self.conv_post) 87 | -------------------------------------------------------------------------------- /src/lightning_vocoders/models/hifigan/hifigan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn 
import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResBlock1(torch.nn.Module): 21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 22 | super(ResBlock1, self).__init__() 23 | self.h = h 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[0], 33 | padding=get_padding(kernel_size, dilation[0]), 34 | ) 35 | ), 36 | weight_norm( 37 | Conv1d( 38 | channels, 39 | channels, 40 | kernel_size, 41 | 1, 42 | dilation=dilation[1], 43 | padding=get_padding(kernel_size, dilation[1]), 44 | ) 45 | ), 46 | weight_norm( 47 | Conv1d( 48 | channels, 49 | channels, 50 | kernel_size, 51 | 1, 52 | dilation=dilation[2], 53 | padding=get_padding(kernel_size, dilation[2]), 54 | ) 55 | ), 56 | ] 57 | ) 58 | self.convs1.apply(init_weights) 59 | 60 | self.convs2 = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | Conv1d( 64 | channels, 65 | channels, 66 | kernel_size, 67 | 1, 68 | dilation=1, 69 | padding=get_padding(kernel_size, 1), 70 | ) 71 | ), 72 | weight_norm( 73 | Conv1d( 74 | channels, 75 | channels, 76 | kernel_size, 77 | 1, 78 | dilation=1, 79 | padding=get_padding(kernel_size, 1), 80 | ) 81 | ), 82 | weight_norm( 83 | Conv1d( 84 | channels, 85 | channels, 86 | kernel_size, 87 | 1, 88 | dilation=1, 89 | padding=get_padding(kernel_size, 1), 90 | ) 91 | ), 92 | ] 93 | ) 94 | self.convs2.apply(init_weights) 95 | 96 | def forward(self, x): 97 | for c1, c2 in zip(self.convs1, self.convs2): 98 | xt = F.leaky_relu(x, LRELU_SLOPE) 99 | xt = c1(xt) 100 | xt = F.leaky_relu(xt, LRELU_SLOPE) 101 | xt = c2(xt) 102 | x = xt + x 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs1: 107 | remove_weight_norm(l) 108 | for l in self.convs2: 109 | remove_weight_norm(l) 110 | 111 | 112 | class ResBlock2(torch.nn.Module): 113 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 114 | super(ResBlock2, self).__init__() 115 | self.h = h 116 | self.convs = nn.ModuleList( 117 | [ 118 | weight_norm( 119 | Conv1d( 120 | channels, 121 | channels, 122 | kernel_size, 123 | 1, 124 | dilation=dilation[0], 125 | padding=get_padding(kernel_size, dilation[0]), 126 | ) 127 | ), 128 | weight_norm( 129 | Conv1d( 130 | channels, 131 | channels, 132 | kernel_size, 133 | 1, 134 | dilation=dilation[1], 135 | padding=get_padding(kernel_size, dilation[1]), 136 | ) 137 | ), 138 | ] 139 | ) 140 | self.convs.apply(init_weights) 141 | 142 | def forward(self, x): 143 | for c in self.convs: 144 | xt = F.leaky_relu(x, LRELU_SLOPE) 145 | xt = c(xt) 146 | x = xt + x 147 | return x 148 | 149 | def remove_weight_norm(self): 150 | for l in self.convs: 151 | remove_weight_norm(l) 152 | 153 | 154 | class Generator(torch.nn.Module): 155 | def __init__(self, h): 156 | super(Generator, self).__init__() 157 | self.h = h 158 | self.num_kernels = len(h.resblock_kernel_sizes) 159 | self.num_upsamples = len(h.upsample_rates) 160 | self.conv_pre = weight_norm( 161 | Conv1d(h.num_input_channels, h.upsample_initial_channel, 7, 1, padding=3) 162 | ) 163 | resblock = ResBlock1 if h.resblock == "1" else ResBlock2 164 | 165 | 
self.ups = nn.ModuleList() 166 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 167 | self.ups.append( 168 | weight_norm( 169 | ConvTranspose1d( 170 | h.upsample_initial_channel // (2**i), 171 | h.upsample_initial_channel // (2 ** (i + 1)), 172 | k, 173 | u, 174 | padding=(k - u) // 2, 175 | ) 176 | ) 177 | ) 178 | 179 | self.resblocks = nn.ModuleList() 180 | for i in range(len(self.ups)): 181 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 182 | for j, (k, d) in enumerate( 183 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 184 | ): 185 | self.resblocks.append(resblock(h, ch, k, d)) 186 | 187 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 188 | self.ups.apply(init_weights) 189 | self.conv_post.apply(init_weights) 190 | 191 | def forward(self, x): 192 | x = self.conv_pre(x.transpose(1, 2)) 193 | for i in range(self.num_upsamples): 194 | x = F.leaky_relu(x, LRELU_SLOPE) 195 | x = self.ups[i](x) 196 | xs = None 197 | for j in range(self.num_kernels): 198 | if xs is None: 199 | xs = self.resblocks[i * self.num_kernels + j](x) 200 | else: 201 | xs += self.resblocks[i * self.num_kernels + j](x) 202 | x = xs / self.num_kernels 203 | x = F.leaky_relu(x) 204 | x = self.conv_post(x) 205 | x = torch.tanh(x) 206 | 207 | return x 208 | 209 | def remove_weight_norm(self): 210 | print("Removing weight norm...") 211 | for l in self.ups: 212 | remove_weight_norm(l) 213 | for l in self.resblocks: 214 | l.remove_weight_norm() 215 | remove_weight_norm(self.conv_pre) 216 | remove_weight_norm(self.conv_post) 217 | 218 | 219 | class DiscriminatorP(torch.nn.Module): 220 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 221 | super(DiscriminatorP, self).__init__() 222 | self.period = period 223 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 224 | self.convs = nn.ModuleList( 225 | [ 226 | norm_f( 227 | Conv2d( 228 | 1, 229 | 32, 230 | (kernel_size, 1), 231 | (stride, 1), 232 | padding=(get_padding(5, 1), 0), 233 | ) 234 | ), 235 | norm_f( 236 | Conv2d( 237 | 32, 238 | 128, 239 | (kernel_size, 1), 240 | (stride, 1), 241 | padding=(get_padding(5, 1), 0), 242 | ) 243 | ), 244 | norm_f( 245 | Conv2d( 246 | 128, 247 | 512, 248 | (kernel_size, 1), 249 | (stride, 1), 250 | padding=(get_padding(5, 1), 0), 251 | ) 252 | ), 253 | norm_f( 254 | Conv2d( 255 | 512, 256 | 1024, 257 | (kernel_size, 1), 258 | (stride, 1), 259 | padding=(get_padding(5, 1), 0), 260 | ) 261 | ), 262 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 263 | ] 264 | ) 265 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 266 | 267 | def forward(self, x): 268 | fmap = [] 269 | 270 | # 1d to 2d 271 | b, c, t = x.shape 272 | if t % self.period != 0: # pad first 273 | n_pad = self.period - (t % self.period) 274 | x = F.pad(x, (0, n_pad), "reflect") 275 | t = t + n_pad 276 | x = x.view(b, c, t // self.period, self.period) 277 | 278 | for l in self.convs: 279 | x = l(x) 280 | x = F.leaky_relu(x, LRELU_SLOPE) 281 | fmap.append(x) 282 | x = self.conv_post(x) 283 | fmap.append(x) 284 | x = torch.flatten(x, 1, -1) 285 | 286 | return x, fmap 287 | 288 | 289 | class MultiPeriodDiscriminator(torch.nn.Module): 290 | def __init__(self): 291 | super(MultiPeriodDiscriminator, self).__init__() 292 | self.discriminators = nn.ModuleList( 293 | [ 294 | DiscriminatorP(2), 295 | DiscriminatorP(3), 296 | DiscriminatorP(5), 297 | DiscriminatorP(7), 298 | DiscriminatorP(11), 299 | ] 300 | ) 301 | 302 | def forward(self, y, y_hat): 
303 | y_d_rs = [] 304 | y_d_gs = [] 305 | fmap_rs = [] 306 | fmap_gs = [] 307 | for i, d in enumerate(self.discriminators): 308 | y_d_r, fmap_r = d(y) 309 | y_d_g, fmap_g = d(y_hat) 310 | y_d_rs.append(y_d_r) 311 | fmap_rs.append(fmap_r) 312 | y_d_gs.append(y_d_g) 313 | fmap_gs.append(fmap_g) 314 | 315 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 316 | 317 | 318 | class DiscriminatorS(torch.nn.Module): 319 | def __init__(self, use_spectral_norm=False): 320 | super(DiscriminatorS, self).__init__() 321 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 322 | self.convs = nn.ModuleList( 323 | [ 324 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 325 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 326 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 327 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 328 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 329 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 330 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 331 | ] 332 | ) 333 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 334 | 335 | def forward(self, x): 336 | fmap = [] 337 | for l in self.convs: 338 | x = l(x) 339 | x = F.leaky_relu(x, LRELU_SLOPE) 340 | fmap.append(x) 341 | x = self.conv_post(x) 342 | fmap.append(x) 343 | x = torch.flatten(x, 1, -1) 344 | 345 | return x, fmap 346 | 347 | 348 | class MultiScaleDiscriminator(torch.nn.Module): 349 | def __init__(self): 350 | super(MultiScaleDiscriminator, self).__init__() 351 | self.discriminators = nn.ModuleList( 352 | [ 353 | DiscriminatorS(use_spectral_norm=True), 354 | DiscriminatorS(), 355 | DiscriminatorS(), 356 | ] 357 | ) 358 | self.meanpools = nn.ModuleList( 359 | [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] 360 | ) 361 | 362 | def forward(self, y, y_hat): 363 | y_d_rs = [] 364 | y_d_gs = [] 365 | fmap_rs = [] 366 | fmap_gs = [] 367 | for i, d in enumerate(self.discriminators): 368 | if i != 0: 369 | y = self.meanpools[i - 1](y) 370 | y_hat = self.meanpools[i - 1](y_hat) 371 | y_d_r, fmap_r = d(y) 372 | y_d_g, fmap_g = d(y_hat) 373 | y_d_rs.append(y_d_r) 374 | fmap_rs.append(fmap_r) 375 | y_d_gs.append(y_d_g) 376 | fmap_gs.append(fmap_g) 377 | 378 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 379 | 380 | 381 | def feature_loss(fmap_r, fmap_g): 382 | loss = 0 383 | for dr, dg in zip(fmap_r, fmap_g): 384 | for rl, gl in zip(dr, dg): 385 | loss += torch.mean(torch.abs(rl - gl)) 386 | 387 | return loss * 2 388 | 389 | 390 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 391 | loss = 0 392 | r_losses = [] 393 | g_losses = [] 394 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 395 | r_loss = torch.mean((1 - dr) ** 2) 396 | g_loss = torch.mean(dg**2) 397 | loss += r_loss + g_loss 398 | r_losses.append(r_loss.item()) 399 | g_losses.append(g_loss.item()) 400 | 401 | return loss, r_losses, g_losses 402 | 403 | 404 | def generator_loss(disc_outputs): 405 | loss = 0 406 | gen_losses = [] 407 | for dg in disc_outputs: 408 | l = torch.mean((1 - dg) ** 2) 409 | gen_losses.append(l) 410 | loss += l 411 | 412 | return loss, gen_losses 413 | -------------------------------------------------------------------------------- /src/lightning_vocoders/models/hifigan/lightning_module.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Any, Optional 3 | from lightning.pytorch import LightningModule, loggers 4 | from lightning.pytorch.utilities.types import STEP_OUTPUT 5 
| from omegaconf import DictConfig 6 | import numpy as np 7 | from webdataset import resampled 8 | from .hifigan import ( 9 | Generator, 10 | MultiPeriodDiscriminator, 11 | MultiScaleDiscriminator, 12 | discriminator_loss, 13 | generator_loss, 14 | feature_loss, 15 | ) 16 | import torch 17 | import hydra 18 | import torchaudio 19 | import transformers 20 | from pathlib import Path 21 | 22 | 23 | class Preprocessor(torch.nn.Module): 24 | def __init__(self, cfg:DictConfig) -> None: 25 | super().__init__() 26 | self.resampler = torchaudio.transforms.Resample(cfg.sample_rate,16_000) 27 | self.mel_spec = torchaudio.transforms.MelSpectrogram( 28 | cfg.sample_rate, 29 | cfg.preprocess.stft.n_fft, 30 | cfg.preprocess.stft.win_length, 31 | cfg.preprocess.stft.hop_length, 32 | cfg.preprocess.mel.f_min, 33 | cfg.preprocess.mel.f_max, 34 | n_mels=cfg.preprocess.mel.n_mels, 35 | ) 36 | def get_logmelspec(self,waveform): 37 | melspec = self.mel_spec(waveform) 38 | logmelspec = torch.log(torch.clamp_min(melspec, 1.0e-5) * 1.0).to(torch.float32) 39 | return logmelspec 40 | 41 | class HiFiGANLightningModule(LightningModule,object): 42 | def __init__(self, cfg: DictConfig) -> None: 43 | super().__init__() 44 | self.generator = Generator(cfg.model.generator) 45 | self.multi_period_discriminator = MultiPeriodDiscriminator() 46 | self.multi_scale_discriminator = MultiScaleDiscriminator() 47 | self.automatic_optimization = False 48 | self.preprocessor = Preprocessor(cfg) 49 | self.cfg = cfg 50 | self.save_hyperparameters() 51 | 52 | def configure_optimizers(self) -> Any: 53 | opt_g = hydra.utils.instantiate( 54 | self.cfg.model.optim.opt_g, params=self.generator.parameters() 55 | ) 56 | opt_d = hydra.utils.instantiate( 57 | self.cfg.model.optim.opt_d, 58 | params=itertools.chain( 59 | self.multi_scale_discriminator.parameters(), 60 | self.multi_period_discriminator.parameters(), 61 | ), 62 | ) 63 | scheduler_g = hydra.utils.instantiate( 64 | self.cfg.model.optim.scheduler_g, optimizer=opt_g 65 | ) 66 | scheduler_d = hydra.utils.instantiate( 67 | self.cfg.model.optim.scheduler_d, optimizer=opt_d 68 | ) 69 | 70 | return [opt_g, opt_d], [ 71 | {"name": "scheduler_g", "scheduler": scheduler_g}, 72 | {"name": "scheduler_d", "scheduler": scheduler_d}, 73 | ] 74 | 75 | def generator_forward(self,batch): 76 | generator_input = batch["input_feature"] 77 | wav_generator_out = self.generator(generator_input) 78 | return wav_generator_out 79 | 80 | def training_step(self, batch, batch_idx) -> STEP_OUTPUT: 81 | wav,generator_input, _ = ( 82 | batch["resampled_speech.pth"], 83 | batch["input_feature"], 84 | batch["filenames"], 85 | ) 86 | mel = self.preprocessor.get_logmelspec(wav) 87 | wav = wav.unsqueeze(1) 88 | wav_generator_out = self.generator_forward(batch) 89 | output_length = min(wav_generator_out.size(2),wav.size(2)) 90 | wav = wav[:,:,:output_length] 91 | wav_generator_out = wav_generator_out[:,:,:output_length] 92 | 93 | opt_g, opt_d = self.optimizers() 94 | sch_g, sch_d = self.lr_schedulers() 95 | if self.global_step >= self.cfg.model.adversarial_start_step: 96 | opt_d.zero_grad() 97 | 98 | # mpd 99 | mpd_out_real, mpd_out_fake, _, _ = self.multi_period_discriminator( 100 | wav, wav_generator_out.detach() 101 | ) 102 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss( 103 | mpd_out_real, mpd_out_fake 104 | ) 105 | 106 | # msd 107 | msd_out_real, msd_out_fake, _, _ = self.multi_scale_discriminator( 108 | wav, wav_generator_out.detach() 109 | ) 110 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = 
discriminator_loss( 111 | msd_out_real, msd_out_fake 112 | ) 113 | 114 | loss_disc_all = loss_disc_s + loss_disc_f 115 | self.manual_backward(loss_disc_all) 116 | opt_d.step() 117 | sch_d.step() 118 | self.log("train/discriminator/loss_disc_f", loss_disc_f) 119 | self.log("train/discriminator/loss_disc_s", loss_disc_s) 120 | else: 121 | loss_disc_f = loss_disc_s = 0.0 122 | 123 | # generator 124 | opt_g.zero_grad() 125 | predicted_mel = self.preprocessor.get_logmelspec(wav_generator_out.squeeze(1)) 126 | loss_recons = self.reconstruction_loss(mel, predicted_mel) 127 | loss_g = loss_recons * self.cfg.model.loss.recons_coef 128 | if self.global_step >= self.cfg.model.adversarial_start_step: 129 | ( 130 | mpd_out_real, 131 | mpd_out_fake, 132 | fmap_f_real, 133 | fmap_f_generated, 134 | ) = self.multi_period_discriminator(wav, wav_generator_out) 135 | loss_fm_mpd = feature_loss(fmap_f_real, fmap_f_generated) 136 | 137 | # msd 138 | ( 139 | msd_out_real, 140 | msd_out_fake, 141 | fmap_scale_real, 142 | fmap_scale_generated, 143 | ) = self.multi_scale_discriminator(wav, wav_generator_out) 144 | loss_fm_msd = feature_loss(fmap_scale_real, fmap_scale_generated) 145 | 146 | loss_g_mpd, losses_gen_f = generator_loss(mpd_out_fake) 147 | loss_g_msd, losses_gen_s = generator_loss(msd_out_fake) 148 | loss_g += loss_fm_mpd * self.cfg.model.loss.fm_mpd_coef 149 | loss_g += loss_fm_msd * self.cfg.model.loss.fm_msd_coef 150 | loss_g += loss_g_mpd * self.cfg.model.loss.g_mpd_coef 151 | loss_g += loss_g_msd * self.cfg.model.loss.g_msd_coef 152 | self.log("train/generator/loss_fm_mpd", loss_fm_mpd) 153 | self.log("train/generator/loss_fm_msd", loss_fm_msd) 154 | self.log("train/generator/loss_g_mpd", loss_g_mpd) 155 | self.log("train/generator/loss_g_msd", loss_g_msd) 156 | self.manual_backward(loss_g) 157 | self.log("train/loss_reconstruction", loss_recons) 158 | self.log("train/generator/loss", loss_g) 159 | opt_g.step() 160 | sch_g.step() 161 | 162 | def validation_step(self, batch, batch_idx): 163 | wav, generator_input, filename, wav_lens = ( 164 | batch["resampled_speech.pth"], 165 | batch["input_feature"], 166 | batch["filenames"], 167 | batch["wav_lens"], 168 | ) 169 | mel = self.preprocessor.get_logmelspec(wav) 170 | wav_generator_out = self.generator_forward(batch) 171 | predicted_mel = self.preprocessor.get_logmelspec(wav_generator_out.squeeze(1)) 172 | loss_recons = self.reconstruction_loss(mel, predicted_mel) 173 | if ( 174 | batch_idx < self.cfg.model.logging_wav_samples 175 | and self.global_rank == 0 176 | and self.local_rank == 0 177 | ): 178 | self.log_audio( 179 | wav_generator_out[0] 180 | .squeeze()[: wav_lens[0]] 181 | .cpu() 182 | .numpy() 183 | .astype(np.float32), 184 | name=f"generated/{filename[0]}", 185 | sampling_rate=self.cfg.sample_rate, 186 | ) 187 | self.log_audio( 188 | wav[0].squeeze()[: wav_lens[0]].cpu().numpy().astype(np.float32), 189 | name=f"natural/{filename[0]}", 190 | sampling_rate=self.cfg.sample_rate, 191 | ) 192 | 193 | self.log("val/reconstruction", loss_recons) 194 | def on_test_start(self): 195 | Path(f"{self.output_path}").mkdir(exist_ok=True,parents=True) 196 | def test_step(self,batch,batch_idx): 197 | generator_input = batch["input_feature"] 198 | wav_generator_out = self.generator_forward(batch) 199 | return wav_generator_out 200 | def on_test_batch_end(self, outputs: STEP_OUTPUT | None, batch: Any, batch_idx: int, dataloader_idx: int = 0): 201 | for output,filename,resampled in zip(outputs,batch["filenames"],batch["resampled_speech.pth"]): 202 | 
torchaudio.save(filepath=f"{self.output_path}/{filename}.wav",src=output.cpu(),sample_rate=self.cfg.sample_rate) 203 | torchaudio.save(filepath=f"{self.output_path}/{filename}_gt.wav",src=resampled.unsqueeze(0).cpu(),sample_rate=self.cfg.sample_rate) 204 | return 205 | 206 | 207 | def reconstruction_loss(self, mel_gt, mel_predicted): 208 | length = min(mel_gt.size(2), mel_predicted.size(2)) 209 | return torch.nn.L1Loss()( 210 | mel_gt[:, :, :length], 211 | mel_predicted[:, :, :length], 212 | ) 213 | 214 | def log_audio(self, audio, name, sampling_rate): 215 | for logger in self.loggers: 216 | if type(logger) == loggers.WandbLogger: 217 | import wandb 218 | 219 | wandb.log( 220 | {name: wandb.Audio(audio, sample_rate=sampling_rate)}, 221 | step=self.global_step, 222 | ) 223 | elif type(logger) == loggers.TensorBoardLogger: 224 | logger.experiment.add_audio( 225 | name, 226 | audio, 227 | self.global_step, 228 | sampling_rate, 229 | ) 230 | -------------------------------------------------------------------------------- /src/lightning_vocoders/models/hifigan/xvector_lightning_module.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig 2 | from .lightning_module import HiFiGANLightningModule 3 | from .generator_xvector import GeneratorWithXvector 4 | 5 | class HiFiGANXvectorLightningModule(HiFiGANLightningModule,object): 6 | def __init__(self, cfg: DictConfig) -> None: 7 | HiFiGANLightningModule.__init__(self,cfg) 8 | self.generator = GeneratorWithXvector(cfg.model.generator) 9 | 10 | def generator_forward(self, batch): 11 | # condition the generator on the utterance-level x-vector in addition to the SSL input feature 12 | wav_generator_out = self.generator(batch["input_feature"],batch['xvector']) 13 | return wav_generator_out 14 | -------------------------------------------------------------------------------- /src/lightning_vocoders/models/wavegrad/lightning_module.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from lightning.pytorch import LightningModule, loggers 3 | from lightning.pytorch.utilities.types import STEP_OUTPUT 4 | import hydra 5 | from .wavegrad import WaveGrad 6 | import numpy as np 7 | import torch 8 | 9 | class WaveGradLightningModule(LightningModule): 10 | def __init__(self, cfg) -> None: 11 | super().__init__() 12 | self.model = WaveGrad(**cfg.model.model_params) 13 | self.cfg = cfg 14 | self.criterion = torch.nn.L1Loss() 15 | self.num_steps = 1000 16 | self.save_hyperparameters() 17 | def setup(self, stage: str) -> None: 18 | self.beta = np.linspace(**self.cfg.model.noise_schedule) 19 | self.alpha = 1-self.beta 20 | self.alpha_cum = np.cumprod(self.alpha) 21 | noise_level = np.cumprod(1-self.beta)**0.5 22 | noise_level = np.concatenate([[1.0], noise_level], axis=0) 23 | self.noise_level = torch.tensor(noise_level.astype(np.float32),device=self.device) 24 | 25 | def training_step(self, batch, batch_idx) -> STEP_OUTPUT: 26 | wav,input_feature = batch['resampled_speech.pth'], batch['input_feature'] 27 | batch_size = wav.size(0) 28 | wav = wav[:,:int(input_feature.size(1)/50*22050)] # trim the waveform to the feature length (assumes 50 Hz features and 22.05 kHz audio) 29 | 30 | s = torch.randint(1, self.num_steps + 1, [batch_size], device=self.device) 31 | l_a, l_b = self.noise_level.to(self.device)[s - 1], self.noise_level.to(self.device)[s] 32 | noise_scale = l_a + torch.rand(batch_size, device=self.device) * (l_b - l_a) 33 | noise_scale = noise_scale.unsqueeze(1) 34 | noise = torch.randn_like(wav) 35 | 36 | noisy_wav = noise_scale * wav + (1.0 - noise_scale**2)**0.5 * noise # forward diffusion step: noise_scale corresponds to sqrt(alpha_cumprod) 37 | 38 | 
predicted = self.model(noisy_wav, input_feature.transpose(1,2),noise_scale.squeeze(1)) 39 | loss = self.criterion(predicted.squeeze(1), noise) 40 | self.log('train/loss',loss) 41 | return loss 42 | 43 | def validation_step(self, batch, batch_idx): 44 | wav,input_feature = batch['resampled_speech.pth'], batch['input_feature'] 45 | batch_size = wav.size(0) 46 | wav = wav[:,:int((input_feature.size(1))/50*22050 + 1e-3)] 47 | 48 | s = torch.randint(1, self.num_steps + 1, [batch_size], device=self.device) 49 | l_a, l_b = self.noise_level.to(self.device)[s - 1], self.noise_level.to(self.device)[s] 50 | noise_scale = l_a + torch.rand(batch_size, device=self.device) * (l_b - l_a) 51 | noise_scale = noise_scale.unsqueeze(1) 52 | noise = torch.randn_like(wav) 53 | 54 | noisy_wav = noise_scale * wav + (1.0 - noise_scale**2)**0.5 * noise 55 | 56 | predicted = self.model(noisy_wav, input_feature.transpose(1,2),noise_scale.squeeze(1)) 57 | loss = self.criterion(predicted.squeeze(1), noise) 58 | self.log('val/loss',loss) 59 | if batch_idx < self.cfg.model.n_logging_wav_samples and self.global_rank == 0 and self.local_rank == 0: 60 | predicted_audio = self.predict(input_feature[0].unsqueeze(0), wav.size(1)) 61 | self.log_audio(predicted_audio[0].cpu().numpy().astype(np.float32),name=f"generated/{batch['filenames'][0]}",sampling_rate=self.cfg.sample_rate) 62 | return loss 63 | def predict(self, input_feature,audio_length): 64 | audio = torch.randn((1, audio_length),device=self.device) 65 | noise_scale = torch.from_numpy(self.alpha_cum**0.5).float().unsqueeze(1).to(self.device) 66 | 67 | for n in range(len(self.alpha) - 1 , -1 ,-1): 68 | c1= 1/ (self.alpha[n] ** 0.5) 69 | 70 | c2 = (1- self.alpha[n]) / (1-self.alpha_cum[n]) ** 0.5 71 | audio = c1 * (audio - c2 * self.model(audio,input_feature.transpose(1,2),noise_scale[n]).squeeze(1)) 72 | if n> 0: 73 | noise = torch.randn_like(audio) 74 | sigma = ((1.0 - self.alpha_cum[n-1])/ (1.0 - self.alpha_cum[n]) * self.beta[n]) ** 0.5 75 | audio += sigma*noise 76 | audio = torch.clamp(audio, -1.0, 1.0) 77 | return audio 78 | def configure_optimizers(self) -> Any: 79 | return hydra.utils.instantiate(self.cfg.model.optim, params=self.parameters()) 80 | 81 | def log_audio(self, audio, name, sampling_rate): 82 | for logger in self.loggers: 83 | match type(logger): 84 | case loggers.WandbLogger: 85 | import wandb 86 | 87 | wandb.log( 88 | {name: wandb.Audio(audio, sample_rate=sampling_rate)}, 89 | step=self.global_step, 90 | ) 91 | case loggers.TensorBoardLogger: 92 | logger.experiment.add_audio( 93 | name, 94 | audio, 95 | self.global_step, 96 | sampling_rate, 97 | ) 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /src/lightning_vocoders/models/wavegrad/wavegrad.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 LMNT, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | from math import log as ln 22 | 23 | 24 | class Conv1d(nn.Conv1d): 25 | def __init__(self, *args, **kwargs): 26 | super().__init__(*args, **kwargs) 27 | self.reset_parameters() 28 | 29 | def reset_parameters(self): 30 | nn.init.orthogonal_(self.weight) 31 | nn.init.zeros_(self.bias) 32 | 33 | 34 | class PositionalEncoding(nn.Module): 35 | def __init__(self, dim): 36 | super().__init__() 37 | self.dim = dim 38 | 39 | def forward(self, x, noise_level): 40 | """ 41 | Arguments: 42 | x: 43 | (shape: [N,C,T], dtype: float32) 44 | noise_level: 45 | (shape: [N], dtype: float32) 46 | 47 | Returns: 48 | noise_level: 49 | (shape: [N,C,T], dtype: float32) 50 | """ 51 | N = x.shape[0] 52 | T = x.shape[2] 53 | return (x + self._build_encoding(noise_level)[:, :, None]) 54 | 55 | def _build_encoding(self, noise_level): 56 | count = self.dim // 2 57 | step = torch.arange(count, dtype=noise_level.dtype, device=noise_level.device) / count 58 | encoding = noise_level.unsqueeze(1) * torch.exp(-ln(1e4) * step.unsqueeze(0)) 59 | encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1) 60 | return encoding 61 | 62 | 63 | class FiLM(nn.Module): 64 | def __init__(self, input_size, output_size): 65 | super().__init__() 66 | self.encoding = PositionalEncoding(input_size) 67 | self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1) 68 | self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1) 69 | self.reset_parameters() 70 | 71 | def reset_parameters(self): 72 | nn.init.xavier_uniform_(self.input_conv.weight) 73 | nn.init.xavier_uniform_(self.output_conv.weight) 74 | nn.init.zeros_(self.input_conv.bias) 75 | nn.init.zeros_(self.output_conv.bias) 76 | 77 | def forward(self, x, noise_scale): 78 | x = self.input_conv(x) 79 | x = F.leaky_relu(x, 0.2) 80 | x = self.encoding(x, noise_scale) 81 | shift, scale = torch.chunk(self.output_conv(x), 2, dim=1) 82 | return shift, scale 83 | 84 | 85 | class UBlock(nn.Module): 86 | def __init__(self, input_size, hidden_size, factor, dilation): 87 | super().__init__() 88 | assert isinstance(dilation, (list, tuple)) 89 | assert len(dilation) == 4 90 | 91 | self.factor = factor 92 | self.block1 = Conv1d(input_size, hidden_size, 1) 93 | self.block2 = nn.ModuleList([ 94 | Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]), 95 | Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1]) 96 | ]) 97 | self.block3 = nn.ModuleList([ 98 | Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]), 99 | Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]) 100 | ]) 101 | 102 | def forward(self, x, film_shift, film_scale): 103 | block1 = F.interpolate(x, size=x.shape[-1] * self.factor) 104 | block1 = self.block1(block1) 105 | 106 | block2 = F.leaky_relu(x, 0.2) 107 | block2 = F.interpolate(block2, size=x.shape[-1] * self.factor) 108 | block2 = self.block2[0](block2) 109 | block2 = film_shift + film_scale * block2 110 | block2 = F.leaky_relu(block2, 0.2) 111 | block2 = self.block2[1](block2) 112 | 113 | x = block1 + block2 114 | 115 | block3 = film_shift + film_scale * x 116 | block3 = F.leaky_relu(block3, 0.2) 117 | block3 = self.block3[0](block3) 118 | block3 = film_shift + film_scale * block3 119 | block3 = F.leaky_relu(block3, 0.2) 120 | block3 = 
self.block3[1](block3) 121 | 122 | x = x + block3 123 | return x 124 | 125 | 126 | class DBlock(nn.Module): 127 | def __init__(self, input_size, hidden_size, factor): 128 | super().__init__() 129 | self.factor = factor 130 | self.residual_dense = Conv1d(input_size, hidden_size, 1) 131 | self.conv = nn.ModuleList([ 132 | Conv1d(input_size, hidden_size, 3, dilation=1, padding=1), 133 | Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2), 134 | Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4), 135 | ]) 136 | 137 | def forward(self, x): 138 | size = x.shape[-1] // self.factor 139 | 140 | residual = self.residual_dense(x) 141 | residual = F.interpolate(residual, size=size) 142 | 143 | x = F.interpolate(x, size=size) 144 | for layer in self.conv: 145 | x = F.leaky_relu(x, 0.2) 146 | x = layer(x) 147 | 148 | return x + residual 149 | 150 | 151 | class WaveGrad(nn.Module): 152 | def __init__(self,n_input_channels,upsamples,downsamples,downsample_conv,film_layers): 153 | super().__init__() 154 | self.downsample = nn.ModuleList() 155 | self.upsample = nn.ModuleList() 156 | self.downsample.append(Conv1d(downsample_conv[0], downsample_conv[1], downsample_conv[2], padding=2)) 157 | for downsample in downsamples: 158 | self.downsample.append( 159 | DBlock(downsample[0], downsample[1], downsample[2]) 160 | ) 161 | for upsample in upsamples: 162 | self.upsample.append( 163 | UBlock(upsample[0], upsample[1], upsample[2], tuple(upsample[3])) 164 | ) 165 | 166 | self.film = nn.ModuleList() 167 | for film_layer in film_layers: 168 | self.film.append( 169 | FiLM(film_layer[0], film_layer[1]) 170 | ) 171 | self.first_conv = Conv1d(n_input_channels, 768, 3, padding=1) 172 | self.last_conv = Conv1d(128, 1, 3, padding=1) 173 | 174 | def forward(self, audio, spectrogram, noise_scale): 175 | x = audio.unsqueeze(1) 176 | downsampled = [] 177 | for film, layer in zip(self.film, self.downsample): 178 | x = layer(x) 179 | downsampled.append(film(x, noise_scale)) 180 | 181 | x = self.first_conv(spectrogram) 182 | assert len(self.upsample) == len(downsampled) 183 | for layer, (film_shift, film_scale) in zip(self.upsample, reversed(downsampled)): 184 | x = layer(x, film_shift, film_scale) 185 | x = self.last_conv(x) 186 | return x -------------------------------------------------------------------------------- /src/lightning_vocoders/preprocessor/dataset/glob_wav_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.dataset import Dataset 3 | import torchaudio 4 | from pathlib import Path 5 | import random 6 | import string 7 | 8 | def generate_random_string(length): 9 | letters = string.ascii_letters 10 | return ''.join(random.choice(letters) for _ in range(length)) 11 | 12 | class GlobWavDataset(Dataset): 13 | def __init__(self, roots, patterns, shuffled: bool = True,add_random_string=True) -> None: 14 | self.wav_files = [] 15 | for root,pattern in zip(roots,patterns): 16 | self.root = Path(root) 17 | self.wav_files.extend(list(self.root.glob(pattern))) 18 | if shuffled: 19 | random.shuffle(self.wav_files) 20 | self.add_random_string = add_random_string 21 | 22 | def __len__(self): 23 | return len(self.wav_files) 24 | 25 | 26 | def __getitem__(self,idx): 27 | wav_path = self.wav_files[idx] 28 | if self.add_random_string: 29 | return wav_path.stem + generate_random_string(5), torchaudio.load(wav_path),str(wav_path) 30 | else: 31 | return wav_path.stem , torchaudio.load(wav_path),str(wav_path) 32 | 
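A minimal usage sketch for GlobWavDataset (not part of the repository; the corpus path is a placeholder), mirroring how synthesize.py instantiates it:

    from lightning_vocoders.preprocessor.dataset.glob_wav_dataset import GlobWavDataset

    # collect every wav under a hypothetical corpus directory, keeping the original order and file stems
    dataset = GlobWavDataset(roots=["/path/to/corpus"], patterns=["**/*.wav"], shuffled=False, add_random_string=False)
    name, (waveform, sample_rate), wav_path = dataset[0]  # waveform: (channels, num_samples) as returned by torchaudio.load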
-------------------------------------------------------------------------------- /src/lightning_vocoders/preprocessor/preprocessor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hydra 3 | import torchaudio 4 | import pathlib 5 | from omegaconf import DictConfig 6 | import numpy as np 7 | import webdataset 8 | import tqdm 9 | from torch.utils.data import DataLoader 10 | 11 | 12 | class Preprocessor: 13 | """ 14 | Preprocess dataset 15 | """ 16 | 17 | def __init__(self, cfg: DictConfig): 18 | """ 19 | Args: 20 | cfg: hydra config 21 | """ 22 | self.cfg = cfg 23 | self.train_dataset = hydra.utils.instantiate(cfg.preprocess.preprocess_dataset.train) 24 | print(len(self.train_dataset)) 25 | self.val_dataset = hydra.utils.instantiate(cfg.preprocess.preprocess_dataset.val) 26 | self.spec_module = torchaudio.transforms.Spectrogram(**cfg.preprocess.stft) 27 | self.mel_scale = torchaudio.transforms.MelScale(**cfg.preprocess.mel) 28 | self.sampling_rate = self.cfg.sample_rate 29 | self.ssl_models = hydra.utils.instantiate(cfg.preprocess.ssl_models) 30 | if hasattr(cfg.data, "xvector"): 31 | self.use_xvector = cfg.data.xvector.use_xvector 32 | else: 33 | self.use_xvector = False 34 | if self.use_xvector: 35 | self.xvector_model = hydra.utils.instantiate(self.cfg.data.xvector.model) 36 | self.xvector_model.eval() 37 | self.xvector_sr = self.cfg.data.xvector.sr 38 | self.xvector_extract_secs = self.cfg.data.xvector.extract_secs 39 | self.xvector_embedding_size = self.cfg.data.xvector.embedding_size 40 | 41 | @torch.no_grad() 42 | def process_utterance( 43 | self, 44 | basename: str, 45 | orig_waveform: torch.Tensor, 46 | sample_rate: int, 47 | audio_file_path 48 | ): 49 | 50 | waveform = torchaudio.functional.resample( 51 | orig_waveform, sample_rate, new_freq=self.sampling_rate 52 | )[ 53 | 0 54 | ] # remove the channel dimension (only mono audio is supported) 55 | print(waveform.size()) 56 | waveform = waveform[:20*self.sampling_rate] # keep at most 20 seconds of audio 57 | 58 | mel_spec, energy = self.calc_spectrogram(waveform) 59 | with open(audio_file_path, mode="rb") as f: 60 | wav_bytes = f.read() 61 | 62 | sample = { 63 | "__key__": basename, 64 | "speech.wav": wav_bytes, 65 | "resampled_speech.pth": webdataset.torch_dumps(waveform), 66 | "mel.pth": webdataset.torch_dumps(mel_spec.T), 67 | } 68 | for ssl_model, processor, feature_cfg in self.ssl_models: 69 | wav_tensor = torchaudio.functional.resample( 70 | waveform=orig_waveform, orig_freq=sample_rate, new_freq=feature_cfg.sr 71 | ) 72 | if processor is not None: 73 | inputs = processor( 74 | wav_tensor.squeeze(), return_tensors="pt", sampling_rate=feature_cfg.sr 75 | ) 76 | inputs.to("cuda") 77 | ssl_model.to("cuda") 78 | output = ssl_model(**inputs, output_hidden_states=True) 79 | sample[feature_cfg.key] = webdataset.torch_dumps( 80 | output.hidden_states[feature_cfg.layer][0].cpu() 81 | ) 82 | else: 83 | ssl_model.to("cuda") 84 | wav_tensor = wav_tensor.unsqueeze(1).to('cuda') 85 | output = ssl_model({"x": wav_tensor},sample_rate=feature_cfg.sr) 86 | sample[feature_cfg.key] = webdataset.torch_dumps( 87 | output.hidden_states[feature_cfg.layer][0].cpu() 88 | ) 89 | if self.use_xvector: 90 | resampled_for_xvector = torchaudio.functional.resample( 91 | orig_waveform, sample_rate, self.xvector_sr 92 | ).squeeze()[: int(self.xvector_sr * self.xvector_extract_secs)] 93 | embeddings = self.xvector_model.encode_batch(resampled_for_xvector.unsqueeze(0)) 94 | sample["xvector.pth"] = embeddings.view(self.xvector_embedding_size) 95 | 96 | 
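# At this point `sample` holds the full webdataset record for one utterance: "__key__",
# the raw "speech.wav" bytes, "resampled_speech.pth", "mel.pth", one serialized tensor per
# configured SSL feature key, and optionally "xvector.pth".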
97 | return sample 98 | 99 | def build_from_path(self): 100 | pathlib.Path("/".join(self.cfg.preprocess.train_tar_sink.pattern.split("/")[:-1])).mkdir(exist_ok=True) 101 | train_sink = hydra.utils.instantiate(self.cfg.preprocess.train_tar_sink) 102 | val_sink = hydra.utils.instantiate(self.cfg.preprocess.val_tar_sink) 103 | dataloader = DataLoader(self.val_dataset,batch_size=1) 104 | sink = val_sink 105 | for idx, (basename, (wav,sr),wav_path) in enumerate(tqdm.tqdm(dataloader)): 106 | sample = self.process_utterance( 107 | basename[0], 108 | wav[0], 109 | sr[0], 110 | wav_path[0] 111 | ) 112 | sink.write(sample) 113 | dataloader = DataLoader(self.train_dataset,batch_size=1) 114 | sink = train_sink 115 | for idx, (basename, (wav,sr),wav_path) in enumerate(tqdm.tqdm(dataloader)): 116 | sample = self.process_utterance( 117 | basename[0], 118 | wav[0], 119 | sr[0], 120 | wav_path[0] 121 | ) 122 | sink.write(sample) 123 | 124 | train_sink.close() 125 | val_sink.close() 126 | 127 | def calc_spectrogram(self, waveform: torch.Tensor): 128 | magspec = self.spec_module(waveform) 129 | melspec = self.mel_scale(magspec) 130 | logmelspec = torch.log(torch.clamp_min(melspec, 1.0e-5) * 1.0).to(torch.float32) 131 | energy = torch.norm(magspec, dim=0) 132 | return logmelspec, energy.numpy() 133 | -------------------------------------------------------------------------------- /src/preprocess.py: -------------------------------------------------------------------------------- 1 | import pyrootutils 2 | import hydra 3 | from omegaconf import DictConfig 4 | from lightning.pytorch import seed_everything 5 | 6 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 7 | from lightning_vocoders.preprocessor.preprocessor import Preprocessor 8 | 9 | 10 | @hydra.main(version_base="1.3", config_name="config", config_path="../config") 11 | def main(cfg: DictConfig): 12 | seed_everything(1234) 13 | preprocessor = Preprocessor(cfg=cfg) 14 | preprocessor.build_from_path() 15 | 16 | 17 | if __name__ == "__main__": 18 | main() 19 | -------------------------------------------------------------------------------- /src/synthesize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import lightning.pytorch as pl 3 | import torch 4 | import io 5 | from omegaconf import DictConfig 6 | import hydra 7 | from pathlib import Path 8 | from lightning_vocoders.preprocessor.preprocessor import Preprocessor 9 | from lightning_vocoders.preprocessor.dataset.glob_wav_dataset import GlobWavDataset 10 | from torch.utils.data.dataloader import DataLoader 11 | 12 | def synthesize(cfg:DictConfig,ckpt_path:Path,wav_path:Path,output_path): 13 | lightning_module:pl.LightningModule = hydra.utils.instantiate(cfg.model.lightning_module,cfg) 14 | lightning_module = lightning_module.load_from_checkpoint(ckpt_path) 15 | 16 | dataset = GlobWavDataset([wav_path],["**/*.wav"],shuffled=False,add_random_string=False) 17 | preprocessor = Preprocessor(lightning_module.cfg) 18 | 19 | @torch.no_grad() 20 | def test_collate_fn(sample): 21 | assert len(sample) == 1 # only expect batch size of 1 22 | wav_name, (wav_data,sr), wav_path = sample[0] 23 | print(wav_data.size()) 24 | wav_data = wav_data[0].unsqueeze(0) 25 | preprocessed_sample = preprocessor.process_utterance(wav_name,wav_data,sr,wav_path) 26 | for k,v in preprocessed_sample.items(): 27 | if k.endswith(".pth"): 28 | preprocessed_sample[k] = torch.load(io.BytesIO(v)) 29 | batch = { 30 | "resampled_speech.pth": 
[preprocessed_sample["resampled_speech.pth"]], 31 | "input_feature": preprocessed_sample[cfg.data.target_feature.key].unsqueeze(0), 32 | "filenames": [preprocessed_sample["__key__"]], 33 | "wav_lens": None 34 | } 35 | return batch 36 | test_dataloader = DataLoader(dataset,collate_fn=test_collate_fn) 37 | lightning_module.output_path = output_path 38 | trainer = pl.Trainer(limit_test_batches=10) 39 | trainer.test(lightning_module,test_dataloader) 40 | 41 | 42 | if __name__ == "__main__": 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--ckpt_path",required=True) 45 | parser.add_argument("--wav", required=True,type=str) 46 | parser.add_argument("--output_path", required=True,type=str) 47 | args = parser.parse_args() 48 | ckpt = torch.load(args.ckpt_path) 49 | 50 | cfg = ckpt['hyper_parameters']['cfg'] 51 | 52 | 53 | synthesize(cfg,args.ckpt_path,args.wav,args.output_path) 54 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import pyrootutils 2 | from pathlib import Path 3 | import hydra 4 | import torch 5 | from omegaconf import DictConfig 6 | from lightning_vocoders.models.hifigan.lightning_module import HiFiGANLightningModule 7 | from lightning.pytorch.callbacks import LearningRateMonitor 8 | from lightning.pytorch import seed_everything 9 | 10 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 11 | 12 | 13 | @hydra.main(version_base="1.3", config_name="config", config_path="../config") 14 | def main(cfg: DictConfig): 15 | seed_everything(1234) 16 | lightning_module = hydra.utils.instantiate(cfg.model.lightning_module, cfg) 17 | if cfg.compile: 18 | lightning_module = torch.compile(lightning_module, dynamic=True) 19 | callbacks = [LearningRateMonitor(logging_interval="step")] 20 | datamodule = hydra.utils.instantiate(cfg.data.datamodule, cfg) 21 | loggers = [hydra.utils.instantiate(logger) for logger in cfg.train.loggers] 22 | trainer = hydra.utils.instantiate( 23 | cfg.train.trainer, logger=loggers, callbacks=callbacks 24 | ) 25 | trainer.fit(lightning_module, datamodule,ckpt_path=cfg.train.ckpt_path) 26 | 27 | 28 | if __name__ == "__main__": 29 | main() 30 | --------------------------------------------------------------------------------
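A minimal invocation sketch for the synthesis entry point defined in src/synthesize.py above (all paths are placeholders):

    python src/synthesize.py --ckpt_path path/to/checkpoint.ckpt --wav path/to/wav_dir --output_path outputs/

The script restores the Hydra config stored in the checkpoint's hyper_parameters, globs **/*.wav under --wav, runs each file through Preprocessor.process_utterance, and writes <name>.wav / <name>_gt.wav pairs into --output_path; because the Trainer is created with limit_test_batches=10, at most ten files are processed per run. Training is launched analogously through src/train.py, with the run configured via Hydra.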