├── .gitignore ├── LICENSE ├── README.md ├── config ├── base.json └── small.json ├── docs ├── config.md ├── credits.md ├── images │ ├── auris_architecture.png │ └── auris_decoder.png ├── infer.md ├── installation.md ├── technical_details.md ├── todo_list.md └── train.md ├── example ├── musicxml │ └── test.musicxml ├── score_inputs │ └── test.json └── text_inputs │ └── transcription.json ├── infer.py ├── infer_webui.py ├── module ├── g2p │ ├── __init__.py │ ├── english.py │ ├── extractor.py │ └── japanese.py ├── infer │ └── __init__.py ├── language_model │ ├── __init__.py │ ├── extractor.py │ └── rinna_roberta.py ├── preprocess │ ├── jvs.py │ ├── processor.py │ ├── scan.py │ └── wave_and_text.py ├── utils │ ├── common.py │ ├── config.py │ ├── dataset.py │ ├── energy_estimation.py │ ├── f0_estimation.py │ └── safetensors.py └── vits │ ├── __init__.py │ ├── convnext.py │ ├── crop.py │ ├── decoder.py │ ├── discriminator.py │ ├── duration_discriminator.py │ ├── duration_predictors.py │ ├── feature_retrieval.py │ ├── flow.py │ ├── generator.py │ ├── loss.py │ ├── monotonic_align │ ├── __init__.py │ ├── core.pyx │ └── setup.py │ ├── normalization.py │ ├── posterior_encoder.py │ ├── prior_encoder.py │ ├── speaker_embedding.py │ ├── spectrogram.py │ ├── text_encoder.py │ ├── transformer.py │ ├── transforms.py │ └── wn.py ├── preprocess.py ├── requirements.txt └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | /*.DS_Store 163 | *.pt 164 | *.pth 165 | *.wav 166 | *.ogg 167 | *.mp3 168 | *.swp 169 | *.swo 170 | module/vits/monotonic_align/core.c 171 | module/vits/monotonic_align/build/* 172 | dataset_cache/* 173 | models/* 174 | models 175 | logs/* 176 | lightning_logs 177 | lightning_logs/* 178 | outputs 179 | outputs/* 180 | audio_inputs/* 181 | audio_inputs 182 | flagged/ 183 | flagged/* 184 | my_config/* 185 | my_config/ 186 | backups/* 187 | backups -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 uthree 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Auris 2 | 歌声合成, 変換, TTSができる効率的な日本語事前学習済みモデルの開発 3 | 4 | ## 目次 5 | - [クレジット表記](docs/credits.md) 6 | - [インストール方法](docs/installation.md) 7 | - [学習方法](docs/train.md) 8 | - [推論方法](docs/infer.md) 9 | - [コンフィグについて](docs/config.md) 10 | - [技術的な詳細](docs/technical_details.md) 11 | 12 | ## Other Languages 13 | - [English](docs/readme_en.md) (comming soon) -------------------------------------------------------------------------------- /config/base.json: -------------------------------------------------------------------------------- 1 | { 2 | "vits": { 3 | "segment_size": 32, 4 | "generator": { 5 | "decoder": { 6 | "sample_rate": 48000, 7 | "frame_size": 480, 8 | "n_fft": 1920, 9 | "speaker_embedding_dim": 256, 10 | "content_channels": 192, 11 | "pe_internal_channels": 256, 12 | "pe_num_layers": 4, 13 | "source_internal_channels": 256, 14 | "source_num_layers": 4, 15 | "num_harmonics": 30, 16 | "filter_channels": [512, 256, 128, 64, 32], 17 | "filter_factors": [ 18 | 5, 19 | 4, 20 | 4, 21 | 3, 22 | 2 23 | ], 24 | "filter_resblock_type": "1", 25 | "filter_down_dilations": [ 26 | [1, 2], 27 | [4, 8] 28 | ], 29 | "filter_down_interpolation": "conv", 30 | "filter_up_dilations": [ 31 | [1, 3, 5], 32 | [1, 3, 5], 33 | [1, 3, 5] 34 | ], 35 | "filter_up_kernel_sizes": [ 36 | 3, 37 | 7, 38 | 11 39 | ], 40 | "filter_up_interpolation": "conv" 41 | }, 42 | "posterior_encoder": { 43 | "n_fft": 1920, 44 | "frame_size": 480, 45 | "internal_channels": 192, 46 | "speaker_embedding_dim": 256, 47 | "content_channels": 192, 48 | "kernel_size": 5, 49 | "dilation": 1, 50 | "num_layers": 16 51 | }, 52 | "speaker_embedding": { 53 | "num_speakers": 8192, 54 | "embedding_dim": 256 55 | }, 56 | "prior_encoder": { 57 | "flow": { 58 | "content_channels": 192, 59 | "internal_channels": 192, 60 | "speaker_embedding_dim": 256, 61 | "kernel_size": 5, 62 | "dilation": 1, 63 | "num_layers": 4, 64 | "num_flows": 4 65 | }, 66 | "text_encoder": { 67 | "num_phonemes": 512, 68 | "num_languages": 256, 69 | "lm_dim": 768, 70 | "internal_channels": 256, 71 | "speaker_embedding_dim": 256, 72 | "content_channels": 192, 73 | "n_heads": 4, 74 | "num_layers": 4, 75 | "window_size": 4 76 | }, 77 | "stochastic_duration_predictor": { 78 | "in_channels": 192, 79 | "filter_channels": 256, 80 | "kernel_size": 5, 81 | "p_dropout": 0.0, 82 | "speaker_embedding_dim": 256 83 | }, 84 | "duration_predictor": { 85 | "content_channels": 192, 86 | "internal_channels": 256, 87 | "speaker_embedding_dim": 256, 88 | "kernel_size": 7, 89 | "num_layers": 4 90 | } 91 | } 92 | }, 93 | "discriminator": { 94 | "mrd": { 95 | "resolutions": [ 96 | 128, 97 | 256, 98 | 512 99 | ], 100 | "channels": 32, 101 | "max_channels": 256, 102 | "num_layers": 4 103 | }, 104 | "mpd": { 105 | "periods": [ 106 | 1, 107 | 2, 108 | 3, 109 | 5, 110 | 7, 111 | 11, 112 | 17, 113 | 23, 114 | 31 115 | ], 116 | "channels": 32, 117 | "channels_mul": 2, 118 | "max_channels": 256, 119 | "num_layers": 4 120 | } 121 | }, 122 | "duration_discriminator": { 123 | "content_channels": 192, 124 | "speaker_embedding_dim": 256, 125 | "num_layers": 3 126 | 
}, 127 | "optimizer": { 128 | "lr": 1e-4, 129 | "betas": [ 130 | 0.8, 131 | 0.99 132 | ] 133 | } 134 | }, 135 | "language_model": { 136 | "type": "rinna_roberta", 137 | "options": { 138 | "hf_repo": "rinna/japanese-roberta-base", 139 | "layer": 12 140 | } 141 | }, 142 | "preprocess": { 143 | "sample_rate": 48000, 144 | "max_waveform_length": 480000, 145 | "pitch_estimation": "fcpe", 146 | "max_phonemes": 100, 147 | "lm_max_tokens": 30, 148 | "frame_size": 480, 149 | "cache": "dataset_cache" 150 | }, 151 | "train": { 152 | "save": { 153 | "models_dir": "models", 154 | "interval": 400 155 | }, 156 | "data_module": { 157 | "cache_dir": "dataset_cache", 158 | "metadata": "models/metadata.json", 159 | "batch_size": 8, 160 | "num_workers": 15 161 | }, 162 | "trainer": { 163 | "devices": "auto", 164 | "max_epochs": 1000000, 165 | "precision": null 166 | } 167 | }, 168 | "infer": { 169 | "n_fft": 1920, 170 | "frame_size": 480, 171 | "sample_rate": 48000, 172 | "max_lm_tokens": 50, 173 | "max_phonemes": 500, 174 | "max_frames": 2000, 175 | "device": "cuda" 176 | } 177 | } -------------------------------------------------------------------------------- /config/small.json: -------------------------------------------------------------------------------- 1 | { 2 | "vits": { 3 | "segment_size": 32, 4 | "generator": { 5 | "decoder": { 6 | "sample_rate": 48000, 7 | "frame_size": 480, 8 | "n_fft": 1920, 9 | "speaker_embedding_dim": 128, 10 | "content_channels": 96, 11 | "pe_internal_channels": 128, 12 | "pe_num_layers": 3, 13 | "source_internal_channels": 128, 14 | "source_num_layers": 3, 15 | "num_harmonics": 14, 16 | "filter_channels": [ 17 | 192, 18 | 96, 19 | 48, 20 | 24 21 | ], 22 | "filter_factors": [ 23 | 4, 24 | 4, 25 | 5, 26 | 6 27 | ], 28 | "filter_resblock_type": "3", 29 | "filter_down_dilations": [ 30 | [ 31 | 1, 32 | 2, 33 | 4 34 | ] 35 | ], 36 | "filter_down_interpolation": "linear", 37 | "filter_up_dilations": [ 38 | [ 39 | 1, 40 | 3, 41 | 9, 42 | 27 43 | ] 44 | ], 45 | "filter_up_kernel_sizes": [ 46 | 3 47 | ], 48 | "filter_up_interpolation": "linear" 49 | }, 50 | "posterior_encoder": { 51 | "n_fft": 1920, 52 | "frame_size": 480, 53 | "internal_channels": 96, 54 | "speaker_embedding_dim": 128, 55 | "content_channels": 96, 56 | "kernel_size": 5, 57 | "dilation": 1, 58 | "num_layers": 16 59 | }, 60 | "speaker_embedding": { 61 | "num_speakers": 8192, 62 | "embedding_dim": 128 63 | }, 64 | "prior_encoder": { 65 | "flow": { 66 | "content_channels": 96, 67 | "internal_channels": 96, 68 | "speaker_embedding_dim": 128, 69 | "kernel_size": 5, 70 | "dilation": 1, 71 | "num_layers": 4, 72 | "num_flows": 4 73 | }, 74 | "text_encoder": { 75 | "num_phonemes": 512, 76 | "num_languages": 256, 77 | "lm_dim": 768, 78 | "internal_channels": 256, 79 | "speaker_embedding_dim": 128, 80 | "content_channels": 96, 81 | "n_heads": 4, 82 | "num_layers": 4, 83 | "window_size": 4 84 | }, 85 | "stochastic_duration_predictor": { 86 | "in_channels": 96, 87 | "filter_channels": 256, 88 | "kernel_size": 5, 89 | "p_dropout": 0.0, 90 | "speaker_embedding_dim": 128 91 | }, 92 | "duration_predictor": { 93 | "content_channels": 96, 94 | "internal_channels": 256, 95 | "speaker_embedding_dim": 128, 96 | "kernel_size": 7, 97 | "num_layers": 4 98 | } 99 | } 100 | }, 101 | "discriminator": { 102 | "mrd": { 103 | "resolutions": [ 104 | 128, 105 | 256, 106 | 512 107 | ], 108 | "channels": 32, 109 | "max_channels": 256, 110 | "num_layers": 4 111 | }, 112 | "mpd": { 113 | "periods": [ 114 | 1, 115 | 2, 116 | 3, 117 | 5, 118 | 7, 119 | 11 
120 | ], 121 | "channels": 32, 122 | "channels_mul": 2, 123 | "max_channels": 256, 124 | "num_layers": 4 125 | } 126 | }, 127 | "duration_discriminator": { 128 | "content_channels": 96, 129 | "speaker_embedding_dim": 128, 130 | "num_layers": 3 131 | }, 132 | "optimizer": { 133 | "lr": 1e-4, 134 | "betas": [ 135 | 0.8, 136 | 0.99 137 | ] 138 | } 139 | }, 140 | "language_model": { 141 | "type": "rinna_roberta", 142 | "options": { 143 | "hf_repo": "rinna/japanese-roberta-base", 144 | "layer": 4 145 | } 146 | }, 147 | "preprocess": { 148 | "sample_rate": 48000, 149 | "max_waveform_length": 480000, 150 | "pitch_estimation": "fcpe", 151 | "max_phonemes": 100, 152 | "lm_max_tokens": 30, 153 | "frame_size": 480, 154 | "cache": "dataset_cache" 155 | }, 156 | "train": { 157 | "save": { 158 | "models_dir": "models", 159 | "interval": 400 160 | }, 161 | "data_module": { 162 | "cache_dir": "dataset_cache", 163 | "metadata": "models/metadata.json", 164 | "batch_size": 4, 165 | "num_workers": 7 166 | }, 167 | "trainer": { 168 | "devices": "auto", 169 | "max_epochs": 1000000, 170 | "precision": null 171 | } 172 | }, 173 | "infer": { 174 | "n_fft": 1920, 175 | "frame_size": 480, 176 | "sample_rate": 48000, 177 | "max_lm_tokens": 50, 178 | "max_phonemes": 500, 179 | "max_frames": 2000, 180 | "device": "cuda" 181 | } 182 | } -------------------------------------------------------------------------------- /docs/config.md: -------------------------------------------------------------------------------- 1 | # コンフィグ 2 | [config](../config/)は、モデルアーキテクチャの設定や学習時の設定などを記述するファイルがあります。 3 | どのコンフィグを選ぶかによって、モデルの性能が変化したりします。 4 | 5 | ## プリセット 6 | 本リポジトリにはコンフィグの設定例としていくつかのコンフィグファイルを用意しています。 7 | - [small](../config/small.json) : 最小規模の構成。動作確認実験のためのもの。 8 | - [base](../config/base.json) : デフォルト。HiFi-GANのV1モデル/VITSに相当する規模のもの。おそらく実用的な性能が出る。 9 | - [large](../config/large.json) (Comming Soon?) 
: baseから次元数を増やしたモデル。大規模な計算震源が必要かも。 10 | 11 | ## パラメータの意味 12 | そもそもこのコンフィグを編集しようとする人はソースコードを読めたり、パラメータ名からある程度どのようなパラメータかを察する事ができるはずなので、必要ないかもしれないが、一部説明を書いておく。 13 | -WIP- -------------------------------------------------------------------------------- /docs/credits.md: -------------------------------------------------------------------------------- 1 | # クレジット表記 2 | MITライセンスのリポジトリからソースコードをコピペしたりしている都合上、その旨を明記しなければいけないため、ここに記す。 3 | 4 | ## 引用したソースコード 5 | - [monotonic_align](../module/vits/monotonic_align/) : [ESPNet](https://github.com/espnet/espnet) から引用。エラーメッセージを一部改変。 6 | - [duration_predictors.py](../module/vits/duration_predictors.py/) : [VITS2](https://github.com/daniilrobnikov/vits2/blob/main/model/duration_predictors.py) から引用、改変。 7 | - [transforms.py](../module/utils/transforms.py) : [VITS2](https://github.com/daniilrobnikov/vits2/blob/main/utils/transforms.py)から引用。 8 | - [transformer.py](../module/vits/transformer.py) : [VITS2](https://github.com/daniilrobnikov/vits2/blob/main/model/transformer.py) から引用、改変。 9 | - [normalization.py](../module/vits/normalization.py) : [VITS2](https://github.com/daniilrobnikov/vits2/blob/main/model/normalization.py) から一部引用。 10 | 11 | ## 学習済みモデル 12 | - [rinna_roberta.py](../module/language_model/rinna_roberta.py) : [rinna/japanese-roberta-base](https://huggingface.co/rinna/japanese-roberta-base) を使用。 13 | 14 | ## 参考文献 15 | 16 | ### 参考にした記事 17 | 参考にしたブログなど記事たち。気づき次第随時追加していく。 18 | - [【機械学習】VITSでアニメ声へ変換できるボイスチェンジャー&読み上げ器を作った話](https://qiita.com/zassou65535/items/00d7d5562711b89689a8) : zassou氏によるVITS解説記事。理論から実装、損失関数の導出など非常にわかりやすく書かれている。本プロジェクトはこの記事に出合えなかったら完成するのは不可能だろう。 19 | 20 | ### 参考にしたリポジトリ 21 | - [zassou65535/VITS](https://github.com/zassou65535/VITS) : zassou氏によるVITS実装。 posterior_encoder.pyをはじめとする実装や、ディレクトリ構成などを参考にした。 22 | - [RVC-Project/Retrieval-based-Voice-Conversion-WebUI](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI) : いわゆるRVC, HnNSF-HifiGANによるピッチ制御可能なデコーダーや、特徴量ベクトルを検索するという案は、RVCから着想を得ている。 23 | - [fishaudio/Bert-VITS2](https://github.com/fishaudio/Bert-VITS2) : VITSのテキストエンコーダー部分にBERTの特徴量を付与し、感情や文脈などもエンコードできるようにするという案は、Bert-VITS2から着想を得ている。 24 | - [uthree/tinyvc](https://github.com/uthree/tinyvc) 自分のリポジトリを参考にするとはどういうことだ、と言われそうだが、TinyVCのデコーダーをほぼそのままスケールアップして採用している。 25 | 26 | ### 論文 27 | 参考にした論文。 28 | #### 全体的な設計 29 | - [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) 30 | - [VITS2: Improving Quality and Efficiency of Single-Stage Text-to-Speech with Adversarial Learning and Architecture Design](https://arxiv.org/abs/2307.16430) 31 | 32 | #### Decoderを設計する際に参考にした論文 33 | - [HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis](https://arxiv.org/abs/2010.05646) 34 | - [FastSVC: Fast Cross-Domain Singing Voice Conversion with Feature-wise Linear Modulation](https://arxiv.org/abs/2011.05731) 35 | - [DDSP: Differentiable Digital Signal Processing](https://arxiv.org/abs/2001.04643) 36 | - [VISinger 2: High-Fidelity End-to-End Singing Voice Synthesis Enhanced by Digital Signal Processing Synthesizer](https://arxiv.org/abs/2211.02903) 37 | - [Neural Concatenative Singing Voice Conversion: Rethinking Concatenation-Based Approach for One-Shot Singing Voice Conversion](https://arxiv.org/abs/2312.04919) 38 | 39 | #### Feature Retrieval (特徴量検索) 40 | - [Voice Conversion With Just Nearest Neighbors](https://arxiv.org/abs/2305.18975) 41 | 42 | #### 言語モデル 43 | - [RoBERTa: A Robustly Optimized BERT Pretraining 
Approach](https://arxiv.org/abs/1907.11692) 44 | 45 | #### その他 46 | - [GAN Vocoder: Multi-Resolution Discriminator Is All You Need](https://arxiv.org/abs/2103.05236) -------------------------------------------------------------------------------- /docs/images/auris_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uthree/auris_experimental_vits_dsp/faffcafa3e38028f25ecaa79dd420207a732e324/docs/images/auris_architecture.png -------------------------------------------------------------------------------- /docs/images/auris_decoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uthree/auris_experimental_vits_dsp/faffcafa3e38028f25ecaa79dd420207a732e324/docs/images/auris_decoder.png -------------------------------------------------------------------------------- /docs/infer.md: -------------------------------------------------------------------------------- 1 | # 推論方法 2 | このドキュメントでは、推論方法について記す 3 | 4 | ## gradioによるUIを使った推論 5 | webuiを起動する。 6 | ```sh 7 | python3 infer_webui.py 8 | ``` 9 | `localhost:7860` にブラウザでアクセスする。 10 | 11 | ## CUIによる推論 12 | 13 | ### 音声再構築タスク 14 | 入力された音声を再構築するタスク。VAEの性能確認用。 15 | 16 | 1. 音声ファイルが複数入ったディレクトリを用意する 17 | ```sh 18 | mkdir audio_inputs # ここに入力音声ファイルを入れる 19 | ``` 20 | 21 | 2. 推論する。 22 | `-s 話者名`で話者を指定する必要がある。 23 | ```sh 24 | python3 infer.py -i audio_inputs -t recon -s jvs001 25 | ``` 26 | 27 | 3. `outputs/`内に出力ファイルが生成されるので、確認する。 28 | 29 | ### 音声読み上げ(TTS) 30 | テキストを読み上げるタスク。 31 | 32 | 1. 台本フォルダを用意する。 33 | 台本の内容は`example/text_inputs/`に例があるので、それを参考に作成する。 34 | ```sh 35 | mkdir text_inputs # ここに台本を入れる 36 | ``` 37 | 38 | 2. 推論する。 39 | ```sh 40 | python3 infer.py -i text_inputs -t tts 41 | ``` 42 | 43 | #### 設定項目の詳細 44 | - `speaker`: 話者名 45 | - `text`: 読み上げるテキスト 46 | - `style_text`: スタイルテキスト。任意。つけない場合は`text`と同じ内容になる。 47 | - `language`: 言語 -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # インストール方法 2 | ## 事前に用意するもの 3 | - Python 3.10.6 or later 4 | - torch 2.0.1 or later, cuda等GPUが使用可能な環境 5 | 6 | ## 手順 7 | 1. このリポジトリをクローンし、ディレクトリ内に移動する 8 | ```sh 9 | git clone https://github.com/uthree/auris 10 | cd auris 11 | ``` 12 | 13 | 2. 依存関係をインストール 14 | ```sh 15 | pip3 install -r requirements.txt 16 | ``` 17 | 18 | 3. 
monotonic_arginをビルドする(任意) 19 | (ビルドしない場合はnumba実装が使われる。ビルドしたほうがパフォーマンスが良い。) 20 | ```sh 21 | cd module/vits/monotonic_align 22 | python3 setup.py build_ext --inplace 23 | ``` 24 | -------------------------------------------------------------------------------- /docs/technical_details.md: -------------------------------------------------------------------------------- 1 | # 技術的な詳細 2 | このドキュメントは、本リポジトリの技術的な詳細を記す。 3 | 4 | ## 開発の動機 5 | - オープンソースのTTS(Text-To-Speech), SVS(Singing-Voice-Synthesis), SVC(Singing-Voice-Conversion)ができるモデルが欲しい。 6 | - なるべくライセンスが緩いものがいい。(MITライセンスなど) 7 | - 日本語の事前学習モデルが欲しい。 8 | 9 | ## モデルアーキテクチャ 10 | VITSをベースに改造するという形になる。 11 | 具体的には、 12 | - DecoderをDSP+HnNSF-HiFiGANに変更。DSP機能を取り入れることで、外部からピッチを制御可能に。これにより歌声合成ができる。 13 | - Text Encoderに言語モデルの特徴量を参照する機能をつけ、感情や文脈などを読み取れるように。 14 | - Feature Retrievalを追加し話者の再現性を向上させる 15 | - VITS2のDuration Discriminatorを追加 16 | - Discriminatorのうち、MultiScaleDiscriminatorをMultiResolutionalDiscriminatorに変更。スペクトログラムのオーバースムージングを回避する。 17 | 18 | 等の改造があげられる。 19 | 20 | ### 全体の構造 21 | ![](./images/auris_architecture.png) 22 | VITSの構造とほぼ同じ。 23 | 24 | ### デコーダー 25 | ![](./images/auris_decoder.png) 26 | DDSPのような加算シンセサイザによる音声合成の後にHiFi-GANに似た構造のフィルターをかけるハイブリッド構造。 -------------------------------------------------------------------------------- /docs/todo_list.md: -------------------------------------------------------------------------------- 1 | # TODO List 2 | - [ ] 話者モーフィング機能の追加 3 | - [ ] 特徴ベクトル辞書の生成 4 | - [ ] メタデータに話者の平均f0等の追加情報をいれる 5 | - [ ] musicxmlを読み込む機能の追加 6 | - [ ] 楽譜から歌声合成を作成する 7 | - [ ] onnxへの出力 -------------------------------------------------------------------------------- /docs/train.md: -------------------------------------------------------------------------------- 1 | # 学習方法 2 | このドキュメントでは、モデルの学習方法について記す。 3 | 4 | ## 前処理 5 | いくつかのデータセットのための前処理スクリプトを用意しておいた。 6 | データセットをダウンロードし、これらのスクリプトを実行するだけで前処理が完了する。 7 | `-c`, `--config` オプションで使用するコンフィグファイルを指定できる。 8 | 9 | ### JVSコーパス 10 | [JVSコーパス](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvs_corpus)の前処理は以下のコマンドで行う。 11 | ```sh 12 | python3 preprocess.py jvs jvs_ver1/ -c config/base.json 13 | ``` 14 | 15 | 16 | ### 自作データセット 17 | データセットを自作する場合、まず以下の構成のディレクトリを用意する。 18 | `root_dir`の名前、`speaker001`等の名前は何でもよいが、半角英数字で命名することを推奨。 19 | 書き起こしの言語が日本語でない場合は、 `preprocess/wave_and_text.py`の`LANGUAGE`を書き換える。 20 | ``` 21 | root_dir/ 22 | - speaker001/ 23 | - speech001.wav 24 | - speech001.txt 25 | - speech002.wav 26 | - speech002.txt 27 | ... 28 | - speaker02/ 29 | - speech001.wav 30 | - speech001.txt 31 | - speech002.wav 32 | - speech002.txt 33 | ... 34 | ... 35 | ``` 36 | wavファイルと同じファイル名のテキストファイルに、その書き起こしが入る形にする。 37 | データセットが用意できたら、前処理を実行する。 38 | ```sh 39 | python3 preprocess.py wav-txt root_dir/ -c config/base.json 40 | ``` 41 | 42 | 43 | ## 学習を実行 44 | ```sh 45 | python3 train.py -c config/base.json 46 | ``` 47 | 48 | ## 学習を再開する 49 | `models/vits.ckpt`を自動的に読み込んで再開してくれる。 50 | 読み込みに失敗した場合はファイルが壊れているので、`lightning_logs/`内にある最新のckptを`vits.ckpt`に名前を変更して`models/`に配置することで復旧できる。 51 | 52 | ## 学習の状態を確認 53 | tensorboardというライブラリを使って学習進捗を可視化することができる。 54 | ```sh 55 | tensorboard --logdir lightning_logs 56 | ``` 57 | をscreen等を用いてバックグラウンドで実行する。 58 | これが実行されている間はtensorboardのサーバーが動いているので、ブラウザで`http://localhost:6006`にアクセスすると進捗を見ることができる。 59 | 60 | ## FAQ 61 | - `models/metadata.json`は何のファイルですか? 62 | 話者名と話者IDを関連付けるためのメタデータが含まれるJSONファイルです。 63 | データセットの前処理を行うと自動的に生成されます。 64 | - `models/config.json`は何のファイルですか? 
65 | 推論時にデフォルトでロードするコンフィグです。前処理時に`-c`, `--config`オプションで指定したコンフィグが複製されます。 -------------------------------------------------------------------------------- /example/musicxml/test.musicxml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Title 6 | 7 | 8 | Composer 9 | 10 | MuseScore 3.6.2 11 | 2024-04-13 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 6.99911 22 | 40 23 | 24 | 25 | 1696.94 26 | 1200.48 27 | 28 | 85.7252 29 | 85.7252 30 | 85.7252 31 | 85.7252 32 | 33 | 34 | 85.7252 35 | 85.7252 36 | 85.7252 37 | 85.7252 38 | 39 | 40 | 41 | 42 | 43 | 44 | title 45 | test score 46 | 47 | 48 | composer 49 | Composer 50 | 51 | 52 | 53 | Piano 54 | Pno. 55 | 56 | Piano 57 | 58 | 59 | 60 | 1 61 | 1 62 | 78.7402 63 | 0 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 50.00 73 | 0.00 74 | 75 | 170.00 76 | 77 | 78 | 79 | 1 80 | 81 | 0 82 | 83 | 87 | 88 | G 89 | 2 90 | 91 | 92 | 93 | 94 | C 95 | 4 96 | 97 | 1 98 | 1 99 | quarter 100 | up 101 | 102 | single 103 | 104 | 105 | 106 | 107 | 108 | E 109 | 4 110 | 111 | 1 112 | 1 113 | quarter 114 | up 115 | 116 | single 117 | 118 | 119 | 120 | 121 | 122 | G 123 | 4 124 | 125 | 1 126 | 1 127 | quarter 128 | up 129 | 130 | single 131 | 132 | 133 | 134 | 135 | 136 | B 137 | 4 138 | 139 | 1 140 | 1 141 | quarter 142 | down 143 | 144 | single 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 4 153 | 1 154 | 155 | 156 | 157 | 158 | 159 | 4 160 | 1 161 | 162 | 163 | 164 | 165 | 166 | 4 167 | 1 168 | 169 | 170 | 171 | 172 | 173 | 174 | 0.00 175 | 0.00 176 | 177 | 135.55 178 | 179 | 180 | 181 | 182 | 4 183 | 1 184 | 185 | 186 | 187 | 188 | 189 | 4 190 | 1 191 | 192 | 193 | 194 | 195 | 196 | 4 197 | 1 198 | 199 | 200 | 201 | 202 | 203 | 4 204 | 1 205 | 206 | 207 | 208 | 209 | 210 | 211 | 0.00 212 | 0.00 213 | 214 | 135.55 215 | 216 | 217 | 218 | 219 | 4 220 | 1 221 | 222 | 223 | 224 | 225 | 226 | 4 227 | 1 228 | 229 | 230 | 231 | 232 | 233 | 4 234 | 1 235 | 236 | 237 | 238 | 239 | 240 | 4 241 | 1 242 | 243 | 244 | 245 | 246 | 247 | 248 | 0.00 249 | 0.00 250 | 251 | 135.55 252 | 253 | 254 | 255 | 256 | 4 257 | 1 258 | 259 | 260 | 261 | 262 | 263 | 4 264 | 1 265 | 266 | 267 | 268 | 269 | 270 | 4 271 | 1 272 | 273 | 274 | 275 | 276 | 277 | 4 278 | 1 279 | 280 | 281 | 282 | 283 | 284 | 285 | 0.00 286 | 0.00 287 | 288 | 135.55 289 | 290 | 291 | 292 | 293 | 4 294 | 1 295 | 296 | 297 | 298 | 299 | 300 | 4 301 | 1 302 | 303 | 304 | 305 | 306 | 307 | 4 308 | 1 309 | 310 | 311 | 312 | 313 | 314 | 4 315 | 1 316 | 317 | 318 | 319 | 320 | 321 | 322 | 0.00 323 | 0.00 324 | 325 | 135.55 326 | 327 | 328 | 329 | 330 | 4 331 | 1 332 | 333 | 334 | 335 | 336 | 337 | 4 338 | 1 339 | 340 | 341 | 342 | 343 | 344 | 4 345 | 1 346 | 347 | 348 | 349 | 350 | 351 | 4 352 | 1 353 | 354 | 355 | 356 | 357 | 358 | 359 | 0.00 360 | 0.00 361 | 362 | 135.55 363 | 364 | 365 | 366 | 367 | 4 368 | 1 369 | 370 | 371 | 372 | 373 | 374 | 4 375 | 1 376 | 377 | 378 | 379 | 380 | 381 | 4 382 | 1 383 | 384 | 385 | 386 | 387 | 388 | 4 389 | 1 390 | 391 | 392 | 393 | 394 | 395 | 396 | 0.00 397 | 503.83 398 | 399 | 135.55 400 | 401 | 402 | 403 | 404 | 4 405 | 1 406 | 407 | 408 | 409 | 410 | 411 | 4 412 | 1 413 | 414 | 415 | 416 | 417 | 418 | 4 419 | 1 420 | 421 | 422 | 423 | 424 | 425 | 4 426 | 1 427 | 428 | 429 | light-heavy 430 | 431 | 432 | 433 | 434 | -------------------------------------------------------------------------------- /example/score_inputs/test.json: -------------------------------------------------------------------------------- 1 | { 2 | "parts": { 3 
| "0": { 4 | "language": "ja", 5 | "style_text": "あ", 6 | "speaker": "jvs001", 7 | "notes": [ 8 | { 9 | "pitch": 64, 10 | "onset": 0.1, 11 | "offset": 0.9, 12 | "lyrics": "あ", 13 | "vibrato": { 14 | "strength": 1.0, 15 | "onset": 0.5, 16 | "offset": 0.8, 17 | "period": 0.05 18 | }, 19 | "yodel": { 20 | "strength": 1.0, 21 | "onset": 0.1, 22 | "offset": 0.2 23 | }, 24 | "fall": { 25 | "strength": 12.0, 26 | "onset": 0.8, 27 | "offset": 0.9 28 | }, 29 | "energy": 1.0 30 | } 31 | ] 32 | } 33 | } 34 | } -------------------------------------------------------------------------------- /example/text_inputs/transcription.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": { 3 | "speaker": "jvs001", 4 | "text": "また、東寺のように、五大明王と呼ばれる、主要な明王の中央に配されることも多い。", 5 | "style_text": "また、東寺のように、五大明王と呼ばれる、主要な明王の中央に配されることも多い。", 6 | "language": "ja" 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | import json 4 | 5 | import torch 6 | import torchaudio 7 | from torchaudio.functional import resample 8 | 9 | from module.infer import Infer 10 | 11 | parser = argparse.ArgumentParser(description="inference") 12 | parser.add_argument('-c', '--config', default='models/config.json') 13 | parser.add_argument('-t', '--task', choices=['tts', 'recon', 'svc', 'svs'], default='tts') 14 | parser.add_argument('-s', '--speaker', default='jvs001') 15 | parser.add_argument('-i', '--inputs', default='inputs') 16 | parser.add_argument('-o', '--outputs', default='outputs') 17 | parser.add_argument('-m', '--model', default='models/generator.safetensors') 18 | parser.add_argument('-meta', '--metadata', default='models/metadata.json') 19 | args = parser.parse_args() 20 | 21 | outputs_dir = Path(args.outputs) 22 | 23 | # load model 24 | infer = Infer(args.model, args.config, args.metadata) 25 | device = infer.device 26 | 27 | # support audio formats 28 | audio_formats = ['mp3', 'wav', 'ogg'] 29 | 30 | # make outputs directory if not exists 31 | if not outputs_dir.exists(): 32 | outputs_dir.mkdir() 33 | 34 | if args.task == 'recon': 35 | print("Task: Audio Reconstruction") 36 | # audio reconstruction task 37 | spk = args.speaker 38 | 39 | # get input path 40 | inputs_dir = Path(args.inputs) 41 | inputs = [] 42 | 43 | # load files 44 | for fmt in audio_formats: 45 | for path in inputs_dir.glob(f"*.{fmt}"): 46 | inputs.append(path) 47 | 48 | # inference 49 | for path in inputs: 50 | print(f"Inferencing {path}") 51 | 52 | # load audio 53 | wf, sr = torchaudio.load(path) 54 | 55 | # resample 56 | if sr != infer.sample_rate: 57 | wf = resample(wf, sr) 58 | 59 | # infer 60 | spk = args.speaker 61 | wf = infer.audio_reconstruction(wf, spk).cpu() 62 | 63 | # save 64 | save_path = outputs_dir / (path.stem + ".wav") 65 | torchaudio.save(save_path, wf, infer.sample_rate) 66 | 67 | elif args.task == 'tts': 68 | print("Task: Text to Speech") 69 | 70 | # get input path 71 | inputs_dir = Path(args.inputs) 72 | 73 | # load files 74 | inputs = [] 75 | for path in inputs_dir.glob("*.json"): 76 | inputs.append(path) 77 | 78 | # inference 79 | for path in inputs: 80 | print(f"Inferencing {path}") 81 | t = json.load(open(path, encoding='utf-8')) 82 | for k, v in zip(t.keys(), t.values()): 83 | print(f" Inferencing {k}") 84 | wf = infer.text_to_speech(**v).cpu() 85 | 86 | # save 87 | save_path = outputs_dir / 
(f"{path.stem}_{k}.wav") 88 | torchaudio.save(save_path, wf, infer.sample_rate) 89 | 90 | elif args.task == 'svs': 91 | print("Task: Singing Voice Synthesis") 92 | 93 | inputs_dir = Path(args.inputs) 94 | 95 | # load score 96 | inputs = [] 97 | for path in inputs_dir.glob("*.json"): 98 | inputs.append(path) 99 | 100 | # inference 101 | for path in inputs: 102 | print(f"Inferencing {path}") 103 | score = json.load(open(path, encoding='utf-8')) 104 | wf = infer.singing_voice_synthesis(score) 105 | 106 | # save 107 | save_path = outputs_dir / (path.stem + ".wav") 108 | torchaudio.save(save_path, wf, infer.sample_rate) 109 | 110 | -------------------------------------------------------------------------------- /infer_webui.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | import torch 4 | import numpy as np 5 | 6 | from module.infer import Infer 7 | 8 | import gradio as gr 9 | 10 | parser = argparse.ArgumentParser(description="inference") 11 | parser.add_argument('-c', '--config', default='models/config.json') 12 | parser.add_argument('-t', '--task', choices=['tts', 'recon', 'svc', 'svs'], default='tts') 13 | parser.add_argument('-m', '--model', default='models/generator.safetensors') 14 | parser.add_argument('-meta', '--metadata', default='models/metadata.json') 15 | parser.add_argument('-p', '--port', default=7860, type=int) 16 | args = parser.parse_args() 17 | 18 | # load model 19 | infer = Infer(args.model, args.config, args.metadata) 20 | device = infer.device 21 | 22 | def text_to_speech(text, style_text, speaker, language, duration_scale, pitch_shift): 23 | if style_text == "": 24 | style_text = text 25 | wf = infer.text_to_speech(text, speaker, language, style_text, duration_scale, pitch_shift) 26 | wf = wf.squeeze(0).cpu().numpy() 27 | wf = (wf * 32768).astype(np.int16) 28 | sample_rate = infer.sample_rate 29 | return sample_rate, wf 30 | 31 | demo = gr.Interface(text_to_speech, inputs=[ 32 | gr.Text(label="Text"), 33 | gr.Text(label="Style"), 34 | gr.Dropdown(infer.speakers(), label="Speaker", value=infer.speakers()[0]), 35 | gr.Dropdown(infer.languages(), label="Language", value=infer.languages()[0]), 36 | gr.Slider(0.1, 3.0, 1.0, label="Duration Scale"), 37 | gr.Slider(-12.0, 12.0, 0.0, label="Pitch Shift"), 38 | ], 39 | outputs=[gr.Audio()]) 40 | 41 | demo.launch(debug=True, server_port=args.port) -------------------------------------------------------------------------------- /module/g2p/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Tuple 2 | import torch 3 | 4 | from .japanese import JapaneseExtractor 5 | #from .english import EnglishExtractor 6 | 7 | 8 | class G2PProcessor: 9 | def __init__(self): 10 | self.extractors = {} 11 | 12 | # If you want to add a language, add processing here 13 | # --- 14 | self.extractors['ja'] = JapaneseExtractor() 15 | #self.extractors['en'] = EnglishExtractor() 16 | # --- 17 | 18 | self.languages = [] 19 | phoneme_vocabs = [] 20 | for mod in self.extractors.values(): 21 | phoneme_vocabs += mod.possible_phonemes() 22 | self.languages += self.extractors.keys() 23 | self.phoneme_vocabs = [''] + phoneme_vocabs 24 | 25 | def grapheme_to_phoneme(self, text: Union[str, List[str]], language: Union[str, List[str]]): 26 | if type(text) == list: 27 | return self._g2p_multiple(text, language) 28 | elif type(text) == str: 29 | return self._g2p_single(text, language) 30 | 31 | def 
_g2p_single(self, text, language): 32 | mod = self.extractors[language] 33 | return mod.g2p(text) 34 | 35 | def _g2p_multiple(self, text, language): 36 | result = [] 37 | for t, l in zip(text, language): 38 | result.append(self._g2p_single(t, l)) 39 | return result 40 | 41 | def phoneme_to_id(self, phonemes: Union[List[str], List[List[str]]]): 42 | if type(phonemes[0]) == list: 43 | return self._p2id_multiple(phonemes) 44 | elif type(phonemes[0]) == str: 45 | return self._p2id_single(phonemes) 46 | 47 | def _p2id_single(self, phonemes: List[str]): 48 | ids = [] 49 | for p in phonemes: 50 | if p in self.phoneme_vocabs: 51 | ids.append(self.phoneme_vocabs.index(p)) 52 | else: 53 | print("warning: unknown phoneme.") 54 | ids.append(0) 55 | return ids 56 | 57 | def _p2id_multiple(self, phonemes: List[List[str]]): 58 | sequences = [] 59 | for s in phonemes: 60 | out = self._p2id_single(s) 61 | sequences.append(out) 62 | return sequences 63 | 64 | def language_to_id(self, languages: Union[str, List[str]]): 65 | if type(languages) == str: 66 | return self._l2id_single(languages) 67 | elif type(languages) == list: 68 | return self._l2id_multiple(languages) 69 | 70 | def _l2id_single(self, language): 71 | if language in self.languages: 72 | return self.languages.index(language) 73 | else: 74 | return 0 75 | 76 | def _l2id_multiple(self, languages): 77 | result = [] 78 | for l in languages: 79 | result.append(self._l2id_single(l)) 80 | return result 81 | 82 | def id_to_phoneme(self, ids): 83 | if type(ids[0]) == list: 84 | return self._id2p_multiple(ids) 85 | elif type(ids[0]) == int: 86 | return self._id2p_single(ids) 87 | 88 | def _id2p_single(self, ids: List[int]) -> List[str]: 89 | phonemes = [] 90 | for i in ids: 91 | if i < len(self.phoneme_vocabs): 92 | p = self.phoneme_vocabs[i] 93 | else: 94 | p = '' 95 | phonemes.append(p) 96 | return phonemes 97 | 98 | def _id2p_multiple(self, ids: List[List[int]]) -> List[List[str]]: 99 | results = [] 100 | for s in ids: 101 | results.append(self._id2p_single(s)) 102 | return results 103 | 104 | def encode(self, sentences: List[str], languages: List[str], max_length: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 105 | ids, lengths = self._enc_multiple(sentences, languages, max_length) 106 | language_ids = self.language_to_id(languages) 107 | 108 | ids = torch.LongTensor(ids) 109 | lengths = torch.LongTensor(lengths) 110 | language_ids = torch.LongTensor(language_ids) 111 | 112 | return ids, lengths, language_ids 113 | 114 | def _enc_single(self, sentence, language, max_length): 115 | phonemes = self.grapheme_to_phoneme(sentence, language) 116 | ids = self.phoneme_to_id(phonemes) 117 | length = min(len(ids), max_length) 118 | if len(ids) > max_length: 119 | ids = ids[:max_length] 120 | while len(ids) < max_length: 121 | ids.append(0) 122 | return ids, length 123 | 124 | def _enc_multiple(self, sentences, languages, max_length): 125 | seq, lengths = [], [] 126 | for s, l in zip(sentences, languages): 127 | ids, length = self._enc_single(s, l, max_length) 128 | seq.append(ids) 129 | lengths.append(length) 130 | return seq, lengths 131 | 132 | -------------------------------------------------------------------------------- /module/g2p/english.py: -------------------------------------------------------------------------------- 1 | from .extractor import PhoneticExtractor 2 | from g2p_en import G2p 3 | 4 | 5 | class EnglishExtractor(PhoneticExtractor): 6 | def __init__(self): 7 | super().__init__() 8 | self.g2p_instance = G2p() 9 | 10 | def 
g2p(self, text): 11 | return self.g2p_instance(text) 12 | 13 | def possible_phonemes(self): 14 | phonemes = self.g2p_instance.phonemes 15 | phonemes.remove('') 16 | return phonemes 17 | -------------------------------------------------------------------------------- /module/g2p/extractor.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | # grapheme-to-phoneme module base class 5 | # Create a module by inheriting this class for each language. 6 | class PhoneticExtractor(): 7 | def __init__(self, *args, **kwargs): 8 | pass 9 | 10 | # grapheme to phoneme 11 | def g2p(self, text: str) -> List[str]: 12 | raise "Not Implemented" 13 | 14 | def possible_phonemes(self) -> List[str]: 15 | raise "Not Implemented" 16 | -------------------------------------------------------------------------------- /module/g2p/japanese.py: -------------------------------------------------------------------------------- 1 | from .extractor import PhoneticExtractor 2 | import pyopenjtalk 3 | from typing import List 4 | 5 | 6 | class JapaneseExtractor(PhoneticExtractor): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def g2p(self, text) -> List[str]: 11 | phonemes = pyopenjtalk.g2p(text).split(" ") 12 | new_phonemes = [] 13 | for p in phonemes: 14 | if p == 'pau': 15 | new_phonemes.append('') 16 | else: 17 | new_phonemes.append(p) 18 | return new_phonemes 19 | 20 | def possible_phonemes(self): 21 | return ['I', 'N', 'U', 'a', 'b', 'by', 'ch', 'cl', 'd', 22 | 'dy', 'e', 'f', 'g', 'gy', 'h', 'hy', 'i', 'j', 'k', 'ky', 23 | 'm', 'my', 'n', 'ny', 'o', 'p', 'py', 'r', 'ry', 's', 'sh', 24 | 't', 'ts', 'ty', 'u', 'v', 'w', 'y', 'z'] 25 | -------------------------------------------------------------------------------- /module/infer/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import torch 4 | import math 5 | import json 6 | from pathlib import Path 7 | 8 | from safetensors import safe_open 9 | from safetensors.torch import save_file 10 | 11 | from module.vits import Generator, spectrogram 12 | from module.utils.safetensors import load_tensors 13 | from module.g2p import G2PProcessor 14 | from module.language_model import LanguageModel 15 | from module.utils.config import load_json_file 16 | 17 | 18 | class Infer: 19 | def __init__(self, safetensors_path, config_path, metadata_path, device=torch.device('cpu')): 20 | self.device = device 21 | self.config = load_json_file(config_path) 22 | self.metadata = load_json_file(metadata_path) 23 | self.g2p = G2PProcessor() 24 | self.lm = LanguageModel(self.config.language_model.type, self.config.language_model.options) 25 | 26 | # load generator 27 | generator = Generator(self.config.vits.generator) 28 | generator.load_state_dict(load_tensors(safetensors_path)) 29 | generator = generator.to(self.device) 30 | self.generator = generator 31 | 32 | self.max_lm_tokens = self.config.infer.max_lm_tokens 33 | self.max_phonemes = self.config.infer.max_phonemes 34 | self.max_frames = self.config.infer.max_frames 35 | 36 | self.n_fft = self.config.infer.n_fft 37 | self.frame_size = self.config.infer.frame_size 38 | self.sample_rate = self.config.infer.sample_rate 39 | 40 | def speakers(self): 41 | return self.metadata.speakers 42 | 43 | def speaker_id(self, speaker): 44 | return self.speakers().index(speaker) 45 | 46 | def languages(self): 47 | return self.g2p.languages 48 | 49 | def language_id(self, language): 50 | return 
self.g2p.language_to_id(language) 51 | 52 | @torch.inference_mode() 53 | def text_to_speech( 54 | self, 55 | text: str, 56 | speaker: str, 57 | language: str, 58 | style_text: Union[None, str] = None, 59 | duration_scale=1.0, 60 | pitch_shift=0.0, 61 | ): 62 | spk = torch.LongTensor([self.speaker_id(speaker)]) 63 | if style_text is None: 64 | style_text = text 65 | lm_feat, lm_feat_len = self.lm.encode([style_text], self.max_lm_tokens) 66 | phoneme, phoneme_len, lang = self.g2p.encode([text], [language], self.max_phonemes) 67 | 68 | device = self.device 69 | phoneme = phoneme.to(device) 70 | phoneme_len = phoneme_len.to(device) 71 | lm_feat = lm_feat.to(device) 72 | lm_feat_len = lm_feat_len.to(device) 73 | spk = spk.to(device) 74 | lang = lang.to(device) 75 | 76 | wf = self.generator.text_to_speech( 77 | phoneme, 78 | phoneme_len, 79 | lm_feat, 80 | lm_feat_len, 81 | lang, 82 | spk, 83 | duration_scale=duration_scale, 84 | pitch_shift=pitch_shift, 85 | ) 86 | return wf.squeeze(0) 87 | 88 | # wf: [Channels, Length] 89 | @torch.inference_mode() 90 | def audio_reconstruction(self, wf: torch.Tensor, speaker:str): 91 | spk = torch.LongTensor([self.speaker_id(speaker)]) 92 | wf = wf.sum(dim=0, keepdim=True) 93 | spec = spectrogram(wf, self.n_fft, self.frame_size) 94 | spec_len = torch.LongTensor([spec.shape[2]]) 95 | 96 | device = self.device 97 | spec = spec.to(device) 98 | spec_len = spec_len.to(device) 99 | spk = spk.to(device) 100 | 101 | wf = self.generator.audio_reconstruction(spec, spec_len, spk) 102 | return wf.squeeze(0) 103 | 104 | def singing_voice_synthesis(self, score): 105 | parts = score['parts'] 106 | for part_name, part in zip(parts.keys(), parts.values()): 107 | print(f"processing {part_name}") 108 | self._svs_generate_part(part) 109 | 110 | # TODO: コメントを英語にする、いつかやる。多分。 111 | def _svs_generate_part(self, part): 112 | language = part['language'] 113 | style_text = part['style_text'] 114 | speaker = part['speaker'] 115 | notes = part['notes'] 116 | 117 | # notes をonset でソート 118 | notes.sort(key=lambda x: x['onset']) 119 | 120 | # get begin and end time 121 | # 開始時刻[秒]と終了時刻[秒]をノート一覧から探す。もっとも小さいonsetが開始時刻で最も大きいoffsetが終了時刻。 122 | t_begin = None 123 | t_end = None 124 | # それと歌詞情報を取得する 125 | part_phonemes = [] # このパートの音素列 126 | note_phoneme_indices = [] # 各ノート毎の(開始index, 終了index) 127 | for note in notes: 128 | b = note['onset'] 129 | e = note['offset'] 130 | if t_begin is None: 131 | t_begin = b 132 | elif b < t_begin: 133 | t_begin = b 134 | if t_end is None: 135 | t_end = e 136 | elif b > t_end: 137 | t_end = e 138 | 139 | # 歌詞情報の処理 140 | note_phonemes = self.g2p.grapheme_to_phoneme(note['lyrics'], language) 141 | note_phoneme_indices.append((len(part_phonemes), len(part_phonemes) + len(note_phonemes) - 1)) 142 | part_phonemes.extend(note_phonemes) 143 | # 音素の数 144 | num_phonemes = len(part_phonemes) 145 | # パートの長さを求める 146 | part_length = t_end - t_begin 147 | # 1秒間に何フレームか 148 | fps = self.sample_rate / self.frame_size 149 | # 生成するフレーム数 150 | num_frames = math.ceil(part_length * fps) 151 | # ピッチ列のバッファ。この段階ではまだMIDIのスケール。-infに近い値で埋めておく。(self._midi2f0(-inf) = 0なので、発声がない区間を0Hzにしたい。) 152 | pitch = torch.full([num_frames], -1e10) 153 | # エネルギー列のバッファ。 これは初期値0 154 | energy = torch.full([num_frames], 0.0) 155 | # Duration 156 | duration = torch.full([num_phonemes], 0.0) 157 | # 話者をエンコードする 158 | speaker_id = self.speaker_id(speaker) 159 | speaker_id = torch.LongTensor([speaker_id]) 160 | spk = self.generator.speaker_embedding(speaker_id) 161 | # 言語をエンコードする 162 | lang_id = 
self.language_id(language) 163 | lang = torch.LongTensor([lang_id]) 164 | 165 | # 音素とテキスト. LMの特徴量をエンコードする 166 | phonemes = torch.LongTensor(self.g2p.phoneme_to_id(part_phonemes)).unsqueeze(1) 167 | phonemes_len = torch.LongTensor([phonemes.shape[1]]) 168 | lm_feat, lm_feat_len = self.lm.encode([style_text], self.max_lm_tokens) 169 | text_encoded, text_mean, text_logvar, text_mask = self.generator.prior_encoder.text_encoder(phonemes, phonemes_len, lm_feat, lm_feat_len, spk, lang) 170 | # durationを推定する 171 | log_dur = self.generator.prior_encoder.stochastic_duration_predictor(text_encoded, text_mask, g=spk, reverse=True) 172 | duration = torch.ceil(torch.exp(log_dur)).to(torch.long) 173 | 174 | # ノートごとに処理する 175 | for i, note in enumerate(notes): 176 | phoneme_begin, phoneme_end = note_phoneme_indices[i] 177 | 178 | # ノートの始点と終点をフレーム単位に変換 179 | onset = round((note['onset'] - t_begin) * fps) 180 | offset = round((note['offset'] - t_begin) * fps) 181 | 182 | # 代入 183 | pitch[onset:offset] = float(note['pitch']) 184 | energy[onset:offset] = float(note['energy']) 185 | 186 | # TODO: ビブラートとかフォールとかenergyの調整とか 187 | 188 | print(pitch) 189 | 190 | 191 | def _f02midi(self, f0): 192 | return torch.log2(f0 / 440.0) * 12.0 + 69.0 193 | 194 | def _midi2f0(self, n): 195 | return 440.0 * 2 ** ((n - 69.0) / 12.0) 196 | -------------------------------------------------------------------------------- /module/language_model/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .rinna_roberta import RinnaRoBERTaExtractor 3 | 4 | 5 | def get_extractor(typename): 6 | if typename == "rinna_roberta": 7 | return RinnaRoBERTaExtractor 8 | else: 9 | raise "Unknown linguistic extractor type" 10 | 11 | 12 | class LanguageModel: 13 | def __init__(self, extractor_type, options): 14 | ext_constructor = get_extractor(extractor_type) 15 | self.extractor = ext_constructor(**options) 16 | 17 | def encode(self, sentences, max_length: int): 18 | if type(sentences) == list: 19 | return self._ext_multiple(sentences, max_length) 20 | elif type(sentences) == str: 21 | return self._ext_single(sentences, max_length) 22 | 23 | def _ext_single(self, sentence, max_length: int): 24 | features, length = self.extractor.extract(sentence) 25 | features = features.cpu() 26 | 27 | N, L, D = features.shape 28 | # add padding 29 | if L < max_length: 30 | pad = torch.zeros(N, max_length - L, D) 31 | features = torch.cat([pad, features], dim=1) 32 | # crop 33 | if L > max_length: 34 | features = features[:, :max_length, :] 35 | # length 36 | length = min(length, max_length) 37 | 38 | return features, length 39 | 40 | def _ext_multiple(self, sentences, max_length: int): 41 | lengths = [] 42 | features = [] 43 | for s in sentences: 44 | f, l = self._ext_single(s, max_length) 45 | features.append(f) 46 | lengths.append(l) 47 | features = torch.cat(features, dim=0) 48 | lengths = torch.LongTensor(lengths) 49 | return features, lengths 50 | -------------------------------------------------------------------------------- /module/language_model/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Tuple 3 | 4 | # base class of linguistic feature extractor 5 | class LinguisticExtractor: 6 | def __init__(self, *args, **kwargs): 7 | pass 8 | 9 | # Output: [1, length, lm_dim] 10 | def extract(self, str) -> Tuple[torch.Tensor, int]: 11 | pass 12 | 
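The `LinguisticExtractor` base class above only declares the interface: `extract` must return a `(features, length)` tuple whose `features` tensor is shaped `[1, length, lm_dim]`, which is what `LanguageModel._ext_single` pads or crops along dimension 1. A minimal sketch of a custom extractor written against that contract — the class name, the zero-valued features, and the choice of `lm_dim=768` (taken from the configs) are illustrative assumptions, not part of the repository:

```python
import torch
from typing import Tuple

# Hypothetical no-op extractor that only demonstrates the LinguisticExtractor
# contract: extract() returns (features, length) with features shaped
# [1, length, lm_dim]. Not part of the repository.
class DummyExtractor(LinguisticExtractor):
    def __init__(self, lm_dim: int = 768):  # 768 matches "lm_dim" in the configs
        super().__init__()
        self.lm_dim = lm_dim

    def extract(self, text: str) -> Tuple[torch.Tensor, int]:
        length = max(len(text), 1)
        features = torch.zeros(1, length, self.lm_dim)  # [1, length, lm_dim]
        return features, length
```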
-------------------------------------------------------------------------------- /module/language_model/rinna_roberta.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, AutoModelForMaskedLM 3 | from .extractor import LinguisticExtractor 4 | 5 | 6 | # feature extractor using rinna/japanese-roberta-base 7 | # from: https://huggingface.co/rinna/japanese-roberta-base 8 | class RinnaRoBERTaExtractor(LinguisticExtractor): 9 | def __init__(self, hf_repo: str, layer: int, device=None): 10 | super().__init__() 11 | self.layer = layer 12 | 13 | if device is None: 14 | if torch.cuda.is_available(): # CUDA available 15 | device = torch.device('cuda') 16 | elif torch.backends.mps.is_available(): # on macos 17 | device = torch.device('mps') 18 | else: 19 | device = torch.device('cpu') 20 | else: 21 | device = torch.device(device) 22 | self.device = device 23 | 24 | self.tokenizer = AutoTokenizer.from_pretrained(hf_repo) 25 | self.tokenizer.do_lower_case = True 26 | self.model = AutoModelForMaskedLM.from_pretrained(hf_repo) 27 | self.model.to(device) 28 | 29 | def extract(self, text): 30 | # add [CLS] token 31 | text = "[CLS]" + text 32 | 33 | # tokenize 34 | tokens = self.tokenizer.tokenize(text) 35 | 36 | # convert to ids 37 | token_ids = self.tokenizer.convert_tokens_to_ids(tokens) 38 | 39 | # convert to tensor 40 | token_tensor = torch.LongTensor([token_ids]).to(self.device) 41 | 42 | # provide position id explicitly 43 | position_ids = list(range(0, token_tensor.shape[1])) 44 | position_id_tensor = torch.LongTensor([position_ids]).to(self.device) 45 | 46 | with torch.no_grad(): 47 | outputs = self.model( 48 | input_ids=token_tensor, 49 | position_ids=position_id_tensor, 50 | output_hidden_states=True) 51 | features = outputs.hidden_states[self.layer] 52 | # return features and length 53 | return features, token_tensor.shape[1] 54 | -------------------------------------------------------------------------------- /module/preprocess/jvs.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from .processor import Preprocessor 3 | from tqdm import tqdm 4 | 5 | 6 | def process_category(path: Path, category, processor: Preprocessor, speaker_name, config): 7 | print(f"Ppocessing {str(path)}") 8 | audio_dir = path / "wav24kHz16bit" 9 | transcription_path = path / "transcripts_utf8.txt" 10 | with open(transcription_path, encoding='utf-8') as f: 11 | transcription_text = f.read() 12 | 13 | counter = 0 14 | for metadata in tqdm(transcription_text.split("\n")): 15 | s = metadata.split(":") 16 | if len(s) >= 2: 17 | audio_file_name, transcription = s[0], s[1] 18 | audio_file_path = audio_dir / (audio_file_name + ".wav") 19 | if not audio_file_path.exists(): 20 | continue 21 | processor.write_cache( 22 | audio_file_path, 23 | transcription, 24 | 'ja', 25 | speaker_name, 26 | f"{category}_{counter}" 27 | ) 28 | counter += 1 29 | 30 | def preprocess_jvs(jvs_root: Path, config): 31 | processor = Preprocessor(config) 32 | cache_dir = Path(config['preprocess']['cache']) 33 | for subdir in jvs_root.glob("*/"): 34 | if subdir.is_dir(): 35 | print(f"Processing {subdir}") 36 | speaker_name = subdir.name 37 | process_category(subdir / "nonpara30", "nonpara30", processor, speaker_name, config) 38 | process_category(subdir / "parallel100", "paralell100", processor, speaker_name, config) 39 | 
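`preprocess_jvs` above is normally reached through the `preprocess.py` entry point documented in docs/train.md (`python3 preprocess.py jvs jvs_ver1/ -c config/base.json`); that script itself is not part of this excerpt. A minimal sketch of driving the same step directly from Python, assuming a JVS corpus extracted to `jvs_ver1/` (the path is illustrative) and that metadata generation is handled by `scan_cache` afterwards:

```python
# Hypothetical direct invocation of the JVS preprocessing step; the real
# preprocess.py CLI (not shown in this excerpt) presumably does the equivalent.
from pathlib import Path

from module.preprocess.jvs import preprocess_jvs
from module.preprocess.scan import scan_cache
from module.utils.config import load_json_file

config = load_json_file("config/base.json")  # Config object with attribute access
preprocess_jvs(Path("jvs_ver1/"), config)    # writes <cache>/<speaker>/<name>.wav and .pt pairs
scan_cache(config)                           # writes models/metadata.json (speaker list, phonemes, ...)
```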
-------------------------------------------------------------------------------- /module/preprocess/processor.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import torchaudio 5 | from torchaudio.functional import resample 6 | 7 | from module.g2p import G2PProcessor 8 | from module.language_model import LanguageModel 9 | from module.utils.f0_estimation import estimate_f0 10 | 11 | 12 | # for dataset preprocess 13 | class Preprocessor: 14 | def __init__(self, config): 15 | self.g2p = G2PProcessor() 16 | self.lm = LanguageModel(config.language_model.type, config.language_model.options) 17 | self.max_phonemes = config.preprocess.max_phonemes 18 | self.lm_max_tokens = config.preprocess.lm_max_tokens 19 | self.pitch_estimation = config.preprocess.pitch_estimation 20 | self.max_waveform_length = config.preprocess.max_waveform_length 21 | self.sample_rate = config.preprocess.sample_rate 22 | self.frame_size = config.preprocess.frame_size 23 | self.config = config 24 | 25 | def write_cache(self, waveform_path: Path, transcription: str, language: str, speaker_name: str, data_name: str): 26 | # load waveform file 27 | wf, sr = torchaudio.load(waveform_path) 28 | 29 | # resampling 30 | if sr != self.sample_rate: 31 | wf = resample(wf, sr, self.sample_rate) # [Channels, Length_wf] 32 | 33 | # mix down 34 | wf = wf.sum(dim=0) # [Length_wf] 35 | 36 | # get length frame size 37 | spec_len = torch.LongTensor([wf.shape[0] // self.frame_size]) 38 | 39 | # padding 40 | if wf.shape[0] < self.max_waveform_length: 41 | wf = torch.cat([wf, torch.zeros(self.max_waveform_length - wf.shape[0])]) 42 | 43 | # crop 44 | if wf.shape[0] > self.max_waveform_length: 45 | wf = wf[:self.max_waveform_length] 46 | 47 | wf = wf.unsqueeze(0) # [1, Length_wf] 48 | 49 | # estimate f0(pitch) 50 | f0 = estimate_f0(wf, self.sample_rate, self.frame_size, self.pitch_estimation) 51 | 52 | # get phonemes 53 | phonemes, phonemes_len, language = self.g2p.encode([transcription], [language], self.max_phonemes) 54 | 55 | # get lm features 56 | lm_feat, lm_feat_len = self.lm.encode([transcription], self.lm_max_tokens) 57 | 58 | # to dict 59 | metadata = { 60 | "spec_len": spec_len, 61 | "f0": f0, 62 | "phonemes": phonemes, 63 | "phonemes_len": phonemes_len, 64 | "language": language, 65 | "lm_feat": lm_feat, 66 | "lm_feat_len": lm_feat_len, 67 | } 68 | 69 | # get target dir. 
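        # Layout produced below, one pair per utterance:
        #   <config.preprocess.cache>/<speaker_name>/<data_name>.wav  -- resampled, padded/cropped waveform
        #   <config.preprocess.cache>/<speaker_name>/<data_name>.pt   -- the metadata dict built above
        # VitsDataset in module/utils/dataset.py later globs "*/*.wav" and loads the matching .pt file.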
70 | cache_dir = Path(self.config.preprocess.cache) 71 | subdir = cache_dir / speaker_name 72 | 73 | # check exists subdir 74 | if not subdir.exists(): 75 | subdir.mkdir() 76 | 77 | audio_path = subdir / (data_name + ".wav") 78 | metadata_path = subdir / (data_name + ".pt") 79 | 80 | # save 81 | torchaudio.save(audio_path, wf, self.sample_rate) 82 | torch.save(metadata, metadata_path) 83 | -------------------------------------------------------------------------------- /module/preprocess/scan.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from module.g2p import G2PProcessor 4 | 5 | 6 | # create metadata 7 | def scan_cache(config): 8 | cache_dir = Path(config.preprocess.cache) 9 | models_dir = Path("models") 10 | metadata_path = models_dir / "metadata.json" 11 | if not models_dir.exists(): 12 | models_dir.mkdir() 13 | 14 | speaker_names = [] 15 | for subdir in cache_dir.glob("*"): 16 | if subdir.is_dir(): 17 | speaker_names.append(subdir.name) 18 | speaker_names = sorted(speaker_names) 19 | g2p = G2PProcessor() 20 | phonemes = g2p.phoneme_vocabs 21 | languages = g2p.languages 22 | num_harmonics = config.vits.generator.decoder.num_harmonics 23 | sample_rate = config.vits.generator.decoder.sample_rate 24 | frame_size = config.vits.generator.decoder.frame_size 25 | metadata = { 26 | "speakers": speaker_names, # speaker list 27 | "phonemes": phonemes, 28 | "languages": languages, 29 | "num_harmonics": num_harmonics, 30 | "sample_rate": sample_rate, 31 | "frame_size": frame_size 32 | } 33 | 34 | with open(metadata_path, 'w') as f: 35 | json.dump(metadata, f) -------------------------------------------------------------------------------- /module/preprocess/wave_and_text.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from .processor import Preprocessor 3 | from tqdm import tqdm 4 | 5 | 6 | LANGUAGE = 'ja' 7 | 8 | def process_speaker(subdir: Path, processor): 9 | speaker_name = subdir.stem 10 | counter = 0 11 | for wave_file in tqdm(subdir.rglob("*.wav")): 12 | text_file = wave_file.parent / (wave_file.stem + ".txt") 13 | if text_file.exists(): 14 | with open(text_file) as f: 15 | text = f.read() 16 | else: 17 | continue 18 | processor.write_cache( 19 | wave_file, 20 | text, 21 | LANGUAGE, 22 | speaker_name, 23 | f"{speaker_name}_{counter}", 24 | ) 25 | counter += 1 26 | 27 | 28 | def preprocess_wave_and_text(root: Path, config): 29 | processor = Preprocessor(config) 30 | for subdir in root.glob("*/"): 31 | if subdir.is_dir(): 32 | print(f"Processing {subdir.stem}") 33 | process_speaker(subdir, processor) -------------------------------------------------------------------------------- /module/utils/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def init_weights(m, mean=0.0, std=0.01): 6 | classname = m.__class__.__name__ 7 | if classname.find("Conv") != -1: 8 | m.weight.data.normal_(mean, std) 9 | 10 | 11 | def get_padding(kernel_size, dilation=1): 12 | return int((kernel_size*dilation - dilation)/2) 13 | 14 | 15 | def convert_pad_shape(pad_shape): 16 | l = pad_shape[::-1] 17 | pad_shape = [item for sublist in l for item in sublist] 18 | return pad_shape 19 | -------------------------------------------------------------------------------- /module/utils/config.py: -------------------------------------------------------------------------------- 1 
| import json 2 | 3 | 4 | class Config: 5 | def __init__(self, **kwargs): 6 | for k, v in kwargs.items(): 7 | if type(v) == dict: 8 | v = Config(**v) 9 | self[k] = v 10 | 11 | def keys(self): 12 | return self.__dict__.keys() 13 | 14 | def items(self): 15 | return self.__dict__.items() 16 | 17 | def values(self): 18 | return self.__dict__.values() 19 | 20 | def __len__(self): 21 | return len(self.__dict__) 22 | 23 | def __getitem__(self, key): 24 | return getattr(self, key) 25 | 26 | def __setitem__(self, key, value): 27 | return setattr(self, key, value) 28 | 29 | def __contains__(self, key): 30 | return key in self.__dict__ 31 | 32 | def __repr__(self): 33 | return self.__dict__.__repr__() 34 | 35 | 36 | def load_json_file(path): 37 | return Config(**json.load(open(path, encoding='utf-8'))) 38 | -------------------------------------------------------------------------------- /module/utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchaudio 4 | import json 5 | from pathlib import Path 6 | import lightning as L 7 | from torch.utils.data import DataLoader, random_split 8 | 9 | 10 | class VitsDataset(torch.utils.data.Dataset): 11 | def __init__(self, cache_dir='dataset_cache', metadata='models/metadata.json'): 12 | super().__init__() 13 | self.root = Path(cache_dir) 14 | metadata = json.load(open(Path(metadata))) 15 | self.speakers = metadata['speakers'] 16 | self.audio_file_paths = [] 17 | self.speaker_ids = [] 18 | self.metadata_paths = [] 19 | for path in self.root.glob("*/*.wav"): 20 | self.audio_file_paths.append(path) 21 | spk = path.parent.name 22 | self.speaker_ids.append(self._speaker_id(spk)) 23 | metadata_path = path.parent / (path.stem + ".pt") 24 | self.metadata_paths.append(metadata_path) 25 | 26 | def _speaker_id(self, speaker: str) -> int: 27 | return self.speakers.index(speaker) 28 | 29 | def __getitem__(self, idx): 30 | speaker_id = self.speaker_ids[idx] 31 | wf, sr = torchaudio.load(self.audio_file_paths[idx]) 32 | metadata = torch.load(self.metadata_paths[idx]) 33 | 34 | spec_len = metadata['spec_len'].item() 35 | f0 = metadata['f0'].squeeze(0) 36 | phoneme = metadata['phonemes'].squeeze(0) 37 | phoneme_len = metadata['phonemes_len'].item() 38 | language = metadata['language'].item() 39 | lm_feat = metadata['lm_feat'].squeeze(0) 40 | lm_feat_len = metadata['lm_feat_len'].item() 41 | return wf, spec_len, speaker_id, f0, phoneme, phoneme_len, lm_feat, lm_feat_len, language 42 | 43 | def __len__(self): 44 | return len(self.audio_file_paths) 45 | 46 | 47 | class VitsDataModule(L.LightningDataModule): 48 | def __init__( 49 | self, 50 | cache_dir='dataset_cache', 51 | metadata='models/metadata.json', 52 | batch_size=1, 53 | num_workers=1, 54 | ): 55 | super().__init__() 56 | self.cache_dir = cache_dir 57 | self.metadata = metadata 58 | self.batch_size = batch_size 59 | self.num_workers = num_workers 60 | 61 | def train_dataloader(self): 62 | dataset = VitsDataset( 63 | self.cache_dir, 64 | self.metadata) 65 | dataloader = DataLoader( 66 | dataset, 67 | self.batch_size, 68 | shuffle=True, 69 | num_workers=self.num_workers, 70 | persistent_workers=(os.name=='nt')) 71 | return dataloader 72 | -------------------------------------------------------------------------------- /module/utils/energy_estimation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | # estimate energy from power 
spectrogram 6 | # input: 7 | # spec: [BatchSize, n_fft//2+1, Length] 8 | # output: 9 | # energy: [BatchSize, 1, Length] 10 | def estimate_energy(spec): 11 | fft_bin = spec.shape[1] 12 | return spec.max(dim=1, keepdim=True).values / fft_bin 13 | -------------------------------------------------------------------------------- /module/utils/f0_estimation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torchaudio.functional import resample 6 | from torchfcpe import spawn_bundled_infer_model 7 | 8 | import numpy as np 9 | import pyworld as pw 10 | 11 | 12 | def estimate_f0_dio(wf, sample_rate=48000, segment_size=960, f0_min=20, f0_max=20000): 13 | if wf.ndim == 1: 14 | device = wf.device 15 | signal = wf.detach().cpu().numpy() 16 | signal = signal.astype(np.double) 17 | _f0, t = pw.dio(signal, sample_rate, f0_floor=f0_min, f0_ceil=f0_max) 18 | f0 = pw.stonemask(signal, _f0, t, sample_rate) 19 | f0 = torch.from_numpy(f0).to(torch.float) 20 | f0 = f0.to(device) 21 | f0 = f0.unsqueeze(0).unsqueeze(0) 22 | f0 = F.interpolate(f0, wf.shape[0] // segment_size, mode='linear') 23 | f0 = f0.squeeze(0) 24 | return f0 25 | elif wf.ndim == 2: 26 | waves = wf.split(1, dim=0) 27 | pitchs = [estimate_f0_dio(wave[0], sample_rate, segment_size) for wave in waves] 28 | pitchs = torch.stack(pitchs, dim=0) 29 | return pitchs 30 | 31 | 32 | def estimate_f0_harvest(wf, sample_rate=4800, segment_size=960, f0_min=20, f0_max=20000): 33 | if wf.ndim == 1: 34 | device = wf.device 35 | signal = wf.detach().cpu().numpy() 36 | signal = signal.astype(np.double) 37 | f0, t = pw.harvest(signal, sample_rate, f0_floor=f0_min, f0_ceil=f0_max) 38 | f0 = torch.from_numpy(f0).to(torch.float) 39 | f0 = f0.to(device) 40 | f0 = f0.unsqueeze(0).unsqueeze(0) 41 | f0 = F.interpolate(f0, wf.shape[0] // segment_size, mode='linear') 42 | f0 = f0.squeeze(0) 43 | return f0 44 | elif wf.ndim == 2: 45 | waves = wf.split(1, dim=0) 46 | pitchs = [estimate_f0_harvest(wave[0], sample_rate, segment_size) for wave in waves] 47 | pitchs = torch.stack(pitchs, dim=0) 48 | return pitchs 49 | 50 | 51 | global torchfcpe_model 52 | torchfcpe_model = None 53 | def estimate_f0_fcpe(wf, sample_rate=48000, segment_size=960, f0_min=20, f0_max=20000): 54 | input_device = wf.device 55 | global torchfcpe_model 56 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 57 | wf = wf.to(device) 58 | if torchfcpe_model is None: 59 | torchfcpe_model = spawn_bundled_infer_model(device) 60 | f0 = torchfcpe_model.infer(wf.unsqueeze(2), sample_rate) 61 | f0 = f0.transpose(1, 2) 62 | f0 = f0.to(input_device) 63 | return f0 64 | 65 | # wf: [BatchSize, Length] 66 | def estimate_f0(wf, sample_rate=48000, segment_size=960, algorithm='harvest'): 67 | l = wf.shape[1] 68 | if algorithm == 'harvest': 69 | f0 = estimate_f0_harvest(wf, sample_rate) 70 | elif algorithm == 'dio': 71 | f0 = estimate_f0_dio(wf, sample_rate) 72 | elif algorithm == 'fcpe': 73 | f0 = estimate_f0_fcpe(wf, sample_rate) 74 | return F.interpolate(f0, l // segment_size, mode='linear') 75 | -------------------------------------------------------------------------------- /module/utils/safetensors.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | from safetensors import safe_open 4 | from safetensors.torch import save_file 5 | 6 | 7 | def save_tensors(tensors: Dict[str, torch.Tensor], path): 8 
| save_file(tensors, path) 9 | 10 | def load_tensors(path) -> Dict[str, torch.Tensor]: 11 | tensors = {} 12 | with safe_open(path, framework="pt", device="cpu") as f: 13 | for key in f.keys(): 14 | tensors[key] = f.get_tensor(key) 15 | return tensors -------------------------------------------------------------------------------- /module/vits/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | 6 | import lightning as L 7 | 8 | from .generator import Generator 9 | from .discriminator import Discriminator 10 | from .duration_discriminator import DurationDiscriminator 11 | from .loss import mel_spectrogram_loss, generator_adversarial_loss, discriminator_adversarial_loss, feature_matching_loss, duration_discriminator_adversarial_loss, duration_generator_adversarial_loss 12 | from .crop import crop_features, crop_waveform, decide_crop_range 13 | from .spectrogram import spectrogram 14 | 15 | 16 | class Vits(L.LightningModule): 17 | def __init__( 18 | self, 19 | config, 20 | ): 21 | super().__init__() 22 | self.generator = Generator(config.generator) 23 | self.discriminator = Discriminator(config.discriminator) 24 | self.duration_discriminator = DurationDiscriminator(**config.duration_discriminator) 25 | self.config = config 26 | 27 | # disable automatic optimization 28 | self.automatic_optimization = False 29 | # save hyperparameters 30 | self.save_hyperparameters() 31 | 32 | def training_step(self, batch): 33 | wf, spec_len, spk, f0, phoneme, phoneme_len, lm_feat, lm_feat_len, lang = batch 34 | wf = wf.squeeze(1) # [Batch, WaveLength] 35 | # get optimizer 36 | opt_g, opt_d, opt_dd = self.optimizers() 37 | 38 | # spectrogram 39 | n_fft = self.generator.posterior_encoder.n_fft 40 | frame_size = self.generator.posterior_encoder.frame_size 41 | spec = spectrogram(wf, n_fft, frame_size) 42 | 43 | # decide crop range 44 | crop_range = decide_crop_range(spec.shape[2], self.config.segment_size) 45 | 46 | # crop real waveform 47 | real = crop_waveform(wf, crop_range, frame_size) 48 | 49 | # start tracking gradient G. 50 | self.toggle_optimizer(opt_g) 51 | 52 | # calculate loss 53 | lossG, loss_dict, (text_encoded, text_mask, fake_log_duration, real_log_duration, spk_emb, dsp_out, fake) = self.generator( 54 | spec, spec_len, phoneme, phoneme_len, lm_feat, lm_feat_len, f0, spk, lang, crop_range) 55 | 56 | loss_dsp = mel_spectrogram_loss(dsp_out, real) 57 | loss_mel = mel_spectrogram_loss(fake, real) 58 | logits_fake, fmap_fake = self.discriminator(fake) 59 | _, fmap_real = self.discriminator(real) 60 | loss_feat = feature_matching_loss(fmap_real, fmap_fake) 61 | loss_adv = generator_adversarial_loss(logits_fake) 62 | dur_logit_fake = self.duration_discriminator(text_encoded, text_mask, fake_log_duration, spk_emb) 63 | loss_dadv = duration_generator_adversarial_loss(dur_logit_fake, text_mask) 64 | 65 | lossG += loss_mel * 45.0 + loss_dsp + loss_feat + loss_adv + loss_dadv 66 | self.manual_backward(lossG) 67 | opt_g.step() 68 | opt_g.zero_grad() 69 | 70 | # stop tracking gradient G. 71 | self.untoggle_optimizer(opt_g) 72 | 73 | # start tracking gradient D. 
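        # The same manual-optimization pattern now repeats for the discriminator and
        # then for the duration discriminator, using the three optimizers returned by
        # configure_optimizers below. A minimal sketch of one round, with
        # `compute_loss` as a hypothetical stand-in for the loss computation above:
        #
        #     self.toggle_optimizer(opt)        # track gradients for this sub-model only
        #     loss = compute_loss(batch)        # hypothetical helper, for illustration
        #     self.manual_backward(loss)
        #     opt.step()
        #     opt.zero_grad()
        #     self.untoggle_optimizer(opt)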
74 | self.toggle_optimizer(opt_d) 75 | 76 | # calculate loss 77 | fake = fake.detach() 78 | logits_fake, _ = self.discriminator(fake) 79 | logits_real, _ = self.discriminator(real) 80 | 81 | lossD = discriminator_adversarial_loss(logits_real, logits_fake) 82 | self.manual_backward(lossD) 83 | opt_d.step() 84 | opt_d.zero_grad() 85 | 86 | # stop tracking gradient D. 87 | self.untoggle_optimizer(opt_d) 88 | 89 | # start tracking gradient Duration Discriminator 90 | self.toggle_optimizer(opt_dd) 91 | 92 | fake_log_duration = fake_log_duration.detach() 93 | real_log_duration = real_log_duration.detach() 94 | text_mask = text_mask.detach() 95 | text_encoded = text_encoded.detach() 96 | spk_emb = spk_emb.detach() 97 | dur_logit_real = self.duration_discriminator(text_encoded, text_mask, real_log_duration, spk_emb) 98 | dur_logit_fake = self.duration_discriminator(text_encoded, text_mask, fake_log_duration, spk_emb) 99 | 100 | lossDD = duration_discriminator_adversarial_loss(dur_logit_real, dur_logit_fake, text_mask) 101 | self.manual_backward(lossDD) 102 | opt_dd.step() 103 | opt_dd.zero_grad() 104 | 105 | # stop tracking gradient Duration Discriminator 106 | self.untoggle_optimizer(opt_dd) 107 | 108 | # write log 109 | loss_dict['Mel'] = loss_mel.item() 110 | loss_dict['Generator Adversarial'] = loss_adv.item() 111 | loss_dict['DSP'] = loss_dsp.item() 112 | loss_dict['Feature Matching'] = loss_feat.item() 113 | loss_dict['Discriminator Adversarial'] = lossD.item() 114 | loss_dict['Duration Generator Adversarial'] = loss_dadv.item() 115 | loss_dict['Duration Discriminator Adversarial'] = lossDD.item() 116 | 117 | for k, v in zip(loss_dict.keys(), loss_dict.values()): 118 | self.log(f"loss/{k}", v) 119 | 120 | def configure_optimizers(self): 121 | lr = self.config.optimizer.lr 122 | betas = self.config.optimizer.betas 123 | 124 | opt_g = optim.AdamW(self.generator.parameters(), lr, betas) 125 | opt_d = optim.AdamW(self.discriminator.parameters(), lr, betas) 126 | opt_dd = optim.AdamW(self.duration_discriminator.parameters(), lr, betas) 127 | return opt_g, opt_d, opt_dd 128 | -------------------------------------------------------------------------------- /module/vits/convnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .normalization import LayerNorm, GRN 6 | 7 | # ConvNeXt v2 8 | class ConvNeXtLayer(nn.Module): 9 | def __init__(self, channels=512, kernel_size=7, mlp_mul=3): 10 | super().__init__() 11 | padding = kernel_size // 2 12 | self.c1 = nn.Conv1d(channels, channels, kernel_size, 1, padding, groups=channels) 13 | self.norm = LayerNorm(channels) 14 | self.c2 = nn.Conv1d(channels, channels * mlp_mul, 1) 15 | self.grn = GRN(channels * mlp_mul) 16 | self.c3 = nn.Conv1d(channels * mlp_mul, channels, 1) 17 | 18 | # x: [batchsize, channels, length] 19 | def forward(self, x): 20 | res = x 21 | x = self.c1(x) 22 | x = self.norm(x) 23 | x = self.c2(x) 24 | x = F.gelu(x) 25 | x = self.grn(x) 26 | x = self.c3(x) 27 | x = x + res 28 | return x 29 | -------------------------------------------------------------------------------- /module/vits/crop.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | def decide_crop_range(max_length=500, frames=50): 5 | left = random.randint(0, max_length-frames) 6 | right = left + frames 7 | return (left, right) 8 | 9 | 10 | def crop_features(z, crop_range): 11 | 
left, right = crop_range[0], crop_range[1] 12 | return z[:, :, left:right] 13 | 14 | 15 | def crop_waveform(wf, crop_range, frame_size): 16 | left, right = crop_range[0], crop_range[1] 17 | return wf[:, left*frame_size:right*frame_size] 18 | 19 | -------------------------------------------------------------------------------- /module/vits/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from torch.nn.utils.parametrizations import weight_norm 8 | from torch.nn.utils import remove_weight_norm 9 | from module.utils.common import get_padding, init_weights 10 | from .normalization import LayerNorm 11 | from .convnext import ConvNeXtLayer 12 | from .feature_retrieval import match_features 13 | 14 | 15 | # Oscillate harmonic signal 16 | # 17 | # Inputs --- 18 | # f0: [BatchSize, 1, Frames] 19 | # 20 | # frame_size: int 21 | # sample_rate: float or int 22 | # min_frequency: float 23 | # num_harmonics: int 24 | # 25 | # Output: [BatchSize, NumHarmonics+1, Length] 26 | # 27 | # length = Frames * frame_size 28 | @torch.cuda.amp.autocast(enabled=False) 29 | def oscillate_harmonics( 30 | f0, 31 | frame_size=960, 32 | sample_rate=48000, 33 | num_harmonics=0, 34 | min_frequency=20.0 35 | ): 36 | N = f0.shape[0] 37 | C = num_harmonics + 1 38 | Lf = f0.shape[2] 39 | Lw = Lf * frame_size 40 | 41 | device = f0.device 42 | 43 | # generate frequency of harmonics 44 | mul = (torch.arange(C, device=device) + 1).unsqueeze(0).unsqueeze(2) 45 | 46 | # change length to wave's 47 | fs = F.interpolate(f0, Lw, mode='linear') * mul 48 | 49 | # unvoiced / voiced mask 50 | uv = (f0 > min_frequency).to(torch.float) 51 | uv = F.interpolate(uv, Lw, mode='linear') 52 | 53 | # generate harmonics 54 | I = torch.cumsum(fs / sample_rate, dim=2) # numerical integration 55 | theta = 2 * math.pi * (I % 1) # convert to radians 56 | 57 | harmonics = torch.sin(theta) * uv 58 | 59 | return harmonics.to(device) 60 | 61 | 62 | # Oscillate noise via gaussian noise and equalizer 63 | # 64 | # fft_bin = n_fft // 2 + 1 65 | # kernels: [BatchSize, fft_bin, Frames] 66 | # 67 | # Output: [BatchSize, 1, Frames * frame_size] 68 | @torch.cuda.amp.autocast(enabled=False) 69 | def oscillate_noise(kernels, frame_size=960, n_fft=3840): 70 | device = kernels.device 71 | N = kernels.shape[0] 72 | Lf = kernels.shape[2] # frame length 73 | Lw = Lf * frame_size # waveform length 74 | dtype = kernels.dtype 75 | 76 | gaussian_noise = torch.randn(N, Lw, device=device, dtype=torch.float) 77 | kernels = kernels.to(torch.float) # to fp32 78 | 79 | # calculate convolution in fourier-domain 80 | # Since the input is an aperiodic signal such as Gaussian noise, 81 | # there is no need to consider the phase on the kernel side. 82 | w = torch.hann_window(n_fft, dtype=torch.float, device=device) 83 | noise_stft = torch.stft(gaussian_noise, n_fft, frame_size, window=w, return_complex=True)[:, :, 1:] 84 | y_stft = noise_stft * kernels # In fourier domain, Multiplication means convolution. 
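    # Shape sketch for the step above, using this function's defaults as illustrative
    # values: with n_fft=3840 and frame_size=960, torch.stft of the [N, Lw] noise gives
    # [N, n_fft//2+1, Lf+1] complex bins (center padding adds one frame), and [:, :, 1:]
    # drops that extra frame so the product with the [N, n_fft//2+1, Lf] kernels is a
    # per-frame, per-bin filtered noise spectrum. The pad and istft below restore the
    # dropped frame and return to the time domain, e.g. Lf=50 frames of kernels yield
    # Lw=48000 samples, one second of filtered noise at 48kHz.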
85 | y_stft = F.pad(y_stft, [1, 0]) # pad 86 | y = torch.istft(y_stft, n_fft, frame_size, window=w) 87 | y = y.unsqueeze(1) 88 | y = y.to(dtype) 89 | return y 90 | 91 | 92 | # this model estimates fundamental frequency (f0) and energy from z 93 | class PitchEnergyEstimator(nn.Module): 94 | def __init__(self, 95 | content_channels=192, 96 | speaker_embedding_dim=256, 97 | internal_channels=256, 98 | num_classes=512, 99 | kernel_size=7, 100 | num_layers=6, 101 | mlp_mul=3, 102 | min_frequency=20.0, 103 | classes_per_octave=48, 104 | ): 105 | super().__init__() 106 | self.content_channels = content_channels 107 | self.num_classes = num_classes 108 | self.classes_per_octave = classes_per_octave 109 | self.min_frequency = min_frequency 110 | 111 | self.content_input = nn.Conv1d(content_channels, internal_channels, 1) 112 | self.speaker_input = nn.Conv1d(speaker_embedding_dim, internal_channels, 1) 113 | self.input_norm = LayerNorm(internal_channels) 114 | self.mid_layers = nn.Sequential( 115 | *[ConvNeXtLayer(internal_channels, kernel_size, mlp_mul) for _ in range(num_layers)]) 116 | self.output_norm = LayerNorm(internal_channels) 117 | self.to_f0_logits = nn.Conv1d(internal_channels, num_classes, 1) 118 | self.to_energy = nn.Conv1d(internal_channels, 1, 1) 119 | 120 | # z_p: [BatchSize, content_channels, Length] 121 | # spk: [BatchSize, speaker_embedding_dim, Length] 122 | def forward(self, z_p, spk): 123 | x = self.content_input(z_p) + self.speaker_input(spk) 124 | x = self.input_norm(x) 125 | x = self.mid_layers(x) 126 | x = self.output_norm(x) 127 | logits, energy = self.to_f0_logits(x), self.to_energy(x) 128 | energy = F.elu(energy) + 1.0 129 | return logits, energy 130 | 131 | # f: [] 132 | def freq2id(self, f): 133 | fmin = self.min_frequency 134 | cpo = self.classes_per_octave 135 | nc = self.num_classes 136 | return torch.ceil(torch.clamp(cpo * torch.log2(f / fmin), 0, nc-1)).to(torch.long) 137 | 138 | # ids: [] 139 | def id2freq(self, ids): 140 | fmin = self.min_frequency 141 | cpo = self.classes_per_octave 142 | x = ids.to(torch.float) 143 | x = fmin * (2 ** (x / cpo)) 144 | x[x <= self.min_frequency] = 0 145 | return x 146 | 147 | # z_p: [BatchSize, content_channels, Length] 148 | # spk: [BatchSize, speaker_embedding_dim, Length] 149 | # Outputs: 150 | # f0: [BatchSize, 1, Length] 151 | # energy: [BatchSize, 1, Length] 152 | def infer(self, z_p, spk, k=4): 153 | logits, energy = self.forward(z_p, spk) 154 | probs, indices = torch.topk(logits, k, dim=1) 155 | probs = F.softmax(probs, dim=1) 156 | freqs = self.id2freq(indices) 157 | f0 = (probs * freqs).sum(dim=1, keepdim=True) 158 | f0[f0 <= self.min_frequency] = 0 159 | return f0, energy 160 | 161 | 162 | class FiLM(nn.Module): 163 | def __init__(self, in_channels, condition_channels): 164 | super().__init__() 165 | self.to_shift = weight_norm(nn.Conv1d(condition_channels, in_channels, 1)) 166 | self.to_scale = weight_norm(nn.Conv1d(condition_channels, in_channels, 1)) 167 | 168 | # x: [BatchSize, in_channels, Length] 169 | # c: [BatchSize, condition_channels, Length] 170 | def forward(self, x, c): 171 | shift = self.to_shift(c) 172 | scale = self.to_scale(c) 173 | return x * scale + shift 174 | 175 | def remove_weight_norm(self): 176 | remove_weight_norm(self.to_shift) 177 | remove_weight_norm(self.to_scale) 178 | 179 | 180 | class SourceNet(nn.Module): 181 | def __init__( 182 | self, 183 | sample_rate=48000, 184 | n_fft=3840, 185 | frame_size=960, 186 | content_channels=192, 187 | speaker_embedding_dim=256, 188 | 
internal_channels=512, 189 | num_harmonics=30, 190 | kernel_size=7, 191 | num_layers=6, 192 | mlp_mul=3 193 | ): 194 | super().__init__() 195 | self.sample_rate = sample_rate 196 | self.n_fft = n_fft 197 | self.frame_size = frame_size 198 | self.num_harmonics = num_harmonics 199 | 200 | self.content_input = nn.Conv1d(content_channels, internal_channels, 1) 201 | self.speaker_input = nn.Conv1d(speaker_embedding_dim, internal_channels, 1) 202 | self.energy_input = nn.Conv1d(1, internal_channels, 1) 203 | self.f0_input = nn.Conv1d(1, internal_channels, 1) 204 | self.input_norm = LayerNorm(internal_channels) 205 | self.mid_layers = nn.Sequential( 206 | *[ConvNeXtLayer(internal_channels, kernel_size, mlp_mul) for _ in range(num_layers)]) 207 | self.output_norm = LayerNorm(internal_channels) 208 | self.to_amps = nn.Conv1d(internal_channels, num_harmonics + 1, 1) 209 | self.to_kernels = nn.Conv1d(internal_channels, n_fft // 2 + 1, 1) 210 | 211 | # x: [BatchSize, content_channels, Length] 212 | # f0: [BatchSize, 1, Length] 213 | # energy: [BatchSize, 1, Length] 214 | # spk: [BatchSize, speaker_embedding_dim, 1] 215 | # Outputs: 216 | # amps: [BatchSize, num_harmonics+1, Length * frame_size] 217 | # kernels: [BatchSize, n_fft //2 + 1, Length] 218 | def amps_and_kernels(self, x, f0, energy, spk): 219 | x = self.content_input(x) + self.speaker_input(spk) + self.f0_input(torch.log(F.relu(f0) + 1e-6)) 220 | x = x * self.energy_input(energy) 221 | x = self.input_norm(x) 222 | x = self.mid_layers(x) 223 | x = self.output_norm(x) 224 | amps = F.elu(self.to_amps(x)) + 1.0 225 | kernels = F.elu(self.to_kernels(x)) + 1.0 226 | return amps, kernels 227 | 228 | 229 | # x: [BatchSize, content_channels, Length] 230 | # f0: [BatchSize, 1, Length] 231 | # spk: [BatchSize, speaker_embedding_dim, 1] 232 | # Outputs: 233 | # dsp_out: [BatchSize, 1, Length * frame_size] 234 | # source: [BatchSize, 1, Length * frame_size] 235 | def forward(self, x, f0, energy, spk): 236 | amps, kernels = self.amps_and_kernels(x, f0, energy, spk) 237 | 238 | # oscillate source signals 239 | harmonics = oscillate_harmonics(f0, self.frame_size, self.sample_rate, self.num_harmonics) 240 | amps = F.interpolate(amps, scale_factor=self.frame_size, mode='linear') 241 | harmonics = harmonics * amps 242 | noise = oscillate_noise(kernels, self.frame_size, self.n_fft) 243 | source = torch.cat([harmonics, noise], dim=1) 244 | dsp_out = torch.sum(source, dim=1, keepdim=True) 245 | 246 | return dsp_out, source 247 | 248 | 249 | # HiFi-GAN's ResBlock1 250 | class ResBlock1(nn.Module): 251 | def __init__(self, channels, condition_channels, kernel_size=3, dilations=[1, 3, 5]): 252 | super().__init__() 253 | self.convs1 = nn.ModuleList([]) 254 | self.convs2 = nn.ModuleList([]) 255 | self.films = nn.ModuleList([]) 256 | 257 | for d in dilations: 258 | padding = get_padding(kernel_size, 1) 259 | self.convs1.append( 260 | weight_norm( 261 | nn.Conv1d(channels, channels, kernel_size, 1, padding, dilation=d, padding_mode='replicate'))) 262 | padding = get_padding(kernel_size, d) 263 | self.convs2.append( 264 | weight_norm( 265 | nn.Conv1d(channels, channels, kernel_size, 1, padding, 1, padding_mode='replicate'))) 266 | self.films.append( 267 | FiLM(channels, condition_channels)) 268 | self.convs1.apply(init_weights) 269 | self.convs2.apply(init_weights) 270 | self.films.apply(init_weights) 271 | 272 | # x: [BatchSize, channels, Length] 273 | # c: [BatchSize, condition_channels, Length] 274 | def forward(self, x, c): 275 | for c1, c2, film in zip(self.convs1, 
self.convs2, self.films): 276 | res = x 277 | x = F.leaky_relu(x, 0.1) 278 | x = c1(x) 279 | x = F.leaky_relu(x, 0.1) 280 | x = c2(x) 281 | x = film(x, c) 282 | x = x + res 283 | return x 284 | 285 | def remove_weight_norm(self): 286 | for c1, c2, film in zip(self.convs1, self.convs2, self.films): 287 | remove_weight_norm(c1) 288 | remove_weight_norm(c2) 289 | film.remove_weight_norm() 290 | 291 | 292 | # HiFi-GAN's ResBlock2 293 | class ResBlock2(nn.Module): 294 | def __init__(self, channels, condition_channels, kernel_size=3, dilations=[1, 3]): 295 | super().__init__() 296 | self.convs = nn.ModuleList([]) 297 | self.films = nn.ModuleList([]) 298 | for d in dilations: 299 | padding = get_padding(kernel_size, d) 300 | self.convs.append( 301 | weight_norm( 302 | nn.Conv1d(channels, channels, kernel_size, 1, padding, dilation=d, padding_mode='replicate'))) 303 | self.films.append(FiLM(channels, condition_channels)) 304 | self.convs.apply(init_weights) 305 | self.films.apply(init_weights) 306 | 307 | # x: [BatchSize, channels, Length] 308 | # c: [BatchSize, condition_channels, Length] 309 | def forward(self, x, c): 310 | for conv, film in zip(self.convs, self.films): 311 | res = x 312 | x = F.leaky_relu(x, 0.1) 313 | x = conv(x) 314 | x = film(x, c) 315 | x = x + res 316 | return x 317 | 318 | def remove_weight_norm(self): 319 | for conv, film in zip(self.convs, self.films): 320 | conv.remove_weight_norm() 321 | film.remove_weight_norm() 322 | 323 | 324 | # TinyVC's Block (from https://github.com/uthree/tinyvc) 325 | class ResBlock3(nn.Module): 326 | def __init__(self, channels, condition_channels, kernel_size=3, dilations=[1, 3, 9, 27]): 327 | super().__init__() 328 | assert len(dilations) == 4, "Resblock 3's len(dilations) should be 4." 329 | self.convs = nn.ModuleList([]) 330 | self.films = nn.ModuleList([]) 331 | for d in dilations: 332 | padding = get_padding(kernel_size, d) 333 | self.convs.append( 334 | weight_norm( 335 | nn.Conv1d(channels, channels, kernel_size, 1, padding, dilation=d, padding_mode='replicate'))) 336 | for _ in range(2): 337 | self.films.append( 338 | FiLM(channels, condition_channels)) 339 | 340 | # x: [BatchSize, channels, Length] 341 | # c: [BatchSize, condition_channels, Length] 342 | def forward(self, x, c): 343 | res = x 344 | x = F.leaky_relu(x, 0.1) 345 | x = self.convs[0](x) 346 | x = F.leaky_relu(x, 0.1) 347 | x = self.convs[1](x) 348 | x = self.films[0](x, c) 349 | x = x + res 350 | 351 | res = x 352 | x = F.leaky_relu(x, 0.1) 353 | x = self.convs[2](x) 354 | x = F.leaky_relu(x, 0.1) 355 | x = self.convs[3](x) 356 | x = self.films[1](x, c) 357 | x = x + res 358 | 359 | return x 360 | 361 | def remove_weight_norm(self): 362 | for conv in self.convs: 363 | remove_weight_norm(conv) 364 | for film in self.films: 365 | film.remove_weight_norm() 366 | 367 | 368 | class MRF(nn.Module): 369 | def __init__(self, 370 | channels, 371 | condition_channels, 372 | resblock_type='1', 373 | kernel_sizes=[3, 7, 11], 374 | dilations=[[1, 3, 5], [1, 3, 5], [1, 3, 5]]): 375 | super().__init__() 376 | self.blocks = nn.ModuleList([]) 377 | self.num_blocks = len(kernel_sizes) 378 | if resblock_type == '1': 379 | block = ResBlock1 380 | elif resblock_type == '2': 381 | block = ResBlock2 382 | elif resblock_type == '3': 383 | block = ResBlock3 384 | for k, d in zip(kernel_sizes, dilations): 385 | self.blocks.append(block(channels, condition_channels, k, d)) 386 | 387 | # x: [BatchSize, channels, Length] 388 | # c: [BatchSize, condition_channels, Length] 389 | def forward(self, x, 
c): 390 | xs = None 391 | for block in self.blocks: 392 | if xs is None: 393 | xs = block(x, c) 394 | else: 395 | xs += block(x, c) 396 | return xs / self.num_blocks 397 | 398 | def remove_weight_norm(self): 399 | for block in self.blocks: 400 | block.remove_weight_norm() 401 | 402 | 403 | class UpBlock(nn.Module): 404 | def __init__(self, 405 | in_channels, 406 | out_channels, 407 | condition_channels, 408 | factor, 409 | resblock_type='1', 410 | kernel_sizes=[3, 7, 11], 411 | dilations=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], 412 | interpolation='conv'): 413 | super().__init__() 414 | self.MRF = MRF(out_channels, condition_channels, resblock_type, kernel_sizes, dilations) 415 | self.interpolation = interpolation 416 | self.factor = factor 417 | if interpolation == 'conv': 418 | self.up_conv = weight_norm( 419 | nn.ConvTranspose1d(in_channels, out_channels, factor*2, factor)) 420 | self.pad_left = factor // 2 421 | self.pad_right = factor - self.pad_left 422 | elif interpolation == 'linear': 423 | self.up_conv = weight_norm(nn.Conv1d(in_channels, out_channels, 1)) 424 | 425 | # x: [BatchSize, in_channels, Length] 426 | # c: [BatchSize, condition_channels, Length(upsampled)] 427 | # Output: [BatchSize, out_channels, Length(upsampled)] 428 | def forward(self, x, c): 429 | x = F.leaky_relu(x, 0.1) 430 | if self.interpolation == 'conv': 431 | x = self.up_conv(x) 432 | x = x[:, :, self.pad_left:-self.pad_right] 433 | elif self.interpolation == 'linear': 434 | x = self.up_conv(x) 435 | x = F.interpolate(x, scale_factor=self.factor) 436 | x = self.MRF(x, c) 437 | return x 438 | 439 | def remove_weight_norm(self): 440 | remove_weight_norm(self.up_conv) 441 | self.MRF.remove_weight_norm() 442 | 443 | 444 | class DownBlock(nn.Module): 445 | def __init__(self, 446 | in_channels, 447 | out_channels, 448 | factor, 449 | dilations=[[1, 2], [4, 8]], 450 | kernel_size=3, 451 | interpolation='conv'): 452 | super().__init__() 453 | pad_left = factor // 2 454 | pad_right = factor - pad_left 455 | self.pad = nn.ReplicationPad1d([pad_left, pad_right]) 456 | self.interpolation = interpolation 457 | self.factor = factor 458 | if interpolation == 'conv': 459 | self.input_conv = weight_norm(nn.Conv1d(in_channels, out_channels, factor*2, factor)) 460 | elif interpolation == 'linear': 461 | self.input_conv = weight_norm(nn.Conv1d(in_channels, out_channels, 1)) 462 | 463 | self.convs = nn.ModuleList([]) 464 | for ds in dilations: 465 | cs = nn.ModuleList([]) 466 | for d in ds: 467 | padding = get_padding(kernel_size, d) 468 | cs.append( 469 | weight_norm( 470 | nn.Conv1d(out_channels, out_channels, kernel_size, 1, padding, dilation=d, padding_mode='replicate'))) 471 | self.convs.append(cs) 472 | 473 | 474 | # x: [BatchSize, in_channels, Length] 475 | # Output: [BatchSize, out_channels, Length] 476 | def forward(self, x): 477 | if self.interpolation == 'conv': 478 | x = self.pad(x) 479 | x = self.input_conv(x) 480 | elif self.interpolation == 'linear': 481 | x = self.pad(x) 482 | x = F.avg_pool1d(x, self.factor*2, self.factor) # approximation of linear interpolation 483 | x = self.input_conv(x) 484 | for block in self.convs: 485 | res = x 486 | for c in block: 487 | x = F.leaky_relu(x, 0.1) 488 | x = c(x) 489 | x = x + res 490 | return x 491 | 492 | def remove_weight_norm(self): 493 | for block in self.convs: 494 | for c in block: 495 | remove_weight_norm(c) 496 | remove_weight_norm(self.output_conv) 497 | 498 | 499 | class FilterNet(nn.Module): 500 | def __init__(self, 501 | content_channels=192, 502 | 
speaker_embedding_dim=256, 503 | channels=[512, 256, 128, 64, 32], 504 | resblock_type='1', 505 | factors=[5, 4, 4, 4, 3], 506 | up_dilations=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], 507 | up_kernel_sizes=[3, 7, 11], 508 | up_interpolation='conv', 509 | down_dilations=[[1, 2], [4, 8]], 510 | down_kernel_size=3, 511 | down_interpolation='conv', 512 | num_harmonics=30, 513 | ): 514 | super().__init__() 515 | # input layer 516 | self.content_input = weight_norm(nn.Conv1d(content_channels, channels[0], 1)) 517 | self.speaker_input = weight_norm(nn.Conv1d(speaker_embedding_dim, channels[0], 1)) 518 | self.energy_input = weight_norm(nn.Conv1d(1, channels[0], 1)) 519 | self.f0_input = weight_norm(nn.Conv1d(1, channels[0], 1)) 520 | 521 | # downsamples 522 | self.downs = nn.ModuleList([]) 523 | self.downs.append(weight_norm(nn.Conv1d(num_harmonics + 2, channels[-1], 1))) 524 | cs = list(reversed(channels[1:])) 525 | ns = cs[1:] + [channels[0]] 526 | fs = list(reversed(factors[1:])) 527 | for c, n, f, in zip(cs, ns, fs): 528 | self.downs.append(DownBlock(c, n, f, down_dilations, down_kernel_size, down_interpolation)) 529 | 530 | # upsamples 531 | self.ups = nn.ModuleList([]) 532 | cs = channels 533 | ns = channels[1:] + [channels[-1]] 534 | fs = factors 535 | for c, n, f in zip(cs, ns, fs): 536 | self.ups.append(UpBlock(c, n, c, f, resblock_type, up_kernel_sizes, up_dilations, up_interpolation)) 537 | self.output_layer = weight_norm( 538 | nn.Conv1d(channels[-1], 1, 7, 1, 3, padding_mode='replicate')) 539 | 540 | # content: [BatchSize, content_channels, Length(frame)] 541 | # f0: [BatchSize, 1, Length(frame)] 542 | # spk: [BatchSize, speaker_embedding_dim, 1] 543 | # source: [BatchSize, num_harmonics+2, Length(Waveform)] 544 | # Output: [Batchsize, 1, Length * frame_size] 545 | def forward(self, content, f0, energy, spk, source): 546 | x = self.content_input(content) + self.speaker_input(spk) + self.f0_input(torch.log(F.relu(f0) + 1e-6)) 547 | x = x * self.energy_input(energy) 548 | 549 | skips = [] 550 | for down in self.downs: 551 | source = down(source) 552 | skips.append(source) 553 | 554 | for up, s in zip(self.ups, reversed(skips)): 555 | x = up(x, s) 556 | x = self.output_layer(x) 557 | return x 558 | 559 | def remove_weight_norm(self): 560 | remove_weight_norm(self.content_input) 561 | remove_weight_norm(self.output_layer) 562 | remove_weight_norm(self.speaker_input) 563 | for down in self.downs: 564 | down.remove_weight_norm() 565 | for up in self.ups: 566 | up.remove_weight_norm() 567 | 568 | 569 | def pitch_estimation_loss(logits, label): 570 | num_classes = logits.shape[1] 571 | device = logits.device 572 | weight = torch.ones(num_classes, device=device) 573 | weight[0] = 1e-2 574 | return F.cross_entropy(logits, label, weight) 575 | 576 | 577 | class Decoder(nn.Module): 578 | def __init__(self, 579 | sample_rate=48000, 580 | frame_size=960, 581 | n_fft=3840, 582 | content_channels=192, 583 | speaker_embedding_dim=256, 584 | pe_internal_channels=256, 585 | pe_num_layers=6, 586 | source_internal_channels=512, 587 | source_num_layers=6, 588 | num_harmonics=30, 589 | filter_channels=[512, 256, 128, 64, 32], 590 | filter_factors=[5, 4, 4, 3, 2], 591 | filter_resblock_type='1', 592 | filter_down_dilations=[[1, 2], [4, 8]], 593 | filter_down_kernel_size=3, 594 | filter_down_interpolation='conv', 595 | filter_up_dilations=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], 596 | filter_up_kernel_sizes=[3, 7, 11], 597 | filter_up_interpolation='conv'): 598 | super().__init__() 599 | self.frame_size = frame_size 600 | 
self.n_fft = n_fft 601 | self.sample_rate = sample_rate 602 | self.num_harmonics = num_harmonics 603 | self.pitch_energy_estimator = PitchEnergyEstimator( 604 | content_channels=content_channels, 605 | speaker_embedding_dim=speaker_embedding_dim, 606 | num_layers=pe_num_layers, 607 | internal_channels=pe_internal_channels 608 | ) 609 | self.source_net = SourceNet( 610 | sample_rate=sample_rate, 611 | frame_size=frame_size, 612 | n_fft=n_fft, 613 | num_harmonics=num_harmonics, 614 | content_channels=content_channels, 615 | speaker_embedding_dim=speaker_embedding_dim, 616 | internal_channels=source_internal_channels, 617 | num_layers=source_num_layers, 618 | ) 619 | self.filter_net = FilterNet( 620 | content_channels=content_channels, 621 | speaker_embedding_dim=speaker_embedding_dim, 622 | channels=filter_channels, 623 | resblock_type=filter_resblock_type, 624 | factors=filter_factors, 625 | up_dilations=filter_up_dilations, 626 | up_kernel_sizes=filter_up_kernel_sizes, 627 | up_interpolation=filter_up_interpolation, 628 | down_dilations=filter_down_dilations, 629 | down_kernel_size=filter_down_kernel_size, 630 | down_interpolation=filter_down_interpolation, 631 | num_harmonics=num_harmonics 632 | ) 633 | # training pass 634 | # 635 | # content: [BatchSize, content_channels, Length] 636 | # f0: [BatchSize, 1, Length] 637 | # spk: [BatchSize, speaker_embedding_dim, 1] 638 | # 639 | # Outputs: 640 | # f0_logits [BatchSize, num_f0_classes, Length] 641 | # estimated_energy: [BatchSize, 1, Length] 642 | # dsp_out: [BatchSize, Length * frame_size] 643 | # output: [BatchSize, Length * frame_size] 644 | def forward(self, content, f0, energy, spk): 645 | # estimate energy, f0 646 | f0_logits, estimated_energy = self.pitch_energy_estimator(content, spk) 647 | f0_label = self.pitch_energy_estimator.freq2id(f0).squeeze(1) 648 | loss_pe = pitch_estimation_loss(f0_logits, f0_label) 649 | loss_ee = (estimated_energy - energy).abs().mean() 650 | loss = loss_pe * 45.0 + loss_ee * 45.0 651 | loss_dict = { 652 | "Pitch Estimation": loss_pe.item(), 653 | "Energy Estimation": loss_ee.item() 654 | } 655 | 656 | # source net 657 | dsp_out, source = self.source_net(content, f0, energy, spk) 658 | dsp_out = dsp_out.squeeze(1) 659 | 660 | # GAN output 661 | output = self.filter_net(content, f0, energy, spk, source) 662 | output = output.squeeze(1) 663 | 664 | return dsp_out, output, loss, loss_dict 665 | 666 | # inference pass 667 | # 668 | # content: [BatchSize, content_channels, Length] 669 | # f0: [BatchSize, 1, Length] 670 | # Output: [BatchSize, 1, Length * frame_size] 671 | # energy: [BatchSize, 1, Length] 672 | # f0: [BatchSize, 1, Length] 673 | # reference: None or [BatchSize, content_channels, NumReferenceVectors] 674 | # alpha: float 0 ~ 1.0 675 | # k: int 676 | def infer(self, content, spk, energy=None, f0=None, reference=None, alpha=0, k=4, metrics='cos'): 677 | if energy is None or f0 is None: 678 | f0_est, energy_est = self.pitch_energy_estimator.infer(content, spk) 679 | if energy is None: 680 | energy = energy_est 681 | if f0 is None: 682 | f0 = f0_est 683 | 684 | # run feature retrieval if got reference vectors 685 | if reference is not None: 686 | content = match_features(content, reference, k, alpha, metrics) 687 | 688 | dsp_out, source = self.source_net(content, f0, energy, spk) 689 | 690 | # filter network 691 | output = self.filter_net(content, f0, energy, spk, source) 692 | return output 693 | 694 | def estimate_pitch_energy(self, content, spk): 695 | f0, energy = 
self.pitch_energy_estimator.infer(content, spk) 696 | return f0, energy 697 | -------------------------------------------------------------------------------- /module/vits/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from module.utils.common import get_padding 6 | 7 | 8 | class DiscriminatorP(nn.Module): 9 | def __init__(self, period, kernel_size=5, stride=3, channels=32, channels_mul=4, num_layers=4, max_channels=1024, use_spectral_norm=False): 10 | super().__init__() 11 | self.period = period 12 | norm_f = nn.utils.parametrizations.weight_norm if use_spectral_norm == False else nn.utils.parametrizations.spectral_norm 13 | 14 | k = kernel_size 15 | s = stride 16 | c = channels 17 | 18 | convs = [nn.Conv2d(1, c, (k, 1), (s, 1), (get_padding(5, 1), 0), padding_mode='replicate')] 19 | for i in range(num_layers): 20 | c_next = min(c * channels_mul, max_channels) 21 | convs.append(nn.Conv2d(c, c_next, (k, 1), (s, 1), (get_padding(5, 1), 0), padding_mode='replicate')) 22 | c = c_next 23 | self.convs = nn.ModuleList([norm_f(c) for c in convs]) 24 | self.post = norm_f(nn.Conv2d(c, 1, (3, 1), 1, (1, 0), padding_mode='replicate')) 25 | 26 | def forward(self, x): 27 | fmap = [] 28 | 29 | # 1d to 2d 30 | b, c, t = x.shape 31 | if t % self.period != 0: 32 | n_pad = self.period - (t % self.period) 33 | x = F.pad(x, (0, n_pad), "reflect") 34 | t = t + n_pad 35 | x = x.view(b, c, t // self.period, self.period) 36 | 37 | for l in self.convs: 38 | x = l(x) 39 | x = F.leaky_relu(x, 0.1) 40 | fmap.append(x) 41 | x = self.post(x) 42 | fmap.append(x) 43 | return x, fmap 44 | 45 | 46 | class MultiPeriodicDiscriminator(nn.Module): 47 | def __init__( 48 | self, 49 | periods=[1, 2, 3, 5, 7, 11], 50 | channels=32, 51 | channels_mul=4, 52 | max_channels=1024, 53 | num_layers=4, 54 | ): 55 | super().__init__() 56 | self.sub_discs = nn.ModuleList([]) 57 | for p in periods: 58 | self.sub_discs.append(DiscriminatorP(p, 59 | channels=channels, 60 | channels_mul=channels_mul, 61 | max_channels=max_channels, 62 | num_layers=num_layers)) 63 | 64 | def forward(self, x): 65 | x = x.unsqueeze(1) 66 | feats = [] 67 | logits = [] 68 | for d in self.sub_discs: 69 | logit, fmap = d(x) 70 | logits.append(logit) 71 | feats += fmap 72 | return logits, feats 73 | 74 | 75 | class DiscriminatorR(nn.Module): 76 | def __init__(self, resolution=128, channels=16, num_layers=4, max_channels=256): 77 | super().__init__() 78 | norm_f = nn.utils.weight_norm 79 | self.convs = nn.ModuleList([norm_f(nn.Conv2d(1, channels, (7, 3), (2, 1), (3, 1)))]) 80 | self.hop_size = resolution 81 | self.n_fft = resolution * 4 82 | c = channels 83 | for _ in range(num_layers): 84 | c_next = min(c * 2, max_channels) 85 | self.convs.append(norm_f(nn.Conv2d(c, c_next, (5, 3), (2, 1), (2, 1)))) 86 | c = c_next 87 | self.post = norm_f(nn.Conv2d(c, 1, 3, 1, 1)) 88 | 89 | def forward(self, x): 90 | w = torch.hann_window(self.n_fft).to(x.device) 91 | x = torch.stft(x, self.n_fft, self.hop_size, window=w, return_complex=True).abs() 92 | x = x.unsqueeze(1) 93 | feats = [] 94 | for l in self.convs: 95 | x = l(x) 96 | F.leaky_relu(x, 0.1) 97 | feats.append(x) 98 | x = self.post(x) 99 | feats.append(x) 100 | return x, feats 101 | 102 | 103 | class MultiResolutionDiscriminator(nn.Module): 104 | def __init__( 105 | self, 106 | resolutions=[128, 256, 512], 107 | channels=32, 108 | num_layers=4, 109 | max_channels=256, 110 | ): 111 | 
super().__init__() 112 | self.sub_discs = nn.ModuleList([]) 113 | for r in resolutions: 114 | self.sub_discs.append(DiscriminatorR(r, channels, num_layers, max_channels)) 115 | 116 | def forward(self, x): 117 | feats = [] 118 | logits = [] 119 | for d in self.sub_discs: 120 | logit, fmap = d(x) 121 | logits.append(logit) 122 | feats += fmap 123 | return logits, feats 124 | 125 | 126 | class Discriminator(nn.Module): 127 | def __init__(self, config): 128 | super().__init__() 129 | self.MPD = MultiPeriodicDiscriminator(**config.mpd) 130 | self.MRD = MultiResolutionDiscriminator(**config.mrd) 131 | 132 | # x: [BatchSize, Length(waveform)] 133 | def forward(self, x): 134 | mpd_logits, mpd_feats = self.MPD(x) 135 | mrd_logits, mrd_feats = self.MRD(x) 136 | return mpd_logits + mrd_logits, mpd_feats + mrd_feats 137 | -------------------------------------------------------------------------------- /module/vits/duration_discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .normalization import LayerNorm 4 | from .convnext import ConvNeXtLayer 5 | 6 | class DurationDiscriminator(nn.Module): 7 | def __init__( 8 | self, 9 | content_channels=192, 10 | speaker_embedding_dim=256, 11 | internal_channels=192, 12 | num_layers=3, 13 | ): 14 | super().__init__() 15 | self.text_input = nn.Conv1d(content_channels, internal_channels, 1) 16 | self.speaker_input = nn.Conv1d(speaker_embedding_dim, internal_channels, 1) 17 | self.duration_input = nn.Conv1d(1, internal_channels, 1) 18 | self.input_norm = LayerNorm(internal_channels) 19 | self.mid_layers = nn.Sequential(*[ConvNeXtLayer(internal_channels) for _ in range(num_layers)]) 20 | self.output_norm = LayerNorm(internal_channels) 21 | self.output_layer = nn.Conv1d(internal_channels, 1, 1) 22 | 23 | def forward( 24 | self, 25 | text_encoded, 26 | text_mask, 27 | log_duration, 28 | spk, 29 | ): 30 | x = self.text_input(text_encoded) + self.duration_input(log_duration) + self.speaker_input(spk) 31 | x = self.input_norm(x) 32 | x = self.mid_layers(x) * text_mask 33 | x = self.output_norm(x) 34 | x = self.output_layer(x) * text_mask 35 | return x 36 | -------------------------------------------------------------------------------- /module/vits/duration_predictors.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .transforms import piecewise_rational_quadratic_transform 7 | from .normalization import LayerNorm 8 | from .convnext import ConvNeXtLayer 9 | 10 | 11 | DEFAULT_MIN_BIN_WIDTH = 1e-3 12 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 13 | DEFAULT_MIN_DERIVATIVE = 1e-3 14 | 15 | 16 | class Flip(nn.Module): 17 | def forward(self, x, *args, reverse=False, **kwargs): 18 | x = torch.flip(x, [1]) 19 | if not reverse: 20 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 21 | return x, logdet 22 | else: 23 | return x 24 | 25 | 26 | class StochasticDurationPredictor(nn.Module): 27 | def __init__( 28 | self, 29 | in_channels=192, 30 | filter_channels=256, 31 | kernel_size=5, 32 | p_dropout=0.0, 33 | n_flows=4, 34 | speaker_embedding_dim=256 35 | ): 36 | super().__init__() 37 | self.log_flow = Log() 38 | self.flows = nn.ModuleList() 39 | self.flows.append(ElementwiseAffine(2)) 40 | for i in range(n_flows): 41 | self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 42 | self.flows.append(Flip()) 43 | 44 | self.pre 
= nn.Linear(in_channels, filter_channels) 45 | self.convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 46 | self.proj = nn.Linear(filter_channels, filter_channels) 47 | 48 | self.post_pre = nn.Linear(1, filter_channels) 49 | self.post_convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 50 | self.post_proj = nn.Linear(filter_channels, filter_channels) 51 | 52 | self.post_flows = nn.ModuleList() 53 | self.post_flows.append(ElementwiseAffine(2)) 54 | for i in range(4): 55 | self.post_flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 56 | self.post_flows.append(Flip()) 57 | 58 | if speaker_embedding_dim != 0: 59 | self.cond = nn.Linear(speaker_embedding_dim, filter_channels) 60 | 61 | 62 | # x: [BatchSize, in_chanels, Length] 63 | # x_mask [BatchSize, 1, Length] 64 | # g: [BatchSize, speaker_embedding_dim, 1] 65 | # w: Optional, Training only shape=[BatchSize, 1, Length] 66 | # Output: [BatchSize, 1, Length] 67 | # 68 | # note: g is speaker embedding 69 | def forward(self, x: torch.Tensor, x_mask: torch.Tensor, w=None, g=None, reverse=False, noise_scale=1.0): 70 | x = torch.detach(x) 71 | x = self.pre(x.mT).mT 72 | if g is not None: 73 | g = torch.detach(g) 74 | x = x + self.cond(g.mT).mT 75 | x = self.convs(x, x_mask) 76 | x = self.proj(x.mT).mT * x_mask 77 | 78 | if not reverse: 79 | flows = self.flows 80 | assert w is not None 81 | 82 | logdet_tot_q = 0 83 | h_w = self.post_pre(w.mT).mT 84 | h_w = self.post_convs(h_w, x_mask) 85 | h_w = self.post_proj(h_w.mT).mT * x_mask 86 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask 87 | z_q = e_q 88 | for flow in self.post_flows: 89 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) 90 | logdet_tot_q += logdet_q 91 | z_u, z1 = torch.split(z_q, [1, 1], 1) 92 | u = torch.sigmoid(z_u) * x_mask 93 | z0 = (w - u) * x_mask 94 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) 95 | logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q 96 | 97 | logdet_tot = 0 98 | z0, logdet = self.log_flow(z0, x_mask) 99 | logdet_tot += logdet 100 | z = torch.cat([z0, z1], 1) 101 | for flow in flows: 102 | z, logdet = flow(z, x_mask, g=x, reverse=reverse) 103 | logdet_tot = logdet_tot + logdet 104 | nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot 105 | return nll + logq # [b] 106 | else: 107 | flows = list(reversed(self.flows)) 108 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow 109 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale 110 | for flow in flows: 111 | z = flow(z, x_mask, g=x, reverse=reverse) 112 | z0, z1 = torch.split(z, [1, 1], 1) 113 | logw = z0 114 | return logw 115 | 116 | 117 | class ConvFlow(nn.Module): 118 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): 119 | super().__init__() 120 | self.filter_channels = filter_channels 121 | self.num_bins = num_bins 122 | self.tail_bound = tail_bound 123 | self.half_channels = in_channels // 2 124 | 125 | self.pre = nn.Linear(self.half_channels, filter_channels) 126 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) 127 | self.proj = nn.Linear(filter_channels, self.half_channels * (num_bins * 3 - 1)) 128 | self.proj.weight.data.zero_() 129 | self.proj.bias.data.zero_() 130 | 131 | def forward(self, x, x_mask, g=None, reverse=False): 132 | x0, x1 = torch.split(x, 
[self.half_channels] * 2, 1) 133 | h = self.pre(x0.mT).mT 134 | h = self.convs(h, x_mask, g=g) 135 | h = self.proj(h.mT).mT * x_mask 136 | 137 | b, c, t = x0.shape 138 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 139 | 140 | unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) 141 | unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels) 142 | unnormalized_derivatives = h[..., 2 * self.num_bins :] 143 | 144 | x1, logabsdet = piecewise_rational_quadratic_transform(x1, unnormalized_widths, unnormalized_heights, unnormalized_derivatives, inverse=reverse, tails="linear", tail_bound=self.tail_bound) 145 | 146 | x = torch.cat([x0, x1], 1) * x_mask 147 | logdet = torch.sum(logabsdet * x_mask, [1, 2]) 148 | if not reverse: 149 | return x, logdet 150 | else: 151 | return x 152 | 153 | 154 | class DDSConv(nn.Module): 155 | """ 156 | Dilated and Depth-Separable Convolution 157 | """ 158 | 159 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): 160 | super().__init__() 161 | self.n_layers = n_layers 162 | 163 | self.drop = nn.Dropout(p_dropout) 164 | self.convs_sep = nn.ModuleList() 165 | self.linears = nn.ModuleList() 166 | self.norms_1 = nn.ModuleList() 167 | self.norms_2 = nn.ModuleList() 168 | for i in range(n_layers): 169 | dilation = kernel_size**i 170 | padding = (kernel_size * dilation - dilation) // 2 171 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)) 172 | self.linears.append(nn.Linear(channels, channels)) 173 | self.norms_1.append(LayerNorm(channels)) 174 | self.norms_2.append(LayerNorm(channels)) 175 | 176 | def forward(self, x, x_mask, g=None): 177 | if g is not None: 178 | x = x + g 179 | for i in range(self.n_layers): 180 | y = self.convs_sep[i](x * x_mask) 181 | y = self.norms_1[i](y) 182 | y = F.gelu(y) 183 | y = self.linears[i](y.mT).mT 184 | y = self.norms_2[i](y) 185 | y = F.gelu(y) 186 | y = self.drop(y) 187 | x = x + y 188 | return x * x_mask 189 | 190 | 191 | # TODO convert to class method 192 | class Log(nn.Module): 193 | def forward(self, x, x_mask, reverse=False, **kwargs): 194 | if not reverse: 195 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 196 | logdet = torch.sum(-y, [1, 2]) 197 | return y, logdet 198 | else: 199 | x = torch.exp(x) * x_mask 200 | return x 201 | 202 | 203 | class ElementwiseAffine(nn.Module): 204 | def __init__(self, channels): 205 | super().__init__() 206 | self.m = nn.Parameter(torch.zeros(channels, 1)) 207 | self.logs = nn.Parameter(torch.zeros(channels, 1)) 208 | 209 | def forward(self, x, x_mask, reverse=False, **kwargs): 210 | if not reverse: 211 | y = self.m + torch.exp(self.logs) * x 212 | y = y * x_mask 213 | logdet = torch.sum(self.logs * x_mask, [1, 2]) 214 | return y, logdet 215 | else: 216 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 217 | return x 218 | 219 | 220 | class DurationPredictor(nn.Module): 221 | def __init__(self, 222 | content_channels=192, 223 | internal_channels=256, 224 | speaker_embedding_dim=256, 225 | kernel_size=7, 226 | num_layers=4): 227 | super().__init__() 228 | self.phoneme_input = nn.Conv1d(content_channels, internal_channels, 1) 229 | self.speaker_input = nn.Conv1d(speaker_embedding_dim, internal_channels, 1) 230 | self.input_norm = LayerNorm(internal_channels) 231 | self.mid_layers = nn.ModuleList() 232 | for _ in range(num_layers): 233 | self.mid_layers.append(ConvNeXtLayer(internal_channels, kernel_size)) 
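        # Unlike the flow-based StochasticDurationPredictor above, this deterministic
        # predictor is a plain ConvNeXt stack: the phoneme features and the broadcast
        # speaker embedding are each projected to internal_channels, summed, refined
        # num_layers times, and reduced to a single channel. Rough shape walk-through
        # with the defaults (batch B, text length T; sizes are illustrative):
        #     x [B, 192, T] + g [B, 256, 1]  ->  [B, 256, T]  ->  ...  ->  [B, 1, T]
        # The closing F.relu in forward() keeps the predicted durations non-negative.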
234 | self.output_norm = LayerNorm(internal_channels) 235 | self.output_layer = nn.Conv1d(internal_channels, 1, 1) 236 | 237 | def forward(self, x, x_mask, g): 238 | x = (self.phoneme_input(x) + self.speaker_input(g)) * x_mask 239 | x = self.input_norm(x) * x_mask 240 | for layer in self.mid_layers: 241 | x = layer(x) * x_mask 242 | x = self.output_norm(x) * x_mask 243 | x = self.output_layer(x) * x_mask 244 | x = F.relu(x) 245 | return x 246 | -------------------------------------------------------------------------------- /module/vits/feature_retrieval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # Feature retrieval via kNN (like kNN-VC, RVC, etc...) 7 | # Warning: this method is not optimized. 8 | # Do not give long sequence. computing complexy is quadratic. 9 | # 10 | # source: [BatchSize, Channels, Length] 11 | # reference: [BatchSize, Channels, Length] 12 | # k: int 13 | # alpha: float (0.0 ~ 1.0) 14 | # metrics: one of ['IP', 'L2', 'cos'], 'IP' means innner product, 'L2' means euclid distance, 'cos' means cosine similarity 15 | # Output: [BatchSize, Channels, Length] 16 | def match_features(source, reference, k=4, alpha=0.0, metrics='cos'): 17 | input_data = source 18 | 19 | source = source.transpose(1, 2) 20 | reference = reference.transpose(1, 2) 21 | if metrics == 'IP': 22 | sims = torch.bmm(source, reference.transpose(1, 2)) 23 | elif metrics == 'L2': 24 | sims = -torch.cdist(source, reference) 25 | elif metrics == 'cos': 26 | reference_norm = torch.norm(reference, dim=2, keepdim=True, p=2) + 1e-6 27 | source_norm = torch.norm(source, dim=2, keepdim=True, p=2) + 1e-6 28 | sims = torch.bmm(source / source_norm, (reference / reference_norm).transpose(1, 2)) 29 | best = torch.topk(sims, k, dim=2) 30 | 31 | result = torch.stack([reference[n][best.indices[n]] for n in range(source.shape[0])], dim=0).mean(dim=2) 32 | result = result.transpose(1, 2) 33 | 34 | return result * (1-alpha) + input_data * alpha 35 | -------------------------------------------------------------------------------- /module/vits/flow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .wn import WN 5 | 6 | 7 | class Flip(nn.Module): 8 | def forward(self, x, *args, **kwargs): 9 | x = torch.flip(x, [1]) 10 | return x 11 | 12 | 13 | class ResidualCouplingLayer(nn.Module): 14 | def __init__(self, 15 | in_channels=192, 16 | internal_channels=192, 17 | speaker_embedding_dim=256, 18 | kernel_size=5, 19 | dilation=1, 20 | num_layers=4): 21 | super().__init__() 22 | self.in_channels = in_channels 23 | self.half_channels = in_channels // 2 24 | 25 | self.pre = nn.Conv1d(self.half_channels, internal_channels, 1) 26 | self.wn = WN(internal_channels, kernel_size, dilation, num_layers, speaker_embedding_dim) 27 | self.post = nn.Conv1d(internal_channels, self.half_channels, 1) 28 | 29 | self.post.weight.data.zero_() 30 | self.post.bias.data.zero_() 31 | 32 | # x: [BatchSize, in_channels, Length] 33 | # x_mask: [BatchSize, 1, Length] 34 | # g: [BatchSize, speaker_embedding_dim, 1] 35 | def forward(self, x, x_mask, g, reverse=False): 36 | x_0, x_1 = torch.chunk(x, 2, dim=1) 37 | h = self.pre(x_0) * x_mask 38 | h = self.wn(h, x_mask, g) 39 | x_1_mean = self.post(h) * x_mask 40 | 41 | if not reverse: 42 | x_1 = x_1_mean + x_1 * x_mask 43 | else: 44 | x_1 = (x_1 - x_1_mean) * x_mask 45 | 46 | x = torch.cat([x_0, 
x_1], dim=1) 47 | return x 48 | 49 | 50 | class Flow(nn.Module): 51 | def __init__(self, 52 | content_channels=192, 53 | internal_channels=192, 54 | speaker_embedding_dim=256, 55 | kernel_size=5, 56 | dilation=1, 57 | num_flows=4, 58 | num_layers=4): 59 | super().__init__() 60 | 61 | self.flows = nn.ModuleList() 62 | for i in range(num_flows): 63 | self.flows.append( 64 | ResidualCouplingLayer( 65 | content_channels, 66 | internal_channels, 67 | speaker_embedding_dim, 68 | kernel_size, 69 | dilation, 70 | num_layers)) 71 | self.flows.append(Flip()) 72 | 73 | # z: [BatchSize, content_channels, Length] 74 | # z_mask: [BatchSize, 1, Length] 75 | # g: [Batchsize, speaker_embedding_dim, 1] 76 | # Output: [BatchSize, content_channels, Length] 77 | def forward(self, z, z_mask, g, reverse=False): 78 | if not reverse: 79 | for flow in self.flows: 80 | z = flow(z, z_mask, g, reverse=False) 81 | else: 82 | for flow in reversed(self.flows): 83 | z = flow(z, z_mask, g, reverse=True) 84 | return z 85 | -------------------------------------------------------------------------------- /module/vits/generator.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .decoder import Decoder 9 | from .posterior_encoder import PosteriorEncoder 10 | from .prior_encoder import PriorEncoder 11 | from .speaker_embedding import SpeakerEmbedding 12 | from .flow import Flow 13 | from .crop import crop_features 14 | from module.utils.energy_estimation import estimate_energy 15 | 16 | 17 | class Generator(nn.Module): 18 | # initialize from config 19 | def __init__(self, config): 20 | super().__init__() 21 | self.decoder = Decoder(**config.decoder) 22 | self.posterior_encoder = PosteriorEncoder(**config.posterior_encoder) 23 | self.prior_encoder = PriorEncoder(config.prior_encoder) 24 | self.speaker_embedding = SpeakerEmbedding(**config.speaker_embedding) 25 | 26 | # training pass 27 | # 28 | # spec: [BatchSize, fft_bin, Length] 29 | # spec_len: [BatchSize] 30 | # phone: [BatchSize, NumPhonemes] 31 | # phone_len: [BatchSize] 32 | # lm_feat: [BatchSize, lm_dim, NumLMfeatures] 33 | # lm_feat_len: [BatchSize] 34 | # f0: [Batchsize, 1, Length] 35 | # spk: [BatchSize] 36 | # lang: [BatchSize] 37 | # crop_range: Tuple[int, int] 38 | # 39 | # Outputs: 40 | # dsp_out: [BatchSize, Length * frame_size] 41 | # fake: [BatchSize, Length * frame_size] 42 | # lossG: [1] 43 | # loss_dict: Dict[str: float] 44 | # 45 | def forward( 46 | self, 47 | spec, 48 | spec_len, 49 | phoneme, 50 | phoneme_len, 51 | lm_feat, 52 | lm_feat_len, 53 | f0, 54 | spk, 55 | lang, 56 | crop_range 57 | ): 58 | 59 | spk = self.speaker_embedding(spk) 60 | z, m_q, logs_q, spec_mask = self.posterior_encoder.forward(spec, spec_len, spk) 61 | energy = estimate_energy(spec) 62 | loss_prior, loss_dict_prior, (text_encoded, text_mask, fake_log_duration, real_log_duration) = self.prior_encoder.forward(spec_mask, z, logs_q, phoneme, phoneme_len, lm_feat, lm_feat_len, lang, spk) 63 | 64 | z_crop = crop_features(z, crop_range) 65 | f0_crop = crop_features(f0, crop_range) 66 | energy_crop = crop_features(energy, crop_range) 67 | dsp_out, fake, loss_decoder, loss_dict_decoder = self.decoder.forward(z_crop, f0_crop, energy_crop, spk) 68 | 69 | loss_dict = (loss_dict_decoder | loss_dict_prior) # merge dict 70 | loss = loss_prior + loss_decoder 71 | 72 | return loss, loss_dict, (text_encoded, text_mask, fake_log_duration, 
real_log_duration, spk.detach(), dsp_out, fake) 73 | 74 | @torch.no_grad() 75 | def text_to_speech( 76 | self, 77 | phoneme, 78 | phoneme_len, 79 | lm_feat, 80 | lm_feat_len, 81 | lang, 82 | spk, 83 | noise_scale=0.6, 84 | max_frames=2000, 85 | use_sdp=True, 86 | duration_scale=1.0, 87 | pitch_shift=0.0, 88 | energy_scale=1.0, 89 | ): 90 | spk = self.speaker_embedding(spk) 91 | z = self.prior_encoder.text_to_speech( 92 | phoneme, phoneme_len, lm_feat, lm_feat_len, lang, spk, 93 | noise_scale=noise_scale, 94 | max_frames=max_frames, 95 | use_sdp=use_sdp, 96 | duration_scale=duration_scale) 97 | f0, energy = self.decoder.estimate_pitch_energy(z, spk) 98 | pitch = torch.log2((f0 + 1e-6) / 440.0) * 12.0 99 | pitch += pitch_shift 100 | f0 = 440.0 * 2 ** (pitch / 12.0) 101 | energy = energy * energy_scale 102 | fake = self.decoder.infer(z, spk, f0=f0, energy=energy) 103 | return fake 104 | 105 | @torch.no_grad() 106 | def audio_reconstruction(self, spec, spec_len, spk): 107 | # embed speaker 108 | spk = self.speaker_embedding(spk) 109 | 110 | # encode linear spectrogram and speaker infomation 111 | z, m_q, logs_q, spec_mask = self.posterior_encoder(spec, spec_len, spk) 112 | 113 | # decode 114 | fake = self.decoder.infer(z, spk) 115 | return fake 116 | 117 | @torch.no_grad() 118 | def singing_voice_conversion(self): 119 | pass # TODO: write this 120 | 121 | @torch.no_grad() 122 | def singing_voice_synthesis(self): 123 | pass # TODO: write this 124 | -------------------------------------------------------------------------------- /module/vits/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchaudio 5 | 6 | 7 | def safe_log(x, eps=1e-6): 8 | return torch.log(x + eps) 9 | 10 | def multiscale_stft_loss(x, y, scales=[16, 32, 64, 128, 256, 512]): 11 | x = x.to(torch.float) 12 | y = y.to(torch.float) 13 | 14 | loss = 0 15 | num_scales = len(scales) 16 | for s in scales: 17 | hop_length = s 18 | n_fft = s * 4 19 | window = torch.hann_window(n_fft, device=x.device) 20 | x_spec = torch.stft(x, n_fft, hop_length, return_complex=True, window=window).abs() 21 | y_spec = torch.stft(y, n_fft, hop_length, return_complex=True, window=window).abs() 22 | 23 | x_spec[x_spec.isnan()] = 0 24 | x_spec[x_spec.isinf()] = 0 25 | y_spec[y_spec.isnan()] = 0 26 | y_spec[y_spec.isinf()] = 0 27 | 28 | loss += (safe_log(x_spec) - safe_log(y_spec)).abs().mean() 29 | return loss / num_scales 30 | 31 | 32 | global mel_spectrogram_modules 33 | mel_spectrogram_modules = {} 34 | def mel_spectrogram_loss(x, y, sample_rate=48000, n_fft=2048, hop_length=512, power=2.0, log=True): 35 | device = x.device 36 | if device not in mel_spectrogram_modules: 37 | mel_spectrogram_modules[device] = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, power=power).to(device) 38 | mel_spectrogram = mel_spectrogram_modules[device] 39 | 40 | x_mel = mel_spectrogram(x) 41 | y_mel = mel_spectrogram(y) 42 | if log: 43 | x_mel = safe_log(x_mel) 44 | y_mel = safe_log(y_mel) 45 | 46 | x_mel[x_mel.isnan()] = 0 47 | x_mel[x_mel.isinf()] = 0 48 | y_mel[y_mel.isnan()] = 0 49 | y_mel[y_mel.isinf()] = 0 50 | 51 | loss = F.l1_loss(x_mel, y_mel) 52 | return loss 53 | 54 | # 1 = fake, 0 = real 55 | def discriminator_adversarial_loss(real_outputs, fake_outputs): 56 | loss = 0 57 | n = min(len(real_outputs), len(fake_outputs)) 58 | for dr, df in zip(real_outputs, fake_outputs): 59 | dr = 
dr.float() 60 | df = df.float() 61 | real_loss = (dr ** 2).mean() 62 | fake_loss = ((df - 1) ** 2).mean() 63 | loss += real_loss + fake_loss 64 | return loss / n 65 | 66 | 67 | def generator_adversarial_loss(fake_outputs): 68 | loss = 0 69 | n = len(fake_outputs) 70 | for dg in fake_outputs: 71 | dg = dg.float() 72 | loss += (dg ** 2).mean() 73 | return loss / n 74 | 75 | 76 | def duration_discriminator_adversarial_loss(real_output, fake_output, text_mask): 77 | loss = (((fake_output - 1) ** 2) * text_mask).sum() / text_mask.sum() 78 | loss += ((real_output ** 2) * text_mask).sum() / text_mask.sum() 79 | return loss 80 | 81 | 82 | def duration_generator_adversarial_loss(fake_output, text_mask): 83 | loss = ((fake_output ** 2) * text_mask).sum() / text_mask.sum() 84 | return loss 85 | 86 | 87 | def feature_matching_loss(fmap_real, fmap_fake): 88 | loss = 0 89 | n = min(len(fmap_real), len(fmap_fake)) 90 | for r, f in zip(fmap_real, fmap_fake): 91 | f = f.float() 92 | r = r.float() 93 | loss += (f - r).abs().mean() 94 | return loss * (2 / n) 95 | -------------------------------------------------------------------------------- /module/vits/monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | """Maximum path calculation module. 2 | 3 | This code is based on https://github.com/jaywalnut310/vits. 4 | 5 | """ 6 | 7 | import warnings 8 | 9 | import numpy as np 10 | import torch 11 | from numba import njit, prange 12 | 13 | try: 14 | from .core import maximum_path_c 15 | 16 | is_cython_avalable = True 17 | except ImportError: 18 | is_cython_avalable = False 19 | warnings.warn( 20 | "Cython version is not available. Fallback to 'EXPERIMETAL' numba version. " 21 | "If you want to use the cython version, please build it as follows: " 22 | "`cd auris/module/vits/monotonic_align; python setup.py build_ext --inplace`" 23 | ) 24 | 25 | 26 | def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: 27 | """Calculate maximum path. 28 | 29 | Args: 30 | neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text). 31 | attn_mask (Tensor): Attention mask (B, T_feats, T_text). 32 | 33 | Returns: 34 | Tensor: Maximum path tensor (B, T_feats, T_text). 
35 | 36 | """ 37 | device, dtype = neg_x_ent.device, neg_x_ent.dtype 38 | neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32) 39 | path = np.zeros(neg_x_ent.shape, dtype=np.int32) 40 | t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32) 41 | t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32) 42 | if is_cython_avalable: 43 | maximum_path_c(path, neg_x_ent, t_t_max, t_s_max) 44 | else: 45 | maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max) 46 | 47 | return torch.from_numpy(path).to(device=device, dtype=dtype) 48 | 49 | 50 | @njit 51 | def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf): 52 | """Calculate a single maximum path with numba.""" 53 | index = t_x - 1 54 | for y in range(t_y): 55 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 56 | if x == y: 57 | v_cur = max_neg_val 58 | else: 59 | v_cur = value[y - 1, x] 60 | if x == 0: 61 | if y == 0: 62 | v_prev = 0.0 63 | else: 64 | v_prev = max_neg_val 65 | else: 66 | v_prev = value[y - 1, x - 1] 67 | value[y, x] += max(v_prev, v_cur) 68 | 69 | for y in range(t_y - 1, -1, -1): 70 | path[y, index] = 1 71 | if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): 72 | index = index - 1 73 | 74 | 75 | @njit(parallel=True) 76 | def maximum_path_numba(paths, values, t_ys, t_xs): 77 | """Calculate batch maximum path with numba.""" 78 | for i in prange(paths.shape[0]): 79 | maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i]) 80 | -------------------------------------------------------------------------------- /module/vits/monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | """Maximum path calculation module with cython optimization. 2 | 3 | This code is copied from https://github.com/jaywalnut310/vits and modifed code format. 
4 | 5 | """ 6 | 7 | cimport cython 8 | 9 | from cython.parallel import prange 10 | 11 | 12 | @cython.boundscheck(False) 13 | @cython.wraparound(False) 14 | cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: 15 | cdef int x 16 | cdef int y 17 | cdef float v_prev 18 | cdef float v_cur 19 | cdef float tmp 20 | cdef int index = t_x - 1 21 | 22 | for y in range(t_y): 23 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 24 | if x == y: 25 | v_cur = max_neg_val 26 | else: 27 | v_cur = value[y - 1, x] 28 | if x == 0: 29 | if y == 0: 30 | v_prev = 0.0 31 | else: 32 | v_prev = max_neg_val 33 | else: 34 | v_prev = value[y - 1, x - 1] 35 | value[y, x] += max(v_prev, v_cur) 36 | 37 | for y in range(t_y - 1, -1, -1): 38 | path[y, index] = 1 39 | if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): 40 | index = index - 1 41 | 42 | 43 | @cython.boundscheck(False) 44 | @cython.wraparound(False) 45 | cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil: 46 | cdef int b = paths.shape[0] 47 | cdef int i 48 | for i in prange(b, nogil=True): 49 | maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) 50 | -------------------------------------------------------------------------------- /module/vits/monotonic_align/setup.py: -------------------------------------------------------------------------------- 1 | """Setup cython code.""" 2 | 3 | from Cython.Build import cythonize 4 | from setuptools import Extension, setup 5 | from setuptools.command.build_ext import build_ext as _build_ext 6 | 7 | 8 | class build_ext(_build_ext): 9 | """Overwrite build_ext.""" 10 | 11 | def finalize_options(self): 12 | """Prevent numpy from thinking it is still in its setup process.""" 13 | _build_ext.finalize_options(self) 14 | __builtins__.__NUMPY_SETUP__ = False 15 | import numpy 16 | 17 | self.include_dirs.append(numpy.get_include()) 18 | 19 | 20 | exts = [ 21 | Extension( 22 | name="core", 23 | sources=["core.pyx"], 24 | ) 25 | ] 26 | setup( 27 | name="monotonic_align", 28 | ext_modules=cythonize(exts, language_level=3), 29 | cmdclass={"build_ext": build_ext}, 30 | ) 31 | -------------------------------------------------------------------------------- /module/vits/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LayerNorm(nn.Module): 7 | def __init__(self, channels, eps=1e-5): 8 | super().__init__() 9 | self.channels = channels 10 | self.eps = eps 11 | 12 | self.gamma = nn.Parameter(torch.ones(channels)) 13 | self.beta = nn.Parameter(torch.zeros(channels)) 14 | 15 | # x: [BatchSize, cnannels, *] 16 | def forward(self, x: torch.Tensor): 17 | x = F.layer_norm(x.mT, (self.channels,), self.gamma, self.beta, self.eps) 18 | return x.mT 19 | 20 | 21 | # Global Resnponse Normalization for 1d Sequence (shape=[BatchSize, Channels, Length]) 22 | class GRN(nn.Module): 23 | def __init__(self, channels, eps=1e-6): 24 | super().__init__() 25 | self.beta = nn.Parameter(torch.zeros(1, channels, 1)) 26 | self.gamma = nn.Parameter(torch.zeros(1, channels, 1)) 27 | self.eps = eps 28 | 29 | # x: [batchsize, channels, length] 30 | def forward(self, x): 31 | gx = torch.norm(x, p=2, dim=2, keepdim=True) 32 | nx = gx / (gx.mean(dim=1, keepdim=True) + self.eps) 33 | return self.gamma * (x * nx) + self.beta + x 34 | 
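# Usage sketch (illustrative; the shapes are arbitrary example values, not taken from
# any config): both modules above operate on channel-first sequence tensors
# [BatchSize, Channels, Length] and return the same shape, so they can sit inside
# Conv1d stacks without transposing.
#
#     x = torch.randn(2, 192, 100)      # [BatchSize, Channels, Length]
#     LayerNorm(192)(x).shape           # -> torch.Size([2, 192, 100])
#     GRN(192)(x).shape                 # -> torch.Size([2, 192, 100])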
-------------------------------------------------------------------------------- /module/vits/posterior_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .wn import WN 4 | 5 | 6 | class PosteriorEncoder(nn.Module): 7 | def __init__(self, 8 | n_fft=3840, 9 | frame_size=960, 10 | internal_channels=192, 11 | speaker_embedding_dim=256, 12 | content_channels=192, 13 | kernel_size=5, 14 | dilation=1, 15 | num_layers=16): 16 | super().__init__() 17 | 18 | self.input_channels = n_fft // 2 + 1 19 | self.n_fft = n_fft 20 | self.frame_size = frame_size 21 | self.pre = nn.Conv1d(self.input_channels, internal_channels, 1) 22 | self.wn = WN(internal_channels, kernel_size, dilation, num_layers, speaker_embedding_dim) 23 | self.post = nn.Conv1d(internal_channels, content_channels * 2, 1) 24 | 25 | # x: [BatchSize, fft_bin, Length] 26 | # x_length: [BatchSize] 27 | # g: [BatchSize, speaker_embedding_dim, 1] 28 | # 29 | # Outputs: 30 | # z: [BatchSize, content_channels, Length] 31 | # mean: [BatchSize, content_channels, Length] 32 | # logvar: [BatchSize, content_channels, Length] 33 | # z_mask: [BatchSize, 1, Length] 34 | # 35 | # where fft_bin = input_channels = n_fft // 2 + 1 36 | def forward(self, x, x_length, g): 37 | # generate mask 38 | max_length = x.shape[2] 39 | progression = torch.arange(max_length, dtype=x_length.dtype, device=x_length.device) 40 | z_mask = (progression.unsqueeze(0) < x_length.unsqueeze(1)) 41 | z_mask = z_mask.unsqueeze(1).to(x.dtype) 42 | 43 | # pass network 44 | x = self.pre(x) * z_mask 45 | x = self.wn(x, z_mask, g) 46 | x = self.post(x) * z_mask 47 | mean, logvar = torch.chunk(x, 2, dim=1) 48 | z = mean + torch.randn_like(mean) * torch.exp(logvar) * z_mask 49 | return z, mean, logvar, z_mask 50 | -------------------------------------------------------------------------------- /module/vits/prior_encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .duration_predictors import StochasticDurationPredictor, DurationPredictor 7 | from .monotonic_align import maximum_path 8 | from .text_encoder import TextEncoder 9 | from .flow import Flow 10 | from module.utils.energy_estimation import estimate_energy 11 | 12 | 13 | def sequence_mask(length, max_length=None): 14 | if max_length is None: 15 | max_length = length.max() 16 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 17 | return x.unsqueeze(0) < length.unsqueeze(1) 18 | 19 | 20 | def generate_path(duration, mask): 21 | """ 22 | duration: [b, 1, t_x] 23 | mask: [b, 1, t_y, t_x] 24 | """ 25 | b, _, t_y, t_x = mask.shape 26 | cum_duration = torch.cumsum(duration, -1) 27 | 28 | cum_duration_flat = cum_duration.view(b * t_x) 29 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 30 | path = path.view(b, t_x, t_y) 31 | 32 | padding_shape = [[0, 0], [1, 0], [0, 0]] 33 | padding = [item for sublist in padding_shape[::-1] for item in sublist] 34 | 35 | path = path - F.pad(path, padding)[:, :-1] 36 | path = path.unsqueeze(1).transpose(2,3) * mask 37 | return path 38 | 39 | 40 | def kl_divergence_loss(z_p, logs_q, m_p, logs_p, z_mask): 41 | z_p = z_p.float() 42 | logs_q = logs_q.float() 43 | m_p = m_p.float() 44 | logs_p = logs_p.float() 45 | z_mask = z_mask.float() 46 | 47 | kl = logs_p - logs_q - 0.5 48 | kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. 
* logs_p) 49 | kl = torch.sum(kl * z_mask) 50 | l = kl / torch.sum(z_mask) 51 | return l 52 | 53 | 54 | # run Monotonic Alignment Search (MAS). 55 | # MAS associates phoneme sequences with sounds. 56 | # 57 | # z_p: [b, d, t'] 58 | # m_p: [b, d, t] 59 | # logs_p: [b, d, t] 60 | # text_mask: [b, 1, t] 61 | # spec_mask: [b, 1, t'] 62 | # Output: [b, 1, t', t] 63 | def search_path(z_p, m_p, logs_p, text_mask, spec_mask, mas_noise_scale=0.0): 64 | with torch.no_grad(): 65 | # calculate nodes 66 | # b = batch size, d = feature dim, t = text length, t' = spec length 67 | s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] 68 | neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, dim=1, keepdim=True) # [b, 1, t] 69 | neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).mT, s_p_sq_r) # [b, t', d] x [b, d, t] = [b, t', t] 70 | neg_cent3 = torch.matmul(z_p.mT, (m_p * s_p_sq_r)) # [b, t', s] x [b, d, t] = [b, t', t] 71 | neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, dim=1, keepdim=True) # [b, 1, t] 72 | neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 # [b, t', t] 73 | 74 | # add noise 75 | if mas_noise_scale > 0.0: 76 | eps = torch.std(neg_cent) * torch.randn_like(neg_cent) * mas_noise_scale 77 | neg_cent += eps 78 | 79 | # mask unnecessary nodes, run D.P. 80 | MAS_node_mask = text_mask.unsqueeze(2) * spec_mask.unsqueeze(-1) # [b, 1, 't, t] 81 | MAS_path = maximum_path(neg_cent, MAS_node_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, 't, t] 82 | return MAS_path 83 | 84 | 85 | class PriorEncoder(nn.Module): 86 | def __init__( 87 | self, 88 | config, 89 | ): 90 | super().__init__() 91 | self.flow = Flow(**config.flow) 92 | self.text_encoder = TextEncoder(**config.text_encoder) 93 | self.duration_predictor = DurationPredictor(**config.duration_predictor) 94 | self.stochastic_duration_predictor = StochasticDurationPredictor(**config.stochastic_duration_predictor) 95 | 96 | def forward(self, spec_mask, z, logs_q, phoneme, phoneme_len, lm_feat, lm_feat_len, lang, spk): 97 | # encode text 98 | text_encoded, m_p, logs_p, text_mask = self.text_encoder(phoneme, phoneme_len, lm_feat, lm_feat_len, spk, lang) 99 | 100 | # remove speaker infomation 101 | z_p = self.flow(z, spec_mask, spk) 102 | 103 | # search path 104 | MAS_path = search_path(z_p, m_p, logs_p, text_mask, spec_mask) 105 | 106 | # KL Divergence loss 107 | m_p = torch.matmul(MAS_path.squeeze(1), m_p.mT).mT 108 | logs_p = torch.matmul(MAS_path.squeeze(1), logs_p.mT).mT 109 | loss_kl = kl_divergence_loss(z_p, logs_q, m_p, logs_p, spec_mask) 110 | 111 | # calculate duration each phonemes 112 | duration = MAS_path.sum(2) 113 | loss_sdp = self.stochastic_duration_predictor( 114 | text_encoded, 115 | text_mask, 116 | w=duration, 117 | g=spk 118 | ).sum() / text_mask.sum() 119 | 120 | logw_y = torch.log(duration + 1e-6) * text_mask 121 | logw_x = self.duration_predictor(text_encoded, text_mask, spk) 122 | loss_dp = torch.sum(((logw_x - logw_y) ** 2) * text_mask) / torch.sum(text_mask) 123 | 124 | # predict duration 125 | fake_log_duration = self.duration_predictor(text_encoded, text_mask, spk) 126 | real_log_duration = logw_y 127 | 128 | loss_dict = { 129 | "StochasticDurationPredictor": loss_sdp.item(), 130 | "DurationPredictor": loss_dp.item(), 131 | "KL Divergence": loss_kl.item(), 132 | } 133 | 134 | loss = loss_sdp + loss_dp + loss_kl 135 | return loss, loss_dict, (text_encoded.detach(), text_mask, fake_log_duration, real_log_duration) 136 | 137 | def text_to_speech(self, phoneme, phoneme_len, lm_feat, lm_feat_len, lang, spk, noise_scale=0.6, 
max_frames=2000, use_sdp=True, duration_scale=1.0): 138 | # encode text 139 | text_encoded, m_p, logs_p, text_mask = self.text_encoder(phoneme, phoneme_len, lm_feat, lm_feat_len, spk, lang) 140 | 141 | # predict duration 142 | if use_sdp: 143 | log_duration = self.stochastic_duration_predictor(text_encoded, text_mask, g=spk, reverse=True) 144 | else: 145 | log_duration = self.duration_predictor(text_encoded, text_mask, spk) 146 | duration = torch.exp(log_duration) 147 | duration = duration * text_mask * duration_scale 148 | duration = torch.ceil(duration) 149 | 150 | spec_len = torch.clamp_min(torch.sum(duration, dim=(1, 2)), 1).long() 151 | spec_len = torch.clamp_max(spec_len, max_frames) 152 | spec_mask = sequence_mask(spec_len).unsqueeze(1).to(text_mask.dtype) 153 | 154 | MAS_node_mask = text_mask.unsqueeze(2) * spec_mask.unsqueeze(-1) 155 | MAS_path = generate_path(duration, MAS_node_mask).float() 156 | 157 | # projection 158 | m_p = torch.matmul(MAS_path.squeeze(1), m_p.mT).mT 159 | logs_p = torch.matmul(MAS_path.squeeze(1), logs_p.mT).mT 160 | 161 | # sample from gaussian 162 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale 163 | 164 | # crop max frames 165 | if z_p.shape[2] > max_frames: 166 | z_p = z_p[:, :, :max_frames] 167 | 168 | # add speaker infomation 169 | z = self.flow(z_p, spec_mask, spk, reverse=True) 170 | 171 | return z -------------------------------------------------------------------------------- /module/vits/speaker_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SpeakerEmbedding(nn.Module): 6 | def __init__(self, num_speakers=8192, embedding_dim=256): 7 | super().__init__() 8 | self.embedding = nn.Embedding(num_speakers, embedding_dim) 9 | 10 | # i: [BatchSize] 11 | # Output: [BatchSize, embedding_dim, 1] 12 | def forward(self, i): 13 | x = self.embedding(i) 14 | x = x.unsqueeze(2) 15 | return x 16 | -------------------------------------------------------------------------------- /module/vits/spectrogram.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # wave: [BatchSize, 1, Length] 7 | # Output: [BatchSize, 1, Frames] 8 | def spectrogram(wave, n_fft, hop_size, power=2.0): 9 | dtype = wave.dtype 10 | wave = wave.to(torch.float) 11 | window = torch.hann_window(n_fft, device=wave.device) 12 | spec = torch.stft(wave, n_fft, hop_size, return_complex=True, window=window).abs() 13 | spec = spec[:, :, 1:] 14 | spec = spec ** power 15 | spec = spec.to(dtype) 16 | return spec 17 | 18 | -------------------------------------------------------------------------------- /module/vits/text_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .transformer import RelativePositionTransformerDecoder 6 | 7 | 8 | class TextEncoder(nn.Module): 9 | def __init__( 10 | self, 11 | num_phonemes=512, 12 | num_languages=256, 13 | internal_channels=256, 14 | speaker_embedding_dim=256, 15 | content_channels=192, 16 | n_heads=4, 17 | lm_dim=768, 18 | kernel_size=1, 19 | dropout=0.0, 20 | window_size=4, 21 | num_layers=4, 22 | ): 23 | super().__init__() 24 | self.lm_proj = nn.Linear(lm_dim, internal_channels) 25 | self.phoneme_embedding = nn.Embedding(num_phonemes, internal_channels) 26 | self.language_embedding = 
nn.Embedding(num_languages, internal_channels) 27 | self.speaker_input = nn.Conv1d(speaker_embedding_dim, internal_channels, 1) 28 | self.transformer = RelativePositionTransformerDecoder( 29 | internal_channels, 30 | internal_channels * 4, 31 | n_heads, 32 | num_layers, 33 | kernel_size, 34 | dropout, 35 | window_size 36 | ) 37 | self.post = nn.Conv1d(internal_channels, content_channels * 2, 1) 38 | 39 | # Note: 40 | # x: Phoneme IDs 41 | # x_length: length of phoneme sequence, for generating mask 42 | # y: language model's output features 43 | # y_length: length of language model's output features, for generating mask 44 | # lang: Language ID 45 | # 46 | # x: [BatchSize, Length_x] 47 | # x_length: [BatchSize] 48 | # y: [BatchSize, Length_y, lm_dim] 49 | # y_length [BatchSize] 50 | # spk: [BatchSize, speaker_embedding_dim, 1] 51 | # lang: [BatchSize] 52 | # 53 | # Outputs: 54 | # z: [BatchSize, content_channels, Length_x] 55 | # mean: [BatchSize, content_channels, Length_x] 56 | # logvar: [BatchSize, content_channels, Length_x] 57 | # z_mask: [BatchSize, 1, Length_x] 58 | def forward(self, x, x_length, y, y_length, spk, lang): 59 | # generate mask 60 | # x mask 61 | max_length = x.shape[1] 62 | progression = torch.arange(max_length, dtype=x_length.dtype, device=x_length.device) 63 | x_mask = (progression.unsqueeze(0) < x_length.unsqueeze(1)) 64 | x_mask = x_mask.unsqueeze(1).to(y.dtype) 65 | z_mask = x_mask 66 | 67 | # y mask 68 | max_length = y.shape[1] 69 | progression = torch.arange(max_length, dtype=y_length.dtype, device=y_length.device) 70 | y_mask = (progression.unsqueeze(0) < y_length.unsqueeze(1)) 71 | y_mask = y_mask.unsqueeze(1).to(y.dtype) 72 | 73 | # pass network 74 | y = self.lm_proj(y).mT # [B, C, L_y] where C = internal_channels, L_y = Length_y, L_x = Length_x, B = BatchSize 75 | x = self.phoneme_embedding(x) # [B, L_x, C] 76 | lang = self.language_embedding(lang) # [B, C] 77 | lang = lang.unsqueeze(1) # [B, 1, C] 78 | x = x + lang # language conditioning 79 | x = x.mT # [B, C, L_x] 80 | x = x + self.speaker_input(spk) 81 | x = self.transformer(x, x_mask, y, y_mask) # [B, C, L_x] 82 | x = self.post(x) * x_mask # [B, 2C, L_x] 83 | mean, logvar = torch.chunk(x, 2, dim=1) 84 | z = mean + torch.randn_like(mean) * torch.exp(logvar) * z_mask 85 | return z, mean, logvar, z_mask 86 | -------------------------------------------------------------------------------- /module/vits/transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from module.utils.common import convert_pad_shape 7 | from .normalization import LayerNorm 8 | 9 | 10 | class RelativePositionTransformerEncoder(nn.Module): 11 | def __init__( 12 | self, 13 | hidden_channels: int, 14 | hidden_channels_ffn: int, 15 | n_heads: int, 16 | n_layers: int, 17 | kernel_size=1, 18 | dropout=0.0, 19 | window_size=4, 20 | ): 21 | super().__init__() 22 | self.n_layers = n_layers 23 | 24 | self.drop = nn.Dropout(dropout) 25 | self.attn_layers = nn.ModuleList() 26 | self.norm_layers_1 = nn.ModuleList() 27 | self.ffn_layers = nn.ModuleList() 28 | self.norm_layers_2 = nn.ModuleList() 29 | for i in range(self.n_layers): 30 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=dropout, window_size=window_size)) 31 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 32 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, hidden_channels_ffn, 
kernel_size, p_dropout=dropout)) 33 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 34 | 35 | # x: [BatchSize, hidden_channels, Length] 36 | # x_mask: [BatchSize, 1, Length] 37 | # 38 | # Output: [BatchSize, hidden_channels, Length] 39 | def forward(self, x: torch.Tensor, x_mask: torch.Tensor): 40 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 41 | x = x * x_mask 42 | for i in range(self.n_layers): 43 | res = x 44 | x = self.attn_layers[i](x, x, attn_mask) 45 | x = self.drop(x) 46 | x = self.norm_layers_1[i](x + res) 47 | 48 | res = x 49 | x = self.ffn_layers[i](x, x_mask) 50 | x = self.drop(x) 51 | x = self.norm_layers_2[i](x + res) 52 | x = x * x_mask 53 | return x 54 | 55 | 56 | class RelativePositionTransformerDecoder(nn.Module): 57 | def __init__( 58 | self, 59 | hidden_channels: int, 60 | hidden_channels_ffn: int, 61 | n_heads: int, 62 | n_layers: int, 63 | kernel_size=1, 64 | dropout=0.0, 65 | window_size=4, 66 | ): 67 | super().__init__() 68 | self.n_layers = n_layers 69 | 70 | self.drop = nn.Dropout(dropout) 71 | self.self_attn_layers = nn.ModuleList() 72 | self.norm_layers_1 = nn.ModuleList() 73 | self.cross_attn_layers = nn.ModuleList() 74 | self.norm_layers_2 = nn.ModuleList() 75 | self.ffn_layers = nn.ModuleList() 76 | self.norm_layers_3 = nn.ModuleList() 77 | for i in range(self.n_layers): 78 | self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=dropout, window_size=window_size)) 79 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 80 | self.cross_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=dropout)) 81 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 82 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, hidden_channels_ffn, kernel_size, p_dropout=dropout)) 83 | self.norm_layers_3.append(LayerNorm(hidden_channels)) 84 | 85 | # x: [BatchSize, hidden_channels, Length_x] 86 | # x_mask: [BatchSize, 1, Length_x] 87 | # y: [BatchSize, hidden_channels, Length_y] 88 | # y_mask: [BatchSize, 1, Length_y] 89 | # 90 | # Output: [BatchSize, hidden_channels, Length_x] 91 | def forward(self, x: torch.Tensor, x_mask: torch.Tensor, y: torch.Tensor, y_mask: torch.Tensor): 92 | self_attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 93 | cross_attn_mask = y_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 94 | x = x * x_mask 95 | for i in range(self.n_layers): 96 | res = x 97 | x = self.self_attn_layers[i](x, x, self_attn_mask) 98 | x = self.drop(x) 99 | x = self.norm_layers_1[i](x + res) 100 | 101 | res = x 102 | x = self.cross_attn_layers[i](x, y, cross_attn_mask) 103 | x = self.drop(x) 104 | x = self.norm_layers_2[i](x + res) 105 | 106 | res = x 107 | x = self.ffn_layers[i](x, x_mask) 108 | x = self.drop(x) 109 | x = self.norm_layers_3[i](x + res) 110 | x = x * x_mask 111 | return x 112 | 113 | 114 | class MultiHeadAttention(nn.Module): 115 | def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False): 116 | super().__init__() 117 | assert channels % n_heads == 0 118 | 119 | self.channels = channels 120 | self.out_channels = out_channels 121 | self.n_heads = n_heads 122 | self.p_dropout = p_dropout 123 | self.window_size = window_size 124 | self.heads_share = heads_share 125 | self.block_length = block_length 126 | self.proximal_bias = proximal_bias 127 | self.proximal_init = proximal_init 128 | self.attn = None 129 | 130 | self.k_channels = channels // 
n_heads 131 | self.conv_q = nn.Linear(channels, channels) 132 | self.conv_k = nn.Linear(channels, channels) 133 | self.conv_v = nn.Linear(channels, channels) 134 | self.conv_o = nn.Linear(channels, out_channels) 135 | self.drop = nn.Dropout(p_dropout) 136 | 137 | if window_size is not None: 138 | n_heads_rel = 1 if heads_share else n_heads 139 | rel_stddev = self.k_channels**-0.5 140 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 141 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 142 | 143 | nn.init.xavier_uniform_(self.conv_q.weight) 144 | nn.init.xavier_uniform_(self.conv_k.weight) 145 | nn.init.xavier_uniform_(self.conv_v.weight) 146 | if proximal_init: 147 | with torch.no_grad(): 148 | self.conv_k.weight.copy_(self.conv_q.weight) 149 | self.conv_k.bias.copy_(self.conv_q.bias) 150 | 151 | def forward(self, x, c, attn_mask=None): 152 | q = self.conv_q(x.mT).mT 153 | k = self.conv_k(c.mT).mT 154 | v = self.conv_v(c.mT).mT 155 | 156 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 157 | 158 | x = self.conv_o(x.mT).mT 159 | return x 160 | 161 | # query: [b, n_h, t_t, d_k] 162 | # key: [b, n_h, t_s, d_k] 163 | # value: [b, n_h, t_s, d_k] 164 | # mask: [b, n_h, t_t, t_s] 165 | def attention(self, query, key, value, mask=None): 166 | # reshape [b, d, t] -> [b, n_h, t, d_k] 167 | b, d, t_s, t_t = (*key.size(), query.size(2)) 168 | query = query.view(b, self.n_heads, self.k_channels, t_t).mT # [b, n_h, t_t, d_k] 169 | key = key.view(b, self.n_heads, self.k_channels, t_s).mT # [b, n_h, t_s, d_k] 170 | value = value.view(b, self.n_heads, self.k_channels, t_s).mT # [b, n_h, t_s, d_k] 171 | 172 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.mT) # [b, n_h, t_t, t_s] 173 | if self.window_size is not None: 174 | assert t_s == t_t, "Relative attention is only available for self-attention." 175 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 176 | rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) 177 | scores_local = self._relative_position_to_absolute_position(rel_logits) 178 | scores = scores + scores_local 179 | if self.proximal_bias: 180 | assert t_s == t_t, "Proximal bias is only available for self-attention." 181 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) 182 | if mask is not None: 183 | scores = scores.masked_fill(mask == 0, -1e4) 184 | if self.block_length is not None: 185 | assert t_s == t_t, "Local attention is only available for self-attention." 
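                # Local ("block") attention: keep only a band of half-width block_length
                # around the diagonal. triu(-block_length).tril(block_length) zeroes every
                # position whose query/key offset |i - j| exceeds block_length, and those
                # positions are then filled with -1e4 (an effectively -inf score that stays
                # finite in half precision), so softmax gives them ~zero weight.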
186 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) 187 | scores = scores.masked_fill(block_mask == 0, -1e4) 188 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 189 | p_attn = self.drop(p_attn) 190 | output = torch.matmul(p_attn, value) 191 | if self.window_size is not None: 192 | relative_weights = self._absolute_position_to_relative_position(p_attn) 193 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 194 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) 195 | output = output.mT.contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] 196 | return output, p_attn 197 | 198 | def _matmul_with_relative_values(self, x: torch.Tensor, y: torch.Tensor): 199 | """ 200 | x: [b, h, l, m] 201 | y: [h or 1, m, d] 202 | ret: [b, h, l, d] 203 | """ 204 | return torch.matmul(x, y.unsqueeze(0)) 205 | 206 | def _matmul_with_relative_keys(self, x: torch.Tensor, y: torch.Tensor): 207 | """ 208 | x: [b, h, l, d] 209 | y: [h or 1, m, d] 210 | ret: [b, h, l, m] 211 | """ 212 | return torch.matmul(x, y.unsqueeze(0).mT) 213 | 214 | def _get_relative_embeddings(self, relative_embeddings, length): 215 | max_relative_position = 2 * self.window_size + 1 216 | # Pad first before slice to avoid using cond ops. 217 | pad_length = max(length - (self.window_size + 1), 0) 218 | slice_start_position = max((self.window_size + 1) - length, 0) 219 | slice_end_position = slice_start_position + 2 * length - 1 220 | if pad_length > 0: 221 | padded_relative_embeddings = F.pad(relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) 222 | else: 223 | padded_relative_embeddings = relative_embeddings 224 | used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] 225 | return used_relative_embeddings 226 | 227 | def _relative_position_to_absolute_position(self, x): 228 | """ 229 | x: [b, h, l, 2*l-1] 230 | ret: [b, h, l, l] 231 | """ 232 | batch, heads, length, _ = x.size() 233 | # Concat columns of pad to shift from relative to absolute indexing. 234 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) 235 | 236 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 237 | x_flat = x.view([batch, heads, length * 2 * length]) 238 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) 239 | 240 | # Reshape and slice out the padded elements. 241 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :] 242 | return x_final 243 | 244 | def _absolute_position_to_relative_position(self, x): 245 | """ 246 | x: [b, h, l, l] 247 | ret: [b, h, l, 2*l-1] 248 | """ 249 | batch, heads, length, _ = x.size() 250 | # padd along column 251 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) 252 | x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) 253 | # add 0's in the beginning that will skew the elements after reshape 254 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 255 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] 256 | return x_final 257 | 258 | def _attention_bias_proximal(self, length): 259 | """Bias for self-attention to encourage attention to close positions. 260 | Args: 261 | length: an integer scalar. 
262 | Returns: 263 | a Tensor with shape [1, 1, length, length] 264 | """ 265 | r = torch.arange(length, dtype=torch.float32) 266 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 267 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 268 | 269 | 270 | class FFN(nn.Module): 271 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, causal=False): 272 | super().__init__() 273 | self.kernel_size = kernel_size 274 | self.padding = self._causal_padding if causal else self._same_padding 275 | 276 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 277 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 278 | self.drop = nn.Dropout(p_dropout) 279 | 280 | def forward(self, x, x_mask): 281 | x = self.conv_1(self.padding(x * x_mask)) 282 | x = torch.relu(x) 283 | x = self.drop(x) 284 | x = self.conv_2(self.padding(x * x_mask)) 285 | return x * x_mask 286 | 287 | def _causal_padding(self, x): 288 | if self.kernel_size == 1: 289 | return x 290 | pad_l = self.kernel_size - 1 291 | pad_r = 0 292 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 293 | x = F.pad(x, convert_pad_shape(padding)) 294 | return x 295 | 296 | def _same_padding(self, x): 297 | if self.kernel_size == 1: 298 | return x 299 | pad_l = (self.kernel_size - 1) // 2 300 | pad_r = self.kernel_size // 2 301 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 302 | x = F.pad(x, convert_pad_shape(padding)) 303 | return x 304 | -------------------------------------------------------------------------------- /module/vits/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | # transforms for stochasitc duration predictor 8 | 9 | 10 | DEFAULT_MIN_BIN_WIDTH = 1e-3 11 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 12 | DEFAULT_MIN_DERIVATIVE = 1e-3 13 | 14 | 15 | def piecewise_rational_quadratic_transform( 16 | inputs, 17 | unnormalized_widths, 18 | unnormalized_heights, 19 | unnormalized_derivatives, 20 | inverse=False, 21 | tails=None, 22 | tail_bound=1.0, 23 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 24 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 25 | min_derivative=DEFAULT_MIN_DERIVATIVE, 26 | ): 27 | if tails is None: 28 | spline_fn = rational_quadratic_spline 29 | spline_kwargs = {} 30 | else: 31 | spline_fn = unconstrained_rational_quadratic_spline 32 | spline_kwargs = {"tails": tails, "tail_bound": tail_bound} 33 | 34 | outputs, logabsdet = spline_fn( 35 | inputs=inputs, 36 | unnormalized_widths=unnormalized_widths, 37 | unnormalized_heights=unnormalized_heights, 38 | unnormalized_derivatives=unnormalized_derivatives, 39 | inverse=inverse, 40 | min_bin_width=min_bin_width, 41 | min_bin_height=min_bin_height, 42 | min_derivative=min_derivative, 43 | **spline_kwargs 44 | ) 45 | return outputs, logabsdet 46 | 47 | 48 | def searchsorted(bin_locations, inputs, eps=1e-6): 49 | bin_locations[..., -1] += eps 50 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 51 | 52 | 53 | def unconstrained_rational_quadratic_spline( 54 | inputs, 55 | unnormalized_widths, 56 | unnormalized_heights, 57 | unnormalized_derivatives, 58 | inverse=False, 59 | tails="linear", 60 | tail_bound=1.0, 61 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 62 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 63 | min_derivative=DEFAULT_MIN_DERIVATIVE, 64 | ): 65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 66 | outside_interval_mask = 
~inside_interval_mask 67 | 68 | outputs = torch.zeros_like(inputs) 69 | logabsdet = torch.zeros_like(inputs) 70 | 71 | if tails == "linear": 72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 73 | constant = np.log(np.exp(1 - min_derivative) - 1) 74 | unnormalized_derivatives[..., 0] = constant 75 | unnormalized_derivatives[..., -1] = constant 76 | 77 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 78 | logabsdet[outside_interval_mask] = 0 79 | else: 80 | raise RuntimeError("{} tails are not implemented.".format(tails)) 81 | 82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, 89 | right=tail_bound, 90 | bottom=-tail_bound, 91 | top=tail_bound, 92 | min_bin_width=min_bin_width, 93 | min_bin_height=min_bin_height, 94 | min_derivative=min_derivative, 95 | ) 96 | 97 | return outputs, logabsdet 98 | 99 | 100 | def rational_quadratic_spline( 101 | inputs, 102 | unnormalized_widths, 103 | unnormalized_heights, 104 | unnormalized_derivatives, 105 | inverse=False, 106 | left=0.0, 107 | right=1.0, 108 | bottom=0.0, 109 | top=1.0, 110 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 111 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 112 | min_derivative=DEFAULT_MIN_DERIVATIVE, 113 | ): 114 | if torch.min(inputs) < left or torch.max(inputs) > right: 115 | raise ValueError("Input to a transform is not within its domain") 116 | 117 | num_bins = unnormalized_widths.shape[-1] 118 | 119 | if min_bin_width * num_bins > 1.0: 120 | raise ValueError("Minimal bin width too large for the number of bins") 121 | if min_bin_height * num_bins > 1.0: 122 | raise ValueError("Minimal bin height too large for the number of bins") 123 | 124 | widths = F.softmax(unnormalized_widths, dim=-1) 125 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 126 | cumwidths = torch.cumsum(widths, dim=-1) 127 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) 128 | cumwidths = (right - left) * cumwidths + left 129 | cumwidths[..., 0] = left 130 | cumwidths[..., -1] = right 131 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 132 | 133 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 134 | 135 | heights = F.softmax(unnormalized_heights, dim=-1) 136 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 137 | cumheights = torch.cumsum(heights, dim=-1) 138 | cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) 139 | cumheights = (top - bottom) * cumheights + bottom 140 | cumheights[..., 0] = bottom 141 | cumheights[..., -1] = top 142 | heights = cumheights[..., 1:] - cumheights[..., :-1] 143 | 144 | if inverse: 145 | bin_idx = searchsorted(cumheights, inputs)[..., None] 146 | else: 147 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 148 | 149 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 150 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 151 | 152 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 153 | delta = heights / widths 154 | input_delta = delta.gather(-1, bin_idx)[..., 0] 155 | 156 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 157 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 
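    # What follows is the monotonic rational-quadratic spline of Durkan et al.,
    # "Neural Spline Flows" (2019). Inside bin k, with theta = (x - x_k) / w_k,
    # delta_k = h_k / w_k (input_delta) and bin-edge derivatives d_k, d_{k+1}
    # (input_derivatives, input_derivatives_plus_one), the forward branch computes
    #
    #   g(x) = y_k + h_k * (delta_k * theta^2 + d_k * theta * (1 - theta))
    #                / (delta_k + (d_k + d_{k+1} - 2 * delta_k) * theta * (1 - theta))
    #
    # and the inverse branch solves the corresponding quadratic in theta with the
    # numerically stable root 2c / (-b - sqrt(b^2 - 4ac)). Each branch also returns the
    # log|det Jacobian| of the mapping it applies, which the flow accumulates for the
    # likelihood term.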
158 | 159 | input_heights = heights.gather(-1, bin_idx)[..., 0] 160 | 161 | if inverse: 162 | a = (inputs - input_cumheights) * (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + input_heights * (input_delta - input_derivatives) 163 | b = input_heights * input_derivatives - (inputs - input_cumheights) * (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 164 | c = -input_delta * (inputs - input_cumheights) 165 | 166 | discriminant = b.pow(2) - 4 * a * c 167 | assert (discriminant >= 0).all() 168 | 169 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 170 | outputs = root * input_bin_widths + input_cumwidths 171 | 172 | theta_one_minus_theta = root * (1 - root) 173 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta) 174 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - root).pow(2)) 175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 176 | 177 | return outputs, -logabsdet 178 | else: 179 | theta = (inputs - input_cumwidths) / input_bin_widths 180 | theta_one_minus_theta = theta * (1 - theta) 181 | 182 | numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) 183 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta) 184 | outputs = input_cumheights + numerator / denominator 185 | 186 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - theta).pow(2)) 187 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 188 | 189 | return outputs, logabsdet 190 | -------------------------------------------------------------------------------- /module/vits/wn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.parametrizations import weight_norm 4 | from torch.nn.utils import remove_weight_norm 5 | 6 | 7 | # WN module from https://arxiv.org/abs/1609.03499 8 | class WNLayer(nn.Module): 9 | def __init__(self, 10 | hidden_channels=192, 11 | kernel_size=5, 12 | dilation=1, 13 | speaker_embedding_dim=256): 14 | super().__init__() 15 | self.speaker_in = weight_norm( 16 | nn.Conv1d( 17 | speaker_embedding_dim , 18 | hidden_channels * 2, 1)) 19 | padding = int((kernel_size * dilation - dilation) / 2) 20 | self.conv = weight_norm( 21 | nn.Conv1d( 22 | hidden_channels, 23 | hidden_channels * 2, 24 | kernel_size, 25 | 1, 26 | padding, 27 | dilation=dilation, 28 | padding_mode='replicate')) 29 | self.out = weight_norm( 30 | nn.Conv1d(hidden_channels, hidden_channels * 2, 1)) 31 | 32 | # x: [BatchSize, hidden_channels, Length] 33 | # x_mask: [BatchSize, 1, Length] 34 | # g: [BatchSize, speaker_embedding_dim, 1] 35 | # Output: [BatchSize, hidden_channels, Length] 36 | def forward(self, x, x_mask, g): 37 | res = x 38 | x = self.conv(x) + self.speaker_in(g) 39 | x_0, x_1 = torch.chunk(x, 2, dim=1) 40 | x = torch.tanh(x_0) * torch.sigmoid(x_1) 41 | x = self.out(x) 42 | out, skip = torch.chunk(x, 2, dim=1) 43 | out = (out + res) * x_mask 44 | skip = skip * x_mask 45 | return out, skip 46 | 47 | def remove_weight_norm(self): 48 | remove_weight_norm(self.speaker_in) 49 | remove_weight_norm(self.conv) 50 | remove_weight_norm(self.out) 51 | 52 | 53 | class 
WN(nn.Module): 54 | def __init__(self, 55 | hidden_channels=192, 56 | kernel_size=5, 57 | dilation=1, 58 | num_layers=4, 59 | speaker_embedding_dim=256): 60 | super().__init__() 61 | self.layers = nn.ModuleList([]) 62 | for _ in range(num_layers): 63 | self.layers.append( 64 | WNLayer(hidden_channels, kernel_size, dilation, speaker_embedding_dim)) 65 | 66 | # x: [BatchSize, hidden_channels, Length] 67 | # x_mask: [BatchSize, 1, Length] 68 | # g: [BatchSize, speaker_embedding_dim, 1] 69 | # Output: [BatchSize, hidden_channels, Length] 70 | def forward(self, x, x_mask, g): 71 | output = None 72 | for layer in self.layers: 73 | x, skip = layer(x, x_mask, g) 74 | if output is None: 75 | output = skip 76 | else: 77 | output += skip 78 | return output 79 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | from pathlib import Path 5 | import shutil 6 | 7 | from module.preprocess.jvs import preprocess_jvs 8 | from module.preprocess.wave_and_text import preprocess_wave_and_text 9 | from module.preprocess.scan import scan_cache 10 | from module.utils.config import load_json_file 11 | 12 | 13 | def get_preprocess_method(dataset_type): 14 | if dataset_type == 'jvs': 15 | return preprocess_jvs 16 | if dataset_type == 'wav-txt': 17 | return preprocess_wave_and_text 18 | else: 19 | raise "Unknown dataset type" 20 | 21 | 22 | parser = argparse.ArgumentParser("preprocess") 23 | parser.add_argument('type') 24 | parser.add_argument('root_dir') 25 | parser.add_argument('-c', '--config', default='./config/base.json') 26 | parser.add_argument('--scan-only', default=False, type=bool) 27 | 28 | args = parser.parse_args() 29 | 30 | config = load_json_file(args.config) 31 | root_dir = Path(args.root_dir) 32 | dataset_type = args.type 33 | 34 | cache_dir = Path(config.preprocess.cache) 35 | if not cache_dir.exists(): 36 | cache_dir.mkdir() 37 | 38 | preprocess_method = get_preprocess_method(dataset_type) 39 | 40 | if not args.scan_only: 41 | print(f"Start preprocess type={dataset_type}, root={str(root_dir)}") 42 | preprocess_method(root_dir, config) 43 | 44 | print(f"Scaning dataset cache") 45 | scan_cache(config) 46 | shutil.copy(args.config, 'models/config.json') 47 | 48 | print(f"Complete!") 49 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools 2 | torch 3 | torchaudio 4 | tqdm 5 | numpy 6 | onnx 7 | transformers 8 | pyworld 9 | torchfcpe 10 | numba 11 | cython 12 | pyopenjtalk 13 | tensorboard 14 | g2p_en 15 | safetensors 16 | lightning 17 | sentencepiece 18 | sox 19 | soundfile 20 | music21 21 | gradio -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import random 5 | from pathlib import Path 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | import lightning as L 13 | import pytorch_lightning 14 | from module.utils.config import load_json_file 15 | from module.vits import Vits 16 | from module.utils.dataset import VitsDataModule 17 | from module.utils.safetensors import save_tensors 18 | 19 | if __name__ == '__main__': 20 | parser = 
argparse.ArgumentParser(description="train") 21 | parser.add_argument('-c', '--config', default='config/base.json') 22 | args = parser.parse_args() 23 | 24 | class SaveCheckpoint(L.Callback): 25 | def __init__(self, models_dir, interval=200): 26 | super().__init__() 27 | self.models_dir = Path(models_dir) 28 | self.interval = interval 29 | if not self.models_dir.exists(): 30 | self.models_dir.mkdir() 31 | 32 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 33 | if batch_idx % self.interval == 0: 34 | ckpt_path = self.models_dir / "vits.ckpt" 35 | trainer.save_checkpoint(ckpt_path) 36 | generator_path = self.models_dir / "generator.safetensors" 37 | save_tensors(pl_module.generator.state_dict(), generator_path) 38 | 39 | config = load_json_file(args.config) 40 | dm = VitsDataModule(**config.train.data_module) 41 | model_path = Path(config.train.save.models_dir) / "vits.ckpt" 42 | 43 | if model_path.exists(): 44 | print(f"loading checkpoint from {model_path}") 45 | model = Vits.load_from_checkpoint(model_path) 46 | else: 47 | print("initialize model") 48 | model = Vits(config.vits) 49 | 50 | print("if you need to check tensorboard, run `tensorboard -logdir lightning_logs`") 51 | cb_save_checkpoint = SaveCheckpoint(config.train.save.models_dir, config.train.save.interval) 52 | trainer = L.Trainer(**config.train.trainer, callbacks=[cb_save_checkpoint]) 53 | trainer.fit(model, dm) 54 | --------------------------------------------------------------------------------
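# End-to-end usage sketch (the corpus path below is a placeholder, not a file shipped
# with the repository); flags follow the argparse definitions in preprocess.py and
# train.py above:
#
#   # 1. build the dataset cache from a wav+txt corpus ('jvs' is also supported)
#   python preprocess.py wav-txt /path/to/corpus -c config/base.json
#
#   # 2. train; every `interval` training batches the checkpoint is written to
#   #    <models_dir>/vits.ckpt and the generator to <models_dir>/generator.safetensors
#   python train.py -c config/base.json
#
#   # 3. monitor losses
#   tensorboard --logdir lightning_logs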