├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── base.json
│   └── small.json
├── docs
│   ├── config.md
│   ├── credits.md
│   ├── images
│   │   ├── auris_architecture.png
│   │   └── auris_decoder.png
│   ├── infer.md
│   ├── installation.md
│   ├── technical_details.md
│   ├── todo_list.md
│   └── train.md
├── example
│   ├── musicxml
│   │   └── test.musicxml
│   ├── score_inputs
│   │   └── test.json
│   └── text_inputs
│       └── transcription.json
├── infer.py
├── infer_webui.py
├── module
│   ├── g2p
│   │   ├── __init__.py
│   │   ├── english.py
│   │   ├── extractor.py
│   │   └── japanese.py
│   ├── infer
│   │   └── __init__.py
│   ├── language_model
│   │   ├── __init__.py
│   │   ├── extractor.py
│   │   └── rinna_roberta.py
│   ├── preprocess
│   │   ├── jvs.py
│   │   ├── processor.py
│   │   ├── scan.py
│   │   └── wave_and_text.py
│   ├── utils
│   │   ├── common.py
│   │   ├── config.py
│   │   ├── dataset.py
│   │   ├── energy_estimation.py
│   │   ├── f0_estimation.py
│   │   └── safetensors.py
│   └── vits
│       ├── __init__.py
│       ├── convnext.py
│       ├── crop.py
│       ├── decoder.py
│       ├── discriminator.py
│       ├── duration_discriminator.py
│       ├── duration_predictors.py
│       ├── feature_retrieval.py
│       ├── flow.py
│       ├── generator.py
│       ├── loss.py
│       ├── monotonic_align
│       │   ├── __init__.py
│       │   ├── core.pyx
│       │   └── setup.py
│       ├── normalization.py
│       ├── posterior_encoder.py
│       ├── prior_encoder.py
│       ├── speaker_embedding.py
│       ├── spectrogram.py
│       ├── text_encoder.py
│       ├── transformer.py
│       ├── transforms.py
│       └── wn.py
├── preprocess.py
├── requirements.txt
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | /*.DS_Store
163 | *.pt
164 | *.pth
165 | *.wav
166 | *.ogg
167 | *.mp3
168 | *.swp
169 | *.swo
170 | module/vits/monotonic_align/core.c
171 | module/vits/monotonic_align/build/*
172 | dataset_cache/*
173 | models/*
174 | models
175 | logs/*
176 | lightning_logs
177 | lightning_logs/*
178 | outputs
179 | outputs/*
180 | audio_inputs/*
181 | audio_inputs
182 | flagged/
183 | flagged/*
184 | my_config/*
185 | my_config/
186 | backups/*
187 | backups
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 uthree
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Auris
2 | Development of an efficient Japanese pre-trained model for singing voice synthesis, voice conversion, and TTS
3 |
4 | ## Table of Contents
5 | - [Credits](docs/credits.md)
6 | - [Installation](docs/installation.md)
7 | - [Training](docs/train.md)
8 | - [Inference](docs/infer.md)
9 | - [Configuration](docs/config.md)
10 | - [Technical Details](docs/technical_details.md)
11 |
12 | ## Other Languages
13 | - [English](docs/readme_en.md) (coming soon)
--------------------------------------------------------------------------------
/config/base.json:
--------------------------------------------------------------------------------
1 | {
2 | "vits": {
3 | "segment_size": 32,
4 | "generator": {
5 | "decoder": {
6 | "sample_rate": 48000,
7 | "frame_size": 480,
8 | "n_fft": 1920,
9 | "speaker_embedding_dim": 256,
10 | "content_channels": 192,
11 | "pe_internal_channels": 256,
12 | "pe_num_layers": 4,
13 | "source_internal_channels": 256,
14 | "source_num_layers": 4,
15 | "num_harmonics": 30,
16 | "filter_channels": [512, 256, 128, 64, 32],
17 | "filter_factors": [
18 | 5,
19 | 4,
20 | 4,
21 | 3,
22 | 2
23 | ],
24 | "filter_resblock_type": "1",
25 | "filter_down_dilations": [
26 | [1, 2],
27 | [4, 8]
28 | ],
29 | "filter_down_interpolation": "conv",
30 | "filter_up_dilations": [
31 | [1, 3, 5],
32 | [1, 3, 5],
33 | [1, 3, 5]
34 | ],
35 | "filter_up_kernel_sizes": [
36 | 3,
37 | 7,
38 | 11
39 | ],
40 | "filter_up_interpolation": "conv"
41 | },
42 | "posterior_encoder": {
43 | "n_fft": 1920,
44 | "frame_size": 480,
45 | "internal_channels": 192,
46 | "speaker_embedding_dim": 256,
47 | "content_channels": 192,
48 | "kernel_size": 5,
49 | "dilation": 1,
50 | "num_layers": 16
51 | },
52 | "speaker_embedding": {
53 | "num_speakers": 8192,
54 | "embedding_dim": 256
55 | },
56 | "prior_encoder": {
57 | "flow": {
58 | "content_channels": 192,
59 | "internal_channels": 192,
60 | "speaker_embedding_dim": 256,
61 | "kernel_size": 5,
62 | "dilation": 1,
63 | "num_layers": 4,
64 | "num_flows": 4
65 | },
66 | "text_encoder": {
67 | "num_phonemes": 512,
68 | "num_languages": 256,
69 | "lm_dim": 768,
70 | "internal_channels": 256,
71 | "speaker_embedding_dim": 256,
72 | "content_channels": 192,
73 | "n_heads": 4,
74 | "num_layers": 4,
75 | "window_size": 4
76 | },
77 | "stochastic_duration_predictor": {
78 | "in_channels": 192,
79 | "filter_channels": 256,
80 | "kernel_size": 5,
81 | "p_dropout": 0.0,
82 | "speaker_embedding_dim": 256
83 | },
84 | "duration_predictor": {
85 | "content_channels": 192,
86 | "internal_channels": 256,
87 | "speaker_embedding_dim": 256,
88 | "kernel_size": 7,
89 | "num_layers": 4
90 | }
91 | }
92 | },
93 | "discriminator": {
94 | "mrd": {
95 | "resolutions": [
96 | 128,
97 | 256,
98 | 512
99 | ],
100 | "channels": 32,
101 | "max_channels": 256,
102 | "num_layers": 4
103 | },
104 | "mpd": {
105 | "periods": [
106 | 1,
107 | 2,
108 | 3,
109 | 5,
110 | 7,
111 | 11,
112 | 17,
113 | 23,
114 | 31
115 | ],
116 | "channels": 32,
117 | "channels_mul": 2,
118 | "max_channels": 256,
119 | "num_layers": 4
120 | }
121 | },
122 | "duration_discriminator": {
123 | "content_channels": 192,
124 | "speaker_embedding_dim": 256,
125 | "num_layers": 3
126 | },
127 | "optimizer": {
128 | "lr": 1e-4,
129 | "betas": [
130 | 0.8,
131 | 0.99
132 | ]
133 | }
134 | },
135 | "language_model": {
136 | "type": "rinna_roberta",
137 | "options": {
138 | "hf_repo": "rinna/japanese-roberta-base",
139 | "layer": 12
140 | }
141 | },
142 | "preprocess": {
143 | "sample_rate": 48000,
144 | "max_waveform_length": 480000,
145 | "pitch_estimation": "fcpe",
146 | "max_phonemes": 100,
147 | "lm_max_tokens": 30,
148 | "frame_size": 480,
149 | "cache": "dataset_cache"
150 | },
151 | "train": {
152 | "save": {
153 | "models_dir": "models",
154 | "interval": 400
155 | },
156 | "data_module": {
157 | "cache_dir": "dataset_cache",
158 | "metadata": "models/metadata.json",
159 | "batch_size": 8,
160 | "num_workers": 15
161 | },
162 | "trainer": {
163 | "devices": "auto",
164 | "max_epochs": 1000000,
165 | "precision": null
166 | }
167 | },
168 | "infer": {
169 | "n_fft": 1920,
170 | "frame_size": 480,
171 | "sample_rate": 48000,
172 | "max_lm_tokens": 50,
173 | "max_phonemes": 500,
174 | "max_frames": 2000,
175 | "device": "cuda"
176 | }
177 | }
--------------------------------------------------------------------------------
/config/small.json:
--------------------------------------------------------------------------------
1 | {
2 | "vits": {
3 | "segment_size": 32,
4 | "generator": {
5 | "decoder": {
6 | "sample_rate": 48000,
7 | "frame_size": 480,
8 | "n_fft": 1920,
9 | "speaker_embedding_dim": 128,
10 | "content_channels": 96,
11 | "pe_internal_channels": 128,
12 | "pe_num_layers": 3,
13 | "source_internal_channels": 128,
14 | "source_num_layers": 3,
15 | "num_harmonics": 14,
16 | "filter_channels": [
17 | 192,
18 | 96,
19 | 48,
20 | 24
21 | ],
22 | "filter_factors": [
23 | 4,
24 | 4,
25 | 5,
26 | 6
27 | ],
28 | "filter_resblock_type": "3",
29 | "filter_down_dilations": [
30 | [
31 | 1,
32 | 2,
33 | 4
34 | ]
35 | ],
36 | "filter_down_interpolation": "linear",
37 | "filter_up_dilations": [
38 | [
39 | 1,
40 | 3,
41 | 9,
42 | 27
43 | ]
44 | ],
45 | "filter_up_kernel_sizes": [
46 | 3
47 | ],
48 | "filter_up_interpolation": "linear"
49 | },
50 | "posterior_encoder": {
51 | "n_fft": 1920,
52 | "frame_size": 480,
53 | "internal_channels": 96,
54 | "speaker_embedding_dim": 128,
55 | "content_channels": 96,
56 | "kernel_size": 5,
57 | "dilation": 1,
58 | "num_layers": 16
59 | },
60 | "speaker_embedding": {
61 | "num_speakers": 8192,
62 | "embedding_dim": 128
63 | },
64 | "prior_encoder": {
65 | "flow": {
66 | "content_channels": 96,
67 | "internal_channels": 96,
68 | "speaker_embedding_dim": 128,
69 | "kernel_size": 5,
70 | "dilation": 1,
71 | "num_layers": 4,
72 | "num_flows": 4
73 | },
74 | "text_encoder": {
75 | "num_phonemes": 512,
76 | "num_languages": 256,
77 | "lm_dim": 768,
78 | "internal_channels": 256,
79 | "speaker_embedding_dim": 128,
80 | "content_channels": 96,
81 | "n_heads": 4,
82 | "num_layers": 4,
83 | "window_size": 4
84 | },
85 | "stochastic_duration_predictor": {
86 | "in_channels": 96,
87 | "filter_channels": 256,
88 | "kernel_size": 5,
89 | "p_dropout": 0.0,
90 | "speaker_embedding_dim": 128
91 | },
92 | "duration_predictor": {
93 | "content_channels": 96,
94 | "internal_channels": 256,
95 | "speaker_embedding_dim": 128,
96 | "kernel_size": 7,
97 | "num_layers": 4
98 | }
99 | }
100 | },
101 | "discriminator": {
102 | "mrd": {
103 | "resolutions": [
104 | 128,
105 | 256,
106 | 512
107 | ],
108 | "channels": 32,
109 | "max_channels": 256,
110 | "num_layers": 4
111 | },
112 | "mpd": {
113 | "periods": [
114 | 1,
115 | 2,
116 | 3,
117 | 5,
118 | 7,
119 | 11
120 | ],
121 | "channels": 32,
122 | "channels_mul": 2,
123 | "max_channels": 256,
124 | "num_layers": 4
125 | }
126 | },
127 | "duration_discriminator": {
128 | "content_channels": 96,
129 | "speaker_embedding_dim": 128,
130 | "num_layers": 3
131 | },
132 | "optimizer": {
133 | "lr": 1e-4,
134 | "betas": [
135 | 0.8,
136 | 0.99
137 | ]
138 | }
139 | },
140 | "language_model": {
141 | "type": "rinna_roberta",
142 | "options": {
143 | "hf_repo": "rinna/japanese-roberta-base",
144 | "layer": 4
145 | }
146 | },
147 | "preprocess": {
148 | "sample_rate": 48000,
149 | "max_waveform_length": 480000,
150 | "pitch_estimation": "fcpe",
151 | "max_phonemes": 100,
152 | "lm_max_tokens": 30,
153 | "frame_size": 480,
154 | "cache": "dataset_cache"
155 | },
156 | "train": {
157 | "save": {
158 | "models_dir": "models",
159 | "interval": 400
160 | },
161 | "data_module": {
162 | "cache_dir": "dataset_cache",
163 | "metadata": "models/metadata.json",
164 | "batch_size": 4,
165 | "num_workers": 7
166 | },
167 | "trainer": {
168 | "devices": "auto",
169 | "max_epochs": 1000000,
170 | "precision": null
171 | }
172 | },
173 | "infer": {
174 | "n_fft": 1920,
175 | "frame_size": 480,
176 | "sample_rate": 48000,
177 | "max_lm_tokens": 50,
178 | "max_phonemes": 500,
179 | "max_frames": 2000,
180 | "device": "cuda"
181 | }
182 | }
--------------------------------------------------------------------------------
/docs/config.md:
--------------------------------------------------------------------------------
1 | # Configuration
2 | [config](../config/) contains the files that describe the model architecture settings, training settings, and so on.
3 | Which config you choose affects the performance of the model.
4 |
5 | ## Presets
6 | This repository provides several config files as examples.
7 | - [small](../config/small.json) : The smallest configuration, intended for quick sanity-check experiments.
8 | - [base](../config/base.json) : The default, roughly the scale of the HiFi-GAN V1 model / VITS; it should reach practical quality.
9 | - [large](../config/large.json) (Coming Soon?) : A model with more dimensions than base. It may require large-scale compute resources.
10 |
11 | ## Meaning of the parameters
12 | Anyone editing this config can probably read the source code or infer what a parameter does from its name, so this may not be strictly necessary, but a partial explanation will be written here (see also the loading sketch below).
13 | -WIP-
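14 |
15 | As a minimal sketch, a preset can be loaded with `load_json_file` from `module/utils/config.py`, which wraps the JSON in a `Config` object readable via attribute access; the printed values below correspond to `config/base.json`:
16 | ```python
17 | from module.utils.config import load_json_file
18 |
19 | # keys become attributes recursively (see module/utils/config.py)
20 | config = load_json_file("config/base.json")
21 |
22 | print(config.vits.segment_size)                   # 32
23 | print(config.vits.generator.decoder.sample_rate)  # 48000
24 | print(config.train.data_module.batch_size)        # 8
25 | ```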
--------------------------------------------------------------------------------
/docs/credits.md:
--------------------------------------------------------------------------------
1 | # Credits
2 | Because parts of this repository copy source code from MIT-licensed repositories, that fact must be stated explicitly; it is recorded here.
3 |
4 | ## Quoted source code
5 | - [monotonic_align](../module/vits/monotonic_align/) : Taken from [ESPNet](https://github.com/espnet/espnet). Some error messages were modified.
6 | - [duration_predictors.py](../module/vits/duration_predictors.py) : Taken and modified from [VITS2](https://github.com/daniilrobnikov/vits2/blob/main/model/duration_predictors.py).
7 | - [transforms.py](../module/vits/transforms.py) : Taken from [VITS2](https://github.com/daniilrobnikov/vits2/blob/main/utils/transforms.py).
8 | - [transformer.py](../module/vits/transformer.py) : Taken and modified from [VITS2](https://github.com/daniilrobnikov/vits2/blob/main/model/transformer.py).
9 | - [normalization.py](../module/vits/normalization.py) : Partially taken from [VITS2](https://github.com/daniilrobnikov/vits2/blob/main/model/normalization.py).
10 |
11 | ## Pre-trained models
12 | - [rinna_roberta.py](../module/language_model/rinna_roberta.py) : Uses [rinna/japanese-roberta-base](https://huggingface.co/rinna/japanese-roberta-base).
13 |
14 | ## References
15 |
16 | ### Articles consulted
17 | Blog posts and other articles that were consulted. More will be added as they are noticed.
18 | - [【機械学習】VITSでアニメ声へ変換できるボイスチェンジャー&読み上げ器を作った話](https://qiita.com/zassou65535/items/00d7d5562711b89689a8) : zassou's explanation of VITS. It covers everything from theory to implementation and the derivation of the loss functions in a very accessible way. This project could not have been completed without this article.
19 |
20 | ### Repositories consulted
21 | - [zassou65535/VITS](https://github.com/zassou65535/VITS) : zassou's VITS implementation. The implementation of posterior_encoder.py and other modules, as well as the directory layout, were used as references.
22 | - [RVC-Project/Retrieval-based-Voice-Conversion-WebUI](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI) : The so-called RVC. The pitch-controllable HnNSF-HiFiGAN decoder and the idea of retrieving feature vectors are inspired by RVC.
23 | - [fishaudio/Bert-VITS2](https://github.com/fishaudio/Bert-VITS2) : The idea of feeding BERT features into the VITS text encoder so that emotion and context can also be encoded is inspired by Bert-VITS2.
24 | - [uthree/tinyvc](https://github.com/uthree/tinyvc) : Referencing one's own repository may seem odd, but the TinyVC decoder is adopted here, scaled up almost as-is.
25 |
26 | ### Papers
27 | Papers that were consulted.
28 | #### Overall design
29 | - [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103)
30 | - [VITS2: Improving Quality and Efficiency of Single-Stage Text-to-Speech with Adversarial Learning and Architecture Design](https://arxiv.org/abs/2307.16430)
31 |
32 | #### Papers consulted for the decoder design
33 | - [HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis](https://arxiv.org/abs/2010.05646)
34 | - [FastSVC: Fast Cross-Domain Singing Voice Conversion with Feature-wise Linear Modulation](https://arxiv.org/abs/2011.05731)
35 | - [DDSP: Differentiable Digital Signal Processing](https://arxiv.org/abs/2001.04643)
36 | - [VISinger 2: High-Fidelity End-to-End Singing Voice Synthesis Enhanced by Digital Signal Processing Synthesizer](https://arxiv.org/abs/2211.02903)
37 | - [Neural Concatenative Singing Voice Conversion: Rethinking Concatenation-Based Approach for One-Shot Singing Voice Conversion](https://arxiv.org/abs/2312.04919)
38 |
39 | #### Feature Retrieval
40 | - [Voice Conversion With Just Nearest Neighbors](https://arxiv.org/abs/2305.18975)
41 |
42 | #### Language model
43 | - [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692)
44 |
45 | #### Other
46 | - [GAN Vocoder: Multi-Resolution Discriminator Is All You Need](https://arxiv.org/abs/2103.05236)
--------------------------------------------------------------------------------
/docs/images/auris_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uthree/auris_experimental_vits_dsp/faffcafa3e38028f25ecaa79dd420207a732e324/docs/images/auris_architecture.png
--------------------------------------------------------------------------------
/docs/images/auris_decoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uthree/auris_experimental_vits_dsp/faffcafa3e38028f25ecaa79dd420207a732e324/docs/images/auris_decoder.png
--------------------------------------------------------------------------------
/docs/infer.md:
--------------------------------------------------------------------------------
1 | # Inference
2 | This document describes how to run inference.
3 |
4 | ## Inference with the gradio UI
5 | Launch the webui.
6 | ```sh
7 | python3 infer_webui.py
8 | ```
9 | Open `localhost:7860` in a browser.
10 |
11 | ## Inference from the command line
12 |
13 | ### Audio reconstruction task
14 | A task that reconstructs the input audio, used to check the performance of the VAE.
15 |
16 | 1. Prepare a directory containing the input audio files.
17 | ```sh
18 | mkdir audio_inputs # put the input audio files here
19 | ```
20 |
21 | 2. Run inference.
22 | The speaker must be specified with `-s <speaker name>`.
23 | ```sh
24 | python3 infer.py -i audio_inputs -t recon -s jvs001
25 | ```
26 |
27 | 3. The output files are written to `outputs/`; check them there.
28 |
29 | ### Text-to-speech (TTS)
30 | A task that reads text aloud.
31 |
32 | 1. Prepare a script directory.
33 | Example script files are in `example/text_inputs/`; create yours based on them (a minimal generation sketch is given at the end of this document).
34 | ```sh
35 | mkdir text_inputs # put the script files here
36 | ```
37 |
38 | 2. Run inference.
39 | ```sh
40 | python3 infer.py -i text_inputs -t tts
41 | ```
42 |
43 | #### Details of the fields
44 | - `speaker`: speaker name
45 | - `text`: the text to read aloud
46 | - `style_text`: style text. Optional; if omitted, the same content as `text` is used.
47 | - `language`: language
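48 |
49 | As a minimal sketch, a script file in the same format as `example/text_inputs/transcription.json` can be generated with the standard library (the utterance text here is only a placeholder):
50 | ```python
51 | import json
52 |
53 | # each top-level key is an utterance id; the fields are the ones described above
54 | script = {
55 |     "0": {
56 |         "speaker": "jvs001",
57 |         "text": "こんにちは。",
58 |         "style_text": "こんにちは。",
59 |         "language": "ja",
60 |     }
61 | }
62 |
63 | with open("text_inputs/transcription.json", "w", encoding="utf-8") as f:
64 |     json.dump(script, f, ensure_ascii=False, indent=4)
65 | ```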
--------------------------------------------------------------------------------
/docs/installation.md:
--------------------------------------------------------------------------------
1 | # Installation
2 | ## Prerequisites
3 | - Python 3.10.6 or later
4 | - torch 2.0.1 or later, and an environment where a GPU (CUDA, etc.) can be used
5 |
6 | ## Steps
7 | 1. Clone this repository and move into its directory.
8 | ```sh
9 | git clone https://github.com/uthree/auris
10 | cd auris
11 | ```
12 |
13 | 2. Install the dependencies.
14 | ```sh
15 | pip3 install -r requirements.txt
16 | ```
17 |
18 | 3. Build monotonic_align (optional).
19 | (If it is not built, the numba implementation is used instead; the built version performs better.)
20 | ```sh
21 | cd module/vits/monotonic_align
22 | python3 setup.py build_ext --inplace
23 | ```
24 |
--------------------------------------------------------------------------------
/docs/technical_details.md:
--------------------------------------------------------------------------------
1 | # Technical Details
2 | This document describes the technical details of this repository.
3 |
4 | ## Motivation
5 | - We want an open-source model that can do TTS (Text-To-Speech), SVS (Singing Voice Synthesis), and SVC (Singing Voice Conversion).
6 | - The license should be as permissive as possible (e.g. the MIT license).
7 | - We want a Japanese pre-trained model.
8 |
9 | ## Model architecture
10 | The model is essentially VITS with a number of modifications.
11 | Concretely, these include:
12 | - The decoder is replaced with DSP + HnNSF-HiFiGAN. Incorporating the DSP component makes the pitch externally controllable, which enables singing voice synthesis.
13 | - The text encoder is given the ability to reference language-model features, so it can pick up emotion and context.
14 | - Feature Retrieval is added to improve speaker fidelity.
15 | - The Duration Discriminator from VITS2 is added.
16 | - Among the discriminators, the MultiScaleDiscriminator is replaced with a MultiResolutionalDiscriminator to avoid over-smoothing of the spectrogram.
17 |
18 | These are the main modifications.
19 |
20 | ### Overall structure
21 | ![auris_architecture](images/auris_architecture.png)
22 | Almost the same as the structure of VITS.
23 |
24 | ### Decoder
25 | ![auris_decoder](images/auris_decoder.png)
26 | A hybrid structure: audio is first synthesized by a DDSP-like additive synthesizer, then passed through a filter whose structure is similar to HiFi-GAN.
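27 |
28 | As an illustration of the additive (DDSP-like) half only, the sketch below sums harmonic sinusoids from per-frame f0 and amplitude envelopes. The function name and signature are hypothetical; the repository's actual decoder lives in `module/vits/decoder.py` (with its HiFi-GAN-like filter stack and, in `config/base.json`, 30 harmonics) and is not reproduced here:
29 | ```python
30 | import torch
31 |
32 | def additive_harmonic_source(f0, amplitudes, sample_rate=48000, frame_size=480):
33 |     """Sum-of-sinusoids source from per-frame f0 (Hz) and harmonic amplitudes.
34 |
35 |     f0:         [B, 1, Frames]   (0 Hz = unvoiced)
36 |     amplitudes: [B, H, Frames]
37 |     returns:    [B, 1, Frames * frame_size]
38 |     """
39 |     num_harmonics = amplitudes.shape[1]
40 |     # upsample frame-rate controls to the audio sample rate
41 |     f0 = torch.nn.functional.interpolate(f0, scale_factor=frame_size, mode='linear')
42 |     amps = torch.nn.functional.interpolate(amplitudes, scale_factor=frame_size, mode='linear')
43 |     # integrate instantaneous frequency to obtain the phase of the fundamental
44 |     phase = torch.cumsum(f0 / sample_rate, dim=2) * 2.0 * torch.pi
45 |     k = torch.arange(1, num_harmonics + 1, device=f0.device).view(1, -1, 1)
46 |     harmonics = torch.sin(phase * k) * amps
47 |     # silence harmonics above the Nyquist frequency, then mix down
48 |     harmonics = harmonics * (f0 * k < sample_rate / 2).float()
49 |     return harmonics.sum(dim=1, keepdim=True)
50 | ```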
--------------------------------------------------------------------------------
/docs/todo_list.md:
--------------------------------------------------------------------------------
1 | # TODO List
2 | - [ ] Add a speaker morphing feature
3 | - [ ] Generate the feature-vector dictionary
4 | - [ ] Add extra information such as each speaker's average f0 to the metadata
5 | - [ ] Add the ability to load musicxml
6 | - [ ] Generate singing voices from scores
7 | - [ ] Export to onnx
--------------------------------------------------------------------------------
/docs/train.md:
--------------------------------------------------------------------------------
1 | # Training
2 | This document describes how to train the model.
3 |
4 | ## Preprocessing
5 | Preprocessing scripts are provided for several datasets.
6 | Just download a dataset and run one of these scripts to complete the preprocessing.
7 | The config file to use can be specified with the `-c`, `--config` option.
8 |
9 | ### JVS corpus
10 | Preprocess the [JVS corpus](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvs_corpus) with the following command.
11 | ```sh
12 | python3 preprocess.py jvs jvs_ver1/ -c config/base.json
13 | ```
14 |
15 |
16 | ### Custom dataset
17 | To build your own dataset, first prepare a directory with the following layout.
18 | The names `root_dir`, `speaker001`, etc. can be anything, but ASCII alphanumeric names are recommended.
19 | If the transcriptions are not in Japanese, change `LANGUAGE` in `preprocess/wave_and_text.py`.
20 | ```
21 | root_dir/
22 |     - speaker001/
23 |         - speech001.wav
24 |         - speech001.txt
25 |         - speech002.wav
26 |         - speech002.txt
27 |         ...
28 |     - speaker002/
29 |         - speech001.wav
30 |         - speech001.txt
31 |         - speech002.wav
32 |         - speech002.txt
33 |         ...
34 |     ...
35 | ```
36 | Each text file contains the transcription of the wav file with the same file name.
37 | Once the dataset is ready, run the preprocessing.
38 | ```sh
39 | python3 preprocess.py wav-txt root_dir/ -c config/base.json
40 | ```
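41 |
42 | Before preprocessing, the following sketch lists wav files that are missing the matching txt transcription (it mirrors the pairing rule used by `module/preprocess/wave_and_text.py`; files without a transcription are skipped):
43 | ```python
44 | from pathlib import Path
45 |
46 | root_dir = Path("root_dir")
47 | for wav_file in sorted(root_dir.rglob("*.wav")):
48 |     txt_file = wav_file.parent / (wav_file.stem + ".txt")
49 |     if not txt_file.exists():
50 |         # such files are silently skipped by the preprocessor
51 |         print(f"missing transcription: {txt_file}")
52 | ```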
53 |
54 |
55 | ## Run training
56 | ```sh
57 | python3 train.py -c config/base.json
58 | ```
59 |
60 | ## Resuming training
61 | `models/vits.ckpt` is loaded automatically and training resumes from it.
62 | If loading fails, the file is corrupted; recover by renaming the most recent ckpt inside `lightning_logs/` to `vits.ckpt` and placing it in `models/`.
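63 |
64 | A minimal recovery sketch, assuming the default PyTorch Lightning checkpoint layout of `lightning_logs/version_*/checkpoints/*.ckpt`:
65 | ```python
66 | import shutil
67 | from pathlib import Path
68 |
69 | # pick the most recently modified checkpoint written by PyTorch Lightning
70 | ckpts = sorted(Path("lightning_logs").rglob("*.ckpt"), key=lambda p: p.stat().st_mtime)
71 | latest = ckpts[-1]
72 | shutil.copy(latest, Path("models") / "vits.ckpt")
73 | print(f"restored {latest} -> models/vits.ckpt")
74 | ```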
75 |
76 | ## Check training progress
77 | Training progress can be visualized with a library called tensorboard.
78 | ```sh
79 | tensorboard --logdir lightning_logs
80 | ```
81 | Run this in the background, e.g. with screen.
82 | While it is running, the tensorboard server is up, and you can see the progress by opening `http://localhost:6006` in a browser.
83 |
84 | ## FAQ
85 | - What is `models/metadata.json`?
86 | A JSON file containing metadata that associates speaker names with speaker IDs.
87 | It is generated automatically when a dataset is preprocessed.
88 | - What is `models/config.json`?
89 | The config that is loaded by default at inference time. It is a copy of the config specified with the `-c`, `--config` option during preprocessing.
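90 |
91 | As a minimal sketch, the speaker-name-to-ID mapping can be inspected like this (the fields follow `module/preprocess/scan.py`, where a speaker's ID is its index in the sorted `speakers` list):
92 | ```python
93 | import json
94 |
95 | with open("models/metadata.json", encoding="utf-8") as f:
96 |     metadata = json.load(f)
97 |
98 | # speaker id = index in the sorted speaker list
99 | for speaker_id, name in enumerate(metadata["speakers"]):
100 |     print(speaker_id, name)
101 | ```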
--------------------------------------------------------------------------------
/example/musicxml/test.musicxml:
--------------------------------------------------------------------------------
[MusicXML markup lost in extraction: only the text nodes of this file survive. The original is a MuseScore 3.6.2 export (dated 2024-04-13) of a single-part test score titled "test score" (composer "Composer", part "Piano"), whose first measure carries four quarter notes C4, E4, G4, B4 with the lyrics あ, い, う, え; the remaining measures appear to contain only rests.]
--------------------------------------------------------------------------------
/example/score_inputs/test.json:
--------------------------------------------------------------------------------
1 | {
2 |     "parts": {
3 |         "0": {
4 |             "language": "ja",
5 |             "style_text": "あ",
6 |             "speaker": "jvs001",
7 |             "notes": [
8 |                 {
9 |                     "pitch": 64,
10 |                     "onset": 0.1,
11 |                     "offset": 0.9,
12 |                     "lyrics": "あ",
13 |                     "vibrato": {
14 |                         "strength": 1.0,
15 |                         "onset": 0.5,
16 |                         "offset": 0.8,
17 |                         "period": 0.05
18 |                     },
19 |                     "yodel": {
20 |                         "strength": 1.0,
21 |                         "onset": 0.1,
22 |                         "offset": 0.2
23 |                     },
24 |                     "fall": {
25 |                         "strength": 12.0,
26 |                         "onset": 0.8,
27 |                         "offset": 0.9
28 |                     },
29 |                     "energy": 1.0
30 |                 }
31 |             ]
32 |         }
33 |     }
34 | }
--------------------------------------------------------------------------------
/example/text_inputs/transcription.json:
--------------------------------------------------------------------------------
1 | {
2 |     "0": {
3 |         "speaker": "jvs001",
4 |         "text": "また、東寺のように、五大明王と呼ばれる、主要な明王の中央に配されることも多い。",
5 |         "style_text": "また、東寺のように、五大明王と呼ばれる、主要な明王の中央に配されることも多い。",
6 |         "language": "ja"
7 |     }
8 | }
9 |
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import argparse
3 | import json
4 |
5 | import torch
6 | import torchaudio
7 | from torchaudio.functional import resample
8 |
9 | from module.infer import Infer
10 |
11 | parser = argparse.ArgumentParser(description="inference")
12 | parser.add_argument('-c', '--config', default='models/config.json')
13 | parser.add_argument('-t', '--task', choices=['tts', 'recon', 'svc', 'svs'], default='tts')
14 | parser.add_argument('-s', '--speaker', default='jvs001')
15 | parser.add_argument('-i', '--inputs', default='inputs')
16 | parser.add_argument('-o', '--outputs', default='outputs')
17 | parser.add_argument('-m', '--model', default='models/generator.safetensors')
18 | parser.add_argument('-meta', '--metadata', default='models/metadata.json')
19 | args = parser.parse_args()
20 |
21 | outputs_dir = Path(args.outputs)
22 |
23 | # load model
24 | infer = Infer(args.model, args.config, args.metadata)
25 | device = infer.device
26 |
27 | # support audio formats
28 | audio_formats = ['mp3', 'wav', 'ogg']
29 |
30 | # make outputs directory if it does not exist
31 | if not outputs_dir.exists():
32 |     outputs_dir.mkdir()
33 |
34 | if args.task == 'recon':
35 |     print("Task: Audio Reconstruction")
36 |     # audio reconstruction task
37 |     spk = args.speaker
38 |
39 |     # get input path
40 |     inputs_dir = Path(args.inputs)
41 |     inputs = []
42 |
43 |     # load files
44 |     for fmt in audio_formats:
45 |         for path in inputs_dir.glob(f"*.{fmt}"):
46 |             inputs.append(path)
47 |
48 |     # inference
49 |     for path in inputs:
50 |         print(f"Inferencing {path}")
51 |
52 |         # load audio
53 |         wf, sr = torchaudio.load(path)
54 |
55 |         # resample to the model's sample rate
56 |         if sr != infer.sample_rate:
57 |             wf = resample(wf, sr, infer.sample_rate)
58 |
59 |         # infer
60 |         spk = args.speaker
61 |         wf = infer.audio_reconstruction(wf, spk).cpu()
62 |
63 |         # save
64 |         save_path = outputs_dir / (path.stem + ".wav")
65 |         torchaudio.save(save_path, wf, infer.sample_rate)
66 |
67 | elif args.task == 'tts':
68 |     print("Task: Text to Speech")
69 |
70 |     # get input path
71 |     inputs_dir = Path(args.inputs)
72 |
73 |     # load files
74 |     inputs = []
75 |     for path in inputs_dir.glob("*.json"):
76 |         inputs.append(path)
77 |
78 |     # inference
79 |     for path in inputs:
80 |         print(f"Inferencing {path}")
81 |         t = json.load(open(path, encoding='utf-8'))
82 |         for k, v in t.items():
83 |             print(f" Inferencing {k}")
84 |             wf = infer.text_to_speech(**v).cpu()
85 |
86 |             # save
87 |             save_path = outputs_dir / (f"{path.stem}_{k}.wav")
88 |             torchaudio.save(save_path, wf, infer.sample_rate)
89 |
90 | elif args.task == 'svs':
91 |     print("Task: Singing Voice Synthesis")
92 |
93 |     inputs_dir = Path(args.inputs)
94 |
95 |     # load score
96 |     inputs = []
97 |     for path in inputs_dir.glob("*.json"):
98 |         inputs.append(path)
99 |
100 |     # inference
101 |     for path in inputs:
102 |         print(f"Inferencing {path}")
103 |         score = json.load(open(path, encoding='utf-8'))
104 |         wf = infer.singing_voice_synthesis(score)
105 |
106 |         # save
107 |         save_path = outputs_dir / (path.stem + ".wav")
108 |         torchaudio.save(save_path, wf, infer.sample_rate)
109 |
110 |
--------------------------------------------------------------------------------
/infer_webui.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import argparse
3 | import torch
4 | import numpy as np
5 |
6 | from module.infer import Infer
7 |
8 | import gradio as gr
9 |
10 | parser = argparse.ArgumentParser(description="inference")
11 | parser.add_argument('-c', '--config', default='models/config.json')
12 | parser.add_argument('-t', '--task', choices=['tts', 'recon', 'svc', 'svs'], default='tts')
13 | parser.add_argument('-m', '--model', default='models/generator.safetensors')
14 | parser.add_argument('-meta', '--metadata', default='models/metadata.json')
15 | parser.add_argument('-p', '--port', default=7860, type=int)
16 | args = parser.parse_args()
17 |
18 | # load model
19 | infer = Infer(args.model, args.config, args.metadata)
20 | device = infer.device
21 |
22 | def text_to_speech(text, style_text, speaker, language, duration_scale, pitch_shift):
23 |     if style_text == "":
24 |         style_text = text
25 |     wf = infer.text_to_speech(text, speaker, language, style_text, duration_scale, pitch_shift)
26 |     wf = wf.squeeze(0).cpu().numpy()
27 |     wf = (wf * 32768).astype(np.int16)
28 |     sample_rate = infer.sample_rate
29 |     return sample_rate, wf
30 |
31 | demo = gr.Interface(text_to_speech, inputs=[
32 |     gr.Text(label="Text"),
33 |     gr.Text(label="Style"),
34 |     gr.Dropdown(infer.speakers(), label="Speaker", value=infer.speakers()[0]),
35 |     gr.Dropdown(infer.languages(), label="Language", value=infer.languages()[0]),
36 |     gr.Slider(0.1, 3.0, 1.0, label="Duration Scale"),
37 |     gr.Slider(-12.0, 12.0, 0.0, label="Pitch Shift"),
38 | ],
39 |     outputs=[gr.Audio()])
40 |
41 | demo.launch(debug=True, server_port=args.port)
--------------------------------------------------------------------------------
/module/g2p/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union, Tuple
2 | import torch
3 |
4 | from .japanese import JapaneseExtractor
5 | #from .english import EnglishExtractor
6 |
7 |
8 | class G2PProcessor:
9 |     def __init__(self):
10 |         self.extractors = {}
11 |
12 |         # If you want to add a language, add processing here
13 |         # ---
14 |         self.extractors['ja'] = JapaneseExtractor()
15 |         #self.extractors['en'] = EnglishExtractor()
16 |         # ---
17 |
18 |         self.languages = []
19 |         phoneme_vocabs = []
20 |         for mod in self.extractors.values():
21 |             phoneme_vocabs += mod.possible_phonemes()
22 |         self.languages += self.extractors.keys()
23 |         self.phoneme_vocabs = [''] + phoneme_vocabs
24 |
25 |     def grapheme_to_phoneme(self, text: Union[str, List[str]], language: Union[str, List[str]]):
26 |         if type(text) == list:
27 |             return self._g2p_multiple(text, language)
28 |         elif type(text) == str:
29 |             return self._g2p_single(text, language)
30 |
31 |     def _g2p_single(self, text, language):
32 |         mod = self.extractors[language]
33 |         return mod.g2p(text)
34 |
35 |     def _g2p_multiple(self, text, language):
36 |         result = []
37 |         for t, l in zip(text, language):
38 |             result.append(self._g2p_single(t, l))
39 |         return result
40 |
41 |     def phoneme_to_id(self, phonemes: Union[List[str], List[List[str]]]):
42 |         if type(phonemes[0]) == list:
43 |             return self._p2id_multiple(phonemes)
44 |         elif type(phonemes[0]) == str:
45 |             return self._p2id_single(phonemes)
46 |
47 |     def _p2id_single(self, phonemes: List[str]):
48 |         ids = []
49 |         for p in phonemes:
50 |             if p in self.phoneme_vocabs:
51 |                 ids.append(self.phoneme_vocabs.index(p))
52 |             else:
53 |                 print("warning: unknown phoneme.")
54 |                 ids.append(0)
55 |         return ids
56 |
57 |     def _p2id_multiple(self, phonemes: List[List[str]]):
58 |         sequences = []
59 |         for s in phonemes:
60 |             out = self._p2id_single(s)
61 |             sequences.append(out)
62 |         return sequences
63 |
64 |     def language_to_id(self, languages: Union[str, List[str]]):
65 |         if type(languages) == str:
66 |             return self._l2id_single(languages)
67 |         elif type(languages) == list:
68 |             return self._l2id_multiple(languages)
69 |
70 |     def _l2id_single(self, language):
71 |         if language in self.languages:
72 |             return self.languages.index(language)
73 |         else:
74 |             return 0
75 |
76 |     def _l2id_multiple(self, languages):
77 |         result = []
78 |         for l in languages:
79 |             result.append(self._l2id_single(l))
80 |         return result
81 |
82 |     def id_to_phoneme(self, ids):
83 |         if type(ids[0]) == list:
84 |             return self._id2p_multiple(ids)
85 |         elif type(ids[0]) == int:
86 |             return self._id2p_single(ids)
87 |
88 |     def _id2p_single(self, ids: List[int]) -> List[str]:
89 |         phonemes = []
90 |         for i in ids:
91 |             if i < len(self.phoneme_vocabs):
92 |                 p = self.phoneme_vocabs[i]
93 |             else:
94 |                 p = ''
95 |             phonemes.append(p)
96 |         return phonemes
97 |
98 |     def _id2p_multiple(self, ids: List[List[int]]) -> List[List[str]]:
99 |         results = []
100 |         for s in ids:
101 |             results.append(self._id2p_single(s))
102 |         return results
103 |
104 |     def encode(self, sentences: List[str], languages: List[str], max_length: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
105 |         ids, lengths = self._enc_multiple(sentences, languages, max_length)
106 |         language_ids = self.language_to_id(languages)
107 |
108 |         ids = torch.LongTensor(ids)
109 |         lengths = torch.LongTensor(lengths)
110 |         language_ids = torch.LongTensor(language_ids)
111 |
112 |         return ids, lengths, language_ids
113 |
114 |     def _enc_single(self, sentence, language, max_length):
115 |         phonemes = self.grapheme_to_phoneme(sentence, language)
116 |         ids = self.phoneme_to_id(phonemes)
117 |         length = min(len(ids), max_length)
118 |         if len(ids) > max_length:
119 |             ids = ids[:max_length]
120 |         while len(ids) < max_length:
121 |             ids.append(0)
122 |         return ids, length
123 |
124 |     def _enc_multiple(self, sentences, languages, max_length):
125 |         seq, lengths = [], []
126 |         for s, l in zip(sentences, languages):
127 |             ids, length = self._enc_single(s, l, max_length)
128 |             seq.append(ids)
129 |             lengths.append(length)
130 |         return seq, lengths
131 |
132 |
--------------------------------------------------------------------------------
/module/g2p/english.py:
--------------------------------------------------------------------------------
1 | from .extractor import PhoneticExtractor
2 | from g2p_en import G2p
3 |
4 |
5 | class EnglishExtractor(PhoneticExtractor):
6 |     def __init__(self):
7 |         super().__init__()
8 |         self.g2p_instance = G2p()
9 |
10 |     def g2p(self, text):
11 |         return self.g2p_instance(text)
12 |
13 |     def possible_phonemes(self):
14 |         phonemes = self.g2p_instance.phonemes
15 |         phonemes.remove('')
16 |         return phonemes
17 |
--------------------------------------------------------------------------------
/module/g2p/extractor.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 |
4 | # grapheme-to-phoneme module base class
5 | # Create a module by inheriting this class for each language.
6 | class PhoneticExtractor():
7 |     def __init__(self, *args, **kwargs):
8 |         pass
9 |
10 |     # grapheme to phoneme
11 |     def g2p(self, text: str) -> List[str]:
12 |         raise NotImplementedError
13 |
14 |     def possible_phonemes(self) -> List[str]:
15 |         raise NotImplementedError
16 |
--------------------------------------------------------------------------------
/module/g2p/japanese.py:
--------------------------------------------------------------------------------
1 | from .extractor import PhoneticExtractor
2 | import pyopenjtalk
3 | from typing import List
4 |
5 |
6 | class JapaneseExtractor(PhoneticExtractor):
7 |     def __init__(self):
8 |         super().__init__()
9 |
10 |     def g2p(self, text) -> List[str]:
11 |         phonemes = pyopenjtalk.g2p(text).split(" ")
12 |         new_phonemes = []
13 |         for p in phonemes:
14 |             if p == 'pau':
15 |                 new_phonemes.append('')
16 |             else:
17 |                 new_phonemes.append(p)
18 |         return new_phonemes
19 |
20 |     def possible_phonemes(self):
21 |         return ['I', 'N', 'U', 'a', 'b', 'by', 'ch', 'cl', 'd',
22 |                 'dy', 'e', 'f', 'g', 'gy', 'h', 'hy', 'i', 'j', 'k', 'ky',
23 |                 'm', 'my', 'n', 'ny', 'o', 'p', 'py', 'r', 'ry', 's', 'sh',
24 |                 't', 'ts', 'ty', 'u', 'v', 'w', 'y', 'z']
25 |
--------------------------------------------------------------------------------
/module/infer/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 | import torch
4 | import math
5 | import json
6 | from pathlib import Path
7 |
8 | from safetensors import safe_open
9 | from safetensors.torch import save_file
10 |
11 | from module.vits import Generator, spectrogram
12 | from module.utils.safetensors import load_tensors
13 | from module.g2p import G2PProcessor
14 | from module.language_model import LanguageModel
15 | from module.utils.config import load_json_file
16 |
17 |
18 | class Infer:
19 |     def __init__(self, safetensors_path, config_path, metadata_path, device=torch.device('cpu')):
20 |         self.device = device
21 |         self.config = load_json_file(config_path)
22 |         self.metadata = load_json_file(metadata_path)
23 |         self.g2p = G2PProcessor()
24 |         self.lm = LanguageModel(self.config.language_model.type, self.config.language_model.options)
25 |
26 |         # load generator
27 |         generator = Generator(self.config.vits.generator)
28 |         generator.load_state_dict(load_tensors(safetensors_path))
29 |         generator = generator.to(self.device)
30 |         self.generator = generator
31 |
32 |         self.max_lm_tokens = self.config.infer.max_lm_tokens
33 |         self.max_phonemes = self.config.infer.max_phonemes
34 |         self.max_frames = self.config.infer.max_frames
35 |
36 |         self.n_fft = self.config.infer.n_fft
37 |         self.frame_size = self.config.infer.frame_size
38 |         self.sample_rate = self.config.infer.sample_rate
39 |
40 |     def speakers(self):
41 |         return self.metadata.speakers
42 |
43 |     def speaker_id(self, speaker):
44 |         return self.speakers().index(speaker)
45 |
46 |     def languages(self):
47 |         return self.g2p.languages
48 |
49 |     def language_id(self, language):
50 |         return self.g2p.language_to_id(language)
51 |
52 |     @torch.inference_mode()
53 |     def text_to_speech(
54 |             self,
55 |             text: str,
56 |             speaker: str,
57 |             language: str,
58 |             style_text: Union[None, str] = None,
59 |             duration_scale=1.0,
60 |             pitch_shift=0.0,
61 |     ):
62 |         spk = torch.LongTensor([self.speaker_id(speaker)])
63 |         if style_text is None:
64 |             style_text = text
65 |         lm_feat, lm_feat_len = self.lm.encode([style_text], self.max_lm_tokens)
66 |         phoneme, phoneme_len, lang = self.g2p.encode([text], [language], self.max_phonemes)
67 |
68 |         device = self.device
69 |         phoneme = phoneme.to(device)
70 |         phoneme_len = phoneme_len.to(device)
71 |         lm_feat = lm_feat.to(device)
72 |         lm_feat_len = lm_feat_len.to(device)
73 |         spk = spk.to(device)
74 |         lang = lang.to(device)
75 |
76 |         wf = self.generator.text_to_speech(
77 |             phoneme,
78 |             phoneme_len,
79 |             lm_feat,
80 |             lm_feat_len,
81 |             lang,
82 |             spk,
83 |             duration_scale=duration_scale,
84 |             pitch_shift=pitch_shift,
85 |         )
86 |         return wf.squeeze(0)
87 |
88 |     # wf: [Channels, Length]
89 |     @torch.inference_mode()
90 |     def audio_reconstruction(self, wf: torch.Tensor, speaker: str):
91 |         spk = torch.LongTensor([self.speaker_id(speaker)])
92 |         wf = wf.sum(dim=0, keepdim=True)
93 |         spec = spectrogram(wf, self.n_fft, self.frame_size)
94 |         spec_len = torch.LongTensor([spec.shape[2]])
95 |
96 |         device = self.device
97 |         spec = spec.to(device)
98 |         spec_len = spec_len.to(device)
99 |         spk = spk.to(device)
100 |
101 |         wf = self.generator.audio_reconstruction(spec, spec_len, spk)
102 |         return wf.squeeze(0)
103 |
104 |     def singing_voice_synthesis(self, score):
105 |         parts = score['parts']
106 |         for part_name, part in zip(parts.keys(), parts.values()):
107 |             print(f"processing {part_name}")
108 |             self._svs_generate_part(part)
109 |
110 |     # TODO: translate the comments into English. Someday, probably.
111 |     def _svs_generate_part(self, part):
112 |         language = part['language']
113 |         style_text = part['style_text']
114 |         speaker = part['speaker']
115 |         notes = part['notes']
116 |
117 |         # sort notes by onset
118 |         notes.sort(key=lambda x: x['onset'])
119 |
120 |         # get begin and end time
121 |         # find the begin [sec] and end [sec] times from the note list: the smallest onset is the begin time, the largest offset is the end time.
122 |         t_begin = None
123 |         t_end = None
124 |         # also collect the lyric information
125 |         part_phonemes = []  # phoneme sequence of this part
126 |         note_phoneme_indices = []  # (start index, end index) for each note
127 |         for note in notes:
128 |             b = note['onset']
129 |             e = note['offset']
130 |             if t_begin is None:
131 |                 t_begin = b
132 |             elif b < t_begin:
133 |                 t_begin = b
134 |             if t_end is None:
135 |                 t_end = e
136 |             elif e > t_end:
137 |                 t_end = e
138 |
139 |             # process the lyrics
140 |             note_phonemes = self.g2p.grapheme_to_phoneme(note['lyrics'], language)
141 |             note_phoneme_indices.append((len(part_phonemes), len(part_phonemes) + len(note_phonemes) - 1))
142 |             part_phonemes.extend(note_phonemes)
143 |         # number of phonemes
144 |         num_phonemes = len(part_phonemes)
145 |         # length of the part
146 |         part_length = t_end - t_begin
147 |         # frames per second
148 |         fps = self.sample_rate / self.frame_size
149 |         # number of frames to generate
150 |         num_frames = math.ceil(part_length * fps)
151 |         # pitch buffer, still in the MIDI scale at this point; filled with a value close to -inf (self._midi2f0(-inf) = 0, so unvoiced regions become 0 Hz).
152 |         pitch = torch.full([num_frames], -1e10)
153 |         # energy buffer, initialized to 0
154 |         energy = torch.full([num_frames], 0.0)
155 |         # duration
156 |         duration = torch.full([num_phonemes], 0.0)
157 |         # encode the speaker
158 |         speaker_id = self.speaker_id(speaker)
159 |         speaker_id = torch.LongTensor([speaker_id])
160 |         spk = self.generator.speaker_embedding(speaker_id)
161 |         # encode the language
162 |         lang_id = self.language_id(language)
163 |         lang = torch.LongTensor([lang_id])
164 |
165 |         # encode the phonemes, text and LM features
166 |         phonemes = torch.LongTensor(self.g2p.phoneme_to_id(part_phonemes)).unsqueeze(0)
167 |         phonemes_len = torch.LongTensor([phonemes.shape[1]])
168 |         lm_feat, lm_feat_len = self.lm.encode([style_text], self.max_lm_tokens)
169 |         text_encoded, text_mean, text_logvar, text_mask = self.generator.prior_encoder.text_encoder(phonemes, phonemes_len, lm_feat, lm_feat_len, spk, lang)
170 |         # estimate durations
171 |         log_dur = self.generator.prior_encoder.stochastic_duration_predictor(text_encoded, text_mask, g=spk, reverse=True)
172 |         duration = torch.ceil(torch.exp(log_dur)).to(torch.long)
173 |
174 |         # process each note
175 |         for i, note in enumerate(notes):
176 |             phoneme_begin, phoneme_end = note_phoneme_indices[i]
177 |
178 |             # convert the note's onset and offset to frame units
179 |             onset = round((note['onset'] - t_begin) * fps)
180 |             offset = round((note['offset'] - t_begin) * fps)
181 |
182 |             # assign
183 |             pitch[onset:offset] = float(note['pitch'])
184 |             energy[onset:offset] = float(note['energy'])
185 |
186 |             # TODO: vibrato, fall, energy adjustment, etc.
187 |
188 |         print(pitch)
189 |
190 |
191 |     def _f02midi(self, f0):
192 |         return torch.log2(f0 / 440.0) * 12.0 + 69.0
193 |
194 |     def _midi2f0(self, n):
195 |         return 440.0 * 2 ** ((n - 69.0) / 12.0)
196 |
--------------------------------------------------------------------------------
/module/language_model/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .rinna_roberta import RinnaRoBERTaExtractor
3 |
4 |
5 | def get_extractor(typename):
6 |     if typename == "rinna_roberta":
7 |         return RinnaRoBERTaExtractor
8 |     else:
9 |         raise ValueError("Unknown linguistic extractor type")
10 |
11 |
12 | class LanguageModel:
13 |     def __init__(self, extractor_type, options):
14 |         ext_constructor = get_extractor(extractor_type)
15 |         self.extractor = ext_constructor(**options)
16 |
17 |     def encode(self, sentences, max_length: int):
18 |         if type(sentences) == list:
19 |             return self._ext_multiple(sentences, max_length)
20 |         elif type(sentences) == str:
21 |             return self._ext_single(sentences, max_length)
22 |
23 |     def _ext_single(self, sentence, max_length: int):
24 |         features, length = self.extractor.extract(sentence)
25 |         features = features.cpu()
26 |
27 |         N, L, D = features.shape
28 |         # add padding
29 |         if L < max_length:
30 |             pad = torch.zeros(N, max_length - L, D)
31 |             features = torch.cat([pad, features], dim=1)
32 |         # crop
33 |         if L > max_length:
34 |             features = features[:, :max_length, :]
35 |         # length
36 |         length = min(length, max_length)
37 |
38 |         return features, length
39 |
40 |     def _ext_multiple(self, sentences, max_length: int):
41 |         lengths = []
42 |         features = []
43 |         for s in sentences:
44 |             f, l = self._ext_single(s, max_length)
45 |             features.append(f)
46 |             lengths.append(l)
47 |         features = torch.cat(features, dim=0)
48 |         lengths = torch.LongTensor(lengths)
49 |         return features, lengths
50 |
--------------------------------------------------------------------------------
/module/language_model/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List, Tuple
3 |
4 | # base class of linguistic feature extractor
5 | class LinguisticExtractor:
6 |     def __init__(self, *args, **kwargs):
7 |         pass
8 |
9 |     # Output: [1, length, lm_dim]
10 |     def extract(self, text: str) -> Tuple[torch.Tensor, int]:
11 |         pass
12 |
--------------------------------------------------------------------------------
/module/language_model/rinna_roberta.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoTokenizer, AutoModelForMaskedLM
3 | from .extractor import LinguisticExtractor
4 |
5 |
6 | # feature extractor using rinna/japanese-roberta-base
7 | # from: https://huggingface.co/rinna/japanese-roberta-base
8 | class RinnaRoBERTaExtractor(LinguisticExtractor):
9 |     def __init__(self, hf_repo: str, layer: int, device=None):
10 |         super().__init__()
11 |         self.layer = layer
12 |
13 |         if device is None:
14 |             if torch.cuda.is_available():  # CUDA available
15 |                 device = torch.device('cuda')
16 |             elif torch.backends.mps.is_available():  # on macOS
17 |                 device = torch.device('mps')
18 |             else:
19 |                 device = torch.device('cpu')
20 |         else:
21 |             device = torch.device(device)
22 |         self.device = device
23 |
24 |         self.tokenizer = AutoTokenizer.from_pretrained(hf_repo)
25 |         self.tokenizer.do_lower_case = True
26 |         self.model = AutoModelForMaskedLM.from_pretrained(hf_repo)
27 |         self.model.to(device)
28 |
29 |     def extract(self, text):
30 |         # add [CLS] token
31 |         text = "[CLS]" + text
32 |
33 |         # tokenize
34 |         tokens = self.tokenizer.tokenize(text)
35 |
36 |         # convert to ids
37 |         token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
38 |
39 |         # convert to tensor
40 |         token_tensor = torch.LongTensor([token_ids]).to(self.device)
41 |
42 |         # provide position ids explicitly
43 |         position_ids = list(range(0, token_tensor.shape[1]))
44 |         position_id_tensor = torch.LongTensor([position_ids]).to(self.device)
45 |
46 |         with torch.no_grad():
47 |             outputs = self.model(
48 |                 input_ids=token_tensor,
49 |                 position_ids=position_id_tensor,
50 |                 output_hidden_states=True)
51 |             features = outputs.hidden_states[self.layer]
52 |         # return features and length
53 |         return features, token_tensor.shape[1]
54 |
--------------------------------------------------------------------------------
/module/preprocess/jvs.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from .processor import Preprocessor
3 | from tqdm import tqdm
4 |
5 |
6 | def process_category(path: Path, category, processor: Preprocessor, speaker_name, config):
7 |     print(f"Processing {str(path)}")
8 |     audio_dir = path / "wav24kHz16bit"
9 |     transcription_path = path / "transcripts_utf8.txt"
10 |     with open(transcription_path, encoding='utf-8') as f:
11 |         transcription_text = f.read()
12 |
13 |     counter = 0
14 |     for metadata in tqdm(transcription_text.split("\n")):
15 |         s = metadata.split(":")
16 |         if len(s) >= 2:
17 |             audio_file_name, transcription = s[0], s[1]
18 |             audio_file_path = audio_dir / (audio_file_name + ".wav")
19 |             if not audio_file_path.exists():
20 |                 continue
21 |             processor.write_cache(
22 |                 audio_file_path,
23 |                 transcription,
24 |                 'ja',
25 |                 speaker_name,
26 |                 f"{category}_{counter}"
27 |             )
28 |             counter += 1
29 |
30 | def preprocess_jvs(jvs_root: Path, config):
31 |     processor = Preprocessor(config)
32 |     cache_dir = Path(config['preprocess']['cache'])
33 |     for subdir in jvs_root.glob("*/"):
34 |         if subdir.is_dir():
35 |             print(f"Processing {subdir}")
36 |             speaker_name = subdir.name
37 |             process_category(subdir / "nonpara30", "nonpara30", processor, speaker_name, config)
38 |             process_category(subdir / "parallel100", "parallel100", processor, speaker_name, config)
39 |
--------------------------------------------------------------------------------
/module/preprocess/processor.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | import torchaudio
5 | from torchaudio.functional import resample
6 |
7 | from module.g2p import G2PProcessor
8 | from module.language_model import LanguageModel
9 | from module.utils.f0_estimation import estimate_f0
10 |
11 |
12 | # for dataset preprocess
13 | class Preprocessor:
14 |     def __init__(self, config):
15 |         self.g2p = G2PProcessor()
16 |         self.lm = LanguageModel(config.language_model.type, config.language_model.options)
17 |         self.max_phonemes = config.preprocess.max_phonemes
18 |         self.lm_max_tokens = config.preprocess.lm_max_tokens
19 |         self.pitch_estimation = config.preprocess.pitch_estimation
20 |         self.max_waveform_length = config.preprocess.max_waveform_length
21 |         self.sample_rate = config.preprocess.sample_rate
22 |         self.frame_size = config.preprocess.frame_size
23 |         self.config = config
24 |
25 |     def write_cache(self, waveform_path: Path, transcription: str, language: str, speaker_name: str, data_name: str):
26 |         # load waveform file
27 |         wf, sr = torchaudio.load(waveform_path)
28 |
29 |         # resampling
30 |         if sr != self.sample_rate:
31 |             wf = resample(wf, sr, self.sample_rate)  # [Channels, Length_wf]
32 |
33 |         # mix down
34 |         wf = wf.sum(dim=0)  # [Length_wf]
35 |
36 |         # get length in frames
37 |         spec_len = torch.LongTensor([wf.shape[0] // self.frame_size])
38 |
39 |         # padding
40 |         if wf.shape[0] < self.max_waveform_length:
41 |             wf = torch.cat([wf, torch.zeros(self.max_waveform_length - wf.shape[0])])
42 |
43 |         # crop
44 |         if wf.shape[0] > self.max_waveform_length:
45 |             wf = wf[:self.max_waveform_length]
46 |
47 |         wf = wf.unsqueeze(0)  # [1, Length_wf]
48 |
49 |         # estimate f0 (pitch)
50 |         f0 = estimate_f0(wf, self.sample_rate, self.frame_size, self.pitch_estimation)
51 |
52 |         # get phonemes
53 |         phonemes, phonemes_len, language = self.g2p.encode([transcription], [language], self.max_phonemes)
54 |
55 |         # get lm features
56 |         lm_feat, lm_feat_len = self.lm.encode([transcription], self.lm_max_tokens)
57 |
58 |         # to dict
59 |         metadata = {
60 |             "spec_len": spec_len,
61 |             "f0": f0,
62 |             "phonemes": phonemes,
63 |             "phonemes_len": phonemes_len,
64 |             "language": language,
65 |             "lm_feat": lm_feat,
66 |             "lm_feat_len": lm_feat_len,
67 |         }
68 |
69 |         # get target dir.
70 |         cache_dir = Path(self.config.preprocess.cache)
71 |         subdir = cache_dir / speaker_name
72 |
73 |         # create subdir if it does not exist
74 |         if not subdir.exists():
75 |             subdir.mkdir()
76 |
77 |         audio_path = subdir / (data_name + ".wav")
78 |         metadata_path = subdir / (data_name + ".pt")
79 |
80 |         # save
81 |         torchaudio.save(audio_path, wf, self.sample_rate)
82 |         torch.save(metadata, metadata_path)
83 |
--------------------------------------------------------------------------------
/module/preprocess/scan.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from module.g2p import G2PProcessor
4 |
5 |
6 | # create metadata
7 | def scan_cache(config):
8 | cache_dir = Path(config.preprocess.cache)
9 | models_dir = Path("models")
10 | metadata_path = models_dir / "metadata.json"
11 | if not models_dir.exists():
12 | models_dir.mkdir()
13 |
14 | speaker_names = []
15 | for subdir in cache_dir.glob("*"):
16 | if subdir.is_dir():
17 | speaker_names.append(subdir.name)
18 | speaker_names = sorted(speaker_names)
19 | g2p = G2PProcessor()
20 | phonemes = g2p.phoneme_vocabs
21 | languages = g2p.languages
22 | num_harmonics = config.vits.generator.decoder.num_harmonics
23 | sample_rate = config.vits.generator.decoder.sample_rate
24 | frame_size = config.vits.generator.decoder.frame_size
25 | metadata = {
26 | "speakers": speaker_names, # speaker list
27 | "phonemes": phonemes,
28 | "languages": languages,
29 | "num_harmonics": num_harmonics,
30 | "sample_rate": sample_rate,
31 | "frame_size": frame_size
32 | }
33 |
34 | with open(metadata_path, 'w') as f:
35 | json.dump(metadata, f)
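For orientation, a sketch of regenerating the metadata after preprocessing, with a rough outline of the resulting `models/metadata.json` (actual values depend on the config and cache contents):

```python
from module.preprocess.scan import scan_cache
from module.utils.config import load_json_file

config = load_json_file("config/base.json")  # assumed config path
scan_cache(config)
# models/metadata.json then looks roughly like:
# {"speakers": ["speaker_a", "speaker_b"], "phonemes": [...], "languages": [...],
#  "num_harmonics": 30, "sample_rate": 48000, "frame_size": 960}
```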
--------------------------------------------------------------------------------
/module/preprocess/wave_and_text.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from .processor import Preprocessor
3 | from tqdm import tqdm
4 |
5 |
6 | LANGUAGE = 'ja'
7 |
8 | def process_speaker(subdir: Path, processor):
9 | speaker_name = subdir.stem
10 | counter = 0
11 | for wave_file in tqdm(subdir.rglob("*.wav")):
12 | text_file = wave_file.parent / (wave_file.stem + ".txt")
13 | if text_file.exists():
14 | with open(text_file) as f:
15 | text = f.read()
16 | else:
17 | continue
18 | processor.write_cache(
19 | wave_file,
20 | text,
21 | LANGUAGE,
22 | speaker_name,
23 | f"{speaker_name}_{counter}",
24 | )
25 | counter += 1
26 |
27 |
28 | def preprocess_wave_and_text(root: Path, config):
29 | processor = Preprocessor(config)
30 | for subdir in root.glob("*/"):
31 | if subdir.is_dir():
32 | print(f"Processing {subdir.stem}")
33 | process_speaker(subdir, processor)
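A sketch of the corpus layout this scanner expects and how it could be invoked directly; the repository's `preprocess.py` is the actual entry point, and the corpus path here is hypothetical:

```python
from pathlib import Path

from module.preprocess.wave_and_text import preprocess_wave_and_text
from module.utils.config import load_json_file

# expected layout:
#   my_corpus/
#     speaker_a/
#       utt_0001.wav
#       utt_0001.txt   <- transcription; .wav files without a matching .txt are skipped
config = load_json_file("config/base.json")
preprocess_wave_and_text(Path("my_corpus"), config)
```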
--------------------------------------------------------------------------------
/module/utils/common.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def init_weights(m, mean=0.0, std=0.01):
6 | classname = m.__class__.__name__
7 | if classname.find("Conv") != -1:
8 | m.weight.data.normal_(mean, std)
9 |
10 |
11 | def get_padding(kernel_size, dilation=1):
12 | return int((kernel_size*dilation - dilation)/2)
13 |
14 |
15 | def convert_pad_shape(pad_shape):
16 | l = pad_shape[::-1]
17 | pad_shape = [item for sublist in l for item in sublist]
18 | return pad_shape
19 |
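`convert_pad_shape` flattens a per-dimension `[left, right]` pad specification (outermost dimension first) into the innermost-first flat list that `F.pad` expects; a quick illustration:

```python
from module.utils.common import convert_pad_shape

# pad spec for a 3-D tensor: no padding on the first two dims, (2, 3) on the last
print(convert_pad_shape([[0, 0], [0, 0], [2, 3]]))  # [2, 3, 0, 0, 0, 0]
```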
--------------------------------------------------------------------------------
/module/utils/config.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 |
4 | class Config:
5 | def __init__(self, **kwargs):
6 | for k, v in kwargs.items():
7 |             if isinstance(v, dict):
8 | v = Config(**v)
9 | self[k] = v
10 |
11 | def keys(self):
12 | return self.__dict__.keys()
13 |
14 | def items(self):
15 | return self.__dict__.items()
16 |
17 | def values(self):
18 | return self.__dict__.values()
19 |
20 | def __len__(self):
21 | return len(self.__dict__)
22 |
23 | def __getitem__(self, key):
24 | return getattr(self, key)
25 |
26 | def __setitem__(self, key, value):
27 | return setattr(self, key, value)
28 |
29 | def __contains__(self, key):
30 | return key in self.__dict__
31 |
32 | def __repr__(self):
33 | return self.__dict__.__repr__()
34 |
35 |
36 | def load_json_file(path):
37 |     with open(path, encoding='utf-8') as f:
38 |         return Config(**json.load(f))
39 |
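`Config` recursively wraps nested dicts so JSON settings can be read with attribute access while still behaving like a mapping, which is what allows the `**config.xxx` unpacking used elsewhere in the codebase. A small self-checking illustration with made-up keys:

```python
from module.utils.config import Config

cfg = Config(**{"preprocess": {"sample_rate": 48000}, "name": "auris"})
assert cfg.preprocess.sample_rate == 48000        # nested dicts become nested Configs
assert cfg["preprocess"]["sample_rate"] == 48000  # dict-style access also works
assert "name" in cfg and len(cfg) == 2
print(dict(cfg))  # conversion works because keys() and __getitem__ are defined
```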
--------------------------------------------------------------------------------
/module/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchaudio
4 | import json
5 | from pathlib import Path
6 | import lightning as L
7 | from torch.utils.data import DataLoader, random_split
8 |
9 |
10 | class VitsDataset(torch.utils.data.Dataset):
11 | def __init__(self, cache_dir='dataset_cache', metadata='models/metadata.json'):
12 | super().__init__()
13 | self.root = Path(cache_dir)
14 | metadata = json.load(open(Path(metadata)))
15 | self.speakers = metadata['speakers']
16 | self.audio_file_paths = []
17 | self.speaker_ids = []
18 | self.metadata_paths = []
19 | for path in self.root.glob("*/*.wav"):
20 | self.audio_file_paths.append(path)
21 | spk = path.parent.name
22 | self.speaker_ids.append(self._speaker_id(spk))
23 | metadata_path = path.parent / (path.stem + ".pt")
24 | self.metadata_paths.append(metadata_path)
25 |
26 | def _speaker_id(self, speaker: str) -> int:
27 | return self.speakers.index(speaker)
28 |
29 | def __getitem__(self, idx):
30 | speaker_id = self.speaker_ids[idx]
31 | wf, sr = torchaudio.load(self.audio_file_paths[idx])
32 | metadata = torch.load(self.metadata_paths[idx])
33 |
34 | spec_len = metadata['spec_len'].item()
35 | f0 = metadata['f0'].squeeze(0)
36 | phoneme = metadata['phonemes'].squeeze(0)
37 | phoneme_len = metadata['phonemes_len'].item()
38 | language = metadata['language'].item()
39 | lm_feat = metadata['lm_feat'].squeeze(0)
40 | lm_feat_len = metadata['lm_feat_len'].item()
41 | return wf, spec_len, speaker_id, f0, phoneme, phoneme_len, lm_feat, lm_feat_len, language
42 |
43 | def __len__(self):
44 | return len(self.audio_file_paths)
45 |
46 |
47 | class VitsDataModule(L.LightningDataModule):
48 | def __init__(
49 | self,
50 | cache_dir='dataset_cache',
51 | metadata='models/metadata.json',
52 | batch_size=1,
53 | num_workers=1,
54 | ):
55 | super().__init__()
56 | self.cache_dir = cache_dir
57 | self.metadata = metadata
58 | self.batch_size = batch_size
59 | self.num_workers = num_workers
60 |
61 | def train_dataloader(self):
62 | dataset = VitsDataset(
63 | self.cache_dir,
64 | self.metadata)
65 | dataloader = DataLoader(
66 | dataset,
67 | self.batch_size,
68 | shuffle=True,
69 | num_workers=self.num_workers,
70 | persistent_workers=(os.name=='nt'))
71 | return dataloader
72 |
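A sketch of pulling one batch out of the data module, assuming preprocessing has already populated `dataset_cache/` and `models/metadata.json`; the batch size and worker count are arbitrary:

```python
from module.utils.dataset import VitsDataModule

datamodule = VitsDataModule(
    cache_dir="dataset_cache",
    metadata="models/metadata.json",
    batch_size=4,
    num_workers=2,
)
loader = datamodule.train_dataloader()

# tuple order matches VitsDataset.__getitem__
wf, spec_len, spk, f0, phoneme, phoneme_len, lm_feat, lm_feat_len, lang = next(iter(loader))
```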
--------------------------------------------------------------------------------
/module/utils/energy_estimation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | # estimate energy from power spectrogram
6 | # input:
7 | # spec: [BatchSize, n_fft//2+1, Length]
8 | # output:
9 | # energy: [BatchSize, 1, Length]
10 | def estimate_energy(spec):
11 | fft_bin = spec.shape[1]
12 | return spec.max(dim=1, keepdim=True).values / fft_bin
13 |
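A quick shape check with a dummy power spectrogram; the values are random and only the shapes matter here:

```python
import torch

from module.utils.energy_estimation import estimate_energy

spec = torch.rand(1, 3840 // 2 + 1, 100)  # [Batch, n_fft//2+1, Frames] for n_fft=3840
energy = estimate_energy(spec)
print(energy.shape)  # torch.Size([1, 1, 100])
```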
--------------------------------------------------------------------------------
/module/utils/f0_estimation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from torchaudio.functional import resample
6 | from torchfcpe import spawn_bundled_infer_model
7 |
8 | import numpy as np
9 | import pyworld as pw
10 |
11 |
12 | def estimate_f0_dio(wf, sample_rate=48000, segment_size=960, f0_min=20, f0_max=20000):
13 | if wf.ndim == 1:
14 | device = wf.device
15 | signal = wf.detach().cpu().numpy()
16 | signal = signal.astype(np.double)
17 | _f0, t = pw.dio(signal, sample_rate, f0_floor=f0_min, f0_ceil=f0_max)
18 | f0 = pw.stonemask(signal, _f0, t, sample_rate)
19 | f0 = torch.from_numpy(f0).to(torch.float)
20 | f0 = f0.to(device)
21 | f0 = f0.unsqueeze(0).unsqueeze(0)
22 | f0 = F.interpolate(f0, wf.shape[0] // segment_size, mode='linear')
23 | f0 = f0.squeeze(0)
24 | return f0
25 | elif wf.ndim == 2:
26 | waves = wf.split(1, dim=0)
27 | pitchs = [estimate_f0_dio(wave[0], sample_rate, segment_size) for wave in waves]
28 | pitchs = torch.stack(pitchs, dim=0)
29 | return pitchs
30 |
31 |
32 | def estimate_f0_harvest(wf, sample_rate=48000, segment_size=960, f0_min=20, f0_max=20000):
33 | if wf.ndim == 1:
34 | device = wf.device
35 | signal = wf.detach().cpu().numpy()
36 | signal = signal.astype(np.double)
37 | f0, t = pw.harvest(signal, sample_rate, f0_floor=f0_min, f0_ceil=f0_max)
38 | f0 = torch.from_numpy(f0).to(torch.float)
39 | f0 = f0.to(device)
40 | f0 = f0.unsqueeze(0).unsqueeze(0)
41 | f0 = F.interpolate(f0, wf.shape[0] // segment_size, mode='linear')
42 | f0 = f0.squeeze(0)
43 | return f0
44 | elif wf.ndim == 2:
45 | waves = wf.split(1, dim=0)
46 | pitchs = [estimate_f0_harvest(wave[0], sample_rate, segment_size) for wave in waves]
47 | pitchs = torch.stack(pitchs, dim=0)
48 | return pitchs
49 |
50 |
51 | # lazily-initialized module-level cache for the bundled FCPE model
52 | torchfcpe_model = None
53 | def estimate_f0_fcpe(wf, sample_rate=48000, segment_size=960, f0_min=20, f0_max=20000):
54 | input_device = wf.device
55 | global torchfcpe_model
56 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
57 | wf = wf.to(device)
58 | if torchfcpe_model is None:
59 | torchfcpe_model = spawn_bundled_infer_model(device)
60 | f0 = torchfcpe_model.infer(wf.unsqueeze(2), sample_rate)
61 | f0 = f0.transpose(1, 2)
62 | f0 = f0.to(input_device)
63 | return f0
64 |
65 | # wf: [BatchSize, Length]
66 | def estimate_f0(wf, sample_rate=48000, segment_size=960, algorithm='harvest'):
67 | l = wf.shape[1]
68 | if algorithm == 'harvest':
69 | f0 = estimate_f0_harvest(wf, sample_rate)
70 | elif algorithm == 'dio':
71 | f0 = estimate_f0_dio(wf, sample_rate)
72 | elif algorithm == 'fcpe':
73 | f0 = estimate_f0_fcpe(wf, sample_rate)
74 | return F.interpolate(f0, l // segment_size, mode='linear')
75 |
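A sketch of the top-level `estimate_f0` call on a dummy waveform; real use would pass actual audio, `harvest`/`dio` need pyworld, and `fcpe` needs torchfcpe (and benefits from a GPU):

```python
import torch

from module.utils.f0_estimation import estimate_f0

wf = torch.randn(1, 48000 * 2)  # 2 seconds at 48 kHz, [Batch, Length]
f0 = estimate_f0(wf, sample_rate=48000, segment_size=960, algorithm='harvest')
print(f0.shape)  # torch.Size([1, 1, 100]), one f0 value per 960-sample frame
```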
--------------------------------------------------------------------------------
/module/utils/safetensors.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import torch
3 | from safetensors import safe_open
4 | from safetensors.torch import save_file
5 |
6 |
7 | def save_tensors(tensors: Dict[str, torch.Tensor], path):
8 | save_file(tensors, path)
9 |
10 | def load_tensors(path) -> Dict[str, torch.Tensor]:
11 | tensors = {}
12 | with safe_open(path, framework="pt", device="cpu") as f:
13 | for key in f.keys():
14 | tensors[key] = f.get_tensor(key)
15 | return tensors
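Round-trip usage is a one-liner each way; the file name here is arbitrary:

```python
import torch

from module.utils.safetensors import save_tensors, load_tensors

save_tensors({"weight": torch.randn(3, 3)}, "example.safetensors")
tensors = load_tensors("example.safetensors")
print(tensors["weight"].shape)  # torch.Size([3, 3])
```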
--------------------------------------------------------------------------------
/module/vits/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 |
6 | import lightning as L
7 |
8 | from .generator import Generator
9 | from .discriminator import Discriminator
10 | from .duration_discriminator import DurationDiscriminator
11 | from .loss import mel_spectrogram_loss, generator_adversarial_loss, discriminator_adversarial_loss, feature_matching_loss, duration_discriminator_adversarial_loss, duration_generator_adversarial_loss
12 | from .crop import crop_features, crop_waveform, decide_crop_range
13 | from .spectrogram import spectrogram
14 |
15 |
16 | class Vits(L.LightningModule):
17 | def __init__(
18 | self,
19 | config,
20 | ):
21 | super().__init__()
22 | self.generator = Generator(config.generator)
23 | self.discriminator = Discriminator(config.discriminator)
24 | self.duration_discriminator = DurationDiscriminator(**config.duration_discriminator)
25 | self.config = config
26 |
27 | # disable automatic optimization
28 | self.automatic_optimization = False
29 | # save hyperparameters
30 | self.save_hyperparameters()
31 |
32 | def training_step(self, batch):
33 | wf, spec_len, spk, f0, phoneme, phoneme_len, lm_feat, lm_feat_len, lang = batch
34 | wf = wf.squeeze(1) # [Batch, WaveLength]
35 | # get optimizer
36 | opt_g, opt_d, opt_dd = self.optimizers()
37 |
38 | # spectrogram
39 | n_fft = self.generator.posterior_encoder.n_fft
40 | frame_size = self.generator.posterior_encoder.frame_size
41 | spec = spectrogram(wf, n_fft, frame_size)
42 |
43 | # decide crop range
44 | crop_range = decide_crop_range(spec.shape[2], self.config.segment_size)
45 |
46 | # crop real waveform
47 | real = crop_waveform(wf, crop_range, frame_size)
48 |
49 | # start tracking gradient G.
50 | self.toggle_optimizer(opt_g)
51 |
52 | # calculate loss
53 | lossG, loss_dict, (text_encoded, text_mask, fake_log_duration, real_log_duration, spk_emb, dsp_out, fake) = self.generator(
54 | spec, spec_len, phoneme, phoneme_len, lm_feat, lm_feat_len, f0, spk, lang, crop_range)
55 |
56 | loss_dsp = mel_spectrogram_loss(dsp_out, real)
57 | loss_mel = mel_spectrogram_loss(fake, real)
58 | logits_fake, fmap_fake = self.discriminator(fake)
59 | _, fmap_real = self.discriminator(real)
60 | loss_feat = feature_matching_loss(fmap_real, fmap_fake)
61 | loss_adv = generator_adversarial_loss(logits_fake)
62 | dur_logit_fake = self.duration_discriminator(text_encoded, text_mask, fake_log_duration, spk_emb)
63 | loss_dadv = duration_generator_adversarial_loss(dur_logit_fake, text_mask)
64 |
65 | lossG += loss_mel * 45.0 + loss_dsp + loss_feat + loss_adv + loss_dadv
66 | self.manual_backward(lossG)
67 | opt_g.step()
68 | opt_g.zero_grad()
69 |
70 | # stop tracking gradient G.
71 | self.untoggle_optimizer(opt_g)
72 |
73 | # start tracking gradient D.
74 | self.toggle_optimizer(opt_d)
75 |
76 | # calculate loss
77 | fake = fake.detach()
78 | logits_fake, _ = self.discriminator(fake)
79 | logits_real, _ = self.discriminator(real)
80 |
81 | lossD = discriminator_adversarial_loss(logits_real, logits_fake)
82 | self.manual_backward(lossD)
83 | opt_d.step()
84 | opt_d.zero_grad()
85 |
86 | # stop tracking gradient D.
87 | self.untoggle_optimizer(opt_d)
88 |
89 | # start tracking gradient Duration Discriminator
90 | self.toggle_optimizer(opt_dd)
91 |
92 | fake_log_duration = fake_log_duration.detach()
93 | real_log_duration = real_log_duration.detach()
94 | text_mask = text_mask.detach()
95 | text_encoded = text_encoded.detach()
96 | spk_emb = spk_emb.detach()
97 | dur_logit_real = self.duration_discriminator(text_encoded, text_mask, real_log_duration, spk_emb)
98 | dur_logit_fake = self.duration_discriminator(text_encoded, text_mask, fake_log_duration, spk_emb)
99 |
100 | lossDD = duration_discriminator_adversarial_loss(dur_logit_real, dur_logit_fake, text_mask)
101 | self.manual_backward(lossDD)
102 | opt_dd.step()
103 | opt_dd.zero_grad()
104 |
105 | # stop tracking gradient Duration Discriminator
106 | self.untoggle_optimizer(opt_dd)
107 |
108 | # write log
109 | loss_dict['Mel'] = loss_mel.item()
110 | loss_dict['Generator Adversarial'] = loss_adv.item()
111 | loss_dict['DSP'] = loss_dsp.item()
112 | loss_dict['Feature Matching'] = loss_feat.item()
113 | loss_dict['Discriminator Adversarial'] = lossD.item()
114 | loss_dict['Duration Generator Adversarial'] = loss_dadv.item()
115 | loss_dict['Duration Discriminator Adversarial'] = lossDD.item()
116 |
117 |         for k, v in loss_dict.items():
118 | self.log(f"loss/{k}", v)
119 |
120 | def configure_optimizers(self):
121 | lr = self.config.optimizer.lr
122 | betas = self.config.optimizer.betas
123 |
124 | opt_g = optim.AdamW(self.generator.parameters(), lr, betas)
125 | opt_d = optim.AdamW(self.discriminator.parameters(), lr, betas)
126 | opt_dd = optim.AdamW(self.duration_discriminator.parameters(), lr, betas)
127 | return opt_g, opt_d, opt_dd
128 |
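A minimal training sketch, assuming `config.vits` in `config/base.json` holds the generator, discriminator, duration discriminator and optimizer sections used above; the repository's `train.py` is the real entry point, so treat this only as an outline:

```python
import lightning as L

from module.utils.config import load_json_file
from module.utils.dataset import VitsDataModule
from module.vits import Vits

config = load_json_file("config/base.json")
model = Vits(config.vits)
datamodule = VitsDataModule(batch_size=1, num_workers=1)
trainer = L.Trainer(max_epochs=1)  # arbitrary settings, for illustration only
trainer.fit(model, datamodule=datamodule)
```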
--------------------------------------------------------------------------------
/module/vits/convnext.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .normalization import LayerNorm, GRN
6 |
7 | # ConvNeXt v2
8 | class ConvNeXtLayer(nn.Module):
9 | def __init__(self, channels=512, kernel_size=7, mlp_mul=3):
10 | super().__init__()
11 | padding = kernel_size // 2
12 | self.c1 = nn.Conv1d(channels, channels, kernel_size, 1, padding, groups=channels)
13 | self.norm = LayerNorm(channels)
14 | self.c2 = nn.Conv1d(channels, channels * mlp_mul, 1)
15 | self.grn = GRN(channels * mlp_mul)
16 | self.c3 = nn.Conv1d(channels * mlp_mul, channels, 1)
17 |
18 | # x: [batchsize, channels, length]
19 | def forward(self, x):
20 | res = x
21 | x = self.c1(x)
22 | x = self.norm(x)
23 | x = self.c2(x)
24 | x = F.gelu(x)
25 | x = self.grn(x)
26 | x = self.c3(x)
27 | x = x + res
28 | return x
29 |
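The layer is shape-preserving, which is what lets the encoders and estimators in this repository stack it freely; a quick check:

```python
import torch

from module.vits.convnext import ConvNeXtLayer

layer = ConvNeXtLayer(channels=512)
x = torch.randn(2, 512, 100)  # [Batch, Channels, Length]
print(layer(x).shape)         # torch.Size([2, 512, 100]), the residual path keeps the shape
```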
--------------------------------------------------------------------------------
/module/vits/crop.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 | def decide_crop_range(max_length=500, frames=50):
5 | left = random.randint(0, max_length-frames)
6 | right = left + frames
7 | return (left, right)
8 |
9 |
10 | def crop_features(z, crop_range):
11 | left, right = crop_range[0], crop_range[1]
12 | return z[:, :, left:right]
13 |
14 |
15 | def crop_waveform(wf, crop_range, frame_size):
16 | left, right = crop_range[0], crop_range[1]
17 | return wf[:, left*frame_size:right*frame_size]
18 |
19 |
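These helpers keep frame-level crops and waveform crops aligned through `frame_size`; for example:

```python
import torch

from module.vits.crop import decide_crop_range, crop_features, crop_waveform

crop_range = decide_crop_range(max_length=500, frames=50)  # e.g. (123, 173)
z = torch.randn(1, 192, 500)     # frame-rate features
wf = torch.randn(1, 500 * 960)   # matching waveform at frame_size=960
print(crop_features(z, crop_range).shape)        # torch.Size([1, 192, 50])
print(crop_waveform(wf, crop_range, 960).shape)  # torch.Size([1, 48000]) = 50 frames * 960 samples
```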
--------------------------------------------------------------------------------
/module/vits/decoder.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from torch.nn.utils.parametrizations import weight_norm
8 | from torch.nn.utils import remove_weight_norm
9 | from module.utils.common import get_padding, init_weights
10 | from .normalization import LayerNorm
11 | from .convnext import ConvNeXtLayer
12 | from .feature_retrieval import match_features
13 |
14 |
15 | # Oscillate harmonic signal
16 | #
17 | # Inputs ---
18 | # f0: [BatchSize, 1, Frames]
19 | #
20 | # frame_size: int
21 | # sample_rate: float or int
22 | # min_frequency: float
23 | # num_harmonics: int
24 | #
25 | # Output: [BatchSize, NumHarmonics+1, Length]
26 | #
27 | # length = Frames * frame_size
28 | @torch.cuda.amp.autocast(enabled=False)
29 | def oscillate_harmonics(
30 | f0,
31 | frame_size=960,
32 | sample_rate=48000,
33 | num_harmonics=0,
34 | min_frequency=20.0
35 | ):
36 | N = f0.shape[0]
37 | C = num_harmonics + 1
38 | Lf = f0.shape[2]
39 | Lw = Lf * frame_size
40 |
41 | device = f0.device
42 |
43 | # generate frequency of harmonics
44 | mul = (torch.arange(C, device=device) + 1).unsqueeze(0).unsqueeze(2)
45 |
46 | # change length to wave's
47 | fs = F.interpolate(f0, Lw, mode='linear') * mul
48 |
49 | # unvoiced / voiced mask
50 | uv = (f0 > min_frequency).to(torch.float)
51 | uv = F.interpolate(uv, Lw, mode='linear')
52 |
53 | # generate harmonics
54 | I = torch.cumsum(fs / sample_rate, dim=2) # numerical integration
55 | theta = 2 * math.pi * (I % 1) # convert to radians
56 |
57 | harmonics = torch.sin(theta) * uv
58 |
59 | return harmonics.to(device)
60 |
61 |
62 | # Oscillate noise via gaussian noise and equalizer
63 | #
64 | # fft_bin = n_fft // 2 + 1
65 | # kernels: [BatchSize, fft_bin, Frames]
66 | #
67 | # Output: [BatchSize, 1, Frames * frame_size]
68 | @torch.cuda.amp.autocast(enabled=False)
69 | def oscillate_noise(kernels, frame_size=960, n_fft=3840):
70 | device = kernels.device
71 | N = kernels.shape[0]
72 | Lf = kernels.shape[2] # frame length
73 | Lw = Lf * frame_size # waveform length
74 | dtype = kernels.dtype
75 |
76 | gaussian_noise = torch.randn(N, Lw, device=device, dtype=torch.float)
77 | kernels = kernels.to(torch.float) # to fp32
78 |
79 | # calculate convolution in fourier-domain
80 | # Since the input is an aperiodic signal such as Gaussian noise,
81 | # there is no need to consider the phase on the kernel side.
82 | w = torch.hann_window(n_fft, dtype=torch.float, device=device)
83 | noise_stft = torch.stft(gaussian_noise, n_fft, frame_size, window=w, return_complex=True)[:, :, 1:]
84 |     y_stft = noise_stft * kernels # multiplication in the Fourier domain corresponds to convolution in time
85 | y_stft = F.pad(y_stft, [1, 0]) # pad
86 | y = torch.istft(y_stft, n_fft, frame_size, window=w)
87 | y = y.unsqueeze(1)
88 | y = y.to(dtype)
89 | return y
90 |
91 |
92 | # this model estimates fundamental frequency (f0) and energy from z
93 | class PitchEnergyEstimator(nn.Module):
94 | def __init__(self,
95 | content_channels=192,
96 | speaker_embedding_dim=256,
97 | internal_channels=256,
98 | num_classes=512,
99 | kernel_size=7,
100 | num_layers=6,
101 | mlp_mul=3,
102 | min_frequency=20.0,
103 | classes_per_octave=48,
104 | ):
105 | super().__init__()
106 | self.content_channels = content_channels
107 | self.num_classes = num_classes
108 | self.classes_per_octave = classes_per_octave
109 | self.min_frequency = min_frequency
110 |
111 | self.content_input = nn.Conv1d(content_channels, internal_channels, 1)
112 | self.speaker_input = nn.Conv1d(speaker_embedding_dim, internal_channels, 1)
113 | self.input_norm = LayerNorm(internal_channels)
114 | self.mid_layers = nn.Sequential(
115 | *[ConvNeXtLayer(internal_channels, kernel_size, mlp_mul) for _ in range(num_layers)])
116 | self.output_norm = LayerNorm(internal_channels)
117 | self.to_f0_logits = nn.Conv1d(internal_channels, num_classes, 1)
118 | self.to_energy = nn.Conv1d(internal_channels, 1, 1)
119 |
120 | # z_p: [BatchSize, content_channels, Length]
121 | # spk: [BatchSize, speaker_embedding_dim, Length]
122 | def forward(self, z_p, spk):
123 | x = self.content_input(z_p) + self.speaker_input(spk)
124 | x = self.input_norm(x)
125 | x = self.mid_layers(x)
126 | x = self.output_norm(x)
127 | logits, energy = self.to_f0_logits(x), self.to_energy(x)
128 | energy = F.elu(energy) + 1.0
129 | return logits, energy
130 |
131 |     # f: frequency tensor in Hz (any shape)
132 | def freq2id(self, f):
133 | fmin = self.min_frequency
134 | cpo = self.classes_per_octave
135 | nc = self.num_classes
136 | return torch.ceil(torch.clamp(cpo * torch.log2(f / fmin), 0, nc-1)).to(torch.long)
137 |
138 |     # ids: class-index tensor (any shape)
139 | def id2freq(self, ids):
140 | fmin = self.min_frequency
141 | cpo = self.classes_per_octave
142 | x = ids.to(torch.float)
143 | x = fmin * (2 ** (x / cpo))
144 | x[x <= self.min_frequency] = 0
145 | return x
146 |
147 | # z_p: [BatchSize, content_channels, Length]
148 | # spk: [BatchSize, speaker_embedding_dim, Length]
149 | # Outputs:
150 | # f0: [BatchSize, 1, Length]
151 | # energy: [BatchSize, 1, Length]
152 | def infer(self, z_p, spk, k=4):
153 | logits, energy = self.forward(z_p, spk)
154 | probs, indices = torch.topk(logits, k, dim=1)
155 | probs = F.softmax(probs, dim=1)
156 | freqs = self.id2freq(indices)
157 | f0 = (probs * freqs).sum(dim=1, keepdim=True)
158 | f0[f0 <= self.min_frequency] = 0
159 | return f0, energy
160 |
161 |
162 | class FiLM(nn.Module):
163 | def __init__(self, in_channels, condition_channels):
164 | super().__init__()
165 | self.to_shift = weight_norm(nn.Conv1d(condition_channels, in_channels, 1))
166 | self.to_scale = weight_norm(nn.Conv1d(condition_channels, in_channels, 1))
167 |
168 | # x: [BatchSize, in_channels, Length]
169 | # c: [BatchSize, condition_channels, Length]
170 | def forward(self, x, c):
171 | shift = self.to_shift(c)
172 | scale = self.to_scale(c)
173 | return x * scale + shift
174 |
175 | def remove_weight_norm(self):
176 | remove_weight_norm(self.to_shift)
177 | remove_weight_norm(self.to_scale)
178 |
179 |
180 | class SourceNet(nn.Module):
181 | def __init__(
182 | self,
183 | sample_rate=48000,
184 | n_fft=3840,
185 | frame_size=960,
186 | content_channels=192,
187 | speaker_embedding_dim=256,
188 | internal_channels=512,
189 | num_harmonics=30,
190 | kernel_size=7,
191 | num_layers=6,
192 | mlp_mul=3
193 | ):
194 | super().__init__()
195 | self.sample_rate = sample_rate
196 | self.n_fft = n_fft
197 | self.frame_size = frame_size
198 | self.num_harmonics = num_harmonics
199 |
200 | self.content_input = nn.Conv1d(content_channels, internal_channels, 1)
201 | self.speaker_input = nn.Conv1d(speaker_embedding_dim, internal_channels, 1)
202 | self.energy_input = nn.Conv1d(1, internal_channels, 1)
203 | self.f0_input = nn.Conv1d(1, internal_channels, 1)
204 | self.input_norm = LayerNorm(internal_channels)
205 | self.mid_layers = nn.Sequential(
206 | *[ConvNeXtLayer(internal_channels, kernel_size, mlp_mul) for _ in range(num_layers)])
207 | self.output_norm = LayerNorm(internal_channels)
208 | self.to_amps = nn.Conv1d(internal_channels, num_harmonics + 1, 1)
209 | self.to_kernels = nn.Conv1d(internal_channels, n_fft // 2 + 1, 1)
210 |
211 | # x: [BatchSize, content_channels, Length]
212 | # f0: [BatchSize, 1, Length]
213 | # energy: [BatchSize, 1, Length]
214 | # spk: [BatchSize, speaker_embedding_dim, 1]
215 | # Outputs:
216 |     #   amps: [BatchSize, num_harmonics+1, Length]
217 | # kernels: [BatchSize, n_fft //2 + 1, Length]
218 | def amps_and_kernels(self, x, f0, energy, spk):
219 | x = self.content_input(x) + self.speaker_input(spk) + self.f0_input(torch.log(F.relu(f0) + 1e-6))
220 | x = x * self.energy_input(energy)
221 | x = self.input_norm(x)
222 | x = self.mid_layers(x)
223 | x = self.output_norm(x)
224 | amps = F.elu(self.to_amps(x)) + 1.0
225 | kernels = F.elu(self.to_kernels(x)) + 1.0
226 | return amps, kernels
227 |
228 |
229 | # x: [BatchSize, content_channels, Length]
230 |     # f0, energy: [BatchSize, 1, Length]
231 | # spk: [BatchSize, speaker_embedding_dim, 1]
232 | # Outputs:
233 | # dsp_out: [BatchSize, 1, Length * frame_size]
234 |     #   source: [BatchSize, num_harmonics+2, Length * frame_size]
235 | def forward(self, x, f0, energy, spk):
236 | amps, kernels = self.amps_and_kernels(x, f0, energy, spk)
237 |
238 | # oscillate source signals
239 | harmonics = oscillate_harmonics(f0, self.frame_size, self.sample_rate, self.num_harmonics)
240 | amps = F.interpolate(amps, scale_factor=self.frame_size, mode='linear')
241 | harmonics = harmonics * amps
242 | noise = oscillate_noise(kernels, self.frame_size, self.n_fft)
243 | source = torch.cat([harmonics, noise], dim=1)
244 | dsp_out = torch.sum(source, dim=1, keepdim=True)
245 |
246 | return dsp_out, source
247 |
248 |
249 | # HiFi-GAN's ResBlock1
250 | class ResBlock1(nn.Module):
251 | def __init__(self, channels, condition_channels, kernel_size=3, dilations=[1, 3, 5]):
252 | super().__init__()
253 | self.convs1 = nn.ModuleList([])
254 | self.convs2 = nn.ModuleList([])
255 | self.films = nn.ModuleList([])
256 |
257 | for d in dilations:
258 | padding = get_padding(kernel_size, 1)
259 | self.convs1.append(
260 | weight_norm(
261 | nn.Conv1d(channels, channels, kernel_size, 1, padding, dilation=d, padding_mode='replicate')))
262 | padding = get_padding(kernel_size, d)
263 | self.convs2.append(
264 | weight_norm(
265 | nn.Conv1d(channels, channels, kernel_size, 1, padding, 1, padding_mode='replicate')))
266 | self.films.append(
267 | FiLM(channels, condition_channels))
268 | self.convs1.apply(init_weights)
269 | self.convs2.apply(init_weights)
270 | self.films.apply(init_weights)
271 |
272 | # x: [BatchSize, channels, Length]
273 | # c: [BatchSize, condition_channels, Length]
274 | def forward(self, x, c):
275 | for c1, c2, film in zip(self.convs1, self.convs2, self.films):
276 | res = x
277 | x = F.leaky_relu(x, 0.1)
278 | x = c1(x)
279 | x = F.leaky_relu(x, 0.1)
280 | x = c2(x)
281 | x = film(x, c)
282 | x = x + res
283 | return x
284 |
285 | def remove_weight_norm(self):
286 | for c1, c2, film in zip(self.convs1, self.convs2, self.films):
287 | remove_weight_norm(c1)
288 | remove_weight_norm(c2)
289 | film.remove_weight_norm()
290 |
291 |
292 | # HiFi-GAN's ResBlock2
293 | class ResBlock2(nn.Module):
294 | def __init__(self, channels, condition_channels, kernel_size=3, dilations=[1, 3]):
295 | super().__init__()
296 | self.convs = nn.ModuleList([])
297 | self.films = nn.ModuleList([])
298 | for d in dilations:
299 | padding = get_padding(kernel_size, d)
300 | self.convs.append(
301 | weight_norm(
302 | nn.Conv1d(channels, channels, kernel_size, 1, padding, dilation=d, padding_mode='replicate')))
303 | self.films.append(FiLM(channels, condition_channels))
304 | self.convs.apply(init_weights)
305 | self.films.apply(init_weights)
306 |
307 | # x: [BatchSize, channels, Length]
308 | # c: [BatchSize, condition_channels, Length]
309 | def forward(self, x, c):
310 | for conv, film in zip(self.convs, self.films):
311 | res = x
312 | x = F.leaky_relu(x, 0.1)
313 | x = conv(x)
314 | x = film(x, c)
315 | x = x + res
316 | return x
317 |
318 | def remove_weight_norm(self):
319 | for conv, film in zip(self.convs, self.films):
320 | conv.remove_weight_norm()
321 | film.remove_weight_norm()
322 |
323 |
324 | # TinyVC's Block (from https://github.com/uthree/tinyvc)
325 | class ResBlock3(nn.Module):
326 | def __init__(self, channels, condition_channels, kernel_size=3, dilations=[1, 3, 9, 27]):
327 | super().__init__()
328 | assert len(dilations) == 4, "Resblock 3's len(dilations) should be 4."
329 | self.convs = nn.ModuleList([])
330 | self.films = nn.ModuleList([])
331 | for d in dilations:
332 | padding = get_padding(kernel_size, d)
333 | self.convs.append(
334 | weight_norm(
335 | nn.Conv1d(channels, channels, kernel_size, 1, padding, dilation=d, padding_mode='replicate')))
336 | for _ in range(2):
337 | self.films.append(
338 | FiLM(channels, condition_channels))
339 |
340 | # x: [BatchSize, channels, Length]
341 | # c: [BatchSize, condition_channels, Length]
342 | def forward(self, x, c):
343 | res = x
344 | x = F.leaky_relu(x, 0.1)
345 | x = self.convs[0](x)
346 | x = F.leaky_relu(x, 0.1)
347 | x = self.convs[1](x)
348 | x = self.films[0](x, c)
349 | x = x + res
350 |
351 | res = x
352 | x = F.leaky_relu(x, 0.1)
353 | x = self.convs[2](x)
354 | x = F.leaky_relu(x, 0.1)
355 | x = self.convs[3](x)
356 | x = self.films[1](x, c)
357 | x = x + res
358 |
359 | return x
360 |
361 | def remove_weight_norm(self):
362 | for conv in self.convs:
363 | remove_weight_norm(conv)
364 | for film in self.films:
365 | film.remove_weight_norm()
366 |
367 |
368 | class MRF(nn.Module):
369 | def __init__(self,
370 | channels,
371 | condition_channels,
372 | resblock_type='1',
373 | kernel_sizes=[3, 7, 11],
374 | dilations=[[1, 3, 5], [1, 3, 5], [1, 3, 5]]):
375 | super().__init__()
376 | self.blocks = nn.ModuleList([])
377 | self.num_blocks = len(kernel_sizes)
378 | if resblock_type == '1':
379 | block = ResBlock1
380 | elif resblock_type == '2':
381 | block = ResBlock2
382 | elif resblock_type == '3':
383 | block = ResBlock3
384 | for k, d in zip(kernel_sizes, dilations):
385 | self.blocks.append(block(channels, condition_channels, k, d))
386 |
387 | # x: [BatchSize, channels, Length]
388 | # c: [BatchSize, condition_channels, Length]
389 | def forward(self, x, c):
390 | xs = None
391 | for block in self.blocks:
392 | if xs is None:
393 | xs = block(x, c)
394 | else:
395 | xs += block(x, c)
396 | return xs / self.num_blocks
397 |
398 | def remove_weight_norm(self):
399 | for block in self.blocks:
400 | block.remove_weight_norm()
401 |
402 |
403 | class UpBlock(nn.Module):
404 | def __init__(self,
405 | in_channels,
406 | out_channels,
407 | condition_channels,
408 | factor,
409 | resblock_type='1',
410 | kernel_sizes=[3, 7, 11],
411 | dilations=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
412 | interpolation='conv'):
413 | super().__init__()
414 | self.MRF = MRF(out_channels, condition_channels, resblock_type, kernel_sizes, dilations)
415 | self.interpolation = interpolation
416 | self.factor = factor
417 | if interpolation == 'conv':
418 | self.up_conv = weight_norm(
419 | nn.ConvTranspose1d(in_channels, out_channels, factor*2, factor))
420 | self.pad_left = factor // 2
421 | self.pad_right = factor - self.pad_left
422 | elif interpolation == 'linear':
423 | self.up_conv = weight_norm(nn.Conv1d(in_channels, out_channels, 1))
424 |
425 | # x: [BatchSize, in_channels, Length]
426 | # c: [BatchSize, condition_channels, Length(upsampled)]
427 | # Output: [BatchSize, out_channels, Length(upsampled)]
428 | def forward(self, x, c):
429 | x = F.leaky_relu(x, 0.1)
430 | if self.interpolation == 'conv':
431 | x = self.up_conv(x)
432 | x = x[:, :, self.pad_left:-self.pad_right]
433 | elif self.interpolation == 'linear':
434 | x = self.up_conv(x)
435 | x = F.interpolate(x, scale_factor=self.factor)
436 | x = self.MRF(x, c)
437 | return x
438 |
439 | def remove_weight_norm(self):
440 | remove_weight_norm(self.up_conv)
441 | self.MRF.remove_weight_norm()
442 |
443 |
444 | class DownBlock(nn.Module):
445 | def __init__(self,
446 | in_channels,
447 | out_channels,
448 | factor,
449 | dilations=[[1, 2], [4, 8]],
450 | kernel_size=3,
451 | interpolation='conv'):
452 | super().__init__()
453 | pad_left = factor // 2
454 | pad_right = factor - pad_left
455 | self.pad = nn.ReplicationPad1d([pad_left, pad_right])
456 | self.interpolation = interpolation
457 | self.factor = factor
458 | if interpolation == 'conv':
459 | self.input_conv = weight_norm(nn.Conv1d(in_channels, out_channels, factor*2, factor))
460 | elif interpolation == 'linear':
461 | self.input_conv = weight_norm(nn.Conv1d(in_channels, out_channels, 1))
462 |
463 | self.convs = nn.ModuleList([])
464 | for ds in dilations:
465 | cs = nn.ModuleList([])
466 | for d in ds:
467 | padding = get_padding(kernel_size, d)
468 | cs.append(
469 | weight_norm(
470 | nn.Conv1d(out_channels, out_channels, kernel_size, 1, padding, dilation=d, padding_mode='replicate')))
471 | self.convs.append(cs)
472 |
473 |
474 | # x: [BatchSize, in_channels, Length]
475 | # Output: [BatchSize, out_channels, Length]
476 | def forward(self, x):
477 | if self.interpolation == 'conv':
478 | x = self.pad(x)
479 | x = self.input_conv(x)
480 | elif self.interpolation == 'linear':
481 | x = self.pad(x)
482 | x = F.avg_pool1d(x, self.factor*2, self.factor) # approximation of linear interpolation
483 | x = self.input_conv(x)
484 | for block in self.convs:
485 | res = x
486 | for c in block:
487 | x = F.leaky_relu(x, 0.1)
488 | x = c(x)
489 | x = x + res
490 | return x
491 |
492 | def remove_weight_norm(self):
493 | for block in self.convs:
494 | for c in block:
495 | remove_weight_norm(c)
496 |         remove_weight_norm(self.input_conv)
497 |
498 |
499 | class FilterNet(nn.Module):
500 | def __init__(self,
501 | content_channels=192,
502 | speaker_embedding_dim=256,
503 | channels=[512, 256, 128, 64, 32],
504 | resblock_type='1',
505 | factors=[5, 4, 4, 4, 3],
506 | up_dilations=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
507 | up_kernel_sizes=[3, 7, 11],
508 | up_interpolation='conv',
509 | down_dilations=[[1, 2], [4, 8]],
510 | down_kernel_size=3,
511 | down_interpolation='conv',
512 | num_harmonics=30,
513 | ):
514 | super().__init__()
515 | # input layer
516 | self.content_input = weight_norm(nn.Conv1d(content_channels, channels[0], 1))
517 | self.speaker_input = weight_norm(nn.Conv1d(speaker_embedding_dim, channels[0], 1))
518 | self.energy_input = weight_norm(nn.Conv1d(1, channels[0], 1))
519 | self.f0_input = weight_norm(nn.Conv1d(1, channels[0], 1))
520 |
521 | # downsamples
522 | self.downs = nn.ModuleList([])
523 | self.downs.append(weight_norm(nn.Conv1d(num_harmonics + 2, channels[-1], 1)))
524 | cs = list(reversed(channels[1:]))
525 | ns = cs[1:] + [channels[0]]
526 | fs = list(reversed(factors[1:]))
527 | for c, n, f, in zip(cs, ns, fs):
528 | self.downs.append(DownBlock(c, n, f, down_dilations, down_kernel_size, down_interpolation))
529 |
530 | # upsamples
531 | self.ups = nn.ModuleList([])
532 | cs = channels
533 | ns = channels[1:] + [channels[-1]]
534 | fs = factors
535 | for c, n, f in zip(cs, ns, fs):
536 | self.ups.append(UpBlock(c, n, c, f, resblock_type, up_kernel_sizes, up_dilations, up_interpolation))
537 | self.output_layer = weight_norm(
538 | nn.Conv1d(channels[-1], 1, 7, 1, 3, padding_mode='replicate'))
539 |
540 | # content: [BatchSize, content_channels, Length(frame)]
541 |     # f0, energy: [BatchSize, 1, Length(frame)]
542 |     # spk: [BatchSize, speaker_embedding_dim, 1]
543 |     # source: [BatchSize, num_harmonics+2, Length(Waveform)]
544 |     # Output: [BatchSize, 1, Length * frame_size]
545 | def forward(self, content, f0, energy, spk, source):
546 | x = self.content_input(content) + self.speaker_input(spk) + self.f0_input(torch.log(F.relu(f0) + 1e-6))
547 | x = x * self.energy_input(energy)
548 |
549 | skips = []
550 | for down in self.downs:
551 | source = down(source)
552 | skips.append(source)
553 |
554 | for up, s in zip(self.ups, reversed(skips)):
555 | x = up(x, s)
556 | x = self.output_layer(x)
557 | return x
558 |
559 | def remove_weight_norm(self):
560 |         for conv in [self.content_input, self.speaker_input, self.energy_input, self.f0_input]:
561 |             remove_weight_norm(conv)
562 |         remove_weight_norm(self.output_layer)
563 | for down in self.downs:
564 | down.remove_weight_norm()
565 | for up in self.ups:
566 | up.remove_weight_norm()
567 |
568 |
569 | def pitch_estimation_loss(logits, label):
570 | num_classes = logits.shape[1]
571 | device = logits.device
572 | weight = torch.ones(num_classes, device=device)
573 | weight[0] = 1e-2
574 | return F.cross_entropy(logits, label, weight)
575 |
576 |
577 | class Decoder(nn.Module):
578 | def __init__(self,
579 | sample_rate=48000,
580 | frame_size=960,
581 | n_fft=3840,
582 | content_channels=192,
583 | speaker_embedding_dim=256,
584 | pe_internal_channels=256,
585 | pe_num_layers=6,
586 | source_internal_channels=512,
587 | source_num_layers=6,
588 | num_harmonics=30,
589 | filter_channels=[512, 256, 128, 64, 32],
590 | filter_factors=[5, 4, 4, 3, 2],
591 | filter_resblock_type='1',
592 | filter_down_dilations=[[1, 2], [4, 8]],
593 | filter_down_kernel_size=3,
594 | filter_down_interpolation='conv',
595 | filter_up_dilations=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
596 | filter_up_kernel_sizes=[3, 7, 11],
597 | filter_up_interpolation='conv'):
598 | super().__init__()
599 | self.frame_size = frame_size
600 | self.n_fft = n_fft
601 | self.sample_rate = sample_rate
602 | self.num_harmonics = num_harmonics
603 | self.pitch_energy_estimator = PitchEnergyEstimator(
604 | content_channels=content_channels,
605 | speaker_embedding_dim=speaker_embedding_dim,
606 | num_layers=pe_num_layers,
607 | internal_channels=pe_internal_channels
608 | )
609 | self.source_net = SourceNet(
610 | sample_rate=sample_rate,
611 | frame_size=frame_size,
612 | n_fft=n_fft,
613 | num_harmonics=num_harmonics,
614 | content_channels=content_channels,
615 | speaker_embedding_dim=speaker_embedding_dim,
616 | internal_channels=source_internal_channels,
617 | num_layers=source_num_layers,
618 | )
619 | self.filter_net = FilterNet(
620 | content_channels=content_channels,
621 | speaker_embedding_dim=speaker_embedding_dim,
622 | channels=filter_channels,
623 | resblock_type=filter_resblock_type,
624 | factors=filter_factors,
625 | up_dilations=filter_up_dilations,
626 | up_kernel_sizes=filter_up_kernel_sizes,
627 | up_interpolation=filter_up_interpolation,
628 | down_dilations=filter_down_dilations,
629 | down_kernel_size=filter_down_kernel_size,
630 | down_interpolation=filter_down_interpolation,
631 | num_harmonics=num_harmonics
632 | )
633 | # training pass
634 | #
635 | # content: [BatchSize, content_channels, Length]
636 |     # f0: [BatchSize, 1, Length]
637 |     # energy: [BatchSize, 1, Length]
638 |     # spk: [BatchSize, speaker_embedding_dim, 1]
639 |     #
640 |     # Outputs:
641 |     #   dsp_out: [BatchSize, Length * frame_size]
642 |     #   output: [BatchSize, Length * frame_size]
643 |     #   loss, loss_dict: pitch/energy estimation loss and a dict of its components for logging
644 | def forward(self, content, f0, energy, spk):
645 | # estimate energy, f0
646 | f0_logits, estimated_energy = self.pitch_energy_estimator(content, spk)
647 | f0_label = self.pitch_energy_estimator.freq2id(f0).squeeze(1)
648 | loss_pe = pitch_estimation_loss(f0_logits, f0_label)
649 | loss_ee = (estimated_energy - energy).abs().mean()
650 | loss = loss_pe * 45.0 + loss_ee * 45.0
651 | loss_dict = {
652 | "Pitch Estimation": loss_pe.item(),
653 | "Energy Estimation": loss_ee.item()
654 | }
655 |
656 | # source net
657 | dsp_out, source = self.source_net(content, f0, energy, spk)
658 | dsp_out = dsp_out.squeeze(1)
659 |
660 | # GAN output
661 | output = self.filter_net(content, f0, energy, spk, source)
662 | output = output.squeeze(1)
663 |
664 | return dsp_out, output, loss, loss_dict
665 |
666 | # inference pass
667 | #
668 | # content: [BatchSize, content_channels, Length]
669 |     # spk: [BatchSize, speaker_embedding_dim, 1]
670 |     # energy: None or [BatchSize, 1, Length]
671 |     # f0: None or [BatchSize, 1, Length]
672 |     # reference: None or [BatchSize, content_channels, NumReferenceVectors]
673 |     # alpha: float 0 ~ 1.0
674 |     # k: int, metrics: one of ['IP', 'L2', 'cos']
675 |     # Output: [BatchSize, 1, Length * frame_size]
676 | def infer(self, content, spk, energy=None, f0=None, reference=None, alpha=0, k=4, metrics='cos'):
677 | if energy is None or f0 is None:
678 | f0_est, energy_est = self.pitch_energy_estimator.infer(content, spk)
679 | if energy is None:
680 | energy = energy_est
681 | if f0 is None:
682 | f0 = f0_est
683 |
684 | # run feature retrieval if got reference vectors
685 | if reference is not None:
686 | content = match_features(content, reference, k, alpha, metrics)
687 |
688 | dsp_out, source = self.source_net(content, f0, energy, spk)
689 |
690 | # filter network
691 | output = self.filter_net(content, f0, energy, spk, source)
692 | return output
693 |
694 | def estimate_pitch_energy(self, content, spk):
695 | f0, energy = self.pitch_energy_estimator.infer(content, spk)
696 | return f0, energy
697 |
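To make the DSP side concrete, a small sketch of `oscillate_harmonics` on a constant 220 Hz contour; the numbers are illustrative only:

```python
import torch

from module.vits.decoder import oscillate_harmonics

f0 = torch.full((1, 1, 10), 220.0)  # 10 frames of a 220 Hz fundamental
source = oscillate_harmonics(f0, frame_size=960, sample_rate=48000, num_harmonics=3)
print(source.shape)  # torch.Size([1, 4, 9600]): fundamental + 3 harmonics, 10 frames * 960 samples
```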
--------------------------------------------------------------------------------
/module/vits/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from module.utils.common import get_padding
6 |
7 |
8 | class DiscriminatorP(nn.Module):
9 | def __init__(self, period, kernel_size=5, stride=3, channels=32, channels_mul=4, num_layers=4, max_channels=1024, use_spectral_norm=False):
10 | super().__init__()
11 | self.period = period
12 |         norm_f = nn.utils.parametrizations.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
13 |
14 | k = kernel_size
15 | s = stride
16 | c = channels
17 |
18 | convs = [nn.Conv2d(1, c, (k, 1), (s, 1), (get_padding(5, 1), 0), padding_mode='replicate')]
19 | for i in range(num_layers):
20 | c_next = min(c * channels_mul, max_channels)
21 | convs.append(nn.Conv2d(c, c_next, (k, 1), (s, 1), (get_padding(5, 1), 0), padding_mode='replicate'))
22 | c = c_next
23 | self.convs = nn.ModuleList([norm_f(c) for c in convs])
24 | self.post = norm_f(nn.Conv2d(c, 1, (3, 1), 1, (1, 0), padding_mode='replicate'))
25 |
26 | def forward(self, x):
27 | fmap = []
28 |
29 | # 1d to 2d
30 | b, c, t = x.shape
31 | if t % self.period != 0:
32 | n_pad = self.period - (t % self.period)
33 | x = F.pad(x, (0, n_pad), "reflect")
34 | t = t + n_pad
35 | x = x.view(b, c, t // self.period, self.period)
36 |
37 | for l in self.convs:
38 | x = l(x)
39 | x = F.leaky_relu(x, 0.1)
40 | fmap.append(x)
41 | x = self.post(x)
42 | fmap.append(x)
43 | return x, fmap
44 |
45 |
46 | class MultiPeriodicDiscriminator(nn.Module):
47 | def __init__(
48 | self,
49 | periods=[1, 2, 3, 5, 7, 11],
50 | channels=32,
51 | channels_mul=4,
52 | max_channels=1024,
53 | num_layers=4,
54 | ):
55 | super().__init__()
56 | self.sub_discs = nn.ModuleList([])
57 | for p in periods:
58 | self.sub_discs.append(DiscriminatorP(p,
59 | channels=channels,
60 | channels_mul=channels_mul,
61 | max_channels=max_channels,
62 | num_layers=num_layers))
63 |
64 | def forward(self, x):
65 | x = x.unsqueeze(1)
66 | feats = []
67 | logits = []
68 | for d in self.sub_discs:
69 | logit, fmap = d(x)
70 | logits.append(logit)
71 | feats += fmap
72 | return logits, feats
73 |
74 |
75 | class DiscriminatorR(nn.Module):
76 | def __init__(self, resolution=128, channels=16, num_layers=4, max_channels=256):
77 | super().__init__()
78 | norm_f = nn.utils.weight_norm
79 | self.convs = nn.ModuleList([norm_f(nn.Conv2d(1, channels, (7, 3), (2, 1), (3, 1)))])
80 | self.hop_size = resolution
81 | self.n_fft = resolution * 4
82 | c = channels
83 | for _ in range(num_layers):
84 | c_next = min(c * 2, max_channels)
85 | self.convs.append(norm_f(nn.Conv2d(c, c_next, (5, 3), (2, 1), (2, 1))))
86 | c = c_next
87 | self.post = norm_f(nn.Conv2d(c, 1, 3, 1, 1))
88 |
89 | def forward(self, x):
90 | w = torch.hann_window(self.n_fft).to(x.device)
91 | x = torch.stft(x, self.n_fft, self.hop_size, window=w, return_complex=True).abs()
92 | x = x.unsqueeze(1)
93 | feats = []
94 | for l in self.convs:
95 | x = l(x)
96 |             x = F.leaky_relu(x, 0.1)
97 | feats.append(x)
98 | x = self.post(x)
99 | feats.append(x)
100 | return x, feats
101 |
102 |
103 | class MultiResolutionDiscriminator(nn.Module):
104 | def __init__(
105 | self,
106 | resolutions=[128, 256, 512],
107 | channels=32,
108 | num_layers=4,
109 | max_channels=256,
110 | ):
111 | super().__init__()
112 | self.sub_discs = nn.ModuleList([])
113 | for r in resolutions:
114 | self.sub_discs.append(DiscriminatorR(r, channels, num_layers, max_channels))
115 |
116 | def forward(self, x):
117 | feats = []
118 | logits = []
119 | for d in self.sub_discs:
120 | logit, fmap = d(x)
121 | logits.append(logit)
122 | feats += fmap
123 | return logits, feats
124 |
125 |
126 | class Discriminator(nn.Module):
127 | def __init__(self, config):
128 | super().__init__()
129 | self.MPD = MultiPeriodicDiscriminator(**config.mpd)
130 | self.MRD = MultiResolutionDiscriminator(**config.mrd)
131 |
132 | # x: [BatchSize, Length(waveform)]
133 | def forward(self, x):
134 | mpd_logits, mpd_feats = self.MPD(x)
135 | mrd_logits, mrd_feats = self.MRD(x)
136 | return mpd_logits + mrd_logits, mpd_feats + mrd_feats
137 |
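A sketch of assembling the combined discriminator from a dict-backed `Config`; the sub-discriminator settings below are reduced, hypothetical values rather than the ones in `config/base.json`:

```python
import torch

from module.utils.config import Config
from module.vits.discriminator import Discriminator

config = Config(**{
    "mpd": {"periods": [2, 3, 5]},       # hypothetical, reduced settings
    "mrd": {"resolutions": [128, 256]},
})
disc = Discriminator(config)
logits, feats = disc(torch.randn(1, 48000))  # input: [Batch, WaveformLength]
print(len(logits))  # 5 = 3 period sub-discriminators + 2 resolution sub-discriminators
```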
--------------------------------------------------------------------------------
/module/vits/duration_discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .normalization import LayerNorm
4 | from .convnext import ConvNeXtLayer
5 |
6 | class DurationDiscriminator(nn.Module):
7 | def __init__(
8 | self,
9 | content_channels=192,
10 | speaker_embedding_dim=256,
11 | internal_channels=192,
12 | num_layers=3,
13 | ):
14 | super().__init__()
15 | self.text_input = nn.Conv1d(content_channels, internal_channels, 1)
16 | self.speaker_input = nn.Conv1d(speaker_embedding_dim, internal_channels, 1)
17 | self.duration_input = nn.Conv1d(1, internal_channels, 1)
18 | self.input_norm = LayerNorm(internal_channels)
19 | self.mid_layers = nn.Sequential(*[ConvNeXtLayer(internal_channels) for _ in range(num_layers)])
20 | self.output_norm = LayerNorm(internal_channels)
21 | self.output_layer = nn.Conv1d(internal_channels, 1, 1)
22 |
23 | def forward(
24 | self,
25 | text_encoded,
26 | text_mask,
27 | log_duration,
28 | spk,
29 | ):
30 | x = self.text_input(text_encoded) + self.duration_input(log_duration) + self.speaker_input(spk)
31 | x = self.input_norm(x)
32 | x = self.mid_layers(x) * text_mask
33 | x = self.output_norm(x)
34 | x = self.output_layer(x) * text_mask
35 | return x
36 |
--------------------------------------------------------------------------------
/module/vits/duration_predictors.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .transforms import piecewise_rational_quadratic_transform
7 | from .normalization import LayerNorm
8 | from .convnext import ConvNeXtLayer
9 |
10 |
11 | DEFAULT_MIN_BIN_WIDTH = 1e-3
12 | DEFAULT_MIN_BIN_HEIGHT = 1e-3
13 | DEFAULT_MIN_DERIVATIVE = 1e-3
14 |
15 |
16 | class Flip(nn.Module):
17 | def forward(self, x, *args, reverse=False, **kwargs):
18 | x = torch.flip(x, [1])
19 | if not reverse:
20 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
21 | return x, logdet
22 | else:
23 | return x
24 |
25 |
26 | class StochasticDurationPredictor(nn.Module):
27 | def __init__(
28 | self,
29 | in_channels=192,
30 | filter_channels=256,
31 | kernel_size=5,
32 | p_dropout=0.0,
33 | n_flows=4,
34 | speaker_embedding_dim=256
35 | ):
36 | super().__init__()
37 | self.log_flow = Log()
38 | self.flows = nn.ModuleList()
39 | self.flows.append(ElementwiseAffine(2))
40 | for i in range(n_flows):
41 | self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
42 | self.flows.append(Flip())
43 |
44 | self.pre = nn.Linear(in_channels, filter_channels)
45 | self.convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
46 | self.proj = nn.Linear(filter_channels, filter_channels)
47 |
48 | self.post_pre = nn.Linear(1, filter_channels)
49 | self.post_convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
50 | self.post_proj = nn.Linear(filter_channels, filter_channels)
51 |
52 | self.post_flows = nn.ModuleList()
53 | self.post_flows.append(ElementwiseAffine(2))
54 | for i in range(4):
55 | self.post_flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
56 | self.post_flows.append(Flip())
57 |
58 | if speaker_embedding_dim != 0:
59 | self.cond = nn.Linear(speaker_embedding_dim, filter_channels)
60 |
61 |
62 |     # x: [BatchSize, in_channels, Length]
63 | # x_mask [BatchSize, 1, Length]
64 | # g: [BatchSize, speaker_embedding_dim, 1]
65 | # w: Optional, Training only shape=[BatchSize, 1, Length]
66 | # Output: [BatchSize, 1, Length]
67 | #
68 | # note: g is speaker embedding
69 | def forward(self, x: torch.Tensor, x_mask: torch.Tensor, w=None, g=None, reverse=False, noise_scale=1.0):
70 | x = torch.detach(x)
71 | x = self.pre(x.mT).mT
72 | if g is not None:
73 | g = torch.detach(g)
74 | x = x + self.cond(g.mT).mT
75 | x = self.convs(x, x_mask)
76 | x = self.proj(x.mT).mT * x_mask
77 |
78 | if not reverse:
79 | flows = self.flows
80 | assert w is not None
81 |
82 | logdet_tot_q = 0
83 | h_w = self.post_pre(w.mT).mT
84 | h_w = self.post_convs(h_w, x_mask)
85 | h_w = self.post_proj(h_w.mT).mT * x_mask
86 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
87 | z_q = e_q
88 | for flow in self.post_flows:
89 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
90 | logdet_tot_q += logdet_q
91 | z_u, z1 = torch.split(z_q, [1, 1], 1)
92 | u = torch.sigmoid(z_u) * x_mask
93 | z0 = (w - u) * x_mask
94 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
95 | logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
96 |
97 | logdet_tot = 0
98 | z0, logdet = self.log_flow(z0, x_mask)
99 | logdet_tot += logdet
100 | z = torch.cat([z0, z1], 1)
101 | for flow in flows:
102 | z, logdet = flow(z, x_mask, g=x, reverse=reverse)
103 | logdet_tot = logdet_tot + logdet
104 | nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
105 | return nll + logq # [b]
106 | else:
107 | flows = list(reversed(self.flows))
108 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow
109 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
110 | for flow in flows:
111 | z = flow(z, x_mask, g=x, reverse=reverse)
112 | z0, z1 = torch.split(z, [1, 1], 1)
113 | logw = z0
114 | return logw
115 |
116 |
117 | class ConvFlow(nn.Module):
118 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
119 | super().__init__()
120 | self.filter_channels = filter_channels
121 | self.num_bins = num_bins
122 | self.tail_bound = tail_bound
123 | self.half_channels = in_channels // 2
124 |
125 | self.pre = nn.Linear(self.half_channels, filter_channels)
126 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
127 | self.proj = nn.Linear(filter_channels, self.half_channels * (num_bins * 3 - 1))
128 | self.proj.weight.data.zero_()
129 | self.proj.bias.data.zero_()
130 |
131 | def forward(self, x, x_mask, g=None, reverse=False):
132 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
133 | h = self.pre(x0.mT).mT
134 | h = self.convs(h, x_mask, g=g)
135 | h = self.proj(h.mT).mT * x_mask
136 |
137 | b, c, t = x0.shape
138 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
139 |
140 | unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
141 | unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
142 | unnormalized_derivatives = h[..., 2 * self.num_bins :]
143 |
144 | x1, logabsdet = piecewise_rational_quadratic_transform(x1, unnormalized_widths, unnormalized_heights, unnormalized_derivatives, inverse=reverse, tails="linear", tail_bound=self.tail_bound)
145 |
146 | x = torch.cat([x0, x1], 1) * x_mask
147 | logdet = torch.sum(logabsdet * x_mask, [1, 2])
148 | if not reverse:
149 | return x, logdet
150 | else:
151 | return x
152 |
153 |
154 | class DDSConv(nn.Module):
155 | """
156 | Dilated and Depth-Separable Convolution
157 | """
158 |
159 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
160 | super().__init__()
161 | self.n_layers = n_layers
162 |
163 | self.drop = nn.Dropout(p_dropout)
164 | self.convs_sep = nn.ModuleList()
165 | self.linears = nn.ModuleList()
166 | self.norms_1 = nn.ModuleList()
167 | self.norms_2 = nn.ModuleList()
168 | for i in range(n_layers):
169 | dilation = kernel_size**i
170 | padding = (kernel_size * dilation - dilation) // 2
171 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding))
172 | self.linears.append(nn.Linear(channels, channels))
173 | self.norms_1.append(LayerNorm(channels))
174 | self.norms_2.append(LayerNorm(channels))
175 |
176 | def forward(self, x, x_mask, g=None):
177 | if g is not None:
178 | x = x + g
179 | for i in range(self.n_layers):
180 | y = self.convs_sep[i](x * x_mask)
181 | y = self.norms_1[i](y)
182 | y = F.gelu(y)
183 | y = self.linears[i](y.mT).mT
184 | y = self.norms_2[i](y)
185 | y = F.gelu(y)
186 | y = self.drop(y)
187 | x = x + y
188 | return x * x_mask
189 |
190 |
191 | # TODO convert to class method
192 | class Log(nn.Module):
193 | def forward(self, x, x_mask, reverse=False, **kwargs):
194 | if not reverse:
195 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
196 | logdet = torch.sum(-y, [1, 2])
197 | return y, logdet
198 | else:
199 | x = torch.exp(x) * x_mask
200 | return x
201 |
202 |
203 | class ElementwiseAffine(nn.Module):
204 | def __init__(self, channels):
205 | super().__init__()
206 | self.m = nn.Parameter(torch.zeros(channels, 1))
207 | self.logs = nn.Parameter(torch.zeros(channels, 1))
208 |
209 | def forward(self, x, x_mask, reverse=False, **kwargs):
210 | if not reverse:
211 | y = self.m + torch.exp(self.logs) * x
212 | y = y * x_mask
213 | logdet = torch.sum(self.logs * x_mask, [1, 2])
214 | return y, logdet
215 | else:
216 | x = (x - self.m) * torch.exp(-self.logs) * x_mask
217 | return x
218 |
219 |
220 | class DurationPredictor(nn.Module):
221 | def __init__(self,
222 | content_channels=192,
223 | internal_channels=256,
224 | speaker_embedding_dim=256,
225 | kernel_size=7,
226 | num_layers=4):
227 | super().__init__()
228 | self.phoneme_input = nn.Conv1d(content_channels, internal_channels, 1)
229 | self.speaker_input = nn.Conv1d(speaker_embedding_dim, internal_channels, 1)
230 | self.input_norm = LayerNorm(internal_channels)
231 | self.mid_layers = nn.ModuleList()
232 | for _ in range(num_layers):
233 | self.mid_layers.append(ConvNeXtLayer(internal_channels, kernel_size))
234 | self.output_norm = LayerNorm(internal_channels)
235 | self.output_layer = nn.Conv1d(internal_channels, 1, 1)
236 |
237 | def forward(self, x, x_mask, g):
238 | x = (self.phoneme_input(x) + self.speaker_input(g)) * x_mask
239 | x = self.input_norm(x) * x_mask
240 | for layer in self.mid_layers:
241 | x = layer(x) * x_mask
242 | x = self.output_norm(x) * x_mask
243 | x = self.output_layer(x) * x_mask
244 | x = F.relu(x)
245 | return x
246 |
--------------------------------------------------------------------------------
/module/vits/feature_retrieval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | # Feature retrieval via kNN (like kNN-VC, RVC, etc...)
7 | # Warning: this method is not optimized.
8 | # Do not pass long sequences; computational complexity is quadratic.
9 | #
10 | # source: [BatchSize, Channels, Length]
11 | # reference: [BatchSize, Channels, Length]
12 | # k: int
13 | # alpha: float (0.0 ~ 1.0)
14 | # metrics: one of ['IP', 'L2', 'cos']; 'IP' means inner product, 'L2' means Euclidean distance, 'cos' means cosine similarity
15 | # Output: [BatchSize, Channels, Length]
16 | def match_features(source, reference, k=4, alpha=0.0, metrics='cos'):
17 | input_data = source
18 |
19 | source = source.transpose(1, 2)
20 | reference = reference.transpose(1, 2)
21 | if metrics == 'IP':
22 | sims = torch.bmm(source, reference.transpose(1, 2))
23 | elif metrics == 'L2':
24 | sims = -torch.cdist(source, reference)
25 | elif metrics == 'cos':
26 | reference_norm = torch.norm(reference, dim=2, keepdim=True, p=2) + 1e-6
27 | source_norm = torch.norm(source, dim=2, keepdim=True, p=2) + 1e-6
28 | sims = torch.bmm(source / source_norm, (reference / reference_norm).transpose(1, 2))
29 | best = torch.topk(sims, k, dim=2)
30 |
31 | result = torch.stack([reference[n][best.indices[n]] for n in range(source.shape[0])], dim=0).mean(dim=2)
32 | result = result.transpose(1, 2)
33 |
34 | return result * (1-alpha) + input_data * alpha
35 |
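A minimal usage sketch for match_features (not part of the repository; shapes and values are illustrative placeholders):

import torch
from module.vits.feature_retrieval import match_features

source = torch.randn(1, 768, 100)     # [BatchSize, Channels, Length] content features to convert
reference = torch.randn(1, 768, 400)  # reference features from the target speaker

# replace each source frame by the mean of its 4 nearest reference frames (cosine similarity)
converted = match_features(source, reference, k=4, alpha=0.0, metrics='cos')
print(converted.shape)  # torch.Size([1, 768, 100])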
--------------------------------------------------------------------------------
/module/vits/flow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .wn import WN
5 |
6 |
7 | class Flip(nn.Module):
8 | def forward(self, x, *args, **kwargs):
9 | x = torch.flip(x, [1])
10 | return x
11 |
12 |
13 | class ResidualCouplingLayer(nn.Module):
14 | def __init__(self,
15 | in_channels=192,
16 | internal_channels=192,
17 | speaker_embedding_dim=256,
18 | kernel_size=5,
19 | dilation=1,
20 | num_layers=4):
21 | super().__init__()
22 | self.in_channels = in_channels
23 | self.half_channels = in_channels // 2
24 |
25 | self.pre = nn.Conv1d(self.half_channels, internal_channels, 1)
26 | self.wn = WN(internal_channels, kernel_size, dilation, num_layers, speaker_embedding_dim)
27 | self.post = nn.Conv1d(internal_channels, self.half_channels, 1)
28 |
29 | self.post.weight.data.zero_()
30 | self.post.bias.data.zero_()
31 |
32 | # x: [BatchSize, in_channels, Length]
33 | # x_mask: [BatchSize, 1, Length]
34 | # g: [BatchSize, speaker_embedding_dim, 1]
35 | def forward(self, x, x_mask, g, reverse=False):
36 | x_0, x_1 = torch.chunk(x, 2, dim=1)
37 | h = self.pre(x_0) * x_mask
38 | h = self.wn(h, x_mask, g)
39 | x_1_mean = self.post(h) * x_mask
40 |
41 | if not reverse:
42 | x_1 = x_1_mean + x_1 * x_mask
43 | else:
44 | x_1 = (x_1 - x_1_mean) * x_mask
45 |
46 | x = torch.cat([x_0, x_1], dim=1)
47 | return x
48 |
49 |
50 | class Flow(nn.Module):
51 | def __init__(self,
52 | content_channels=192,
53 | internal_channels=192,
54 | speaker_embedding_dim=256,
55 | kernel_size=5,
56 | dilation=1,
57 | num_flows=4,
58 | num_layers=4):
59 | super().__init__()
60 |
61 | self.flows = nn.ModuleList()
62 | for i in range(num_flows):
63 | self.flows.append(
64 | ResidualCouplingLayer(
65 | content_channels,
66 | internal_channels,
67 | speaker_embedding_dim,
68 | kernel_size,
69 | dilation,
70 | num_layers))
71 | self.flows.append(Flip())
72 |
73 | # z: [BatchSize, content_channels, Length]
74 | # z_mask: [BatchSize, 1, Length]
75 | # g: [Batchsize, speaker_embedding_dim, 1]
76 | # Output: [BatchSize, content_channels, Length]
77 | def forward(self, z, z_mask, g, reverse=False):
78 | if not reverse:
79 | for flow in self.flows:
80 | z = flow(z, z_mask, g, reverse=False)
81 | else:
82 | for flow in reversed(self.flows):
83 | z = flow(z, z_mask, g, reverse=True)
84 | return z
85 |
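A hypothetical round-trip check (illustrative shapes): the residual coupling layers and Flip are exactly invertible, so running the flow forward and then in reverse should reproduce the input.

import torch
from module.vits.flow import Flow

flow = Flow()                       # defaults: content_channels=192, speaker_embedding_dim=256
z = torch.randn(2, 192, 50)         # [BatchSize, content_channels, Length]
z_mask = torch.ones(2, 1, 50)       # [BatchSize, 1, Length]
g = torch.randn(2, 256, 1)          # [BatchSize, speaker_embedding_dim, 1]

z_p = flow(z, z_mask, g)                    # remove speaker information
z_rec = flow(z_p, z_mask, g, reverse=True)  # add it back
print(torch.allclose(z, z_rec, atol=1e-4))  # True (up to numerical error)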
--------------------------------------------------------------------------------
/module/vits/generator.py:
--------------------------------------------------------------------------------
1 | import random
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from .decoder import Decoder
9 | from .posterior_encoder import PosteriorEncoder
10 | from .prior_encoder import PriorEncoder
11 | from .speaker_embedding import SpeakerEmbedding
12 | from .flow import Flow
13 | from .crop import crop_features
14 | from module.utils.energy_estimation import estimate_energy
15 |
16 |
17 | class Generator(nn.Module):
18 | # initialize from config
19 | def __init__(self, config):
20 | super().__init__()
21 | self.decoder = Decoder(**config.decoder)
22 | self.posterior_encoder = PosteriorEncoder(**config.posterior_encoder)
23 | self.prior_encoder = PriorEncoder(config.prior_encoder)
24 | self.speaker_embedding = SpeakerEmbedding(**config.speaker_embedding)
25 |
26 | # training pass
27 | #
28 | # spec: [BatchSize, fft_bin, Length]
29 | # spec_len: [BatchSize]
30 | # phoneme: [BatchSize, NumPhonemes]
31 | # phoneme_len: [BatchSize]
32 | # lm_feat: [BatchSize, NumLMFeatures, lm_dim]
33 | # lm_feat_len: [BatchSize]
34 | # f0: [BatchSize, 1, Length]
35 | # spk: [BatchSize]
36 | # lang: [BatchSize]
37 | # crop_range: Tuple[int, int]
38 | #
39 | # Outputs:
40 | # loss: [1]
41 | # loss_dict: Dict[str, float]
42 | # extras: (text_encoded, text_mask, fake_log_duration, real_log_duration, spk, dsp_out, fake)
43 | # where dsp_out and fake are [BatchSize, Length * frame_size]
44 | #
45 | def forward(
46 | self,
47 | spec,
48 | spec_len,
49 | phoneme,
50 | phoneme_len,
51 | lm_feat,
52 | lm_feat_len,
53 | f0,
54 | spk,
55 | lang,
56 | crop_range
57 | ):
58 |
59 | spk = self.speaker_embedding(spk)
60 | z, m_q, logs_q, spec_mask = self.posterior_encoder.forward(spec, spec_len, spk)
61 | energy = estimate_energy(spec)
62 | loss_prior, loss_dict_prior, (text_encoded, text_mask, fake_log_duration, real_log_duration) = self.prior_encoder.forward(spec_mask, z, logs_q, phoneme, phoneme_len, lm_feat, lm_feat_len, lang, spk)
63 |
64 | z_crop = crop_features(z, crop_range)
65 | f0_crop = crop_features(f0, crop_range)
66 | energy_crop = crop_features(energy, crop_range)
67 | dsp_out, fake, loss_decoder, loss_dict_decoder = self.decoder.forward(z_crop, f0_crop, energy_crop, spk)
68 |
69 | loss_dict = (loss_dict_decoder | loss_dict_prior) # merge dict
70 | loss = loss_prior + loss_decoder
71 |
72 | return loss, loss_dict, (text_encoded, text_mask, fake_log_duration, real_log_duration, spk.detach(), dsp_out, fake)
73 |
74 | @torch.no_grad()
75 | def text_to_speech(
76 | self,
77 | phoneme,
78 | phoneme_len,
79 | lm_feat,
80 | lm_feat_len,
81 | lang,
82 | spk,
83 | noise_scale=0.6,
84 | max_frames=2000,
85 | use_sdp=True,
86 | duration_scale=1.0,
87 | pitch_shift=0.0,
88 | energy_scale=1.0,
89 | ):
90 | spk = self.speaker_embedding(spk)
91 | z = self.prior_encoder.text_to_speech(
92 | phoneme, phoneme_len, lm_feat, lm_feat_len, lang, spk,
93 | noise_scale=noise_scale,
94 | max_frames=max_frames,
95 | use_sdp=use_sdp,
96 | duration_scale=duration_scale)
97 | f0, energy = self.decoder.estimate_pitch_energy(z, spk)
98 | pitch = torch.log2((f0 + 1e-6) / 440.0) * 12.0
99 | pitch += pitch_shift
100 | f0 = 440.0 * 2 ** (pitch / 12.0)
101 | energy = energy * energy_scale
102 | fake = self.decoder.infer(z, spk, f0=f0, energy=energy)
103 | return fake
104 |
105 | @torch.no_grad()
106 | def audio_reconstruction(self, spec, spec_len, spk):
107 | # embed speaker
108 | spk = self.speaker_embedding(spk)
109 |
110 | # encode linear spectrogram and speaker information
111 | z, m_q, logs_q, spec_mask = self.posterior_encoder(spec, spec_len, spk)
112 |
113 | # decode
114 | fake = self.decoder.infer(z, spk)
115 | return fake
116 |
117 | @torch.no_grad()
118 | def singing_voice_conversion(self):
119 | pass # TODO: write this
120 |
121 | @torch.no_grad()
122 | def singing_voice_synthesis(self):
123 | pass # TODO: write this
124 |
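The semitone arithmetic used in text_to_speech can be checked in isolation; a small worked example with an illustrative +2 semitone shift (not part of the original file):

import torch

f0 = torch.tensor([440.0])                       # A4
pitch = torch.log2((f0 + 1e-6) / 440.0) * 12.0   # 0 semitones relative to 440 Hz
pitch += 2.0                                     # pitch_shift of +2 semitones
f0_shifted = 440.0 * 2 ** (pitch / 12.0)
print(f0_shifted)                                # ~493.88 Hz (B4)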
--------------------------------------------------------------------------------
/module/vits/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchaudio
5 |
6 |
7 | def safe_log(x, eps=1e-6):
8 | return torch.log(x + eps)
9 |
10 | def multiscale_stft_loss(x, y, scales=[16, 32, 64, 128, 256, 512]):
11 | x = x.to(torch.float)
12 | y = y.to(torch.float)
13 |
14 | loss = 0
15 | num_scales = len(scales)
16 | for s in scales:
17 | hop_length = s
18 | n_fft = s * 4
19 | window = torch.hann_window(n_fft, device=x.device)
20 | x_spec = torch.stft(x, n_fft, hop_length, return_complex=True, window=window).abs()
21 | y_spec = torch.stft(y, n_fft, hop_length, return_complex=True, window=window).abs()
22 |
23 | x_spec[x_spec.isnan()] = 0
24 | x_spec[x_spec.isinf()] = 0
25 | y_spec[y_spec.isnan()] = 0
26 | y_spec[y_spec.isinf()] = 0
27 |
28 | loss += (safe_log(x_spec) - safe_log(y_spec)).abs().mean()
29 | return loss / num_scales
30 |
31 |
32 | # per-device cache of MelSpectrogram modules
33 | mel_spectrogram_modules = {}
34 | def mel_spectrogram_loss(x, y, sample_rate=48000, n_fft=2048, hop_length=512, power=2.0, log=True):
35 | device = x.device
36 | if device not in mel_spectrogram_modules:
37 | mel_spectrogram_modules[device] = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, power=power).to(device)
38 | mel_spectrogram = mel_spectrogram_modules[device]
39 |
40 | x_mel = mel_spectrogram(x)
41 | y_mel = mel_spectrogram(y)
42 | if log:
43 | x_mel = safe_log(x_mel)
44 | y_mel = safe_log(y_mel)
45 |
46 | x_mel[x_mel.isnan()] = 0
47 | x_mel[x_mel.isinf()] = 0
48 | y_mel[y_mel.isnan()] = 0
49 | y_mel[y_mel.isinf()] = 0
50 |
51 | loss = F.l1_loss(x_mel, y_mel)
52 | return loss
53 |
54 | # 1 = fake, 0 = real
55 | def discriminator_adversarial_loss(real_outputs, fake_outputs):
56 | loss = 0
57 | n = min(len(real_outputs), len(fake_outputs))
58 | for dr, df in zip(real_outputs, fake_outputs):
59 | dr = dr.float()
60 | df = df.float()
61 | real_loss = (dr ** 2).mean()
62 | fake_loss = ((df - 1) ** 2).mean()
63 | loss += real_loss + fake_loss
64 | return loss / n
65 |
66 |
67 | def generator_adversarial_loss(fake_outputs):
68 | loss = 0
69 | n = len(fake_outputs)
70 | for dg in fake_outputs:
71 | dg = dg.float()
72 | loss += (dg ** 2).mean()
73 | return loss / n
74 |
75 |
76 | def duration_discriminator_adversarial_loss(real_output, fake_output, text_mask):
77 | loss = (((fake_output - 1) ** 2) * text_mask).sum() / text_mask.sum()
78 | loss += ((real_output ** 2) * text_mask).sum() / text_mask.sum()
79 | return loss
80 |
81 |
82 | def duration_generator_adversarial_loss(fake_output, text_mask):
83 | loss = ((fake_output ** 2) * text_mask).sum() / text_mask.sum()
84 | return loss
85 |
86 |
87 | def feature_matching_loss(fmap_real, fmap_fake):
88 | loss = 0
89 | n = min(len(fmap_real), len(fmap_fake))
90 | for r, f in zip(fmap_real, fmap_fake):
91 | f = f.float()
92 | r = r.float()
93 | loss += (f - r).abs().mean()
94 | return loss * (2 / n)
95 |
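A hypothetical usage sketch for the reconstruction losses (waveform shapes are illustrative):

import torch
from module.vits.loss import multiscale_stft_loss, mel_spectrogram_loss

real = torch.randn(2, 48000)   # [BatchSize, Samples], 1 second at 48 kHz
fake = torch.randn(2, 48000)

print(multiscale_stft_loss(real, fake))   # scalar: log-magnitude STFT distance averaged over scales
print(mel_spectrogram_loss(real, fake))   # scalar: L1 distance between log mel spectrograms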
--------------------------------------------------------------------------------
/module/vits/monotonic_align/__init__.py:
--------------------------------------------------------------------------------
1 | """Maximum path calculation module.
2 |
3 | This code is based on https://github.com/jaywalnut310/vits.
4 |
5 | """
6 |
7 | import warnings
8 |
9 | import numpy as np
10 | import torch
11 | from numba import njit, prange
12 |
13 | try:
14 | from .core import maximum_path_c
15 |
16 | is_cython_available = True
17 | except ImportError:
18 | is_cython_available = False
19 | warnings.warn(
20 | "Cython version is not available. Falling back to the 'EXPERIMENTAL' numba version. "
21 | "If you want to use the cython version, please build it as follows: "
22 | "`cd auris/module/vits/monotonic_align; python setup.py build_ext --inplace`"
23 | )
24 |
25 |
26 | def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
27 | """Calculate maximum path.
28 |
29 | Args:
30 | neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
31 | attn_mask (Tensor): Attention mask (B, T_feats, T_text).
32 |
33 | Returns:
34 | Tensor: Maximum path tensor (B, T_feats, T_text).
35 |
36 | """
37 | device, dtype = neg_x_ent.device, neg_x_ent.dtype
38 | neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32)
39 | path = np.zeros(neg_x_ent.shape, dtype=np.int32)
40 | t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
41 | t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
42 | if is_cython_available:
43 | maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
44 | else:
45 | maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)
46 |
47 | return torch.from_numpy(path).to(device=device, dtype=dtype)
48 |
49 |
50 | @njit
51 | def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
52 | """Calculate a single maximum path with numba."""
53 | index = t_x - 1
54 | for y in range(t_y):
55 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
56 | if x == y:
57 | v_cur = max_neg_val
58 | else:
59 | v_cur = value[y - 1, x]
60 | if x == 0:
61 | if y == 0:
62 | v_prev = 0.0
63 | else:
64 | v_prev = max_neg_val
65 | else:
66 | v_prev = value[y - 1, x - 1]
67 | value[y, x] += max(v_prev, v_cur)
68 |
69 | for y in range(t_y - 1, -1, -1):
70 | path[y, index] = 1
71 | if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
72 | index = index - 1
73 |
74 |
75 | @njit(parallel=True)
76 | def maximum_path_numba(paths, values, t_ys, t_xs):
77 | """Calculate batch maximum path with numba."""
78 | for i in prange(paths.shape[0]):
79 | maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])
80 |
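A hypothetical toy example (random scores; real inputs come from the prior encoder):

import torch
from module.vits.monotonic_align import maximum_path

B, T_feats, T_text = 1, 6, 3
neg_x_ent = torch.randn(B, T_feats, T_text)   # alignment scores between frames and phonemes
attn_mask = torch.ones(B, T_feats, T_text)    # no padding in this toy case

path = maximum_path(neg_x_ent, attn_mask)     # [B, T_feats, T_text], one phoneme per frame
print(path[0])                                # monotonic, non-decreasing alignment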
--------------------------------------------------------------------------------
/module/vits/monotonic_align/core.pyx:
--------------------------------------------------------------------------------
1 | """Maximum path calculation module with cython optimization.
2 |
3 | This code is copied from https://github.com/jaywalnut310/vits and modifed code format.
4 |
5 | """
6 |
7 | cimport cython
8 |
9 | from cython.parallel import prange
10 |
11 |
12 | @cython.boundscheck(False)
13 | @cython.wraparound(False)
14 | cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
15 | cdef int x
16 | cdef int y
17 | cdef float v_prev
18 | cdef float v_cur
19 | cdef float tmp
20 | cdef int index = t_x - 1
21 |
22 | for y in range(t_y):
23 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
24 | if x == y:
25 | v_cur = max_neg_val
26 | else:
27 | v_cur = value[y - 1, x]
28 | if x == 0:
29 | if y == 0:
30 | v_prev = 0.0
31 | else:
32 | v_prev = max_neg_val
33 | else:
34 | v_prev = value[y - 1, x - 1]
35 | value[y, x] += max(v_prev, v_cur)
36 |
37 | for y in range(t_y - 1, -1, -1):
38 | path[y, index] = 1
39 | if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
40 | index = index - 1
41 |
42 |
43 | @cython.boundscheck(False)
44 | @cython.wraparound(False)
45 | cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
46 | cdef int b = paths.shape[0]
47 | cdef int i
48 | for i in prange(b, nogil=True):
49 | maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
50 |
--------------------------------------------------------------------------------
/module/vits/monotonic_align/setup.py:
--------------------------------------------------------------------------------
1 | """Setup cython code."""
2 |
3 | from Cython.Build import cythonize
4 | from setuptools import Extension, setup
5 | from setuptools.command.build_ext import build_ext as _build_ext
6 |
7 |
8 | class build_ext(_build_ext):
9 | """Overwrite build_ext."""
10 |
11 | def finalize_options(self):
12 | """Prevent numpy from thinking it is still in its setup process."""
13 | _build_ext.finalize_options(self)
14 | __builtins__.__NUMPY_SETUP__ = False
15 | import numpy
16 |
17 | self.include_dirs.append(numpy.get_include())
18 |
19 |
20 | exts = [
21 | Extension(
22 | name="core",
23 | sources=["core.pyx"],
24 | )
25 | ]
26 | setup(
27 | name="monotonic_align",
28 | ext_modules=cythonize(exts, language_level=3),
29 | cmdclass={"build_ext": build_ext},
30 | )
31 |
--------------------------------------------------------------------------------
/module/vits/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class LayerNorm(nn.Module):
7 | def __init__(self, channels, eps=1e-5):
8 | super().__init__()
9 | self.channels = channels
10 | self.eps = eps
11 |
12 | self.gamma = nn.Parameter(torch.ones(channels))
13 | self.beta = nn.Parameter(torch.zeros(channels))
14 |
15 | # x: [BatchSize, channels, *]
16 | def forward(self, x: torch.Tensor):
17 | x = F.layer_norm(x.mT, (self.channels,), self.gamma, self.beta, self.eps)
18 | return x.mT
19 |
20 |
21 | # Global Response Normalization for 1d sequences (shape=[BatchSize, Channels, Length])
22 | class GRN(nn.Module):
23 | def __init__(self, channels, eps=1e-6):
24 | super().__init__()
25 | self.beta = nn.Parameter(torch.zeros(1, channels, 1))
26 | self.gamma = nn.Parameter(torch.zeros(1, channels, 1))
27 | self.eps = eps
28 |
29 | # x: [batchsize, channels, length]
30 | def forward(self, x):
31 | gx = torch.norm(x, p=2, dim=2, keepdim=True)
32 | nx = gx / (gx.mean(dim=1, keepdim=True) + self.eps)
33 | return self.gamma * (x * nx) + self.beta + x
34 |
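A hypothetical shape check (both modules operate on [BatchSize, Channels, Length] tensors; values are illustrative):

import torch
from module.vits.normalization import LayerNorm, GRN

x = torch.randn(2, 192, 50)
print(LayerNorm(192)(x).shape)  # torch.Size([2, 192, 50]), normalized over the channel dim
print(GRN(192)(x).shape)        # torch.Size([2, 192, 50]), channels re-weighted by global response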
--------------------------------------------------------------------------------
/module/vits/posterior_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .wn import WN
4 |
5 |
6 | class PosteriorEncoder(nn.Module):
7 | def __init__(self,
8 | n_fft=3840,
9 | frame_size=960,
10 | internal_channels=192,
11 | speaker_embedding_dim=256,
12 | content_channels=192,
13 | kernel_size=5,
14 | dilation=1,
15 | num_layers=16):
16 | super().__init__()
17 |
18 | self.input_channels = n_fft // 2 + 1
19 | self.n_fft = n_fft
20 | self.frame_size = frame_size
21 | self.pre = nn.Conv1d(self.input_channels, internal_channels, 1)
22 | self.wn = WN(internal_channels, kernel_size, dilation, num_layers, speaker_embedding_dim)
23 | self.post = nn.Conv1d(internal_channels, content_channels * 2, 1)
24 |
25 | # x: [BatchSize, fft_bin, Length]
26 | # x_length: [BatchSize]
27 | # g: [BatchSize, speaker_embedding_dim, 1]
28 | #
29 | # Outputs:
30 | # z: [BatchSize, content_channels, Length]
31 | # mean: [BatchSize, content_channels, Length]
32 | # logvar: [BatchSize, content_channels, Length]
33 | # z_mask: [BatchSize, 1, Length]
34 | #
35 | # where fft_bin = input_channels = n_fft // 2 + 1
36 | def forward(self, x, x_length, g):
37 | # generate mask
38 | max_length = x.shape[2]
39 | progression = torch.arange(max_length, dtype=x_length.dtype, device=x_length.device)
40 | z_mask = (progression.unsqueeze(0) < x_length.unsqueeze(1))
41 | z_mask = z_mask.unsqueeze(1).to(x.dtype)
42 |
43 | # pass network
44 | x = self.pre(x) * z_mask
45 | x = self.wn(x, z_mask, g)
46 | x = self.post(x) * z_mask
47 | mean, logvar = torch.chunk(x, 2, dim=1)
48 | z = mean + torch.randn_like(mean) * torch.exp(logvar) * z_mask
49 | return z, mean, logvar, z_mask
50 |
--------------------------------------------------------------------------------
/module/vits/prior_encoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .duration_predictors import StochasticDurationPredictor, DurationPredictor
7 | from .monotonic_align import maximum_path
8 | from .text_encoder import TextEncoder
9 | from .flow import Flow
10 | from module.utils.energy_estimation import estimate_energy
11 |
12 |
13 | def sequence_mask(length, max_length=None):
14 | if max_length is None:
15 | max_length = length.max()
16 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
17 | return x.unsqueeze(0) < length.unsqueeze(1)
18 |
19 |
20 | def generate_path(duration, mask):
21 | """
22 | duration: [b, 1, t_x]
23 | mask: [b, 1, t_y, t_x]
24 | """
25 | b, _, t_y, t_x = mask.shape
26 | cum_duration = torch.cumsum(duration, -1)
27 |
28 | cum_duration_flat = cum_duration.view(b * t_x)
29 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
30 | path = path.view(b, t_x, t_y)
31 |
32 | padding_shape = [[0, 0], [1, 0], [0, 0]]
33 | padding = [item for sublist in padding_shape[::-1] for item in sublist]
34 |
35 | path = path - F.pad(path, padding)[:, :-1]
36 | path = path.unsqueeze(1).transpose(2,3) * mask
37 | return path
38 |
39 |
40 | def kl_divergence_loss(z_p, logs_q, m_p, logs_p, z_mask):
41 | z_p = z_p.float()
42 | logs_q = logs_q.float()
43 | m_p = m_p.float()
44 | logs_p = logs_p.float()
45 | z_mask = z_mask.float()
46 |
47 | kl = logs_p - logs_q - 0.5
48 | kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
49 | kl = torch.sum(kl * z_mask)
50 | l = kl / torch.sum(z_mask)
51 | return l
52 |
53 |
54 | # run Monotonic Alignment Search (MAS).
55 | # MAS aligns each phoneme with a contiguous span of spectrogram frames.
56 | #
57 | # z_p: [b, d, t']
58 | # m_p: [b, d, t]
59 | # logs_p: [b, d, t]
60 | # text_mask: [b, 1, t]
61 | # spec_mask: [b, 1, t']
62 | # Output: [b, 1, t', t]
63 | def search_path(z_p, m_p, logs_p, text_mask, spec_mask, mas_noise_scale=0.0):
64 | with torch.no_grad():
65 | # calculate nodes
66 | # b = batch size, d = feature dim, t = text length, t' = spec length
67 | s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
68 | neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, dim=1, keepdim=True) # [b, 1, t]
69 | neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).mT, s_p_sq_r) # [b, t', d] x [b, d, t] = [b, t', t]
70 | neg_cent3 = torch.matmul(z_p.mT, (m_p * s_p_sq_r)) # [b, t', d] x [b, d, t] = [b, t', t]
71 | neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, dim=1, keepdim=True) # [b, 1, t]
72 | neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 # [b, t', t]
73 |
74 | # add noise
75 | if mas_noise_scale > 0.0:
76 | eps = torch.std(neg_cent) * torch.randn_like(neg_cent) * mas_noise_scale
77 | neg_cent += eps
78 |
79 | # mask unnecessary nodes, run D.P.
80 | MAS_node_mask = text_mask.unsqueeze(2) * spec_mask.unsqueeze(-1) # [b, 1, t', t]
81 | MAS_path = maximum_path(neg_cent, MAS_node_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t', t]
82 | return MAS_path
83 |
84 |
85 | class PriorEncoder(nn.Module):
86 | def __init__(
87 | self,
88 | config,
89 | ):
90 | super().__init__()
91 | self.flow = Flow(**config.flow)
92 | self.text_encoder = TextEncoder(**config.text_encoder)
93 | self.duration_predictor = DurationPredictor(**config.duration_predictor)
94 | self.stochastic_duration_predictor = StochasticDurationPredictor(**config.stochastic_duration_predictor)
95 |
96 | def forward(self, spec_mask, z, logs_q, phoneme, phoneme_len, lm_feat, lm_feat_len, lang, spk):
97 | # encode text
98 | text_encoded, m_p, logs_p, text_mask = self.text_encoder(phoneme, phoneme_len, lm_feat, lm_feat_len, spk, lang)
99 |
100 | # remove speaker information
101 | z_p = self.flow(z, spec_mask, spk)
102 |
103 | # search path
104 | MAS_path = search_path(z_p, m_p, logs_p, text_mask, spec_mask)
105 |
106 | # KL Divergence loss
107 | m_p = torch.matmul(MAS_path.squeeze(1), m_p.mT).mT
108 | logs_p = torch.matmul(MAS_path.squeeze(1), logs_p.mT).mT
109 | loss_kl = kl_divergence_loss(z_p, logs_q, m_p, logs_p, spec_mask)
110 |
111 | # calculate duration of each phoneme
112 | duration = MAS_path.sum(2)
113 | loss_sdp = self.stochastic_duration_predictor(
114 | text_encoded,
115 | text_mask,
116 | w=duration,
117 | g=spk
118 | ).sum() / text_mask.sum()
119 |
120 | logw_y = torch.log(duration + 1e-6) * text_mask
121 | logw_x = self.duration_predictor(text_encoded, text_mask, spk)
122 | loss_dp = torch.sum(((logw_x - logw_y) ** 2) * text_mask) / torch.sum(text_mask)
123 |
124 | # predict duration
125 | fake_log_duration = self.duration_predictor(text_encoded, text_mask, spk)
126 | real_log_duration = logw_y
127 |
128 | loss_dict = {
129 | "StochasticDurationPredictor": loss_sdp.item(),
130 | "DurationPredictor": loss_dp.item(),
131 | "KL Divergence": loss_kl.item(),
132 | }
133 |
134 | loss = loss_sdp + loss_dp + loss_kl
135 | return loss, loss_dict, (text_encoded.detach(), text_mask, fake_log_duration, real_log_duration)
136 |
137 | def text_to_speech(self, phoneme, phoneme_len, lm_feat, lm_feat_len, lang, spk, noise_scale=0.6, max_frames=2000, use_sdp=True, duration_scale=1.0):
138 | # encode text
139 | text_encoded, m_p, logs_p, text_mask = self.text_encoder(phoneme, phoneme_len, lm_feat, lm_feat_len, spk, lang)
140 |
141 | # predict duration
142 | if use_sdp:
143 | log_duration = self.stochastic_duration_predictor(text_encoded, text_mask, g=spk, reverse=True)
144 | else:
145 | log_duration = self.duration_predictor(text_encoded, text_mask, spk)
146 | duration = torch.exp(log_duration)
147 | duration = duration * text_mask * duration_scale
148 | duration = torch.ceil(duration)
149 |
150 | spec_len = torch.clamp_min(torch.sum(duration, dim=(1, 2)), 1).long()
151 | spec_len = torch.clamp_max(spec_len, max_frames)
152 | spec_mask = sequence_mask(spec_len).unsqueeze(1).to(text_mask.dtype)
153 |
154 | MAS_node_mask = text_mask.unsqueeze(2) * spec_mask.unsqueeze(-1)
155 | MAS_path = generate_path(duration, MAS_node_mask).float()
156 |
157 | # projection
158 | m_p = torch.matmul(MAS_path.squeeze(1), m_p.mT).mT
159 | logs_p = torch.matmul(MAS_path.squeeze(1), logs_p.mT).mT
160 |
161 | # sample from gaussian
162 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
163 |
164 | # crop max frames
165 | if z_p.shape[2] > max_frames:
166 | z_p = z_p[:, :, :max_frames]
167 |
168 | # add speaker information
169 | z = self.flow(z_p, spec_mask, spk, reverse=True)
170 |
171 | return z
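A hypothetical toy example for generate_path, which expands per-phoneme durations into a frame-level alignment (values are illustrative):

import torch
from module.vits.prior_encoder import generate_path, sequence_mask

duration = torch.tensor([[[2., 1., 3.]]])   # [b, 1, t_x]: 3 phonemes, 6 frames in total
t_y = int(duration.sum())
mask = torch.ones(1, 1, t_y, 3)             # [b, 1, t_y, t_x]
path = generate_path(duration, mask)        # [b, 1, t_y, t_x]
print(path[0, 0])                           # frames 0-1 -> phoneme 0, frame 2 -> phoneme 1, frames 3-5 -> phoneme 2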
--------------------------------------------------------------------------------
/module/vits/speaker_embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class SpeakerEmbedding(nn.Module):
6 | def __init__(self, num_speakers=8192, embedding_dim=256):
7 | super().__init__()
8 | self.embedding = nn.Embedding(num_speakers, embedding_dim)
9 |
10 | # i: [BatchSize]
11 | # Output: [BatchSize, embedding_dim, 1]
12 | def forward(self, i):
13 | x = self.embedding(i)
14 | x = x.unsqueeze(2)
15 | return x
16 |
--------------------------------------------------------------------------------
/module/vits/spectrogram.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | # wave: [BatchSize, Length]
7 | # Output: [BatchSize, fft_bin, Frames] where fft_bin = n_fft // 2 + 1
8 | def spectrogram(wave, n_fft, hop_size, power=2.0):
9 | dtype = wave.dtype
10 | wave = wave.to(torch.float)
11 | window = torch.hann_window(n_fft, device=wave.device)
12 | spec = torch.stft(wave, n_fft, hop_size, return_complex=True, window=window).abs()
13 | spec = spec[:, :, 1:]
14 | spec = spec ** power
15 | spec = spec.to(dtype)
16 | return spec
17 |
18 |
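A hypothetical usage sketch (the n_fft/hop values below mirror the PosteriorEncoder defaults; the waveform is a placeholder):

import torch
from module.vits.spectrogram import spectrogram

wave = torch.randn(2, 48000)                  # [BatchSize, Samples]
spec = spectrogram(wave, n_fft=3840, hop_size=960, power=2.0)
print(spec.shape)                             # [BatchSize, n_fft // 2 + 1, Frames]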
--------------------------------------------------------------------------------
/module/vits/text_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .transformer import RelativePositionTransformerDecoder
6 |
7 |
8 | class TextEncoder(nn.Module):
9 | def __init__(
10 | self,
11 | num_phonemes=512,
12 | num_languages=256,
13 | internal_channels=256,
14 | speaker_embedding_dim=256,
15 | content_channels=192,
16 | n_heads=4,
17 | lm_dim=768,
18 | kernel_size=1,
19 | dropout=0.0,
20 | window_size=4,
21 | num_layers=4,
22 | ):
23 | super().__init__()
24 | self.lm_proj = nn.Linear(lm_dim, internal_channels)
25 | self.phoneme_embedding = nn.Embedding(num_phonemes, internal_channels)
26 | self.language_embedding = nn.Embedding(num_languages, internal_channels)
27 | self.speaker_input = nn.Conv1d(speaker_embedding_dim, internal_channels, 1)
28 | self.transformer = RelativePositionTransformerDecoder(
29 | internal_channels,
30 | internal_channels * 4,
31 | n_heads,
32 | num_layers,
33 | kernel_size,
34 | dropout,
35 | window_size
36 | )
37 | self.post = nn.Conv1d(internal_channels, content_channels * 2, 1)
38 |
39 | # Note:
40 | # x: Phoneme IDs
41 | # x_length: length of phoneme sequence, for generating mask
42 | # y: language model's output features
43 | # y_length: length of language model's output features, for generating mask
44 | # lang: Language ID
45 | #
46 | # x: [BatchSize, Length_x]
47 | # x_length: [BatchSize]
48 | # y: [BatchSize, Length_y, lm_dim]
49 | # y_length [BatchSize]
50 | # spk: [BatchSize, speaker_embedding_dim, 1]
51 | # lang: [BatchSize]
52 | #
53 | # Outputs:
54 | # z: [BatchSize, content_channels, Length_x]
55 | # mean: [BatchSize, content_channels, Length_x]
56 | # logvar: [BatchSize, content_channels, Length_x]
57 | # z_mask: [BatchSize, 1, Length_x]
58 | def forward(self, x, x_length, y, y_length, spk, lang):
59 | # generate mask
60 | # x mask
61 | max_length = x.shape[1]
62 | progression = torch.arange(max_length, dtype=x_length.dtype, device=x_length.device)
63 | x_mask = (progression.unsqueeze(0) < x_length.unsqueeze(1))
64 | x_mask = x_mask.unsqueeze(1).to(y.dtype)
65 | z_mask = x_mask
66 |
67 | # y mask
68 | max_length = y.shape[1]
69 | progression = torch.arange(max_length, dtype=y_length.dtype, device=y_length.device)
70 | y_mask = (progression.unsqueeze(0) < y_length.unsqueeze(1))
71 | y_mask = y_mask.unsqueeze(1).to(y.dtype)
72 |
73 | # pass network
74 | y = self.lm_proj(y).mT # [B, C, L_y] where C = internal_channels, L_y = Length_y, L_x = Length_x, B = BatchSize
75 | x = self.phoneme_embedding(x) # [B, L_x, C]
76 | lang = self.language_embedding(lang) # [B, C]
77 | lang = lang.unsqueeze(1) # [B, 1, C]
78 | x = x + lang # language conditioning
79 | x = x.mT # [B, C, L_x]
80 | x = x + self.speaker_input(spk)
81 | x = self.transformer(x, x_mask, y, y_mask) # [B, C, L_x]
82 | x = self.post(x) * x_mask # [B, 2C, L_x]
83 | mean, logvar = torch.chunk(x, 2, dim=1)
84 | z = mean + torch.randn_like(mean) * torch.exp(logvar) * z_mask
85 | return z, mean, logvar, z_mask
86 |
--------------------------------------------------------------------------------
/module/vits/transformer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from module.utils.common import convert_pad_shape
7 | from .normalization import LayerNorm
8 |
9 |
10 | class RelativePositionTransformerEncoder(nn.Module):
11 | def __init__(
12 | self,
13 | hidden_channels: int,
14 | hidden_channels_ffn: int,
15 | n_heads: int,
16 | n_layers: int,
17 | kernel_size=1,
18 | dropout=0.0,
19 | window_size=4,
20 | ):
21 | super().__init__()
22 | self.n_layers = n_layers
23 |
24 | self.drop = nn.Dropout(dropout)
25 | self.attn_layers = nn.ModuleList()
26 | self.norm_layers_1 = nn.ModuleList()
27 | self.ffn_layers = nn.ModuleList()
28 | self.norm_layers_2 = nn.ModuleList()
29 | for i in range(self.n_layers):
30 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=dropout, window_size=window_size))
31 | self.norm_layers_1.append(LayerNorm(hidden_channels))
32 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, hidden_channels_ffn, kernel_size, p_dropout=dropout))
33 | self.norm_layers_2.append(LayerNorm(hidden_channels))
34 |
35 | # x: [BatchSize, hidden_channels, Length]
36 | # x_mask: [BatchSize, 1, Length]
37 | #
38 | # Output: [BatchSize, hidden_channels, Length]
39 | def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
40 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
41 | x = x * x_mask
42 | for i in range(self.n_layers):
43 | res = x
44 | x = self.attn_layers[i](x, x, attn_mask)
45 | x = self.drop(x)
46 | x = self.norm_layers_1[i](x + res)
47 |
48 | res = x
49 | x = self.ffn_layers[i](x, x_mask)
50 | x = self.drop(x)
51 | x = self.norm_layers_2[i](x + res)
52 | x = x * x_mask
53 | return x
54 |
55 |
56 | class RelativePositionTransformerDecoder(nn.Module):
57 | def __init__(
58 | self,
59 | hidden_channels: int,
60 | hidden_channels_ffn: int,
61 | n_heads: int,
62 | n_layers: int,
63 | kernel_size=1,
64 | dropout=0.0,
65 | window_size=4,
66 | ):
67 | super().__init__()
68 | self.n_layers = n_layers
69 |
70 | self.drop = nn.Dropout(dropout)
71 | self.self_attn_layers = nn.ModuleList()
72 | self.norm_layers_1 = nn.ModuleList()
73 | self.cross_attn_layers = nn.ModuleList()
74 | self.norm_layers_2 = nn.ModuleList()
75 | self.ffn_layers = nn.ModuleList()
76 | self.norm_layers_3 = nn.ModuleList()
77 | for i in range(self.n_layers):
78 | self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=dropout, window_size=window_size))
79 | self.norm_layers_1.append(LayerNorm(hidden_channels))
80 | self.cross_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=dropout))
81 | self.norm_layers_2.append(LayerNorm(hidden_channels))
82 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, hidden_channels_ffn, kernel_size, p_dropout=dropout))
83 | self.norm_layers_3.append(LayerNorm(hidden_channels))
84 |
85 | # x: [BatchSize, hidden_channels, Length_x]
86 | # x_mask: [BatchSize, 1, Length_x]
87 | # y: [BatchSize, hidden_channels, Length_y]
88 | # y_mask: [BatchSize, 1, Length_y]
89 | #
90 | # Output: [BatchSize, hidden_channels, Length_x]
91 | def forward(self, x: torch.Tensor, x_mask: torch.Tensor, y: torch.Tensor, y_mask: torch.Tensor):
92 | self_attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
93 | cross_attn_mask = y_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
94 | x = x * x_mask
95 | for i in range(self.n_layers):
96 | res = x
97 | x = self.self_attn_layers[i](x, x, self_attn_mask)
98 | x = self.drop(x)
99 | x = self.norm_layers_1[i](x + res)
100 |
101 | res = x
102 | x = self.cross_attn_layers[i](x, y, cross_attn_mask)
103 | x = self.drop(x)
104 | x = self.norm_layers_2[i](x + res)
105 |
106 | res = x
107 | x = self.ffn_layers[i](x, x_mask)
108 | x = self.drop(x)
109 | x = self.norm_layers_3[i](x + res)
110 | x = x * x_mask
111 | return x
112 |
113 |
114 | class MultiHeadAttention(nn.Module):
115 | def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
116 | super().__init__()
117 | assert channels % n_heads == 0
118 |
119 | self.channels = channels
120 | self.out_channels = out_channels
121 | self.n_heads = n_heads
122 | self.p_dropout = p_dropout
123 | self.window_size = window_size
124 | self.heads_share = heads_share
125 | self.block_length = block_length
126 | self.proximal_bias = proximal_bias
127 | self.proximal_init = proximal_init
128 | self.attn = None
129 |
130 | self.k_channels = channels // n_heads
131 | self.conv_q = nn.Linear(channels, channels)
132 | self.conv_k = nn.Linear(channels, channels)
133 | self.conv_v = nn.Linear(channels, channels)
134 | self.conv_o = nn.Linear(channels, out_channels)
135 | self.drop = nn.Dropout(p_dropout)
136 |
137 | if window_size is not None:
138 | n_heads_rel = 1 if heads_share else n_heads
139 | rel_stddev = self.k_channels**-0.5
140 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
141 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
142 |
143 | nn.init.xavier_uniform_(self.conv_q.weight)
144 | nn.init.xavier_uniform_(self.conv_k.weight)
145 | nn.init.xavier_uniform_(self.conv_v.weight)
146 | if proximal_init:
147 | with torch.no_grad():
148 | self.conv_k.weight.copy_(self.conv_q.weight)
149 | self.conv_k.bias.copy_(self.conv_q.bias)
150 |
151 | def forward(self, x, c, attn_mask=None):
152 | q = self.conv_q(x.mT).mT
153 | k = self.conv_k(c.mT).mT
154 | v = self.conv_v(c.mT).mT
155 |
156 | x, self.attn = self.attention(q, k, v, mask=attn_mask)
157 |
158 | x = self.conv_o(x.mT).mT
159 | return x
160 |
161 | # query: [b, n_h, t_t, d_k]
162 | # key: [b, n_h, t_s, d_k]
163 | # value: [b, n_h, t_s, d_k]
164 | # mask: [b, n_h, t_t, t_s]
165 | def attention(self, query, key, value, mask=None):
166 | # reshape [b, d, t] -> [b, n_h, t, d_k]
167 | b, d, t_s, t_t = (*key.size(), query.size(2))
168 | query = query.view(b, self.n_heads, self.k_channels, t_t).mT # [b, n_h, t_t, d_k]
169 | key = key.view(b, self.n_heads, self.k_channels, t_s).mT # [b, n_h, t_s, d_k]
170 | value = value.view(b, self.n_heads, self.k_channels, t_s).mT # [b, n_h, t_s, d_k]
171 |
172 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.mT) # [b, n_h, t_t, t_s]
173 | if self.window_size is not None:
174 | assert t_s == t_t, "Relative attention is only available for self-attention."
175 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
176 | rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
177 | scores_local = self._relative_position_to_absolute_position(rel_logits)
178 | scores = scores + scores_local
179 | if self.proximal_bias:
180 | assert t_s == t_t, "Proximal bias is only available for self-attention."
181 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
182 | if mask is not None:
183 | scores = scores.masked_fill(mask == 0, -1e4)
184 | if self.block_length is not None:
185 | assert t_s == t_t, "Local attention is only available for self-attention."
186 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
187 | scores = scores.masked_fill(block_mask == 0, -1e4)
188 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
189 | p_attn = self.drop(p_attn)
190 | output = torch.matmul(p_attn, value)
191 | if self.window_size is not None:
192 | relative_weights = self._absolute_position_to_relative_position(p_attn)
193 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
194 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
195 | output = output.mT.contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
196 | return output, p_attn
197 |
198 | def _matmul_with_relative_values(self, x: torch.Tensor, y: torch.Tensor):
199 | """
200 | x: [b, h, l, m]
201 | y: [h or 1, m, d]
202 | ret: [b, h, l, d]
203 | """
204 | return torch.matmul(x, y.unsqueeze(0))
205 |
206 | def _matmul_with_relative_keys(self, x: torch.Tensor, y: torch.Tensor):
207 | """
208 | x: [b, h, l, d]
209 | y: [h or 1, m, d]
210 | ret: [b, h, l, m]
211 | """
212 | return torch.matmul(x, y.unsqueeze(0).mT)
213 |
214 | def _get_relative_embeddings(self, relative_embeddings, length):
215 | max_relative_position = 2 * self.window_size + 1
216 | # Pad first before slice to avoid using cond ops.
217 | pad_length = max(length - (self.window_size + 1), 0)
218 | slice_start_position = max((self.window_size + 1) - length, 0)
219 | slice_end_position = slice_start_position + 2 * length - 1
220 | if pad_length > 0:
221 | padded_relative_embeddings = F.pad(relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
222 | else:
223 | padded_relative_embeddings = relative_embeddings
224 | used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
225 | return used_relative_embeddings
226 |
227 | def _relative_position_to_absolute_position(self, x):
228 | """
229 | x: [b, h, l, 2*l-1]
230 | ret: [b, h, l, l]
231 | """
232 | batch, heads, length, _ = x.size()
233 | # Concat columns of pad to shift from relative to absolute indexing.
234 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
235 |
236 | # Concat extra elements so to add up to shape (len+1, 2*len-1).
237 | x_flat = x.view([batch, heads, length * 2 * length])
238 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
239 |
240 | # Reshape and slice out the padded elements.
241 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
242 | return x_final
243 |
244 | def _absolute_position_to_relative_position(self, x):
245 | """
246 | x: [b, h, l, l]
247 | ret: [b, h, l, 2*l-1]
248 | """
249 | batch, heads, length, _ = x.size()
250 | # pad along column
251 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
252 | x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
253 | # add 0's in the beginning that will skew the elements after reshape
254 | x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
255 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
256 | return x_final
257 |
258 | def _attention_bias_proximal(self, length):
259 | """Bias for self-attention to encourage attention to close positions.
260 | Args:
261 | length: an integer scalar.
262 | Returns:
263 | a Tensor with shape [1, 1, length, length]
264 | """
265 | r = torch.arange(length, dtype=torch.float32)
266 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
267 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
268 |
269 |
270 | class FFN(nn.Module):
271 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, causal=False):
272 | super().__init__()
273 | self.kernel_size = kernel_size
274 | self.padding = self._causal_padding if causal else self._same_padding
275 |
276 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
277 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
278 | self.drop = nn.Dropout(p_dropout)
279 |
280 | def forward(self, x, x_mask):
281 | x = self.conv_1(self.padding(x * x_mask))
282 | x = torch.relu(x)
283 | x = self.drop(x)
284 | x = self.conv_2(self.padding(x * x_mask))
285 | return x * x_mask
286 |
287 | def _causal_padding(self, x):
288 | if self.kernel_size == 1:
289 | return x
290 | pad_l = self.kernel_size - 1
291 | pad_r = 0
292 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
293 | x = F.pad(x, convert_pad_shape(padding))
294 | return x
295 |
296 | def _same_padding(self, x):
297 | if self.kernel_size == 1:
298 | return x
299 | pad_l = (self.kernel_size - 1) // 2
300 | pad_r = self.kernel_size // 2
301 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
302 | x = F.pad(x, convert_pad_shape(padding))
303 | return x
304 |
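A hypothetical shape check for the relative-position transformer encoder (channel and length values are illustrative):

import torch
from module.vits.transformer import RelativePositionTransformerEncoder

enc = RelativePositionTransformerEncoder(hidden_channels=192, hidden_channels_ffn=768, n_heads=4, n_layers=2)
x = torch.randn(2, 192, 40)      # [BatchSize, hidden_channels, Length]
x_mask = torch.ones(2, 1, 40)    # [BatchSize, 1, Length]
print(enc(x, x_mask).shape)      # torch.Size([2, 192, 40])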
--------------------------------------------------------------------------------
/module/vits/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import numpy as np
5 |
6 |
7 | # transforms for the stochastic duration predictor
8 |
9 |
10 | DEFAULT_MIN_BIN_WIDTH = 1e-3
11 | DEFAULT_MIN_BIN_HEIGHT = 1e-3
12 | DEFAULT_MIN_DERIVATIVE = 1e-3
13 |
14 |
15 | def piecewise_rational_quadratic_transform(
16 | inputs,
17 | unnormalized_widths,
18 | unnormalized_heights,
19 | unnormalized_derivatives,
20 | inverse=False,
21 | tails=None,
22 | tail_bound=1.0,
23 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
24 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
25 | min_derivative=DEFAULT_MIN_DERIVATIVE,
26 | ):
27 | if tails is None:
28 | spline_fn = rational_quadratic_spline
29 | spline_kwargs = {}
30 | else:
31 | spline_fn = unconstrained_rational_quadratic_spline
32 | spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
33 |
34 | outputs, logabsdet = spline_fn(
35 | inputs=inputs,
36 | unnormalized_widths=unnormalized_widths,
37 | unnormalized_heights=unnormalized_heights,
38 | unnormalized_derivatives=unnormalized_derivatives,
39 | inverse=inverse,
40 | min_bin_width=min_bin_width,
41 | min_bin_height=min_bin_height,
42 | min_derivative=min_derivative,
43 | **spline_kwargs
44 | )
45 | return outputs, logabsdet
46 |
47 |
48 | def searchsorted(bin_locations, inputs, eps=1e-6):
49 | bin_locations[..., -1] += eps
50 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
51 |
52 |
53 | def unconstrained_rational_quadratic_spline(
54 | inputs,
55 | unnormalized_widths,
56 | unnormalized_heights,
57 | unnormalized_derivatives,
58 | inverse=False,
59 | tails="linear",
60 | tail_bound=1.0,
61 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
62 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
63 | min_derivative=DEFAULT_MIN_DERIVATIVE,
64 | ):
65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
66 | outside_interval_mask = ~inside_interval_mask
67 |
68 | outputs = torch.zeros_like(inputs)
69 | logabsdet = torch.zeros_like(inputs)
70 |
71 | if tails == "linear":
72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
73 | constant = np.log(np.exp(1 - min_derivative) - 1)
74 | unnormalized_derivatives[..., 0] = constant
75 | unnormalized_derivatives[..., -1] = constant
76 |
77 | outputs[outside_interval_mask] = inputs[outside_interval_mask]
78 | logabsdet[outside_interval_mask] = 0
79 | else:
80 | raise RuntimeError("{} tails are not implemented.".format(tails))
81 |
82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
83 | inputs=inputs[inside_interval_mask],
84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87 | inverse=inverse,
88 | left=-tail_bound,
89 | right=tail_bound,
90 | bottom=-tail_bound,
91 | top=tail_bound,
92 | min_bin_width=min_bin_width,
93 | min_bin_height=min_bin_height,
94 | min_derivative=min_derivative,
95 | )
96 |
97 | return outputs, logabsdet
98 |
99 |
100 | def rational_quadratic_spline(
101 | inputs,
102 | unnormalized_widths,
103 | unnormalized_heights,
104 | unnormalized_derivatives,
105 | inverse=False,
106 | left=0.0,
107 | right=1.0,
108 | bottom=0.0,
109 | top=1.0,
110 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
111 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
112 | min_derivative=DEFAULT_MIN_DERIVATIVE,
113 | ):
114 | if torch.min(inputs) < left or torch.max(inputs) > right:
115 | raise ValueError("Input to a transform is not within its domain")
116 |
117 | num_bins = unnormalized_widths.shape[-1]
118 |
119 | if min_bin_width * num_bins > 1.0:
120 | raise ValueError("Minimal bin width too large for the number of bins")
121 | if min_bin_height * num_bins > 1.0:
122 | raise ValueError("Minimal bin height too large for the number of bins")
123 |
124 | widths = F.softmax(unnormalized_widths, dim=-1)
125 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
126 | cumwidths = torch.cumsum(widths, dim=-1)
127 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
128 | cumwidths = (right - left) * cumwidths + left
129 | cumwidths[..., 0] = left
130 | cumwidths[..., -1] = right
131 | widths = cumwidths[..., 1:] - cumwidths[..., :-1]
132 |
133 | derivatives = min_derivative + F.softplus(unnormalized_derivatives)
134 |
135 | heights = F.softmax(unnormalized_heights, dim=-1)
136 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
137 | cumheights = torch.cumsum(heights, dim=-1)
138 | cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
139 | cumheights = (top - bottom) * cumheights + bottom
140 | cumheights[..., 0] = bottom
141 | cumheights[..., -1] = top
142 | heights = cumheights[..., 1:] - cumheights[..., :-1]
143 |
144 | if inverse:
145 | bin_idx = searchsorted(cumheights, inputs)[..., None]
146 | else:
147 | bin_idx = searchsorted(cumwidths, inputs)[..., None]
148 |
149 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
150 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
151 |
152 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
153 | delta = heights / widths
154 | input_delta = delta.gather(-1, bin_idx)[..., 0]
155 |
156 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
157 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
158 |
159 | input_heights = heights.gather(-1, bin_idx)[..., 0]
160 |
161 | if inverse:
162 | a = (inputs - input_cumheights) * (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + input_heights * (input_delta - input_derivatives)
163 | b = input_heights * input_derivatives - (inputs - input_cumheights) * (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
164 | c = -input_delta * (inputs - input_cumheights)
165 |
166 | discriminant = b.pow(2) - 4 * a * c
167 | assert (discriminant >= 0).all()
168 |
169 | root = (2 * c) / (-b - torch.sqrt(discriminant))
170 | outputs = root * input_bin_widths + input_cumwidths
171 |
172 | theta_one_minus_theta = root * (1 - root)
173 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta)
174 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - root).pow(2))
175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
176 |
177 | return outputs, -logabsdet
178 | else:
179 | theta = (inputs - input_cumwidths) / input_bin_widths
180 | theta_one_minus_theta = theta * (1 - theta)
181 |
182 | numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
183 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta)
184 | outputs = input_cumheights + numerator / denominator
185 |
186 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - theta).pow(2))
187 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
188 |
189 | return outputs, logabsdet
190 |
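A hypothetical invertibility check for the unconstrained spline with random parameters (tensor shapes are illustrative; with tails='linear', the derivative parameters have num_bins - 1 entries and are padded internally):

import torch
from module.vits.transforms import piecewise_rational_quadratic_transform

num_bins = 10
x = torch.randn(8)
widths = torch.randn(8, num_bins)
heights = torch.randn(8, num_bins)
derivatives = torch.randn(8, num_bins - 1)

y, logdet = piecewise_rational_quadratic_transform(x, widths, heights, derivatives, tails='linear')
x_rec, _ = piecewise_rational_quadratic_transform(y, widths, heights, derivatives, inverse=True, tails='linear')
print(torch.allclose(x, x_rec, atol=1e-4))   # True: the transform is invertible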
--------------------------------------------------------------------------------
/module/vits/wn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils.parametrizations import weight_norm
4 | from torch.nn.utils.parametrize import remove_parametrizations
5 |
6 |
7 | # WN module from https://arxiv.org/abs/1609.03499
8 | class WNLayer(nn.Module):
9 | def __init__(self,
10 | hidden_channels=192,
11 | kernel_size=5,
12 | dilation=1,
13 | speaker_embedding_dim=256):
14 | super().__init__()
15 | self.speaker_in = weight_norm(
16 | nn.Conv1d(
17 | speaker_embedding_dim,
18 | hidden_channels * 2, 1))
19 | padding = int((kernel_size * dilation - dilation) / 2)
20 | self.conv = weight_norm(
21 | nn.Conv1d(
22 | hidden_channels,
23 | hidden_channels * 2,
24 | kernel_size,
25 | 1,
26 | padding,
27 | dilation=dilation,
28 | padding_mode='replicate'))
29 | self.out = weight_norm(
30 | nn.Conv1d(hidden_channels, hidden_channels * 2, 1))
31 |
32 | # x: [BatchSize, hidden_channels, Length]
33 | # x_mask: [BatchSize, 1, Length]
34 | # g: [BatchSize, speaker_embedding_dim, 1]
35 | # Output: [BatchSize, hidden_channels, Length]
36 | def forward(self, x, x_mask, g):
37 | res = x
38 | x = self.conv(x) + self.speaker_in(g)
39 | x_0, x_1 = torch.chunk(x, 2, dim=1)
40 | x = torch.tanh(x_0) * torch.sigmoid(x_1)
41 | x = self.out(x)
42 | out, skip = torch.chunk(x, 2, dim=1)
43 | out = (out + res) * x_mask
44 | skip = skip * x_mask
45 | return out, skip
46 |
47 | def remove_weight_norm(self):
48 | remove_parametrizations(self.speaker_in, 'weight')
49 | remove_parametrizations(self.conv, 'weight')
50 | remove_parametrizations(self.out, 'weight')
51 |
52 |
53 | class WN(nn.Module):
54 | def __init__(self,
55 | hidden_channels=192,
56 | kernel_size=5,
57 | dilation=1,
58 | num_layers=4,
59 | speaker_embedding_dim=256):
60 | super().__init__()
61 | self.layers = nn.ModuleList([])
62 | for _ in range(num_layers):
63 | self.layers.append(
64 | WNLayer(hidden_channels, kernel_size, dilation, speaker_embedding_dim))
65 |
66 | # x: [BatchSize, hidden_channels, Length]
67 | # x_mask: [BatchSize, 1, Length]
68 | # g: [BatchSize, speaker_embedding_dim, 1]
69 | # Output: [BatchSize, hidden_channels, Length]
70 | def forward(self, x, x_mask, g):
71 | output = None
72 | for layer in self.layers:
73 | x, skip = layer(x, x_mask, g)
74 | if output is None:
75 | output = skip
76 | else:
77 | output += skip
78 | return output
79 |
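A hypothetical shape check for the gated WaveNet-style stack (shapes are illustrative):

import torch
from module.vits.wn import WN

wn = WN(hidden_channels=192, kernel_size=5, dilation=1, num_layers=4, speaker_embedding_dim=256)
x = torch.randn(2, 192, 50)      # [BatchSize, hidden_channels, Length]
x_mask = torch.ones(2, 1, 50)    # [BatchSize, 1, Length]
g = torch.randn(2, 256, 1)       # speaker embedding
print(wn(x, x_mask, g).shape)    # torch.Size([2, 192, 50]) (sum of per-layer skip connections)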
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | from pathlib import Path
5 | import shutil
6 |
7 | from module.preprocess.jvs import preprocess_jvs
8 | from module.preprocess.wave_and_text import preprocess_wave_and_text
9 | from module.preprocess.scan import scan_cache
10 | from module.utils.config import load_json_file
11 |
12 |
13 | def get_preprocess_method(dataset_type):
14 | if dataset_type == 'jvs':
15 | return preprocess_jvs
16 | elif dataset_type == 'wav-txt':
17 | return preprocess_wave_and_text
18 | else:
19 | raise ValueError(f"Unknown dataset type: {dataset_type}")
20 |
21 |
22 | parser = argparse.ArgumentParser("preprocess")
23 | parser.add_argument('type')
24 | parser.add_argument('root_dir')
25 | parser.add_argument('-c', '--config', default='./config/base.json')
26 | parser.add_argument('--scan-only', action='store_true')
27 |
28 | args = parser.parse_args()
29 |
30 | config = load_json_file(args.config)
31 | root_dir = Path(args.root_dir)
32 | dataset_type = args.type
33 |
34 | cache_dir = Path(config.preprocess.cache)
35 | if not cache_dir.exists():
36 | cache_dir.mkdir()
37 |
38 | preprocess_method = get_preprocess_method(dataset_type)
39 |
40 | if not args.scan_only:
41 | print(f"Start preprocess type={dataset_type}, root={str(root_dir)}")
42 | preprocess_method(root_dir, config)
43 |
44 | print(f"Scaning dataset cache")
45 | scan_cache(config)
46 | os.makedirs('models', exist_ok=True)
47 | shutil.copy(args.config, 'models/config.json')
48 | print(f"Complete!")
49 |
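Usage note (not part of the original file; the dataset path is a placeholder): preprocessing is launched as `python preprocess.py jvs /path/to/jvs_corpus -c config/base.json`, or with `wav-txt` for plain wave/text datasets; pass `--scan-only` to rebuild the dataset cache index without re-running preprocessing.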
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | setuptools
2 | torch
3 | torchaudio
4 | tqdm
5 | numpy
6 | onnx
7 | transformers
8 | pyworld
9 | torchfcpe
10 | numba
11 | cython
12 | pyopenjtalk
13 | tensorboard
14 | g2p_en
15 | safetensors
16 | lightning
17 | sentencepiece
18 | sox
19 | soundfile
20 | music21
21 | gradio
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | import random
5 | from pathlib import Path
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 |
12 | import lightning as L
13 | import pytorch_lightning
14 | from module.utils.config import load_json_file
15 | from module.vits import Vits
16 | from module.utils.dataset import VitsDataModule
17 | from module.utils.safetensors import save_tensors
18 |
19 | if __name__ == '__main__':
20 | parser = argparse.ArgumentParser(description="train")
21 | parser.add_argument('-c', '--config', default='config/base.json')
22 | args = parser.parse_args()
23 |
24 | class SaveCheckpoint(L.Callback):
25 | def __init__(self, models_dir, interval=200):
26 | super().__init__()
27 | self.models_dir = Path(models_dir)
28 | self.interval = interval
29 | if not self.models_dir.exists():
30 | self.models_dir.mkdir()
31 |
32 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
33 | if batch_idx % self.interval == 0:
34 | ckpt_path = self.models_dir / "vits.ckpt"
35 | trainer.save_checkpoint(ckpt_path)
36 | generator_path = self.models_dir / "generator.safetensors"
37 | save_tensors(pl_module.generator.state_dict(), generator_path)
38 |
39 | config = load_json_file(args.config)
40 | dm = VitsDataModule(**config.train.data_module)
41 | model_path = Path(config.train.save.models_dir) / "vits.ckpt"
42 |
43 | if model_path.exists():
44 | print(f"loading checkpoint from {model_path}")
45 | model = Vits.load_from_checkpoint(model_path)
46 | else:
47 | print("initialize model")
48 | model = Vits(config.vits)
49 |
50 | print("if you need to check tensorboard, run `tensorboard -logdir lightning_logs`")
51 | cb_save_checkpoint = SaveCheckpoint(config.train.save.models_dir, config.train.save.interval)
52 | trainer = L.Trainer(**config.train.trainer, callbacks=[cb_save_checkpoint])
53 | trainer.fit(model, dm)
54 |
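Usage note (not part of the original file): training is launched as `python train.py -c config/base.json`. The SaveCheckpoint callback writes `vits.ckpt` and `generator.safetensors` to `config.train.save.models_dir` every `interval` batches, and TensorBoard logs go to `lightning_logs/` by default.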
--------------------------------------------------------------------------------