├── .idea
│   ├── .gitignore
│   ├── VITS_voice_conversion.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── DATA.MD
├── DATA_EN.MD
├── LICENSE
├── LOCAL.md
├── README.md
├── README_ZH.md
├── VC_inference.py
├── attentions.py
├── cmd_inference.py
├── commons.py
├── configs
│   ├── modified_finetune_speaker.json
│   └── uma_trilingual.json
├── data_utils.py
├── finetune_speaker_v2.py
├── losses.py
├── mel_processing.py
├── models.py
├── models_infer.py
├── modules.py
├── monotonic_align
│   ├── __init__.py
│   ├── core.pyx
│   └── setup.py
├── preprocess_v2.py
├── requirements.txt
├── scripts
│   ├── denoise_audio.py
│   ├── download_model.py
│   ├── download_video.py
│   ├── long_audio_transcribe.py
│   ├── rearrange_speaker.py
│   ├── resample.py
│   ├── short_audio_transcribe.py
│   ├── video2audio.py
│   └── voice_upload.py
├── text
│   ├── LICENSE
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── cleaners.cpython-37.pyc
│   │   ├── english.cpython-37.pyc
│   │   ├── japanese.cpython-37.pyc
│   │   ├── korean.cpython-37.pyc
│   │   ├── mandarin.cpython-37.pyc
│   │   ├── sanskrit.cpython-37.pyc
│   │   ├── symbols.cpython-37.pyc
│   │   └── thai.cpython-37.pyc
│   ├── cantonese.py
│   ├── cleaners.py
│   ├── english.py
│   ├── japanese.py
│   ├── korean.py
│   ├── mandarin.py
│   ├── ngu_dialect.py
│   ├── sanskrit.py
│   ├── shanghainese.py
│   ├── symbols.py
│   └── thai.py
├── transforms.py
└── utils.py


/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/DATA.MD:
--------------------------------------------------------------------------------
1 | The pipeline in this repo supports several ways of uploading voice samples; choose any one or more of them depending on the samples you have.
2 |
3 | 1. Short audio clips organized by character name and packed into a single `.zip` file. The archive should be structured as follows:
4 | ```
5 | Your-zip-file.zip
6 | ├───Character_name_1
7 | │   ├───xxx.wav
8 | │   ├───...
9 | │   ├───yyy.mp3
10 | │   └───zzz.wav
11 | ├───Character_name_2
12 | │   ├───xxx.wav
13 | │   ├───...
14 | │   ├───yyy.mp3
15 | │   └───zzz.wav
16 | ├───...
17 | │
18 | └───Character_name_n
19 |     ├───xxx.wav
20 |     ├───...
21 |     ├───yyy.mp3
22 |     └───zzz.wav
23 | ```
24 | Note that neither the format nor the names of the audio files matter, as long as they are audio files.
25 | Quality requirements: between 2 and 10 seconds long, with as little background noise as possible.
26 | Quantity requirements: at least 10 clips per character; 20 or more per character is recommended.
27 | 2. Long audio files named after the character. Each file may contain only a single speaker; background sound will be removed automatically. Naming format: `{CharacterName}_{random_number}.wav`
28 | (e.g. `Diana_234135.wav`, `MinatoAqua_234252.wav`). Files must be `.wav` files shorter than 20 minutes (otherwise you may run out of memory).
29 |
30 | 3. Long video files named after the character. Each video may contain only a single speaker; background sound will be removed automatically. Naming format: `{CharacterName}_{random_number}.mp4`
31 | (e.g. `Taffy_332452.mp4`, `Dingzhen_957315.mp4`). Files must be `.mp4` files shorter than 20 minutes (otherwise you may run out of memory).
32 | Note: in file names, `CharacterName` must consist of English letters only. `random_number` distinguishes multiple files of the same character and is mandatory; it can be any integer from 0 to 999999.
33 |
34 | 4. A `.txt` file containing multiple lines of `{CharacterName}|{video_url}`, formatted as follows:
35 | ```
36 | Char1|https://xyz.com/video1/
37 | Char2|https://xyz.com/video2/
38 | Char2|https://xyz.com/video3/
39 | Char3|https://xyz.com/video4/
40 | ```
41 | Each video may contain only a single speaker; background sound will be removed automatically. Currently only bilibili videos are supported; URLs from other websites have not been tested yet.
42 | If you have questions about the formats, you can find data samples for every format [here](https://drive.google.com/file/d/132l97zjanpoPY4daLgqXoM7HKXPRbS84/view?usp=sharing).
43 |

--------------------------------------------------------------------------------
/DATA_EN.MD:
--------------------------------------------------------------------------------
1 | The pipeline in this repo supports multiple ways of uploading voice data; you can choose one or more options depending on the data you have.
2 |
3 | 1. Short audio clips organized by character name and packed into a single `.zip` file. The archive should be structured as shown below:
4 | ```
5 | Your-zip-file.zip
6 | ├───Character_name_1
7 | │   ├───xxx.wav
8 | │   ├───...
9 | │   ├───yyy.mp3
10 | │   └───zzz.wav
11 | ├───Character_name_2
12 | │   ├───xxx.wav
13 | │   ├───...
14 | │   ├───yyy.mp3
15 | │   └───zzz.wav
16 | ├───...
17 | │
18 | └───Character_name_n
19 |     ├───xxx.wav
20 |     ├───...
21 |     ├───yyy.mp3
22 |     └───zzz.wav
23 | ```
24 | Note that neither the format nor the names of the audio files matter, as long as they are audio files.
25 | Quality requirements: 2 to 10 seconds per clip, with as little background sound as possible.
26 | Quantity requirements: at least 10 clips per character; 20 or more per character is recommended.
27 |
28 | 2. Long audio files named after the character. Each file should contain a single speaker's voice only. Background sound is
29 | acceptable, since it will be removed automatically. File name format: `{CharacterName}_{random_number}.wav`
30 | (e.g. `Diana_234135.wav`, `MinatoAqua_234252.wav`). Files must be `.wav` files shorter than 20 minutes (longer files may run out of memory).
31 |
32 |
33 | 3. Long video files named after the character. Each video should contain a single speaker's voice only. Background sound is
34 | acceptable, since it will be removed automatically. File name format: `{CharacterName}_{random_number}.mp4`
35 | (e.g. `Taffy_332452.mp4`, `Dingzhen_957315.mp4`). Files must be `.mp4` files shorter than 20 minutes (longer files may run out of memory).
36 | Note: `CharacterName` must consist of English letters only. `random_number` distinguishes multiple files of the same character
37 | and is mandatory; it can be any integer from 0 to 999999.
38 |
39 | 4. A `.txt` file containing multiple lines of `{CharacterName}|{video_url}`, formatted as follows:
40 | ```
41 | Char1|https://xyz.com/video1/
42 | Char2|https://xyz.com/video2/
43 | Char2|https://xyz.com/video3/
44 | Char3|https://xyz.com/video4/
45 | ```
46 | Each video should contain a single speaker only. Currently only bilibili video links are supported; other websites have not been tested yet.
47 | Questions about the data format? Find data samples for every format [here](https://drive.google.com/file/d/132l97zjanpoPY4daLgqXoM7HKXPRbS84/view?usp=sharing).
48 |

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity.
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 |

--------------------------------------------------------------------------------
/LOCAL.md:
--------------------------------------------------------------------------------
1 | # Train locally
2 | ### Build environment
3 | 0. Make sure you have installed `Python==3.8`, CMake, C/C++ compilers, and ffmpeg;
4 | 1. Clone this repository;
5 | 2. Run `pip install -r requirements.txt`;
6 | 3. Install the GPU version of PyTorch (make sure you have CUDA 11.6 or 11.7 installed):
7 | ```
8 | # CUDA 11.6
9 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
10 | # CUDA 11.7
11 | pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
12 | ```
13 | 4. Install the libraries needed for processing video data:
14 | ```
15 | pip install imageio==2.4.1
16 | pip install moviepy
17 | ```
18 | 5. Build `monotonic_align` (required for training):
19 | ```
20 | cd monotonic_align
21 | mkdir monotonic_align
22 | python setup.py build_ext --inplace
23 | cd ..
24 | ```
25 | 6. Download auxiliary data for training and create the necessary directories:
26 | ```
27 | mkdir pretrained_models
28 | # download data for fine-tuning
29 | wget https://huggingface.co/datasets/Plachta/sampled_audio4ft/resolve/main/sampled_audio4ft_v2.zip
30 | unzip sampled_audio4ft_v2.zip
31 | # create necessary directories
32 | mkdir video_data
33 | mkdir raw_audio
34 | mkdir denoised_audio
35 | mkdir custom_character_voice
36 | mkdir segmented_character_voice
37 | ```
38 | 7. Download a pretrained model. Available options:
39 | ```
40 | CJE: Trilingual (Chinese, Japanese, English)
41 | CJ: Bilingual (Chinese, Japanese)
42 | C: Chinese only
43 | ```
44 | ### Linux
45 | To download the `CJE` model, run the following:
46 | ```
47 | wget https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/resolve/main/pretrained_models/D_trilingual.pth -O ./pretrained_models/D_0.pth
48 | wget https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/resolve/main/pretrained_models/G_trilingual.pth -O ./pretrained_models/G_0.pth
49 | wget https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/resolve/main/configs/uma_trilingual.json -O ./configs/finetune_speaker.json
50 | ```
51 | To download the `CJ` model, run the following:
52 | ```
53 | wget https://huggingface.co/spaces/sayashi/vits-uma-genshin-honkai/resolve/main/model/D_0-p.pth -O ./pretrained_models/D_0.pth
54 | wget https://huggingface.co/spaces/sayashi/vits-uma-genshin-honkai/resolve/main/model/G_0-p.pth -O ./pretrained_models/G_0.pth
55 | wget https://huggingface.co/spaces/sayashi/vits-uma-genshin-honkai/resolve/main/model/config.json -O ./configs/finetune_speaker.json
56 | ```
57 | To download the `C` model, run the following:
58 | ```
59 | wget https://huggingface.co/datasets/Plachta/sampled_audio4ft/resolve/main/VITS-Chinese/D_0.pth -O ./pretrained_models/D_0.pth
60 | wget https://huggingface.co/datasets/Plachta/sampled_audio4ft/resolve/main/VITS-Chinese/G_0.pth -O ./pretrained_models/G_0.pth
61 | wget https://huggingface.co/datasets/Plachta/sampled_audio4ft/resolve/main/VITS-Chinese/config.json -O ./configs/finetune_speaker.json
62 | ```
63 | ### Windows
64 | Manually download the `G` model, the `D` model, and the `.json` config file from the URLs given for one of the options above.
65 |
66 | Rename the `G` model to `G_0.pth`, the `D` model to `D_0.pth`, and the config file (`.json`) to `finetune_speaker.json`.
67 | Put `G_0.pth` and `D_0.pth` under the `pretrained_models` directory;
68 | Put `finetune_speaker.json` under the `configs` directory.
69 |
70 | #### Please note that downloading another model overwrites the previous one.
71 | 8. Put your voice data under the corresponding directories; see [DATA.MD](https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/DATA_EN.MD) for a detailed description of the different uploading options.
72 | ### Short audios
73 | 1. Prepare your data as a single `.zip` file according to [DATA.MD](https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/DATA_EN.MD);
74 | 2. Name the file `custom_character_voice.zip` and put it under the directory `./custom_character_voice/`;
75 | 3. Run `unzip ./custom_character_voice/custom_character_voice.zip -d ./custom_character_voice/`
76 |
77 | ### Long audios
78 | 1. Name your audio files according to [DATA.MD](https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/DATA_EN.MD);
79 | 2. Put your renamed audio files under the directory `./raw_audio/`
80 |
81 | ### Videos
82 | 1. Name your video files according to [DATA.MD](https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/DATA_EN.MD);
83 | 2. Put your renamed video files under the directory `./video_data/`
84 | 9. Process all audio data:
85 | ```
86 | python scripts/video2audio.py
87 | python scripts/denoise_audio.py
88 | python scripts/long_audio_transcribe.py --languages "{PRETRAINED_MODEL}" --whisper_size large
89 | python scripts/short_audio_transcribe.py --languages "{PRETRAINED_MODEL}" --whisper_size large
90 | python scripts/resample.py
91 | ```
92 | Replace `"{PRETRAINED_MODEL}"` with one of `{CJ, CJE, C}` according to the model you chose in step 7.
93 | Make sure you have at least 12 GB of GPU memory. If not, change the argument `--whisper_size` to `medium` or `small`.
94 |
95 | 10. Process all text data.
96 | If you choose to add auxiliary data, run `python preprocess_v2.py --add_auxiliary_data True --languages "{PRETRAINED_MODEL}"`
97 | If not, run `python preprocess_v2.py --languages "{PRETRAINED_MODEL}"`
98 | Replace `"{PRETRAINED_MODEL}"` with one of `{CJ, CJE, C}` according to the model you chose in step 7.
99 |
100 | 11. Start training.
101 | Run `python finetune_speaker_v2.py -m ./OUTPUT_MODEL --max_epochs "{Maximum_epochs}" --drop_speaker_embed True`
102 | Replace `{Maximum_epochs}` with your desired number of epochs. Empirically, 100 or more is recommended.
103 | To continue training from a previous checkpoint, change the training command to: `python finetune_speaker_v2.py -m ./OUTPUT_MODEL --max_epochs "{Maximum_epochs}" --drop_speaker_embed False --cont True`. Before doing this, make sure the previous `G_latest.pth` and `D_latest.pth` are in the `./OUTPUT_MODEL/` directory.
104 | To view training progress, open a new terminal, `cd` to the project root directory, run `tensorboard --logdir=./OUTPUT_MODEL`, and then visit `localhost:6006` in your web browser.
105 |
106 | 12. After training is completed, you can use your model by running:
107 | `python VC_inference.py --model_dir ./OUTPUT_MODEL/G_latest.pth --share True`
108 | 13. To clear all audio data, run:
109 | ### Linux
110 | ```
111 | rm -rf ./custom_character_voice/* ./video_data/* ./raw_audio/* ./denoised_audio/* ./segmented_character_voice/* ./separated/* long_character_anno.txt short_character_anno.txt
112 | ```
113 | ### Windows
114 | ```
115 | del /Q /S .\custom_character_voice\* .\video_data\* .\raw_audio\* .\denoised_audio\* .\segmented_character_voice\* .\separated\* long_character_anno.txt short_character_anno.txt
116 | ```
117 |
118 |
119 |

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [Click here for the Chinese documentation](https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/README_ZH.md)
2 | # VITS Fast Fine-tuning
3 | This repo will guide you through adding your own character voices, or even your own voice, to an existing VITS TTS model,
4 | enabling it to perform the following tasks in less than 1 hour:
5 |
6 | 1. Many-to-many voice conversion between any characters you added and the preset characters in the model.
7 | 2. English, Japanese & Chinese text-to-speech synthesis with the characters you added and the preset characters.
8 |
9 |
10 | Feel free to play around with the base models!
11 | Chinese & English & Japanese: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer) Author: Me
12 |
13 | Chinese & Japanese: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/sayashi/vits-uma-genshin-honkai) Author: [SayaSS](https://github.com/SayaSS)
14 |
15 | Chinese only: (no Hugging Face Space available) Author: [Wwwwhy230825](https://github.com/Wwwwhy230825)
16 |
17 |
18 | ### Currently Supported Tasks:
19 | - [x] Clone a character's voice from 10+ short audio clips
20 | - [x] Clone a character's voice from long audio(s) of 3+ minutes (each audio file should contain a single speaker only)
21 | - [x] Clone a character's voice from video(s) of 3+ minutes (each video should contain a single speaker only)
22 | - [x] Clone a character's voice from bilibili video links (each video should contain a single speaker only)
23 |
24 | ### Currently Supported Characters for TTS & VC:
25 | - [x] Any character you like, as long as you have samples of their voice!
26 | (Note that voice conversion can only be performed between two speakers that exist in the model.)
27 |
28 |
29 |
30 | ## Fine-tuning
31 | See [LOCAL.md](https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/LOCAL.md) for the local training guide.
32 | Alternatively, you can perform fine-tuning on [Google Colab](https://colab.research.google.com/drive/1pn1xnFfdLK63gVXDwV4zCXfVeo8c-I-0?usp=sharing).
33 |
34 |
35 | ### How long does it take?
36 | 1. Install dependencies (3 min)
37 | 2. Choose a pretrained model to start from. The detailed differences between them are described in the [Colab Notebook](https://colab.research.google.com/drive/1pn1xnFfdLK63gVXDwV4zCXfVeo8c-I-0?usp=sharing).
38 | 3. Upload voice samples of the characters you wish to add; see [DATA.MD](https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/DATA_EN.MD) for the detailed uploading options.
39 | 4. Start fine-tuning. This takes 20 minutes to 2 hours, depending on the number of voice samples you uploaded.
40 |
41 |
42 | ## Inference or Usage (currently Windows only)
43 | 0. Remember to download your fine-tuned model!
44 | 1. Download the latest release.
45 | 2. Put your model and config file, named `G_latest.pth` and `finetune_speaker.json` respectively, into the `inference` folder.
46 | 3. The file structure should be as follows:
47 | ```
48 | inference
49 | ├───inference.exe
50 | ├───...
51 | ├───finetune_speaker.json
52 | └───G_latest.pth
53 | ```
54 | 4. Run `inference.exe`; the browser should open automatically. Note that its path must not contain Chinese characters or spaces.
55 | 5. Note: you must install `ffmpeg` to enable the voice conversion feature.
56 |
57 |
58 | ## Inference with CLI
59 | This example shows how to run inference with the default pretrained model. All commands are run from the repository root directory.
60 | 1. Create the necessary folders and download the necessary files:
61 | ```
62 | cd monotonic_align/
63 | mkdir monotonic_align
64 | python setup.py build_ext --inplace
65 | cd ..
66 | mkdir pretrained_models
67 | # download data for fine-tuning
68 | wget https://huggingface.co/datasets/Plachta/sampled_audio4ft/resolve/main/sampled_audio4ft_v2.zip
69 | unzip sampled_audio4ft_v2.zip
70 | ```
71 |
72 | If you plan to fine-tune your own model, you may also need to create these directories:
73 | ```
74 | mkdir video_data
75 | mkdir raw_audio
76 | mkdir denoised_audio
77 | mkdir custom_character_voice
78 | mkdir segmented_character_voice
79 | ```
80 | 2. Download a pretrained model, for example the trilingual one:
81 | ```
82 | wget https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/resolve/main/pretrained_models/D_trilingual.pth -O ./pretrained_models/D_0.pth
83 | wget https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/resolve/main/pretrained_models/G_trilingual.pth -O ./pretrained_models/G_0.pth
84 | wget https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer/resolve/main/configs/uma_trilingual.json -O ./configs/finetune_speaker.json
85 | ```
86 | 3. Activate your environment and run the following command:
87 | `python3 cmd_inference.py -m pretrained_models/G_0.pth -c configs/finetune_speaker.json -t 你好,训练员先生,很高兴见到你。 -s "派蒙 Paimon (Genshin Impact)" -l "简体中文"`
88 | You can also choose another language, customize the output folder, or change the text and character; all available parameters are defined in `cmd_inference.py`.
89 | The next step shows only how to change the character.
90 | 4. To change the character, open the config file (`configs/finetune_speaker.json`) and find the `speakers` dictionary, which contains the full list of speakers. Copy the name of the character you need and use it in place of `"派蒙 Paimon (Genshin Impact)"`.
91 | 5. If inference succeeds, the output `.wav` file can be found in `output/vits`.
92 |
93 |
94 | ## Use in MoeGoe
95 | 0. Prepare the downloaded model and config file, named `G_latest.pth` and `moegoe_config.json` respectively.
96 | 1. Follow the instructions on the [MoeGoe](https://github.com/CjangCjengh/MoeGoe) page to install it, configure the paths, and use it.
97 |
98 | ## Looking for help?
99 | If you have any questions, please feel free to open an [issue](https://github.com/Plachtaa/VITS-fast-fine-tuning/issues/new) or join our [Discord](https://discord.gg/TcrjDFvm5A) server.
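Step 4 of "Inference with CLI" asks you to look up speaker names in the `speakers` dictionary by hand. The minimal sketch below lists them instead. It assumes `speakers` is a top-level JSON object mapping each speaker name to an integer ID, and accepts a plain list of names as a fallback (a layout some MoeGoe-style configs may use). Both layouts, the helper name `list_speakers`, and the script name in the usage comment are illustrative assumptions, not part of this repo.

```python
import json
import sys
from pathlib import Path
from typing import List  # typing.List keeps this compatible with Python 3.8


def list_speakers(config_path: str) -> List[str]:
    """Return the speaker names found in a VITS config file.

    Assumption: "speakers" is a top-level object mapping name -> integer ID.
    A plain list of names is also accepted as a fallback (assumed format).
    """
    config = json.loads(Path(config_path).read_text(encoding="utf-8"))
    speakers = config.get("speakers", {})
    if isinstance(speakers, list):
        # Fallback: the list position is treated as the speaker ID.
        return list(speakers)
    # Sort by ID for a stable, readable listing.
    return [name for name, _ in sorted(speakers.items(), key=lambda item: item[1])]


if __name__ == "__main__" and len(sys.argv) > 1:
    # Hypothetical usage: python list_speakers.py configs/finetune_speaker.json
    for name in list_speakers(sys.argv[1]):
        print(name)
```

Any printed name can then be passed to `cmd_inference.py` via `-s`, quoted as in step 3 above.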
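100 | -------------------------------------------------------------------------------- /README_ZH.md: -------------------------------------------------------------------------------- 1 | For the English documentation, please click [here](https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/README.md) 2 | # VITS Fast Fine-tuning 3 | This repository guides you through adding custom characters (or even yourself) to a pretrained VITS model. After less than an hour of fine-tuning, the model can: 4 | 1. Convert voices between any two characters contained in the model 5 | 2. Synthesize Chinese, Japanese and English speech from text in the voice of a character you added. 6 | 7 | The base models used in this project cover common anime-style male/female voice-acting timbres (from a Genshin Impact dataset) as well as common real-world male/female voices (from the VCTK dataset). They support Chinese, Japanese and English, so they adapt quickly to new voices during fine-tuning. 8 | 9 | Feel free to try the base models used for fine-tuning! 10 | 11 | Chinese, Japanese and English: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer) Author: me 12 | 13 | Chinese and Japanese: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/sayashi/vits-uma-genshin-honkai) Author: [SayaSS](https://github.com/SayaSS) 14 | 15 | Chinese only: (no Hugging Face demo) Author: [Wwwwhy230825](https://github.com/Wwwwhy230825) 16 | 17 | ### Currently supported tasks: 18 | - [x] Clone a character's voice from 10 or more short audio clips 19 | - [x] Clone a character's voice from 3 or more minutes of long audio (each audio file may contain only one speaker) 20 | - [x] Clone a character's voice from 3 or more minutes of video (each video may contain only one speaker) 21 | - [x] Clone a character's voice from a bilibili video link (each video may contain only one speaker) 22 | 23 | ### Characters currently supported for voice conversion and Chinese/Japanese/English TTS 24 | - [x] Any character (as long as you have voice samples of that character) 25 | (Note: voice conversion only works between two speakers that both exist in the model) 26 | 27 | 28 | 29 | 30 | ## Fine-tuning 31 | To train on your local machine, see [LOCAL.md](https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/LOCAL.md). 32 | Alternatively, you can use [Google Colab](https://colab.research.google.com/drive/1pn1xnFfdLK63gVXDwV4zCXfVeo8c-I-0?usp=sharing) 33 | to run the fine-tuning. 34 | ### How long will it take? 35 | 1. Install dependencies (about 10 min on Google Colab) 36 | 2. Choose a pretrained model; see the [Colab notebook page](https://colab.research.google.com/drive/1pn1xnFfdLK63gVXDwV4zCXfVeo8c-I-0?usp=sharing) for the differences between them. 37 | 3. Upload the voices of the other characters you want to add; see [DATA.MD](https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/DATA.MD) for how to upload them 38 | 4. Run fine-tuning. Depending on the fine-tuning method and the number of samples, this takes anywhere from 20 minutes to 2 hours. 39 | 40 | After fine-tuning, you can download the fine-tuned model directly and run it locally later (no GPU required) 41 | 42 | ## Running and inference locally 43 | 0. Remember to download the fine-tuned model and config file! 44 | 1. Download the latest Release package (on the right side of the GitHub page) 45 | 2. Put the downloaded model and config file into the `inference` folder, named `G_latest.pth` and `finetune_speaker.json` respectively. 46 | 3. Once everything is ready, the file structure should look like this: 47 | ``` 48 | inference 49 | ├───inference.exe 50 | ├───... 51 | ├───finetune_speaker.json 52 | └───G_latest.pth 53 | ``` 54 | 4. Run `inference.exe`; a browser window opens automatically. Note that its path must not contain Chinese characters or spaces. 55 | 5. Note that the voice conversion feature requires `ffmpeg` to be installed. 56 | 57 | ## Using in MoeGoe 58 | 0. MoeGoe and other similar VITS inference UIs use a slightly different config format; the files to download are the model `G_latest.pth` and the config file `moegoe_config.json` 59 | 1. Follow the instructions on the [MoeGoe](https://github.com/CjangCjengh/MoeGoe) page to configure the paths, and it is ready to use. 60 | 2. When entering a sentence in MoeGoe, wrap it in the matching language tag so it can be synthesized correctly ([JA] for Japanese, [ZH] for Chinese, [EN] for English), for example: 61 | [JA]こんにちわ。[JA] 62 | [ZH]你好![ZH] 63 | [EN]Hello![EN] 64 | 65 | ## Help 66 | If you run into any problems, you can open an issue [here](https://github.com/Plachtaa/VITS-fast-fine-tuning/issues/new) or join the Discord server for help: [Discord](https://discord.gg/TcrjDFvm5A). 67 | -------------------------------------------------------------------------------- /VC_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch import no_grad, LongTensor 5 | import argparse 6 | import commons 7 | from mel_processing import spectrogram_torch 8 | import utils 9 | from models import SynthesizerTrn 10 | import gradio as gr 11 | import librosa 12 | import webbrowser 13 | 14 | from text import text_to_sequence, _clean_text 15 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 16 | import logging 17 | logging.getLogger("PIL").setLevel(logging.WARNING) 18 | logging.getLogger("urllib3").setLevel(logging.WARNING) 19 | logging.getLogger("markdown_it").setLevel(logging.WARNING) 20 | logging.getLogger("httpx").setLevel(logging.WARNING) 21 | logging.getLogger("asyncio").setLevel(logging.WARNING) 22 | 23 | language_marks = { 24 | "Japanese": "", 25 | "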
日本語": "[JA]", 26 | "简体中文": "[ZH]", 27 | "English": "[EN]", 28 | "Mix": "", 29 | } 30 | lang = ['日本語', '简体中文', 'English', 'Mix'] 31 | def get_text(text, hps, is_symbol): 32 | text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners) 33 | if hps.data.add_blank: 34 | text_norm = commons.intersperse(text_norm, 0) 35 | text_norm = LongTensor(text_norm) 36 | return text_norm 37 | 38 | def create_tts_fn(model, hps, speaker_ids): 39 | def tts_fn(text, speaker, language, speed): 40 | if language is not None: 41 | text = language_marks[language] + text + language_marks[language] 42 | speaker_id = speaker_ids[speaker] 43 | stn_tst = get_text(text, hps, False) 44 | with no_grad(): 45 | x_tst = stn_tst.unsqueeze(0).to(device) 46 | x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device) 47 | sid = LongTensor([speaker_id]).to(device) 48 | audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, 49 | length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy() 50 | del stn_tst, x_tst, x_tst_lengths, sid 51 | return "Success", (hps.data.sampling_rate, audio) 52 | 53 | return tts_fn 54 | 55 | def create_vc_fn(model, hps, speaker_ids): 56 | def vc_fn(original_speaker, target_speaker, record_audio, upload_audio): 57 | input_audio = record_audio if record_audio is not None else upload_audio 58 | if input_audio is None: 59 | return "You need to record or upload an audio clip", None 60 | sampling_rate, audio = input_audio 61 | original_speaker_id = speaker_ids[original_speaker] 62 | target_speaker_id = speaker_ids[target_speaker] 63 | 64 | audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32) 65 | if len(audio.shape) > 1: 66 | audio = librosa.to_mono(audio.transpose(1, 0)) 67 | if sampling_rate != hps.data.sampling_rate: 68 | audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=hps.data.sampling_rate) 69 | with no_grad(): 70 | y = torch.FloatTensor(audio) 71 | y = y / max(-y.min(), y.max()) * 0.99 # peak-normalize to 0.99 72 | y
= y.to(device) 73 | y = y.unsqueeze(0) 74 | spec = spectrogram_torch(y, hps.data.filter_length, 75 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 76 | center=False).to(device) 77 | spec_lengths = LongTensor([spec.size(-1)]).to(device) 78 | sid_src = LongTensor([original_speaker_id]).to(device) 79 | sid_tgt = LongTensor([target_speaker_id]).to(device) 80 | audio = model.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt)[0][ 81 | 0, 0].data.cpu().float().numpy() 82 | del y, spec, spec_lengths, sid_src, sid_tgt 83 | return "Success", (hps.data.sampling_rate, audio) 84 | 85 | return vc_fn 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument("--model_dir", default="./G_latest.pth", help="path to your fine-tuned model checkpoint") 89 | parser.add_argument("--config_dir", default="./finetune_speaker.json", help="path to your model config file") 90 | parser.add_argument("--share", default=False, help="make link public (used in colab)") 91 | 92 | args = parser.parse_args() 93 | hps = utils.get_hparams_from_file(args.config_dir) 94 | 95 | 96 | net_g = SynthesizerTrn( 97 | len(hps.symbols), 98 | hps.data.filter_length // 2 + 1, 99 | hps.train.segment_size // hps.data.hop_length, 100 | n_speakers=hps.data.n_speakers, 101 | **hps.model).to(device) 102 | _ = net_g.eval() 103 | 104 | _ = utils.load_checkpoint(args.model_dir, net_g, None) 105 | speaker_ids = hps.speakers 106 | speakers = list(hps.speakers.keys()) 107 | tts_fn = create_tts_fn(net_g, hps, speaker_ids) 108 | vc_fn = create_vc_fn(net_g, hps, speaker_ids) 109 | app = gr.Blocks() 110 | with app: 111 | with gr.Tab("Text-to-Speech"): 112 | with gr.Row(): 113 | with gr.Column(): 114 | textbox = gr.TextArea(label="Text", 115 | placeholder="Type your sentence here", 116 | value="こんにちわ。", elem_id="tts-input") 117 | # select character 118 | char_dropdown = gr.Dropdown(choices=speakers, value=speakers[0], label='character') 119 |
language_dropdown = gr.Dropdown(choices=lang, value=lang[0], label='language') 120 | duration_slider = gr.Slider(minimum=0.1, maximum=5, value=1, step=0.1, 121 | label='Speed') 122 | with gr.Column(): 123 | text_output = gr.Textbox(label="Message") 124 | audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio") 125 | btn = gr.Button("Generate!") 126 | btn.click(tts_fn, 127 | inputs=[textbox, char_dropdown, language_dropdown, duration_slider,], 128 | outputs=[text_output, audio_output]) 129 | with gr.Tab("Voice Conversion"): 130 | gr.Markdown(""" 131 | Record or upload audio, then select the voice to convert it to. 132 | """) 133 | with gr.Column(): 134 | record_audio = gr.Audio(label="record your voice", source="microphone") 135 | upload_audio = gr.Audio(label="or upload audio here", source="upload") 136 | source_speaker = gr.Dropdown(choices=speakers, value=speakers[0], label="source speaker") 137 | target_speaker = gr.Dropdown(choices=speakers, value=speakers[0], label="target speaker") 138 | with gr.Column(): 139 | message_box = gr.Textbox(label="Message") 140 | converted_audio = gr.Audio(label='converted audio') 141 | btn = gr.Button("Convert!") 142 | btn.click(vc_fn, inputs=[source_speaker, target_speaker, record_audio, upload_audio], 143 | outputs=[message_box, converted_audio]) 144 | webbrowser.open("http://127.0.0.1:7860") 145 | app.launch(share=args.share) 146 | 147 | -------------------------------------------------------------------------------- /attentions.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | import commons 9 | import modules 10 | from modules import LayerNorm 11 | 12 | 13 | class Encoder(nn.Module): 14 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs): 15 | super().__init__() 16 | self.hidden_channels =
hidden_channels 17 | self.filter_channels = filter_channels 18 | self.n_heads = n_heads 19 | self.n_layers = n_layers 20 | self.kernel_size = kernel_size 21 | self.p_dropout = p_dropout 22 | self.window_size = window_size 23 | 24 | self.drop = nn.Dropout(p_dropout) 25 | self.attn_layers = nn.ModuleList() 26 | self.norm_layers_1 = nn.ModuleList() 27 | self.ffn_layers = nn.ModuleList() 28 | self.norm_layers_2 = nn.ModuleList() 29 | for i in range(self.n_layers): 30 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) 31 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 32 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) 33 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 34 | 35 | def forward(self, x, x_mask): 36 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 37 | x = x * x_mask 38 | for i in range(self.n_layers): 39 | y = self.attn_layers[i](x, x, attn_mask) 40 | y = self.drop(y) 41 | x = self.norm_layers_1[i](x + y) 42 | 43 | y = self.ffn_layers[i](x, x_mask) 44 | y = self.drop(y) 45 | x = self.norm_layers_2[i](x + y) 46 | x = x * x_mask 47 | return x 48 | 49 | 50 | class Decoder(nn.Module): 51 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs): 52 | super().__init__() 53 | self.hidden_channels = hidden_channels 54 | self.filter_channels = filter_channels 55 | self.n_heads = n_heads 56 | self.n_layers = n_layers 57 | self.kernel_size = kernel_size 58 | self.p_dropout = p_dropout 59 | self.proximal_bias = proximal_bias 60 | self.proximal_init = proximal_init 61 | 62 | self.drop = nn.Dropout(p_dropout) 63 | self.self_attn_layers = nn.ModuleList() 64 | self.norm_layers_0 = nn.ModuleList() 65 | self.encdec_attn_layers = nn.ModuleList() 66 | self.norm_layers_1 = nn.ModuleList() 67 | self.ffn_layers = 
nn.ModuleList() 68 | self.norm_layers_2 = nn.ModuleList() 69 | for i in range(self.n_layers): 70 | self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init)) 71 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 72 | self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) 73 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 74 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) 75 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 76 | 77 | def forward(self, x, x_mask, h, h_mask): 78 | """ 79 | x: decoder input 80 | h: encoder output 81 | """ 82 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) 83 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 84 | x = x * x_mask 85 | for i in range(self.n_layers): 86 | y = self.self_attn_layers[i](x, x, self_attn_mask) 87 | y = self.drop(y) 88 | x = self.norm_layers_0[i](x + y) 89 | 90 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 91 | y = self.drop(y) 92 | x = self.norm_layers_1[i](x + y) 93 | 94 | y = self.ffn_layers[i](x, x_mask) 95 | y = self.drop(y) 96 | x = self.norm_layers_2[i](x + y) 97 | x = x * x_mask 98 | return x 99 | 100 | 101 | class MultiHeadAttention(nn.Module): 102 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False): 103 | super().__init__() 104 | assert channels % n_heads == 0 105 | 106 | self.channels = channels 107 | self.out_channels = out_channels 108 | self.n_heads = n_heads 109 | self.p_dropout = p_dropout 110 | self.window_size = window_size 111 | self.heads_share = heads_share 112 | self.block_length = block_length 113 | self.proximal_bias = proximal_bias 114 | 
self.proximal_init = proximal_init 115 | self.attn = None 116 | 117 | self.k_channels = channels // n_heads 118 | self.conv_q = nn.Conv1d(channels, channels, 1) 119 | self.conv_k = nn.Conv1d(channels, channels, 1) 120 | self.conv_v = nn.Conv1d(channels, channels, 1) 121 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 122 | self.drop = nn.Dropout(p_dropout) 123 | 124 | if window_size is not None: 125 | n_heads_rel = 1 if heads_share else n_heads 126 | rel_stddev = self.k_channels**-0.5 127 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 128 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 129 | 130 | nn.init.xavier_uniform_(self.conv_q.weight) 131 | nn.init.xavier_uniform_(self.conv_k.weight) 132 | nn.init.xavier_uniform_(self.conv_v.weight) 133 | if proximal_init: 134 | with torch.no_grad(): 135 | self.conv_k.weight.copy_(self.conv_q.weight) 136 | self.conv_k.bias.copy_(self.conv_q.bias) 137 | 138 | def forward(self, x, c, attn_mask=None): 139 | q = self.conv_q(x) 140 | k = self.conv_k(c) 141 | v = self.conv_v(c) 142 | 143 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 144 | 145 | x = self.conv_o(x) 146 | return x 147 | 148 | def attention(self, query, key, value, mask=None): 149 | # reshape [b, d, t] -> [b, n_h, t, d_k] 150 | b, d, t_s, t_t = (*key.size(), query.size(2)) 151 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 152 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 153 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 154 | 155 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 156 | if self.window_size is not None: 157 | assert t_s == t_t, "Relative attention is only available for self-attention." 
158 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 159 | rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings) 160 | scores_local = self._relative_position_to_absolute_position(rel_logits) 161 | scores = scores + scores_local 162 | if self.proximal_bias: 163 | assert t_s == t_t, "Proximal bias is only available for self-attention." 164 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) 165 | if mask is not None: 166 | scores = scores.masked_fill(mask == 0, -1e4) 167 | if self.block_length is not None: 168 | assert t_s == t_t, "Local attention is only available for self-attention." 169 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) 170 | scores = scores.masked_fill(block_mask == 0, -1e4) 171 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 172 | p_attn = self.drop(p_attn) 173 | output = torch.matmul(p_attn, value) 174 | if self.window_size is not None: 175 | relative_weights = self._absolute_position_to_relative_position(p_attn) 176 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 177 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) 178 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] 179 | return output, p_attn 180 | 181 | def _matmul_with_relative_values(self, x, y): 182 | """ 183 | x: [b, h, l, m] 184 | y: [h or 1, m, d] 185 | ret: [b, h, l, d] 186 | """ 187 | ret = torch.matmul(x, y.unsqueeze(0)) 188 | return ret 189 | 190 | def _matmul_with_relative_keys(self, x, y): 191 | """ 192 | x: [b, h, l, d] 193 | y: [h or 1, m, d] 194 | ret: [b, h, l, m] 195 | """ 196 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 197 | return ret 198 | 199 | def _get_relative_embeddings(self, relative_embeddings, length): 200 | max_relative_position = 2 * 
self.window_size + 1 201 | # Pad first before slice to avoid using cond ops. 202 | pad_length = max(length - (self.window_size + 1), 0) 203 | slice_start_position = max((self.window_size + 1) - length, 0) 204 | slice_end_position = slice_start_position + 2 * length - 1 205 | if pad_length > 0: 206 | padded_relative_embeddings = F.pad( 207 | relative_embeddings, 208 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) 209 | else: 210 | padded_relative_embeddings = relative_embeddings 211 | used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position] 212 | return used_relative_embeddings 213 | 214 | def _relative_position_to_absolute_position(self, x): 215 | """ 216 | x: [b, h, l, 2*l-1] 217 | ret: [b, h, l, l] 218 | """ 219 | batch, heads, length, _ = x.size() 220 | # Concat columns of pad to shift from relative to absolute indexing. 221 | x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) 222 | 223 | # Concat extra elements so as to add up to shape (len+1, 2*len-1). 224 | x_flat = x.view([batch, heads, length * 2 * length]) 225 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]])) 226 | 227 | # Reshape and slice out the padded elements.
228 | x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] 229 | return x_final 230 | 231 | def _absolute_position_to_relative_position(self, x): 232 | """ 233 | x: [b, h, l, l] 234 | ret: [b, h, l, 2*l-1] 235 | """ 236 | batch, heads, length, _ = x.size() 237 | # pad along the column (last) dimension 238 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) 239 | x_flat = x.view([batch, heads, length**2 + length*(length -1)]) 240 | # prepend zeros that skew the elements after the reshape 241 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 242 | x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] 243 | return x_final 244 | 245 | def _attention_bias_proximal(self, length): 246 | """Bias for self-attention to encourage attention to close positions. 247 | Args: 248 | length: an integer scalar. 249 | Returns: 250 | a Tensor with shape [1, 1, length, length] 251 | """ 252 | r = torch.arange(length, dtype=torch.float32) 253 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 254 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 255 | 256 | 257 | class FFN(nn.Module): 258 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False): 259 | super().__init__() 260 | self.in_channels = in_channels 261 | self.out_channels = out_channels 262 | self.filter_channels = filter_channels 263 | self.kernel_size = kernel_size 264 | self.p_dropout = p_dropout 265 | self.activation = activation 266 | self.causal = causal 267 | 268 | if causal: 269 | self.padding = self._causal_padding 270 | else: 271 | self.padding = self._same_padding 272 | 273 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 274 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 275 | self.drop = nn.Dropout(p_dropout) 276 | 277 | def forward(self, x, x_mask): 278 | x =
self.conv_1(self.padding(x * x_mask)) 279 | if self.activation == "gelu": 280 | x = x * torch.sigmoid(1.702 * x) 281 | else: 282 | x = torch.relu(x) 283 | x = self.drop(x) 284 | x = self.conv_2(self.padding(x * x_mask)) 285 | return x * x_mask 286 | 287 | def _causal_padding(self, x): 288 | if self.kernel_size == 1: 289 | return x 290 | pad_l = self.kernel_size - 1 291 | pad_r = 0 292 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 293 | x = F.pad(x, commons.convert_pad_shape(padding)) 294 | return x 295 | 296 | def _same_padding(self, x): 297 | if self.kernel_size == 1: 298 | return x 299 | pad_l = (self.kernel_size - 1) // 2 300 | pad_r = self.kernel_size // 2 301 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 302 | x = F.pad(x, commons.convert_pad_shape(padding)) 303 | return x 304 | -------------------------------------------------------------------------------- /cmd_inference.py: -------------------------------------------------------------------------------- 1 | """该模块用于生成VITS文件 2 | 使用方法 3 | 4 | python cmd_inference.py -m 模型路径 -c 配置文件路径 -o 输出文件路径 -l 输入的语言 -t 输入文本 -s 合成目标说话人名称 5 | 6 | 可选参数 7 | -ns 感情变化程度 8 | -nsw 音素发音长度 9 | -ls 整体语速 10 | -on 输出文件的名称 11 | 12 | """ 13 | """English version: this module synthesizes speech from text with a VITS model and saves it as a .wav file 14 | Usage 15 | 16 | python cmd_inference.py -m model_path -c config_file_path -o output_dir -l input_language -t input_text -s target_speaker_name 17 | 18 | Optional parameters 19 | -ns degree of emotional variation 20 | -nsw phoneme pronunciation length 21 | -ls overall speaking speed (higher is faster) 22 | -on name of the output file 23 | """ 24 | 25 | from pathlib import Path 26 | import utils 27 | from models import SynthesizerTrn 28 | import torch 29 | from torch import no_grad, LongTensor 30 | import librosa 31 | from text import text_to_sequence, _clean_text 32 | import commons 33 | import scipy.io.wavfile as wavf 34 | import os 35 | 36 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 37 | 38 |
language_marks = { 39 | "Japanese": "", 40 | "日本語": "[JA]", 41 | "简体中文": "[ZH]", 42 | "English": "[EN]", 43 | "Mix": "", 44 | } 45 | 46 | 47 | def get_text(text, hps, is_symbol): 48 | text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners) 49 | if hps.data.add_blank: 50 | text_norm = commons.intersperse(text_norm, 0) 51 | text_norm = LongTensor(text_norm) 52 | return text_norm 53 | 54 | 55 | 56 | if __name__ == "__main__": 57 | import argparse 58 | 59 | """ 60 | English description of some parameters: 61 | -s: speaker name; use the name, not the numeric id 62 | """ 63 | parser = argparse.ArgumentParser(description='vits inference') 64 | # Main parameters (-t and -s have no default and must be given) 65 | parser.add_argument('-m', '--model_path', type=str, default="logs/44k/G_0.pth", help='path to the model') 66 | parser.add_argument('-c', '--config_path', type=str, default="configs/config.json", help='path to the config file') 67 | parser.add_argument('-o', '--output_path', type=str, default="output/vits", help='output directory') 68 | parser.add_argument('-l', '--language', type=str, default="日本語", help='language of the input text') 69 | parser.add_argument('-t', '--text', type=str, required=True, help='text to synthesize') 70 | parser.add_argument('-s', '--spk', type=str, required=True, help='name of the target speaker') 71 | # Optional parameters 72 | parser.add_argument('-on', '--output_name', type=str, default="output", help='name of the output file (without .wav)') 73 | parser.add_argument('-ns', '--noise_scale', type=float, default=.667, help='degree of emotional variation') 74 | parser.add_argument('-nsw', '--noise_scale_w', type=float, default=0.6, help='phoneme pronunciation length') 75 | parser.add_argument('-ls', '--length_scale', type=float, default=1, help='overall speaking speed (higher is faster)') 76 | 77 | args = parser.parse_args() 78 | 79 | model_path = args.model_path 80 | config_path = args.config_path 81 | output_dir = Path(args.output_path) 82 | output_dir.mkdir(parents=True, exist_ok=True) 83 | 84 | language = args.language 85 | text = args.text 86 | spk = args.spk 87 | noise_scale = args.noise_scale 88 | noise_scale_w = args.noise_scale_w 89 | length = args.length_scale 90 | output_name = args.output_name
91 | 92 | hps = utils.get_hparams_from_file(config_path) 93 | net_g = SynthesizerTrn( 94 | len(hps.symbols), 95 | hps.data.filter_length // 2 + 1, 96 | hps.train.segment_size // hps.data.hop_length, 97 | n_speakers=hps.data.n_speakers, 98 | **hps.model).to(device) 99 | _ = net_g.eval() 100 | _ = utils.load_checkpoint(model_path, net_g, None) 101 | 102 | speaker_ids = hps.speakers 103 | 104 | 105 | if language is not None: 106 | text = language_marks[language] + text + language_marks[language] 107 | speaker_id = speaker_ids[spk] 108 | stn_tst = get_text(text, hps, False) 109 | with no_grad(): 110 | x_tst = stn_tst.unsqueeze(0).to(device) 111 | x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device) 112 | sid = LongTensor([speaker_id]).to(device) 113 | audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale, noise_scale_w=noise_scale_w, 114 | length_scale=1.0 / length)[0][0, 0].data.cpu().float().numpy() 115 | del stn_tst, x_tst, x_tst_lengths, sid 116 | 117 | wavf.write(str(output_dir)+"/"+output_name+".wav",hps.data.sampling_rate,audio) 118 | 119 | 120 | 121 | 122 | -------------------------------------------------------------------------------- /commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size*dilation - dilation)/2) 16 | 17 | 18 | def convert_pad_shape(pad_shape): 19 | l = pad_shape[::-1] 20 | pad_shape = [item for sublist in l for item in sublist] 21 | return pad_shape 22 | 23 | 24 | def intersperse(lst, item): 25 | result = [item] * (len(lst) * 2 + 1) 26 | result[1::2] = lst 27 | return result 28 | 29 | 30 | def kl_divergence(m_p, 
logs_p, m_q, logs_q): 31 | """KL(P||Q)""" 32 | kl = (logs_q - logs_p) - 0.5 33 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | ret = torch.zeros_like(x[:, :, :segment_size]) 50 | for i in range(x.size(0)): 51 | idx_str = ids_str[i] 52 | idx_end = idx_str + segment_size 53 | try: 54 | ret[i] = x[i, :, idx_str:idx_end] 55 | except RuntimeError: 56 | print("?") 57 | return ret 58 | 59 | 60 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 61 | b, d, t = x.size() 62 | if x_lengths is None: 63 | x_lengths = t 64 | ids_str_max = x_lengths - segment_size + 1 65 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 66 | ret = slice_segments(x, ids_str, segment_size) 67 | return ret, ids_str 68 | 69 | 70 | def get_timing_signal_1d( 71 | length, channels, min_timescale=1.0, max_timescale=1.0e4): 72 | position = torch.arange(length, dtype=torch.float) 73 | num_timescales = channels // 2 74 | log_timescale_increment = ( 75 | math.log(float(max_timescale) / float(min_timescale)) / 76 | (num_timescales - 1)) 77 | inv_timescales = min_timescale * torch.exp( 78 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 79 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 80 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 81 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 82 | signal = signal.view(1, channels, length) 83 | return signal 84 | 85 | 86 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 87 
| b, channels, length = x.size() 88 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 89 | return x + signal.to(dtype=x.dtype, device=x.device) 90 | 91 | 92 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 93 | b, channels, length = x.size() 94 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 95 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 96 | 97 | 98 | def subsequent_mask(length): 99 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 100 | return mask 101 | 102 | 103 | @torch.jit.script 104 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 105 | n_channels_int = n_channels[0] 106 | in_act = input_a + input_b 107 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 108 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 109 | acts = t_act * s_act 110 | return acts 111 | 112 | 113 | def convert_pad_shape(pad_shape): 114 | l = pad_shape[::-1] 115 | pad_shape = [item for sublist in l for item in sublist] 116 | return pad_shape 117 | 118 | 119 | def shift_1d(x): 120 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 121 | return x 122 | 123 | 124 | def sequence_mask(length, max_length=None): 125 | if max_length is None: 126 | max_length = length.max() 127 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 128 | return x.unsqueeze(0) < length.unsqueeze(1) 129 | 130 | 131 | def generate_path(duration, mask): 132 | """ 133 | duration: [b, 1, t_x] 134 | mask: [b, 1, t_y, t_x] 135 | """ 136 | device = duration.device 137 | 138 | b, _, t_y, t_x = mask.shape 139 | cum_duration = torch.cumsum(duration, -1) 140 | 141 | cum_duration_flat = cum_duration.view(b * t_x) 142 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 143 | path = path.view(b, t_x, t_y) 144 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 145 | path = 
path.unsqueeze(1).transpose(2,3) * mask 146 | return path 147 | 148 | 149 | def clip_grad_value_(parameters, clip_value, norm_type=2): 150 | if isinstance(parameters, torch.Tensor): 151 | parameters = [parameters] 152 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 153 | norm_type = float(norm_type) 154 | if clip_value is not None: 155 | clip_value = float(clip_value) 156 | 157 | total_norm = 0 158 | for p in parameters: 159 | param_norm = p.grad.data.norm(norm_type) 160 | total_norm += param_norm.item() ** norm_type 161 | if clip_value is not None: 162 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 163 | total_norm = total_norm ** (1. / norm_type) 164 | return total_norm 165 | -------------------------------------------------------------------------------- /configs/modified_finetune_speaker.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 10, 4 | "eval_interval": 100, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 0.0002, 8 | "betas": [ 9 | 0.8, 10 | 0.99 11 | ], 12 | "eps": 1e-09, 13 | "batch_size": 16, 14 | "fp16_run": true, 15 | "lr_decay": 0.999875, 16 | "segment_size": 8192, 17 | "init_lr_ratio": 1, 18 | "warmup_epochs": 0, 19 | "c_mel": 45, 20 | "c_kl": 1.0 21 | }, 22 | "data": { 23 | "training_files": "final_annotation_train.txt", 24 | "validation_files": "final_annotation_val.txt", 25 | "text_cleaners": [ 26 | "chinese_cleaners" 27 | ], 28 | "max_wav_value": 32768.0, 29 | "sampling_rate": 22050, 30 | "filter_length": 1024, 31 | "hop_length": 256, 32 | "win_length": 1024, 33 | "n_mel_channels": 80, 34 | "mel_fmin": 0.0, 35 | "mel_fmax": null, 36 | "add_blank": true, 37 | "n_speakers": 2, 38 | "cleaned_text": true 39 | }, 40 | "model": { 41 | "inter_channels": 192, 42 | "hidden_channels": 192, 43 | "filter_channels": 768, 44 | "n_heads": 2, 45 | "n_layers": 6, 46 | "kernel_size": 3, 47 | "p_dropout": 0.1, 48 | "resblock": "1", 49 | 
"resblock_kernel_sizes": [ 50 | 3, 51 | 7, 52 | 11 53 | ], 54 | "resblock_dilation_sizes": [ 55 | [ 56 | 1, 57 | 3, 58 | 5 59 | ], 60 | [ 61 | 1, 62 | 3, 63 | 5 64 | ], 65 | [ 66 | 1, 67 | 3, 68 | 5 69 | ] 70 | ], 71 | "upsample_rates": [ 72 | 8, 73 | 8, 74 | 2, 75 | 2 76 | ], 77 | "upsample_initial_channel": 512, 78 | "upsample_kernel_sizes": [ 79 | 16, 80 | 16, 81 | 4, 82 | 4 83 | ], 84 | "n_layers_q": 3, 85 | "use_spectral_norm": false, 86 | "gin_channels": 256 87 | }, 88 | "symbols": [ 89 | "_", 90 | "\uff1b", 91 | "\uff1a", 92 | "\uff0c", 93 | "\u3002", 94 | "\uff01", 95 | "\uff1f", 96 | "-", 97 | "\u201c", 98 | "\u201d", 99 | "\u300a", 100 | "\u300b", 101 | "\u3001", 102 | "\uff08", 103 | "\uff09", 104 | "\u2026", 105 | "\u2014", 106 | " ", 107 | "A", 108 | "B", 109 | "C", 110 | "D", 111 | "E", 112 | "F", 113 | "G", 114 | "H", 115 | "I", 116 | "J", 117 | "K", 118 | "L", 119 | "M", 120 | "N", 121 | "O", 122 | "P", 123 | "Q", 124 | "R", 125 | "S", 126 | "T", 127 | "U", 128 | "V", 129 | "W", 130 | "X", 131 | "Y", 132 | "Z", 133 | "a", 134 | "b", 135 | "c", 136 | "d", 137 | "e", 138 | "f", 139 | "g", 140 | "h", 141 | "i", 142 | "j", 143 | "k", 144 | "l", 145 | "m", 146 | "n", 147 | "o", 148 | "p", 149 | "q", 150 | "r", 151 | "s", 152 | "t", 153 | "u", 154 | "v", 155 | "w", 156 | "x", 157 | "y", 158 | "z", 159 | "1", 160 | "2", 161 | "3", 162 | "4", 163 | "5", 164 | "0", 165 | "\uff22", 166 | "\uff30" 167 | ], 168 | "speakers": { 169 | "dingzhen": 0, 170 | "taffy": 1 171 | } 172 | } -------------------------------------------------------------------------------- /configs/uma_trilingual.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 16, 11 | "fp16_run": true, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8192, 14 | 
"init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "training_files":"../CH_JA_EN_mix_voice/clipped_3_vits_trilingual_annotations.train.txt.cleaned", 21 | "validation_files":"../CH_JA_EN_mix_voice/clipped_3_vits_trilingual_annotations.val.txt.cleaned", 22 | "text_cleaners":["cjke_cleaners2"], 23 | "max_wav_value": 32768.0, 24 | "sampling_rate": 22050, 25 | "filter_length": 1024, 26 | "hop_length": 256, 27 | "win_length": 1024, 28 | "n_mel_channels": 80, 29 | "mel_fmin": 0.0, 30 | "mel_fmax": null, 31 | "add_blank": true, 32 | "n_speakers": 999, 33 | "cleaned_text": true 34 | }, 35 | "model": { 36 | "inter_channels": 192, 37 | "hidden_channels": 192, 38 | "filter_channels": 768, 39 | "n_heads": 2, 40 | "n_layers": 6, 41 | "kernel_size": 3, 42 | "p_dropout": 0.1, 43 | "resblock": "1", 44 | "resblock_kernel_sizes": [3,7,11], 45 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 46 | "upsample_rates": [8,8,2,2], 47 | "upsample_initial_channel": 512, 48 | "upsample_kernel_sizes": [16,16,4,4], 49 | "n_layers_q": 3, 50 | "use_spectral_norm": false, 51 | "gin_channels": 256 52 | }, 53 | "symbols": ["_", ",", ".", "!", "?", "-", "~", "\u2026", "N", "Q", "a", "b", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "s", "t", "u", "v", "w", "x", "y", "z", "\u0251", "\u00e6", "\u0283", "\u0291", "\u00e7", "\u026f", "\u026a", "\u0254", "\u025b", "\u0279", "\u00f0", "\u0259", "\u026b", "\u0265", "\u0278", "\u028a", "\u027e", "\u0292", "\u03b8", "\u03b2", "\u014b", "\u0266", "\u207c", "\u02b0", "`", "^", "#", "*", "=", "\u02c8", "\u02cc", "\u2192", "\u2193", "\u2191", " "] 54 | } -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.utils.data 7 | import torchaudio 8 | 9 | import 
commons 10 | from mel_processing import spectrogram_torch 11 | from utils import load_wav_to_torch, load_filepaths_and_text 12 | from text import text_to_sequence, cleaned_text_to_sequence 13 | """Multi speaker version""" 14 | 15 | 16 | class TextAudioSpeakerLoader(torch.utils.data.Dataset): 17 | """ 18 | 1) loads audio, speaker_id, text pairs 19 | 2) normalizes text and converts them to sequences of integers 20 | 3) computes spectrograms from audio files. 21 | """ 22 | 23 | def __init__(self, audiopaths_sid_text, hparams, symbols): 24 | self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text) 25 | self.text_cleaners = hparams.text_cleaners 26 | self.max_wav_value = hparams.max_wav_value 27 | self.sampling_rate = hparams.sampling_rate 28 | self.filter_length = hparams.filter_length 29 | self.hop_length = hparams.hop_length 30 | self.win_length = hparams.win_length 31 | self.sampling_rate = hparams.sampling_rate 32 | 33 | self.cleaned_text = getattr(hparams, "cleaned_text", False) 34 | 35 | self.add_blank = hparams.add_blank 36 | self.min_text_len = getattr(hparams, "min_text_len", 1) 37 | self.max_text_len = getattr(hparams, "max_text_len", 190) 38 | self.symbols = symbols 39 | 40 | random.seed(1234) 41 | random.shuffle(self.audiopaths_sid_text) 42 | self._filter() 43 | 44 | def _filter(self): 45 | """ 46 | Filter text & store spec lengths 47 | """ 48 | # Store spectrogram lengths for Bucketing 49 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) 50 | # spec_length = wav_length // hop_length 51 | 52 | audiopaths_sid_text_new = [] 53 | lengths = [] 54 | for audiopath, sid, text in self.audiopaths_sid_text: 55 | # audiopath = "./user_voice/" + audiopath 56 | 57 | if self.min_text_len <= len(text) and len(text) <= self.max_text_len: 58 | audiopaths_sid_text_new.append([audiopath, sid, text]) 59 | lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length)) 60 | self.audiopaths_sid_text = audiopaths_sid_text_new 61 
| self.lengths = lengths 62 | 63 | def get_audio_text_speaker_pair(self, audiopath_sid_text): 64 | # separate filename, speaker_id and text 65 | audiopath, sid, text = audiopath_sid_text[0], audiopath_sid_text[1], audiopath_sid_text[2] 66 | text = self.get_text(text) 67 | spec, wav = self.get_audio(audiopath) 68 | sid = self.get_sid(sid) 69 | return (text, spec, wav, sid) 70 | 71 | def get_audio(self, filename): 72 | # audio, sampling_rate = load_wav_to_torch(filename) 73 | # if sampling_rate != self.sampling_rate: 74 | # raise ValueError("{} {} SR doesn't match target {} SR".format( 75 | # sampling_rate, self.sampling_rate)) 76 | # audio_norm = audio / self.max_wav_value if audio.max() > 10 else audio 77 | # audio_norm = audio_norm.unsqueeze(0) 78 | audio_norm, sampling_rate = torchaudio.load(filename, frame_offset=0, num_frames=-1, normalize=True, channels_first=True) 79 | # spec_filename = filename.replace(".wav", ".spec.pt") 80 | # if os.path.exists(spec_filename): 81 | # spec = torch.load(spec_filename) 82 | # else: 83 | # try: 84 | spec = spectrogram_torch(audio_norm, self.filter_length, 85 | self.sampling_rate, self.hop_length, self.win_length, 86 | center=False) 87 | spec = spec.squeeze(0) 88 | # except NotImplementedError: 89 | # print("?") 90 | # spec = torch.squeeze(spec, 0) 91 | # torch.save(spec, spec_filename) 92 | return spec, audio_norm 93 | 94 | def get_text(self, text): 95 | if self.cleaned_text: 96 | text_norm = cleaned_text_to_sequence(text, self.symbols) 97 | else: 98 | text_norm = text_to_sequence(text, self.text_cleaners) 99 | if self.add_blank: 100 | text_norm = commons.intersperse(text_norm, 0) 101 | text_norm = torch.LongTensor(text_norm) 102 | return text_norm 103 | 104 | def get_sid(self, sid): 105 | sid = torch.LongTensor([int(sid)]) 106 | return sid 107 | 108 | def __getitem__(self, index): 109 | return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index]) 110 | 111 | def __len__(self): 112 | return 
len(self.audiopaths_sid_text) 113 | 114 | 115 | class TextAudioSpeakerCollate(): 116 | """ Zero-pads model inputs and targets 117 | """ 118 | 119 | def __init__(self, return_ids=False): 120 | self.return_ids = return_ids 121 | 122 | def __call__(self, batch): 123 | """Collates a training batch from normalized text, audio and speaker identities 124 | PARAMS 125 | ------ 126 | batch: [text_normalized, spec_normalized, wav_normalized, sid] 127 | """ 128 | # Sort by spectrogram length (longest first); all sequences are right zero-padded below 129 | _, ids_sorted_decreasing = torch.sort( 130 | torch.LongTensor([x[1].size(1) for x in batch]), 131 | dim=0, descending=True) 132 | 133 | max_text_len = max([len(x[0]) for x in batch]) 134 | max_spec_len = max([x[1].size(1) for x in batch]) 135 | max_wav_len = max([x[2].size(1) for x in batch]) 136 | 137 | text_lengths = torch.LongTensor(len(batch)) 138 | spec_lengths = torch.LongTensor(len(batch)) 139 | wav_lengths = torch.LongTensor(len(batch)) 140 | sid = torch.LongTensor(len(batch)) 141 | 142 | text_padded = torch.LongTensor(len(batch), max_text_len) 143 | spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) 144 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) 145 | text_padded.zero_() 146 | spec_padded.zero_() 147 | wav_padded.zero_() 148 | for i in range(len(ids_sorted_decreasing)): 149 | row = batch[ids_sorted_decreasing[i]] 150 | 151 | text = row[0] 152 | text_padded[i, :text.size(0)] = text 153 | text_lengths[i] = text.size(0) 154 | 155 | spec = row[1] 156 | spec_padded[i, :, :spec.size(1)] = spec 157 | spec_lengths[i] = spec.size(1) 158 | 159 | wav = row[2] 160 | wav_padded[i, :, :wav.size(1)] = wav 161 | wav_lengths[i] = wav.size(1) 162 | 163 | sid[i] = row[3] 164 | 165 | if self.return_ids: 166 | return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid, ids_sorted_decreasing 167 | return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid
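The sort-then-pad behaviour of `TextAudioSpeakerCollate` above can be illustrated without PyTorch. The following is a minimal sketch, not the repository's code: `collate_sketch` and the `(token ids, waveform samples)` item layout are hypothetical, and waveform length stands in for the spectrogram frame count the real collate sorts by (the two grow together for a fixed hop length).

```python
from typing import List, Sequence, Tuple

# Hypothetical item layout for this sketch: (token ids, waveform samples).
Item = Tuple[List[int], List[float]]


def collate_sketch(batch: Sequence[Item]):
    """Pure-Python analogue of TextAudioSpeakerCollate (illustration only).

    Orders the batch by audio length, longest first, then right-pads every
    text and waveform with zeros to the longest one in the batch.
    """
    # Longest audio first, like torch.sort(..., descending=True). Python's sort
    # is stable, so ties keep their input order (torch.sort does not promise this).
    order = sorted(range(len(batch)), key=lambda i: len(batch[i][1]), reverse=True)
    max_text = max(len(text) for text, _ in batch)
    max_wav = max(len(wav) for _, wav in batch)

    text_padded, text_lengths, wav_padded, wav_lengths = [], [], [], []
    for i in order:
        text, wav = batch[i]
        text_padded.append(text + [0] * (max_text - len(text)))
        text_lengths.append(len(text))
        wav_padded.append(wav + [0.0] * (max_wav - len(wav)))
        wav_lengths.append(len(wav))
    return text_padded, text_lengths, wav_padded, wav_lengths, order


batch = [([5, 6], [0.1, 0.2]), ([7, 8, 9], [0.3, 0.4, 0.5, 0.6]), ([1], [0.7])]
texts, text_lens, wavs, wav_lens, order = collate_sketch(batch)
print(order)     # [1, 0, 2]
print(texts)     # [[7, 8, 9], [5, 6, 0], [1, 0, 0]]
print(wav_lens)  # [4, 2, 1]
```

Keeping the true lengths next to the padded tensors is what lets later code build masks that ignore the zero padding.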
168 | 169 | 170 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): 171 | """ 172 | Maintain similar input lengths in a batch. 173 | Length groups are specified by boundaries. 174 | Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. 175 | 176 | It removes samples which are not included in the boundaries. 177 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. 178 | """ 179 | 180 | def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True): 181 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 182 | self.lengths = dataset.lengths 183 | self.batch_size = batch_size 184 | self.boundaries = boundaries 185 | 186 | self.buckets, self.num_samples_per_bucket = self._create_buckets() 187 | self.total_size = sum(self.num_samples_per_bucket) 188 | self.num_samples = self.total_size // self.num_replicas 189 | 190 | def _create_buckets(self): 191 | buckets = [[] for _ in range(len(self.boundaries) - 1)] 192 | for i in range(len(self.lengths)): 193 | length = self.lengths[i] 194 | idx_bucket = self._bisect(length) 195 | if idx_bucket != -1: 196 | buckets[idx_bucket].append(i) 197 | 198 | try: 199 | for i in range(len(buckets) - 1, 0, -1): 200 | if len(buckets[i]) == 0: 201 | buckets.pop(i) 202 | self.boundaries.pop(i + 1) 203 | assert all(len(bucket) > 0 for bucket in buckets) 204 | # When one bucket is not traversed 205 | except Exception as e: 206 | print('Bucket warning ', e) 207 | for i in range(len(buckets) - 1, -1, -1): 208 | if len(buckets[i]) == 0: 209 | buckets.pop(i) 210 | self.boundaries.pop(i + 1) 211 | 212 | num_samples_per_bucket = [] 213 | for i in range(len(buckets)): 214 | len_bucket = len(buckets[i]) 215 | total_batch_size = self.num_replicas * self.batch_size 216 | rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size 217 | 
num_samples_per_bucket.append(len_bucket + rem) 218 | return buckets, num_samples_per_bucket 219 | 220 | def __iter__(self): 221 | # deterministically shuffle based on epoch 222 | g = torch.Generator() 223 | g.manual_seed(self.epoch) 224 | 225 | indices = [] 226 | if self.shuffle: 227 | for bucket in self.buckets: 228 | indices.append(torch.randperm(len(bucket), generator=g).tolist()) 229 | else: 230 | for bucket in self.buckets: 231 | indices.append(list(range(len(bucket)))) 232 | 233 | batches = [] 234 | for i in range(len(self.buckets)): 235 | bucket = self.buckets[i] 236 | len_bucket = len(bucket) 237 | ids_bucket = indices[i] 238 | num_samples_bucket = self.num_samples_per_bucket[i] 239 | 240 | # add extra samples to make it evenly divisible 241 | rem = num_samples_bucket - len_bucket 242 | ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)] 243 | 244 | # subsample 245 | ids_bucket = ids_bucket[self.rank::self.num_replicas] 246 | 247 | # batching 248 | for j in range(len(ids_bucket) // self.batch_size): 249 | batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]] 250 | batches.append(batch) 251 | 252 | if self.shuffle: 253 | batch_ids = torch.randperm(len(batches), generator=g).tolist() 254 | batches = [batches[i] for i in batch_ids] 255 | self.batches = batches 256 | 257 | assert len(self.batches) * self.batch_size == self.num_samples 258 | return iter(self.batches) 259 | 260 | def _bisect(self, x, lo=0, hi=None): 261 | if hi is None: 262 | hi = len(self.boundaries) - 1 263 | 264 | if hi > lo: 265 | mid = (hi + lo) // 2 266 | if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: 267 | return mid 268 | elif x <= self.boundaries[mid]: 269 | return self._bisect(x, lo, mid) 270 | else: 271 | return self._bisect(x, mid + 1, hi) 272 | else: 273 | return -1 274 | 275 | def __len__(self): 276 | return self.num_samples // self.batch_size 277 | 
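The grouping performed by `DistributedBucketSampler` can be sketched in plain Python to show how samples are assigned to length buckets, padded so every replica gets the same number of full batches, and split by rank. This is a simplified sketch under stated assumptions: `bisect_bucket` and `bucket_batches` are hypothetical names, `random.Random(seed)` replaces the `torch.Generator` seeded with the epoch, and empty buckets are simply dropped rather than also popping entries from `boundaries`.

```python
import random
from typing import List


def bisect_bucket(boundaries: List[int], x: int) -> int:
    """Index i with boundaries[i] < x <= boundaries[i + 1], or -1 if x is out of range.

    Iterative version of the recursive DistributedBucketSampler._bisect.
    """
    lo, hi = 0, len(boundaries) - 1
    while hi > lo:
        mid = (lo + hi) // 2
        if boundaries[mid] < x <= boundaries[mid + 1]:
            return mid
        if x <= boundaries[mid]:
            hi = mid
        else:
            lo = mid + 1
    return -1


def bucket_batches(lengths, boundaries, batch_size, num_replicas=1, rank=0, seed=0):
    """Length-bucketed batches for one rank (simplified sketch, not the repo's code)."""
    buckets = [[] for _ in range(len(boundaries) - 1)]
    for i, length in enumerate(lengths):
        b = bisect_bucket(boundaries, length)
        if b != -1:  # samples outside the boundaries are discarded
            buckets[b].append(i)
    buckets = [b for b in buckets if b]  # drop empty buckets

    rng = random.Random(seed)  # stands in for torch.Generator seeded with the epoch
    total_batch = num_replicas * batch_size
    batches = []
    for bucket in buckets:
        ids = list(range(len(bucket)))
        rng.shuffle(ids)
        # Pad by repeating ids so the bucket splits evenly across replicas and batches.
        rem = (total_batch - len(bucket) % total_batch) % total_batch
        ids = ids + ids * (rem // len(bucket)) + ids[: rem % len(bucket)]
        ids = ids[rank::num_replicas]  # this rank's share
        for j in range(len(ids) // batch_size):
            batches.append([bucket[k] for k in ids[j * batch_size:(j + 1) * batch_size]])
    rng.shuffle(batches)  # same seed on every rank -> same batch order
    return batches


lengths = [50, 60, 350, 360, 370, 10, 900]
print(bisect_bucket([32, 300, 400], 350))                        # 1
print(len(bucket_batches(lengths, [32, 300, 400], batch_size=2)))  # 3
```

Because every rank draws from an identically seeded generator, the ranks agree on the shuffles and each one takes a disjoint stride of the padded bucket, which is why the padding must make each bucket a multiple of `num_replicas * batch_size`.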
-------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import commons 5 | 6 | 7 | def feature_loss(fmap_r, fmap_g): 8 | loss = 0 9 | for dr, dg in zip(fmap_r, fmap_g): 10 | for rl, gl in zip(dr, dg): 11 | rl = rl.float().detach() 12 | gl = gl.float() 13 | loss += torch.mean(torch.abs(rl - gl)) 14 | 15 | return loss * 2 16 | 17 | 18 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 19 | loss = 0 20 | r_losses = [] 21 | g_losses = [] 22 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 23 | dr = dr.float() 24 | dg = dg.float() 25 | r_loss = torch.mean((1-dr)**2) 26 | g_loss = torch.mean(dg**2) 27 | loss += (r_loss + g_loss) 28 | r_losses.append(r_loss.item()) 29 | g_losses.append(g_loss.item()) 30 | 31 | return loss, r_losses, g_losses 32 | 33 | 34 | def generator_loss(disc_outputs): 35 | loss = 0 36 | gen_losses = [] 37 | for dg in disc_outputs: 38 | dg = dg.float() 39 | l = torch.mean((1-dg)**2) 40 | gen_losses.append(l) 41 | loss += l 42 | 43 | return loss, gen_losses 44 | 45 | 46 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): 47 | """ 48 | z_p, logs_q: [b, h, t_t] 49 | m_p, logs_p: [b, h, t_t] 50 | """ 51 | z_p = z_p.float() 52 | logs_q = logs_q.float() 53 | m_p = m_p.float() 54 | logs_p = logs_p.float() 55 | z_mask = z_mask.float() 56 | 57 | kl = logs_p - logs_q - 0.5 58 | kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. 
* logs_p) 59 | kl = torch.sum(kl * z_mask) 60 | l = kl / torch.sum(z_mask) 61 | return l 62 | -------------------------------------------------------------------------------- /mel_processing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import torch.utils.data 8 | import numpy as np 9 | import librosa 10 | import librosa.util as librosa_util 11 | from librosa.util import normalize, pad_center, tiny 12 | from scipy.signal import get_window 13 | from scipy.io.wavfile import read 14 | from librosa.filters import mel as librosa_mel_fn 15 | 16 | MAX_WAV_VALUE = 32768.0 17 | 18 | 19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 20 | """ 21 | PARAMS 22 | ------ 23 | C: compression factor 24 | """ 25 | return torch.log(torch.clamp(x, min=clip_val) * C) 26 | 27 | 28 | def dynamic_range_decompression_torch(x, C=1): 29 | """ 30 | PARAMS 31 | ------ 32 | C: compression factor used to compress 33 | """ 34 | return torch.exp(x) / C 35 | 36 | 37 | def spectral_normalize_torch(magnitudes): 38 | output = dynamic_range_compression_torch(magnitudes) 39 | return output 40 | 41 | 42 | def spectral_de_normalize_torch(magnitudes): 43 | output = dynamic_range_decompression_torch(magnitudes) 44 | return output 45 | 46 | 47 | mel_basis = {} 48 | hann_window = {} 49 | 50 | 51 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 52 | if torch.min(y) < -1.: 53 | print('min value is ', torch.min(y)) 54 | if torch.max(y) > 1.: 55 | print('max value is ', torch.max(y)) 56 | 57 | global hann_window 58 | dtype_device = str(y.dtype) + '_' + str(y.device) 59 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 60 | if wnsize_dtype_device not in hann_window: 61 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 62 | 63 | y = 
torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 64 | y = y.squeeze(1) 65 | 66 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 67 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 68 | 69 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 70 | return spec 71 | 72 | 73 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 74 | global mel_basis 75 | dtype_device = str(spec.dtype) + '_' + str(spec.device) 76 | fmax_dtype_device = str(fmax) + '_' + dtype_device 77 | if fmax_dtype_device not in mel_basis: 78 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 79 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) 80 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 81 | spec = spectral_normalize_torch(spec) 82 | return spec 83 | 84 | 85 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 86 | if torch.min(y) < -1.: 87 | print('min value is ', torch.min(y)) 88 | if torch.max(y) > 1.: 89 | print('max value is ', torch.max(y)) 90 | 91 | global mel_basis, hann_window 92 | dtype_device = str(y.dtype) + '_' + str(y.device) 93 | fmax_dtype_device = str(fmax) + '_' + dtype_device 94 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 95 | if fmax_dtype_device not in mel_basis: 96 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 97 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) 98 | if wnsize_dtype_device not in hann_window: 99 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 100 | 101 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 102 | y = y.squeeze(1) 103 | 104 | spec =
torch.stft(y.float(), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 105 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 106 | 107 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 108 | 109 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 110 | spec = spectral_normalize_torch(spec) 111 | 112 | return spec 113 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import scipy 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 10 | from torch.nn.utils import weight_norm, remove_weight_norm 11 | 12 | import commons 13 | from commons import init_weights, get_padding 14 | from transforms import piecewise_rational_quadratic_transform 15 | 16 | 17 | LRELU_SLOPE = 0.1 18 | 19 | 20 | class LayerNorm(nn.Module): 21 | def __init__(self, channels, eps=1e-5): 22 | super().__init__() 23 | self.channels = channels 24 | self.eps = eps 25 | 26 | self.gamma = nn.Parameter(torch.ones(channels)) 27 | self.beta = nn.Parameter(torch.zeros(channels)) 28 | 29 | def forward(self, x): 30 | x = x.transpose(1, -1) 31 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 32 | return x.transpose(1, -1) 33 | 34 | 35 | class ConvReluNorm(nn.Module): 36 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): 37 | super().__init__() 38 | self.in_channels = in_channels 39 | self.hidden_channels = hidden_channels 40 | self.out_channels = out_channels 41 | self.kernel_size = kernel_size 42 | self.n_layers = n_layers 43 | self.p_dropout = p_dropout 44 | assert n_layers > 1, "Number of layers should be larger than 1."
45 | 46 | self.conv_layers = nn.ModuleList() 47 | self.norm_layers = nn.ModuleList() 48 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 49 | self.norm_layers.append(LayerNorm(hidden_channels)) 50 | self.relu_drop = nn.Sequential( 51 | nn.ReLU(), 52 | nn.Dropout(p_dropout)) 53 | for _ in range(n_layers-1): 54 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 55 | self.norm_layers.append(LayerNorm(hidden_channels)) 56 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 57 | self.proj.weight.data.zero_() 58 | self.proj.bias.data.zero_() 59 | 60 | def forward(self, x, x_mask): 61 | x_org = x 62 | for i in range(self.n_layers): 63 | x = self.conv_layers[i](x * x_mask) 64 | x = self.norm_layers[i](x) 65 | x = self.relu_drop(x) 66 | x = x_org + self.proj(x) 67 | return x * x_mask 68 | 69 | 70 | class DDSConv(nn.Module): 71 | """ 72 | Dilated and Depth-Separable Convolution 73 | """ 74 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): 75 | super().__init__() 76 | self.channels = channels 77 | self.kernel_size = kernel_size 78 | self.n_layers = n_layers 79 | self.p_dropout = p_dropout 80 | 81 | self.drop = nn.Dropout(p_dropout) 82 | self.convs_sep = nn.ModuleList() 83 | self.convs_1x1 = nn.ModuleList() 84 | self.norms_1 = nn.ModuleList() 85 | self.norms_2 = nn.ModuleList() 86 | for i in range(n_layers): 87 | dilation = kernel_size ** i 88 | padding = (kernel_size * dilation - dilation) // 2 89 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, 90 | groups=channels, dilation=dilation, padding=padding 91 | )) 92 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 93 | self.norms_1.append(LayerNorm(channels)) 94 | self.norms_2.append(LayerNorm(channels)) 95 | 96 | def forward(self, x, x_mask, g=None): 97 | if g is not None: 98 | x = x + g 99 | for i in range(self.n_layers): 100 | y = self.convs_sep[i](x * x_mask) 101 | 
y = self.norms_1[i](y) 102 | y = F.gelu(y) 103 | y = self.convs_1x1[i](y) 104 | y = self.norms_2[i](y) 105 | y = F.gelu(y) 106 | y = self.drop(y) 107 | x = x + y 108 | return x * x_mask 109 | 110 | 111 | class WN(torch.nn.Module): 112 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 113 | super(WN, self).__init__() 114 | assert(kernel_size % 2 == 1) 115 | self.hidden_channels = hidden_channels 116 | self.kernel_size = kernel_size 117 | self.dilation_rate = dilation_rate 118 | self.n_layers = n_layers 119 | self.gin_channels = gin_channels 120 | self.p_dropout = p_dropout 121 | 122 | self.in_layers = torch.nn.ModuleList() 123 | self.res_skip_layers = torch.nn.ModuleList() 124 | self.drop = nn.Dropout(p_dropout) 125 | 126 | if gin_channels != 0: 127 | cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) 128 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 129 | 130 | for i in range(n_layers): 131 | dilation = dilation_rate ** i 132 | padding = int((kernel_size * dilation - dilation) / 2) 133 | in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, 134 | dilation=dilation, padding=padding) 135 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 136 | self.in_layers.append(in_layer) 137 | 138 | # the last layer needs no residual half, only the skip output 139 | if i < n_layers - 1: 140 | res_skip_channels = 2 * hidden_channels 141 | else: 142 | res_skip_channels = hidden_channels 143 | 144 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 145 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 146 | self.res_skip_layers.append(res_skip_layer) 147 | 148 | def forward(self, x, x_mask, g=None, **kwargs): 149 | output = torch.zeros_like(x) 150 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 151 | 152 | if g is not None: 153 | g = self.cond_layer(g) 154 | 155 | for i in range(self.n_layers): 156 | x_in =
self.in_layers[i](x) 157 | if g is not None: 158 | cond_offset = i * 2 * self.hidden_channels 159 | g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] 160 | else: 161 | g_l = torch.zeros_like(x_in) 162 | 163 | acts = commons.fused_add_tanh_sigmoid_multiply( 164 | x_in, 165 | g_l, 166 | n_channels_tensor) 167 | acts = self.drop(acts) 168 | 169 | res_skip_acts = self.res_skip_layers[i](acts) 170 | if i < self.n_layers - 1: 171 | res_acts = res_skip_acts[:,:self.hidden_channels,:] 172 | x = (x + res_acts) * x_mask 173 | output = output + res_skip_acts[:,self.hidden_channels:,:] 174 | else: 175 | output = output + res_skip_acts 176 | return output * x_mask 177 | 178 | def remove_weight_norm(self): 179 | if self.gin_channels != 0: 180 | torch.nn.utils.remove_weight_norm(self.cond_layer) 181 | for l in self.in_layers: 182 | torch.nn.utils.remove_weight_norm(l) 183 | for l in self.res_skip_layers: 184 | torch.nn.utils.remove_weight_norm(l) 185 | 186 | 187 | class ResBlock1(torch.nn.Module): 188 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 189 | super(ResBlock1, self).__init__() 190 | self.convs1 = nn.ModuleList([ 191 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 192 | padding=get_padding(kernel_size, dilation[0]))), 193 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 194 | padding=get_padding(kernel_size, dilation[1]))), 195 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 196 | padding=get_padding(kernel_size, dilation[2]))) 197 | ]) 198 | self.convs1.apply(init_weights) 199 | 200 | self.convs2 = nn.ModuleList([ 201 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 202 | padding=get_padding(kernel_size, 1))), 203 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 204 | padding=get_padding(kernel_size, 1))), 205 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 206 | 
padding=get_padding(kernel_size, 1))) 207 | ]) 208 | self.convs2.apply(init_weights) 209 | 210 | def forward(self, x, x_mask=None): 211 | for c1, c2 in zip(self.convs1, self.convs2): 212 | xt = F.leaky_relu(x, LRELU_SLOPE) 213 | if x_mask is not None: 214 | xt = xt * x_mask 215 | xt = c1(xt) 216 | xt = F.leaky_relu(xt, LRELU_SLOPE) 217 | if x_mask is not None: 218 | xt = xt * x_mask 219 | xt = c2(xt) 220 | x = xt + x 221 | if x_mask is not None: 222 | x = x * x_mask 223 | return x 224 | 225 | def remove_weight_norm(self): 226 | for l in self.convs1: 227 | remove_weight_norm(l) 228 | for l in self.convs2: 229 | remove_weight_norm(l) 230 | 231 | 232 | class ResBlock2(torch.nn.Module): 233 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 234 | super(ResBlock2, self).__init__() 235 | self.convs = nn.ModuleList([ 236 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 237 | padding=get_padding(kernel_size, dilation[0]))), 238 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 239 | padding=get_padding(kernel_size, dilation[1]))) 240 | ]) 241 | self.convs.apply(init_weights) 242 | 243 | def forward(self, x, x_mask=None): 244 | for c in self.convs: 245 | xt = F.leaky_relu(x, LRELU_SLOPE) 246 | if x_mask is not None: 247 | xt = xt * x_mask 248 | xt = c(xt) 249 | x = xt + x 250 | if x_mask is not None: 251 | x = x * x_mask 252 | return x 253 | 254 | def remove_weight_norm(self): 255 | for l in self.convs: 256 | remove_weight_norm(l) 257 | 258 | 259 | class Log(nn.Module): 260 | def forward(self, x, x_mask, reverse=False, **kwargs): 261 | if not reverse: 262 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 263 | logdet = torch.sum(-y, [1, 2]) 264 | return y, logdet 265 | else: 266 | x = torch.exp(x) * x_mask 267 | return x 268 | 269 | 270 | class Flip(nn.Module): 271 | def forward(self, x, *args, reverse=False, **kwargs): 272 | x = torch.flip(x, [1]) 273 | if not reverse: 274 | logdet = 
torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 275 | return x, logdet 276 | else: 277 | return x 278 | 279 | 280 | class ElementwiseAffine(nn.Module): 281 | def __init__(self, channels): 282 | super().__init__() 283 | self.channels = channels 284 | self.m = nn.Parameter(torch.zeros(channels,1)) 285 | self.logs = nn.Parameter(torch.zeros(channels,1)) 286 | 287 | def forward(self, x, x_mask, reverse=False, **kwargs): 288 | if not reverse: 289 | y = self.m + torch.exp(self.logs) * x 290 | y = y * x_mask 291 | logdet = torch.sum(self.logs * x_mask, [1,2]) 292 | return y, logdet 293 | else: 294 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 295 | return x 296 | 297 | 298 | class ResidualCouplingLayer(nn.Module): 299 | def __init__(self, 300 | channels, 301 | hidden_channels, 302 | kernel_size, 303 | dilation_rate, 304 | n_layers, 305 | p_dropout=0, 306 | gin_channels=0, 307 | mean_only=False): 308 | assert channels % 2 == 0, "channels should be divisible by 2" 309 | super().__init__() 310 | self.channels = channels 311 | self.hidden_channels = hidden_channels 312 | self.kernel_size = kernel_size 313 | self.dilation_rate = dilation_rate 314 | self.n_layers = n_layers 315 | self.half_channels = channels // 2 316 | self.mean_only = mean_only 317 | 318 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 319 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) 320 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 321 | self.post.weight.data.zero_() 322 | self.post.bias.data.zero_() 323 | 324 | def forward(self, x, x_mask, g=None, reverse=False): 325 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 326 | h = self.pre(x0) * x_mask 327 | h = self.enc(h, x_mask, g=g) 328 | stats = self.post(h) * x_mask 329 | if not self.mean_only: 330 | m, logs = torch.split(stats, [self.half_channels]*2, 1) 331 | else: 332 | m = stats 333 | logs = torch.zeros_like(m) 
334 | 335 | if not reverse: 336 | x1 = m + x1 * torch.exp(logs) * x_mask 337 | x = torch.cat([x0, x1], 1) 338 | logdet = torch.sum(logs, [1,2]) 339 | return x, logdet 340 | else: 341 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 342 | x = torch.cat([x0, x1], 1) 343 | return x 344 | 345 | 346 | class ConvFlow(nn.Module): 347 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): 348 | super().__init__() 349 | self.in_channels = in_channels 350 | self.filter_channels = filter_channels 351 | self.kernel_size = kernel_size 352 | self.n_layers = n_layers 353 | self.num_bins = num_bins 354 | self.tail_bound = tail_bound 355 | self.half_channels = in_channels // 2 356 | 357 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 358 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) 359 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) 360 | self.proj.weight.data.zero_() 361 | self.proj.bias.data.zero_() 362 | 363 | def forward(self, x, x_mask, g=None, reverse=False): 364 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 365 | h = self.pre(x0) 366 | h = self.convs(h, x_mask, g=g) 367 | h = self.proj(h) * x_mask 368 | 369 | b, c, t = x0.shape 370 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 
371 | 372 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) 373 | unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels) 374 | unnormalized_derivatives = h[..., 2 * self.num_bins:] 375 | 376 | x1, logabsdet = piecewise_rational_quadratic_transform(x1, 377 | unnormalized_widths, 378 | unnormalized_heights, 379 | unnormalized_derivatives, 380 | inverse=reverse, 381 | tails='linear', 382 | tail_bound=self.tail_bound 383 | ) 384 | 385 | x = torch.cat([x0, x1], 1) * x_mask 386 | logdet = torch.sum(logabsdet * x_mask, [1,2]) 387 | if not reverse: 388 | return x, logdet 389 | else: 390 | return x 391 | -------------------------------------------------------------------------------- /monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .monotonic_align.core import maximum_path_c 4 | 5 | 6 | def maximum_path(neg_cent, mask): 7 | """ Cython optimized version. 
8 | neg_cent: [b, t_t, t_s] 9 | mask: [b, t_t, t_s] 10 | """ 11 | device = neg_cent.device 12 | dtype = neg_cent.dtype 13 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32) 14 | path = np.zeros(neg_cent.shape, dtype=np.int32) 15 | 16 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) 17 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) 18 | maximum_path_c(path, neg_cent, t_t_max, t_s_max) 19 | return torch.from_numpy(path).to(device=device, dtype=dtype) 20 | -------------------------------------------------------------------------------- /monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | cimport cython 2 | from cython.parallel import prange 3 | 4 | 5 | @cython.boundscheck(False) 6 | @cython.wraparound(False) 7 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: 8 | cdef int x 9 | cdef int y 10 | cdef float v_prev 11 | cdef float v_cur 12 | cdef float tmp 13 | cdef int index = t_x - 1 14 | 15 | for y in range(t_y): 16 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 17 | if x == y: 18 | v_cur = max_neg_val 19 | else: 20 | v_cur = value[y-1, x] 21 | if x == 0: 22 | if y == 0: 23 | v_prev = 0. 
24 | else: 25 | v_prev = max_neg_val 26 | else: 27 | v_prev = value[y-1, x-1] 28 | value[y, x] += max(v_prev, v_cur) 29 | 30 | for y in range(t_y - 1, -1, -1): 31 | path[y, index] = 1 32 | if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]): 33 | index = index - 1 34 | 35 | 36 | @cython.boundscheck(False) 37 | @cython.wraparound(False) 38 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil: 39 | cdef int b = paths.shape[0] 40 | cdef int i 41 | for i in prange(b, nogil=True): 42 | maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) 43 | -------------------------------------------------------------------------------- /monotonic_align/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from Cython.Build import cythonize 3 | import numpy 4 | 5 | setup( 6 | name = 'monotonic_align', 7 | ext_modules = cythonize("core.pyx"), 8 | include_dirs=[numpy.get_include()] 9 | ) 10 | -------------------------------------------------------------------------------- /preprocess_v2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import sys 5 | sys.setrecursionlimit(500000) # Raise the limit to avoid "RecursionError: maximum recursion depth exceeded while calling a Python object"; adjust the value as needed.
6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--add_auxiliary_data", type=lambda s: s.lower() in ("true", "1", "yes"), default=False, help="Whether to add extra data as fine-tuning helper") 10 | parser.add_argument("--languages", default="CJE", choices=["CJE", "CJ", "C"]) 11 | args = parser.parse_args() 12 | if args.languages == "CJE": 13 | langs = ["[ZH]", "[JA]", "[EN]"] 14 | elif args.languages == "CJ": 15 | langs = ["[ZH]", "[JA]"] 16 | elif args.languages == "C": 17 | langs = ["[ZH]"] 18 | new_annos = [] 19 | # Source 1: transcribed short audios 20 | if os.path.exists("short_character_anno.txt"): 21 | with open("short_character_anno.txt", 'r', encoding='utf-8') as f: 22 | short_character_anno = f.readlines() 23 | new_annos += short_character_anno 24 | # Source 2: transcribed long audio segments 25 | if os.path.exists("./long_character_anno.txt"): 26 | with open("./long_character_anno.txt", 'r', encoding='utf-8') as f: 27 | long_character_anno = f.readlines() 28 | new_annos += long_character_anno 29 | 30 | # Get all speaker names 31 | speakers = [] 32 | for line in new_annos: 33 | path, speaker, text = line.split("|") 34 | if speaker not in speakers: 35 | speakers.append(speaker) 36 | assert (len(speakers) != 0), "No audio file found. Please check your uploaded file structure."
37 | # Source 3 (Optional): sampled audios as extra training helpers 38 | if args.add_auxiliary_data: 39 | with open("./sampled_audio4ft.txt", 'r', encoding='utf-8') as f: 40 | old_annos = f.readlines() 41 | # filter old_annos according to supported languages 42 | filtered_old_annos = [] 43 | for line in old_annos: 44 | for lang in langs: 45 | if lang in line: 46 | filtered_old_annos.append(line) 47 | old_annos = filtered_old_annos 48 | for line in old_annos: 49 | path, speaker, text = line.split("|") 50 | if speaker not in speakers: 51 | speakers.append(speaker) 52 | num_old_voices = len(old_annos) 53 | num_new_voices = len(new_annos) 54 | # STEP 1: balance number of new & old voices 55 | cc_duplicate = num_old_voices // num_new_voices 56 | if cc_duplicate == 0: 57 | cc_duplicate = 1 58 | 59 | 60 | # STEP 2: modify config file 61 | with open("./configs/finetune_speaker.json", 'r', encoding='utf-8') as f: 62 | hps = json.load(f) 63 | 64 | # assign ids to new speakers 65 | speaker2id = {} 66 | for i, speaker in enumerate(speakers): 67 | speaker2id[speaker] = i 68 | # modify n_speakers 69 | hps['data']["n_speakers"] = len(speakers) 70 | # overwrite speaker names 71 | hps['speakers'] = speaker2id 72 | hps['train']['log_interval'] = 10 73 | hps['train']['eval_interval'] = 100 74 | hps['train']['batch_size'] = 16 75 | hps['data']['training_files'] = "final_annotation_train.txt" 76 | hps['data']['validation_files'] = "final_annotation_val.txt" 77 | # save modified config 78 | with open("./configs/modified_finetune_speaker.json", 'w', encoding='utf-8') as f: 79 | json.dump(hps, f, indent=2) 80 | 81 | # STEP 3: clean annotations, replace speaker names with assigned speaker IDs 82 | import text 83 | cleaned_new_annos = [] 84 | for i, line in enumerate(new_annos): 85 | path, speaker, txt = line.split("|") 86 | if len(txt) > 150: 87 | continue 88 | cleaned_text = text._clean_text(txt, hps['data']['text_cleaners']) 89 | cleaned_text += "\n" if not cleaned_text.endswith("\n") 
else "" 90 | cleaned_new_annos.append(path + "|" + str(speaker2id[speaker]) + "|" + cleaned_text) 91 | cleaned_old_annos = [] 92 | for i, line in enumerate(old_annos): 93 | path, speaker, txt = line.split("|") 94 | if len(txt) > 150: 95 | continue 96 | cleaned_text = text._clean_text(txt, hps['data']['text_cleaners']) 97 | cleaned_text += "\n" if not cleaned_text.endswith("\n") else "" 98 | cleaned_old_annos.append(path + "|" + str(speaker2id[speaker]) + "|" + cleaned_text) 99 | # merge with old annotation 100 | final_annos = cleaned_old_annos + cc_duplicate * cleaned_new_annos 101 | # save annotation file 102 | with open("./final_annotation_train.txt", 'w', encoding='utf-8') as f: 103 | for line in final_annos: 104 | f.write(line) 105 | # save annotation file for validation 106 | with open("./final_annotation_val.txt", 'w', encoding='utf-8') as f: 107 | for line in cleaned_new_annos: 108 | f.write(line) 109 | print("finished") 110 | else: 111 | # Do not add extra helper data 112 | # STEP 1: modify config file 113 | with open("./configs/finetune_speaker.json", 'r', encoding='utf-8') as f: 114 | hps = json.load(f) 115 | 116 | # assign ids to new speakers 117 | speaker2id = {} 118 | for i, speaker in enumerate(speakers): 119 | speaker2id[speaker] = i 120 | # modify n_speakers 121 | hps['data']["n_speakers"] = len(speakers) 122 | # overwrite speaker names 123 | hps['speakers'] = speaker2id 124 | hps['train']['log_interval'] = 10 125 | hps['train']['eval_interval'] = 100 126 | hps['train']['batch_size'] = 16 127 | hps['data']['training_files'] = "final_annotation_train.txt" 128 | hps['data']['validation_files'] = "final_annotation_val.txt" 129 | # save modified config 130 | with open("./configs/modified_finetune_speaker.json", 'w', encoding='utf-8') as f: 131 | json.dump(hps, f, indent=2) 132 | 133 | # STEP 2: clean annotations, replace speaker names with assigned speaker IDs 134 | import text 135 | 136 | cleaned_new_annos = [] 137 | for i, line in 
enumerate(new_annos): 138 | path, speaker, txt = line.split("|") 139 | if len(txt) > 150: 140 | continue 141 | cleaned_text = text._clean_text(txt, hps['data']['text_cleaners']).replace("[ZH]", "") 142 | cleaned_text += "\n" if not cleaned_text.endswith("\n") else "" 143 | cleaned_new_annos.append(path + "|" + str(speaker2id[speaker]) + "|" + cleaned_text) 144 | 145 | final_annos = cleaned_new_annos 146 | # save annotation file 147 | with open("./final_annotation_train.txt", 'w', encoding='utf-8') as f: 148 | for line in final_annos: 149 | f.write(line) 150 | # save annotation file for validation 151 | with open("./final_annotation_val.txt", 'w', encoding='utf-8') as f: 152 | for line in cleaned_new_annos: 153 | f.write(line) 154 | print("finished") 155 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==0.29.21 2 | librosa==0.9.2 3 | matplotlib==3.3.1 4 | scikit-learn==1.0.2 5 | scipy 6 | numpy==1.22 7 | tensorboard 8 | torch==2.1.2 9 | torchvision==0.16.2 10 | torchaudio==2.1.2 11 | unidecode 12 | pyopenjtalk-prebuilt 13 | jamo 14 | pypinyin 15 | jieba 16 | protobuf 17 | cn2an 18 | inflect 19 | eng_to_ipa 20 | ko_pron 21 | indic_transliteration==2.3.37 22 | num_thai==0.0.5 23 | opencc==1.1.1 24 | demucs 25 | git+https://github.com/openai/whisper.git 26 | gradio 27 | -------------------------------------------------------------------------------- /scripts/denoise_audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torchaudio 4 | raw_audio_dir = "./raw_audio/" 5 | denoise_audio_dir = "./denoised_audio/" 6 | filelist = list(os.walk(raw_audio_dir))[0][2] 7 | # 2023/4/21: Get the target sampling rate 8 | with open("./configs/finetune_speaker.json", 'r', encoding='utf-8') as f: 9 | hps = json.load(f) 10 | target_sr = hps['data']['sampling_rate'] 11 | 
for file in filelist: 12 | if file.endswith(".wav"): 13 | os.system(f"demucs --two-stems=vocals {raw_audio_dir}{file}") 14 | for file in [f for f in filelist if f.endswith(".wav")]: 15 | file = os.path.splitext(file)[0] 16 | wav, sr = torchaudio.load(f"./separated/htdemucs/{file}/vocals.wav", frame_offset=0, num_frames=-1, normalize=True, 17 | channels_first=True) 18 | # merge two channels into one 19 | wav = wav.mean(dim=0).unsqueeze(0) 20 | if sr != target_sr: 21 | wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(wav) 22 | torchaudio.save(denoise_audio_dir + file + ".wav", wav, target_sr, channels_first=True) -------------------------------------------------------------------------------- /scripts/download_model.py: -------------------------------------------------------------------------------- 1 | from google.colab import files 2 | files.download("./G_latest.pth") 3 | files.download("./finetune_speaker.json") 4 | files.download("./moegoe_config.json") -------------------------------------------------------------------------------- /scripts/download_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | from concurrent.futures import ThreadPoolExecutor 5 | from google.colab import files 6 | import subprocess 7 | 8 | basepath = os.getcwd() 9 | uploaded = files.upload() # upload files 10 | for filename in uploaded.keys(): 11 | assert (filename.endswith(".txt")), "speaker-videolink info must be a .txt file!"
12 | shutil.move(os.path.join(basepath, filename), os.path.join("./speaker_links.txt")) 13 | 14 | 15 | def generate_infos(): 16 | infos = [] 17 | with open("./speaker_links.txt", 'r', encoding='utf-8') as f: 18 | lines = f.readlines() 19 | for line in lines: 20 | line = line.replace("\n", "").replace(" ", "") 21 | if line == "": 22 | continue 23 | speaker, link = line.split("|") 24 | filename = speaker + "_" + str(random.randint(0, 1000000)) 25 | infos.append({"link": link, "filename": filename}) 26 | return infos 27 | 28 | 29 | def download_video(info): 30 | 31 | link = info["link"] 32 | filename = info["filename"] 33 | print(f"Starting download for:\nFilename: {filename}\nLink: {link}") 34 | 35 | try: 36 | result = subprocess.run( 37 | ["yt-dlp", "-f", "30280", link, "-o", f"./video_data/{filename}.mp4", "--no-check-certificate"], 38 | stdout=subprocess.PIPE, 39 | stderr=subprocess.PIPE, 40 | text=True, 41 | check=True 42 | ) 43 | print(f"Download completed for {filename}:\n{result.stdout}") 44 | except subprocess.CalledProcessError as e: 45 | print(f"Failed to download {link}:\n{e.stderr}") 46 | 47 | 48 | if __name__ == "__main__": 49 | infos = generate_infos() 50 | with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: 51 | executor.map(download_video, infos) 52 | -------------------------------------------------------------------------------- /scripts/long_audio_transcribe.py: -------------------------------------------------------------------------------- 1 | from moviepy.editor import AudioFileClip 2 | import whisper 3 | import os 4 | import json 5 | import torchaudio 6 | import librosa 7 | import torch 8 | import argparse 9 | parent_dir = "./denoised_audio/" 10 | filelist = list(os.walk(parent_dir))[0][2] 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--languages", default="CJE", choices=["CJE", "CJ", "C"]) 14 | parser.add_argument("--whisper_size", default="medium") 15 | args = parser.parse_args() 16 | if args.languages ==
"CJE": 17 | lang2token = { 18 | 'zh': "[ZH]", 19 | 'ja': "[JA]", 20 | "en": "[EN]", 21 | } 22 | elif args.languages == "CJ": 23 | lang2token = { 24 | 'zh': "[ZH]", 25 | 'ja': "[JA]", 26 | } 27 | elif args.languages == "C": 28 | lang2token = { 29 | 'zh': "[ZH]", 30 | } 31 | assert(torch.cuda.is_available()), "Please enable GPU in order to run Whisper!" 32 | with open("./configs/finetune_speaker.json", 'r', encoding='utf-8') as f: 33 | hps = json.load(f) 34 | target_sr = hps['data']['sampling_rate'] 35 | model = whisper.load_model(args.whisper_size) 36 | speaker_annos = [] 37 | for file in filelist: 38 | audio_path = os.path.join(parent_dir, file) 39 | print(f"Transcribing {audio_path}...\n") 40 | 41 | options = dict(beam_size=5, best_of=5) 42 | transcribe_options = dict(task="transcribe", **options) 43 | 44 | result = model.transcribe(audio_path, word_timestamps=True, **transcribe_options) 45 | segments = result["segments"] 46 | lang = result['language'] 47 | if lang not in lang2token: 48 | print(f"{lang} not supported, ignoring...\n") 49 | continue 50 | 51 | character_name = os.path.splitext(file)[0].split("_")[0] 52 | code = os.path.splitext(file)[0].split("_")[1] 53 | outdir = os.path.join("./segmented_character_voice", character_name) 54 | os.makedirs(outdir, exist_ok=True) 55 | 56 | wav, sr = torchaudio.load( 57 | audio_path, 58 | frame_offset=0, 59 | num_frames=-1, 60 | normalize=True, 61 | channels_first=True 62 | ) 63 | 64 | for i, seg in enumerate(segments): 65 | start_time = seg['start'] 66 | end_time = seg['end'] 67 | text = seg['text'] 68 | text_tokened = lang2token[lang] + text.replace("\n", "") + lang2token[lang] + "\n" 69 | start_idx = int(start_time * sr) 70 | end_idx = int(end_time * sr) 71 | num_samples = end_idx - start_idx 72 | if num_samples <= 0: 73 | print(f"Skipping zero-length segment: start={start_time}, end={end_time}") 74 | continue 75 | wav_seg = wav[:, start_idx:end_idx] 76 | if wav_seg.shape[1] == 0: 77 | print(f"Skipping empty segment i={i},
shape={wav_seg.shape}") 78 | continue 79 | if sr != target_sr: 80 | resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr) 81 | wav_seg = resampler(wav_seg) 82 | 83 | wav_seg_name = f"{character_name}_{code}_{i}.wav" 84 | savepth = os.path.join(outdir, wav_seg_name) 85 | speaker_annos.append(savepth + "|" + character_name + "|" + text_tokened) 86 | print(f"Transcribed segment: {speaker_annos[-1]}") 87 | torchaudio.save(savepth, wav_seg, target_sr, channels_first=True) 88 | 89 | if len(speaker_annos) == 0: 90 | print("Warning: no long audios & videos found, this IS expected if you have only uploaded short audios") 91 | print("this IS NOT expected if you have uploaded any long audios, videos or video links. Please check your file structure or make sure your audio/video language is supported.") 92 | with open("./long_character_anno.txt", 'w', encoding='utf-8') as f: 93 | for line in speaker_annos: 94 | f.write(line) 95 | -------------------------------------------------------------------------------- /scripts/rearrange_speaker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import json 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--model_dir", type=str, default="./OUTPUT_MODEL/G_latest.pth") 8 | parser.add_argument("--config_dir", type=str, default="./configs/modified_finetune_speaker.json") 9 | args = parser.parse_args() 10 | 11 | model_sd = torch.load(args.model_dir, map_location='cpu') 12 | with open(args.config_dir, 'r', encoding='utf-8') as f: 13 | hps = json.load(f) 14 | 15 | valid_speakers = list(hps['speakers'].keys()) 16 | if hps['data']['n_speakers'] > len(valid_speakers): 17 | old_emb_g = model_sd['model']['emb_g.weight'] 18 | new_emb_g = torch.zeros([len(valid_speakers), old_emb_g.shape[1]], dtype=old_emb_g.dtype) 19 | for i, speaker in enumerate(valid_speakers): 20 | new_emb_g[i, :] = old_emb_g[hps['speakers'][speaker], :] 21 |
hps['speakers'][speaker] = i 22 | hps['data']['n_speakers'] = len(valid_speakers) 23 | model_sd['model']['emb_g.weight'] = new_emb_g 24 | with open("./finetune_speaker.json", 'w', encoding='utf-8') as f: 25 | json.dump(hps, f, indent=2) 26 | torch.save(model_sd, "./G_latest.pth") 27 | else: 28 | with open("./finetune_speaker.json", 'w', encoding='utf-8') as f: 29 | json.dump(hps, f, indent=2) 30 | torch.save(model_sd, "./G_latest.pth") 31 | # save another config file copy in MoeGoe format 32 | hps['speakers'] = valid_speakers 33 | with open("./moegoe_config.json", 'w', encoding='utf-8') as f: 34 | json.dump(hps, f, indent=2) 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /scripts/resample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import torchaudio 5 | 6 | 7 | def main(): 8 | with open("./configs/finetune_speaker.json", 'r', encoding='utf-8') as f: 9 | hps = json.load(f) 10 | target_sr = hps['data']['sampling_rate'] 11 | filelist = list(os.walk("./sampled_audio4ft"))[0][2] 12 | if target_sr != 22050: 13 | for wavfile in filelist: 14 | wav, sr = torchaudio.load("./sampled_audio4ft" + "/" + wavfile, frame_offset=0, num_frames=-1, 15 | normalize=True, channels_first=True) 16 | wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(wav) 17 | torchaudio.save("./sampled_audio4ft" + "/" + wavfile, wav, target_sr, channels_first=True) 18 | 19 | if __name__ == "__main__": 20 | main() -------------------------------------------------------------------------------- /scripts/short_audio_transcribe.py: -------------------------------------------------------------------------------- 1 | import whisper 2 | import os 3 | import json 4 | import torchaudio 5 | import argparse 6 | import torch 7 | 8 | lang2token = { 9 | 'zh': "[ZH]", 10 | 'ja': "[JA]", 11 | "en": "[EN]", 12 | } 13 | def transcribe_one(audio_path): 14 | 
try: 15 | # load audio and pad/trim it to fit 30 seconds 16 | audio = whisper.load_audio(audio_path) 17 | audio = whisper.pad_or_trim(audio) 18 | 19 | # make log-Mel spectrogram and move to the same device as the model 20 | mel = whisper.log_mel_spectrogram(audio).to(model.device) 21 | 22 | # detect the spoken language 23 | _, probs = model.detect_language(mel) 24 | print(f"Detected language: {max(probs, key=probs.get)}") 25 | lang = max(probs, key=probs.get) 26 | # decode the audio 27 | options = whisper.DecodingOptions(beam_size=5) 28 | result = whisper.decode(model, mel, options) 29 | 30 | # print the recognized text 31 | print(result.text) 32 | return lang, result.text 33 | except Exception as e: 34 | print(e) 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--languages", default="CJE") 38 | parser.add_argument("--whisper_size", default="medium") 39 | args = parser.parse_args() 40 | if args.languages == "CJE": 41 | lang2token = { 42 | 'zh': "[ZH]", 43 | 'ja': "[JA]", 44 | "en": "[EN]", 45 | } 46 | elif args.languages == "CJ": 47 | lang2token = { 48 | 'zh': "[ZH]", 49 | 'ja': "[JA]", 50 | } 51 | elif args.languages == "C": 52 | lang2token = { 53 | 'zh': "[ZH]", 54 | } 55 | assert (torch.cuda.is_available()), "Please enable GPU in order to run Whisper!" 
56 | model = whisper.load_model(args.whisper_size) 57 | parent_dir = "./custom_character_voice/" 58 | speaker_names = list(os.walk(parent_dir))[0][1] 59 | speaker_annos = [] 60 | total_files = sum([len(files) for r, d, files in os.walk(parent_dir)]) 61 | # resample audios 62 | # 2023/4/21: Get the target sampling rate 63 | with open("./configs/finetune_speaker.json", 'r', encoding='utf-8') as f: 64 | hps = json.load(f) 65 | target_sr = hps['data']['sampling_rate'] 66 | processed_files = 0 67 | for speaker in speaker_names: 68 | for i, wavfile in enumerate(list(os.walk(parent_dir + speaker))[0][2]): 69 | # try to load file as audio 70 | if wavfile.startswith("processed_"): 71 | continue 72 | try: 73 | wav, sr = torchaudio.load(parent_dir + speaker + "/" + wavfile, frame_offset=0, num_frames=-1, normalize=True, 74 | channels_first=True) 75 | wav = wav.mean(dim=0).unsqueeze(0) 76 | if sr != target_sr: 77 | wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(wav) 78 | if wav.shape[1] / target_sr > 20: 79 | print(f"{wavfile} too long, ignoring\n"); continue 80 | save_path = parent_dir + speaker + "/" + f"processed_{i}.wav" 81 | torchaudio.save(save_path, wav, target_sr, channels_first=True) 82 | # transcribe text 83 | lang, text = transcribe_one(save_path) 84 | if lang not in list(lang2token.keys()): 85 | print(f"{lang} not supported, ignoring\n") 86 | continue 87 | text = lang2token[lang] + text + lang2token[lang] + "\n" 88 | speaker_annos.append(save_path + "|" + speaker + "|" + text) 89 | 90 | processed_files += 1 91 | print(f"Processed: {processed_files}/{total_files}") 92 | except Exception: 93 | continue 94 | 95 | # # clean annotation 96 | # import argparse 97 | # import text 98 | # from utils import load_filepaths_and_text 99 | # for i, line in enumerate(speaker_annos): 100 | # path, sid, txt = line.split("|") 101 | # cleaned_text = text._clean_text(txt, ["cjke_cleaners2"]) 102 | # cleaned_text += "\n" if not cleaned_text.endswith("\n") else "" 103 | # speaker_annos[i] =
path + "|" + sid + "|" + cleaned_text 104 | # write into annotation 105 | if len(speaker_annos) == 0: 106 | print("Warning: no short audios found, this IS expected if you have only uploaded long audios, videos or video links.") 107 | print("this IS NOT expected if you have uploaded a zip file of short audios. Please check your file structure or make sure your audio language is supported.") 108 | with open("short_character_anno.txt", 'w', encoding='utf-8') as f: 109 | for line in speaker_annos: 110 | f.write(line) 111 | 112 | # import json 113 | # # generate new config 114 | # with open("./configs/finetune_speaker.json", 'r', encoding='utf-8') as f: 115 | # hps = json.load(f) 116 | # # modify n_speakers 117 | # hps['data']["n_speakers"] = 1000 + len(speaker2id) 118 | # # add speaker names 119 | # for speaker in speaker_names: 120 | # hps['speakers'][speaker] = speaker2id[speaker] 121 | # # save modified config 122 | # with open("./configs/modified_finetune_speaker.json", 'w', encoding='utf-8') as f: 123 | # json.dump(hps, f, indent=2) 124 | # print("finished") 125 | -------------------------------------------------------------------------------- /scripts/video2audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from concurrent.futures import ThreadPoolExecutor 3 | 4 | from moviepy.editor import AudioFileClip 5 | 6 | video_dir = "./video_data/" 7 | audio_dir = "./raw_audio/" 8 | filelist = list(os.walk(video_dir))[0][2] 9 | 10 | 11 | def generate_infos(): 12 | videos = [] 13 | for file in filelist: 14 | if file.endswith(".mp4"): 15 | videos.append(file) 16 | return videos 17 | 18 | 19 | def clip_file(file): 20 | my_audio_clip = AudioFileClip(video_dir + file) 21 | my_audio_clip.write_audiofile(audio_dir + file.rstrip("mp4") + "wav") 22 | 23 | 24 | if __name__ == "__main__": 25 | infos = generate_infos() 26 | with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: 27 | executor.map(clip_file, infos) 28 
| -------------------------------------------------------------------------------- /scripts/voice_upload.py: -------------------------------------------------------------------------------- 1 | from google.colab import files 2 | import shutil 3 | import os 4 | import argparse 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--type", type=str, required=True, help="type of file to upload") 8 | args = parser.parse_args() 9 | file_type = args.type 10 | 11 | basepath = os.getcwd() 12 | uploaded = files.upload() # upload files 13 | assert(file_type in ['zip', 'audio', 'video']) 14 | if file_type == "zip": 15 | upload_path = "./custom_character_voice/" 16 | for filename in uploaded.keys(): 17 | # Move the uploaded file to the target location 18 | shutil.move(os.path.join(basepath, filename), os.path.join(upload_path, "custom_character_voice.zip")) 19 | elif file_type == "audio": 20 | upload_path = "./raw_audio/" 21 | for filename in uploaded.keys(): 22 | # Move the uploaded file to the target location 23 | shutil.move(os.path.join(basepath, filename), os.path.join(upload_path, filename)) 24 | elif file_type == "video": 25 | upload_path = "./video_data/" 26 | for filename in uploaded.keys(): 27 | # Move the uploaded file to the target location 28 | shutil.move(os.path.join(basepath, filename), os.path.join(upload_path, filename)) -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice
and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from text import cleaners 3 | from text.symbols import symbols 4 | 5 | 6 | # Mappings from symbol to numeric ID and vice versa: 7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 9 | 10 | 11 | def text_to_sequence(text, symbols, cleaner_names): 12 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
13 | Args: 14 | text: string to convert to a sequence 15 | cleaner_names: names of the cleaner functions to run the text through 16 | Returns: 17 | List of integers corresponding to the symbols in the text 18 | ''' 19 | sequence = [] 20 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 21 | clean_text = _clean_text(text, cleaner_names) 22 | print(clean_text) 23 | print(f" length:{len(clean_text)}") 24 | for symbol in clean_text: 25 | if symbol not in symbol_to_id.keys(): 26 | continue 27 | symbol_id = symbol_to_id[symbol] 28 | sequence += [symbol_id] 29 | print(f" length:{len(sequence)}") 30 | return sequence 31 | 32 | 33 | def cleaned_text_to_sequence(cleaned_text, symbols): 34 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 35 | Args: 36 | cleaned_text: already-cleaned string to convert to a sequence 37 | Returns: 38 | List of integers corresponding to the symbols in the text 39 | ''' 40 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 41 | sequence = [symbol_to_id[symbol] for symbol in cleaned_text if symbol in symbol_to_id.keys()] 42 | return sequence 43 | 44 | 45 | def sequence_to_text(sequence): 46 | '''Converts a sequence of IDs back to a string''' 47 | result = '' 48 | for symbol_id in sequence: 49 | s = _id_to_symbol[symbol_id] 50 | result += s 51 | return result 52 | 53 | 54 | def _clean_text(text, cleaner_names): 55 | for name in cleaner_names: 56 | cleaner = getattr(cleaners, name, None) 57 | if not cleaner: 58 | raise Exception('Unknown cleaner: %s' % name) 59 | text = cleaner(text) 60 | return text 61 | -------------------------------------------------------------------------------- /text/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Plachtaa/VITS-fast-fine-tuning/8d341c7215f7770e81e6ac4486602179883d09af/text/__pycache__/__init__.cpython-37.pyc
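The text front end in `text/__init__.py` works in two stages: `_clean_text` runs the named cleaner functions in order, then `cleaned_text_to_sequence` maps each character to its index in the symbol list, silently dropping anything not in that list, and `sequence_to_text` inverts the mapping. The sketch below reproduces that flow in isolation so it can be run without the language-specific dependencies. The toy symbol list and the `lowercase_cleaners` function are hypothetical stand-ins (the real ones live in `text/symbols.py` and `text/cleaners.py`), and the lookup uses `getattr(..., None)` so that an unknown cleaner name raises a clear error rather than an `AttributeError`.

```python
import re
from types import SimpleNamespace

# Hypothetical toy symbol table; the real list comes from text/symbols.py.
symbols = ["_", ",", ".", " ", "a", "b", "c"]
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}


def lowercase_cleaners(text):
    """Hypothetical cleaner: lowercase and collapse runs of whitespace."""
    return re.sub(r"\s+", " ", text.lower()).strip()


# Stand-in for the `text.cleaners` module, looked up by attribute name.
cleaners = SimpleNamespace(lowercase_cleaners=lowercase_cleaners)


def clean_text(text, cleaner_names):
    """Apply each named cleaner in order, like _clean_text."""
    for name in cleaner_names:
        cleaner = getattr(cleaners, name, None)  # None instead of AttributeError
        if cleaner is None:
            raise ValueError(f"Unknown cleaner: {name}")
        text = cleaner(text)
    return text


def cleaned_text_to_sequence(cleaned_text):
    # Characters missing from the symbol table are dropped, not mapped.
    return [_symbol_to_id[ch] for ch in cleaned_text if ch in _symbol_to_id]


def sequence_to_text(sequence):
    return "".join(_id_to_symbol[i] for i in sequence)


if __name__ == "__main__":
    cleaned = clean_text("  ABC,  cab!  ", ["lowercase_cleaners"])
    seq = cleaned_text_to_sequence(cleaned)
    print(cleaned, seq, sequence_to_text(seq))
```

Running it prints `abc, cab! [4, 5, 6, 1, 3, 6, 4, 5] abc, cab`. The `!` is not in the toy symbol table, so it vanishes from the ID sequence and cannot be recovered on the way back; this is why the cleaners must emit only characters present in the model's symbol list.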
-------------------------------------------------------------------------------- /text/__pycache__/cleaners.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Plachtaa/VITS-fast-fine-tuning/8d341c7215f7770e81e6ac4486602179883d09af/text/__pycache__/cleaners.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/english.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Plachtaa/VITS-fast-fine-tuning/8d341c7215f7770e81e6ac4486602179883d09af/text/__pycache__/english.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/japanese.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Plachtaa/VITS-fast-fine-tuning/8d341c7215f7770e81e6ac4486602179883d09af/text/__pycache__/japanese.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/korean.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Plachtaa/VITS-fast-fine-tuning/8d341c7215f7770e81e6ac4486602179883d09af/text/__pycache__/korean.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/mandarin.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Plachtaa/VITS-fast-fine-tuning/8d341c7215f7770e81e6ac4486602179883d09af/text/__pycache__/mandarin.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/sanskrit.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Plachtaa/VITS-fast-fine-tuning/8d341c7215f7770e81e6ac4486602179883d09af/text/__pycache__/sanskrit.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/symbols.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Plachtaa/VITS-fast-fine-tuning/8d341c7215f7770e81e6ac4486602179883d09af/text/__pycache__/symbols.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/thai.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Plachtaa/VITS-fast-fine-tuning/8d341c7215f7770e81e6ac4486602179883d09af/text/__pycache__/thai.cpython-37.pyc -------------------------------------------------------------------------------- /text/cantonese.py: -------------------------------------------------------------------------------- 1 | import re 2 | import cn2an 3 | import opencc 4 | 5 | 6 | converter = opencc.OpenCC('jyutjyu') 7 | 8 | # List of (Latin alphabet, ipa) pairs: 9 | _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 10 | ('A', 'ei˥'), 11 | ('B', 'biː˥'), 12 | ('C', 'siː˥'), 13 | ('D', 'tiː˥'), 14 | ('E', 'iː˥'), 15 | ('F', 'e˥fuː˨˩'), 16 | ('G', 'tsiː˥'), 17 | ('H', 'ɪk̚˥tsʰyː˨˩'), 18 | ('I', 'ɐi˥'), 19 | ('J', 'tsei˥'), 20 | ('K', 'kʰei˥'), 21 | ('L', 'e˥llou˨˩'), 22 | ('M', 'ɛːm˥'), 23 | ('N', 'ɛːn˥'), 24 | ('O', 'ou˥'), 25 | ('P', 'pʰiː˥'), 26 | ('Q', 'kʰiːu˥'), 27 | ('R', 'aː˥lou˨˩'), 28 | ('S', 'ɛː˥siː˨˩'), 29 | ('T', 'tʰiː˥'), 30 | ('U', 'juː˥'), 31 | ('V', 'wiː˥'), 32 | ('W', 'tʊk̚˥piː˥juː˥'), 33 | ('X', 'ɪk̚˥siː˨˩'), 34 | ('Y', 'waːi˥'), 35 | ('Z', 'iː˨sɛːt̚˥') 36 | ]] 37 | 38 | 39 | def number_to_cantonese(text): 40 | return re.sub(r'\d+(?:\.?\d+)?', lambda x: cn2an.an2cn(x.group()), text) 41 | 42 | 43 | def latin_to_ipa(text): 44 | for regex, replacement in 
_latin_to_ipa: 45 | text = re.sub(regex, replacement, text) 46 | return text 47 | 48 | 49 | def cantonese_to_ipa(text): 50 | text = number_to_cantonese(text.upper()) 51 | text = converter.convert(text).replace('-','').replace('$',' ') 52 | text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group())+' ', text) 53 | text = re.sub(r'[、；：]', '，', text) 54 | text = re.sub(r'\s*，\s*', ', ', text) 55 | text = re.sub(r'\s*。\s*', '. ', text) 56 | text = re.sub(r'\s*？\s*', '? ', text) 57 | text = re.sub(r'\s*！\s*', '! ', text) 58 | text = re.sub(r'\s*$', '', text) 59 | return text 60 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | import re 2 | from text.japanese import japanese_to_romaji_with_accent, japanese_to_ipa, japanese_to_ipa2, japanese_to_ipa3 3 | from text.korean import latin_to_hangul, number_to_hangul, divide_hangul, korean_to_lazy_ipa, korean_to_ipa 4 | from text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2 5 | from text.sanskrit import devanagari_to_ipa 6 | from text.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2 7 | from text.thai import num_to_thai, latin_to_thai 8 | # from text.shanghainese import shanghainese_to_ipa 9 | # from text.cantonese import cantonese_to_ipa 10 | # from text.ngu_dialect import ngu_dialect_to_ipa 11 | 12 | 13 | def japanese_cleaners(text): 14 | text = japanese_to_romaji_with_accent(text) 15 | text = re.sub(r'([A-Za-z])$', r'\1.', text) 16 | return text 17 | 18 | 19 | def japanese_cleaners2(text): 20 | return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…') 21 | 22 | 23 | def korean_cleaners(text): 24 | '''Pipeline for Korean text''' 25 | text = latin_to_hangul(text) 26 | text = number_to_hangul(text) 27 | text = divide_hangul(text) 28 | text = 
re.sub(r'([\u3131-\u3163])$', r'\1.', text) 29 | return text 30 | 31 | 32 | def chinese_cleaners(text): 33 | '''Pipeline for Chinese text''' 34 | text = text.replace("[ZH]", "") 35 | text = number_to_chinese(text) 36 | text = chinese_to_bopomofo(text) 37 | text = latin_to_bopomofo(text) 38 | text = re.sub(r'([ˉˊˇˋ˙])$', r'\1。', text) 39 | return text 40 | 41 | 42 | def zh_ja_mixture_cleaners(text): 43 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 44 | lambda x: chinese_to_romaji(x.group(1))+' ', text) 45 | text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_romaji_with_accent( 46 | x.group(1)).replace('ts', 'ʦ').replace('u', 'ɯ').replace('...', '…')+' ', text) 47 | text = re.sub(r'\s+$', '', text) 48 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 49 | return text 50 | 51 | 52 | def sanskrit_cleaners(text): 53 | text = text.replace('॥', '।').replace('ॐ', 'ओम्') 54 | text = re.sub(r'([^।])$', r'\1।', text) 55 | return text 56 | 57 | 58 | def cjks_cleaners(text): 59 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 60 | lambda x: chinese_to_lazy_ipa(x.group(1))+' ', text) 61 | text = re.sub(r'\[JA\](.*?)\[JA\]', 62 | lambda x: japanese_to_ipa(x.group(1))+' ', text) 63 | text = re.sub(r'\[KO\](.*?)\[KO\]', 64 | lambda x: korean_to_lazy_ipa(x.group(1))+' ', text) 65 | text = re.sub(r'\[SA\](.*?)\[SA\]', 66 | lambda x: devanagari_to_ipa(x.group(1))+' ', text) 67 | text = re.sub(r'\[EN\](.*?)\[EN\]', 68 | lambda x: english_to_lazy_ipa(x.group(1))+' ', text) 69 | text = re.sub(r'\s+$', '', text) 70 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 71 | return text 72 | 73 | 74 | def cjke_cleaners(text): 75 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', lambda x: chinese_to_lazy_ipa(x.group(1)).replace( 76 | 'ʧ', 'tʃ').replace('ʦ', 'ts').replace('ɥan', 'ɥæn')+' ', text) 77 | text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_ipa(x.group(1)).replace('ʧ', 'tʃ').replace( 78 | 'ʦ', 'ts').replace('ɥan', 'ɥæn').replace('ʥ', 'dz')+' ', text) 79 | text = re.sub(r'\[KO\](.*?)\[KO\]', 80 | 
lambda x: korean_to_ipa(x.group(1))+' ', text) 81 | text = re.sub(r'\[EN\](.*?)\[EN\]', lambda x: english_to_ipa2(x.group(1)).replace('ɑ', 'a').replace( 82 | 'ɔ', 'o').replace('ɛ', 'e').replace('ɪ', 'i').replace('ʊ', 'u')+' ', text) 83 | text = re.sub(r'\s+$', '', text) 84 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 85 | return text 86 | 87 | 88 | def cjke_cleaners2(text): 89 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 90 | lambda x: chinese_to_ipa(x.group(1))+' ', text) 91 | text = re.sub(r'\[JA\](.*?)\[JA\]', 92 | lambda x: japanese_to_ipa2(x.group(1))+' ', text) 93 | text = re.sub(r'\[KO\](.*?)\[KO\]', 94 | lambda x: korean_to_ipa(x.group(1))+' ', text) 95 | text = re.sub(r'\[EN\](.*?)\[EN\]', 96 | lambda x: english_to_ipa2(x.group(1))+' ', text) 97 | text = re.sub(r'\s+$', '', text) 98 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 99 | return text 100 | 101 | 102 | def thai_cleaners(text): 103 | text = num_to_thai(text) 104 | text = latin_to_thai(text) 105 | return text 106 | 107 | 108 | # def shanghainese_cleaners(text): 109 | # text = shanghainese_to_ipa(text) 110 | # text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 111 | # return text 112 | 113 | 114 | # def chinese_dialect_cleaners(text): 115 | # text = re.sub(r'\[ZH\](.*?)\[ZH\]', 116 | # lambda x: chinese_to_ipa2(x.group(1))+' ', text) 117 | # text = re.sub(r'\[JA\](.*?)\[JA\]', 118 | # lambda x: japanese_to_ipa3(x.group(1)).replace('Q', 'ʔ')+' ', text) 119 | # text = re.sub(r'\[SH\](.*?)\[SH\]', lambda x: shanghainese_to_ipa(x.group(1)).replace('1', '˥˧').replace('5', 120 | # '˧˧˦').replace('6', '˩˩˧').replace('7', '˥').replace('8', '˩˨').replace('ᴀ', 'ɐ').replace('ᴇ', 'e')+' ', text) 121 | # text = re.sub(r'\[GD\](.*?)\[GD\]', 122 | # lambda x: cantonese_to_ipa(x.group(1))+' ', text) 123 | # text = re.sub(r'\[EN\](.*?)\[EN\]', 124 | # lambda x: english_to_lazy_ipa2(x.group(1))+' ', text) 125 | # text = re.sub(r'\[([A-Z]{2})\](.*?)\[\1\]', lambda x: ngu_dialect_to_ipa(x.group(2), x.group( 126 | # 
1)).replace('ʣ', 'dz').replace('ʥ', 'dʑ').replace('ʦ', 'ts').replace('ʨ', 'tɕ')+' ', text) 127 | # text = re.sub(r'\s+$', '', text) 128 | # text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 129 | # return text 130 | -------------------------------------------------------------------------------- /text/english.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | 16 | # Regular expression matching whitespace: 17 | 18 | 19 | import re 20 | import inflect 21 | from unidecode import unidecode 22 | import eng_to_ipa as ipa 23 | _inflect = inflect.engine() 24 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 25 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 26 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 27 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 28 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 29 | _number_re = re.compile(r'[0-9]+') 30 | 31 | # List of (regular expression, replacement) pairs for abbreviations: 32 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 33 | ('mrs', 'misess'), 34 | ('mr', 'mister'), 35 | ('dr', 'doctor'), 36 | ('st', 'saint'), 37 | ('co', 'company'), 38 | ('jr', 'junior'), 39 | ('maj', 'major'), 40 | ('gen', 'general'), 41 | ('drs', 'doctors'), 42 | ('rev', 'reverend'), 43 | ('lt', 'lieutenant'), 44 | ('hon', 'honorable'), 45 | ('sgt', 'sergeant'), 46 | ('capt', 'captain'), 47 | ('esq', 'esquire'), 48 | ('ltd', 'limited'), 49 | ('col', 'colonel'), 50 | ('ft', 'fort'), 51 | ]] 52 | 53 | 54 | # List of (ipa, lazy ipa) pairs: 55 | _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 56 | ('r', 'ɹ'), 57 | ('æ', 'e'), 58 | ('ɑ', 'a'), 59 | ('ɔ', 'o'), 60 | ('ð', 'z'), 61 | ('θ', 's'), 62 | ('ɛ', 'e'), 63 | ('ɪ', 'i'), 64 | ('ʊ', 'u'), 65 | ('ʒ', 'ʥ'), 66 | ('ʤ', 'ʥ'), 67 | ('ˈ', '↓'), 68 | ]] 69 | 70 | # List of (ipa, lazy ipa2) pairs: 71 | _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 72 | ('r', 'ɹ'), 73 | ('ð', 'z'), 74 | ('θ', 's'), 75 | ('ʒ', 'ʑ'), 76 | ('ʤ', 'dʑ'), 77 | ('ˈ', '↓'), 78 | ]] 79 | 80 | # List of (ipa, ipa2) pairs 81 | _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 82 | ('r', 'ɹ'), 83 | ('ʤ', 'dʒ'), 84 | ('ʧ', 'tʃ') 85 | ]] 86 | 87 | 88 | def expand_abbreviations(text): 89 | for regex, replacement in _abbreviations: 90 | text = re.sub(regex, replacement, text) 91 | return text 92 | 93 | 94 | def collapse_whitespace(text): 95 | return re.sub(r'\s+', ' ', text) 96 | 97 | 98 | def _remove_commas(m): 99 | return m.group(1).replace(',', '') 100 | 101 | 102 | def _expand_decimal_point(m): 103 | return m.group(1).replace('.', ' point ') 104 | 105 | 106 | def _expand_dollars(m): 107 | match = m.group(1) 108 | parts = match.split('.') 109 | if len(parts) > 2: 110 | return match + ' dollars' # Unexpected format 111 | dollars = int(parts[0]) if parts[0] else 0 112 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 113 | if dollars and cents: 114 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 115 | cent_unit = 
'cent' if cents == 1 else 'cents' 116 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 117 | elif dollars: 118 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 119 | return '%s %s' % (dollars, dollar_unit) 120 | elif cents: 121 | cent_unit = 'cent' if cents == 1 else 'cents' 122 | return '%s %s' % (cents, cent_unit) 123 | else: 124 | return 'zero dollars' 125 | 126 | 127 | def _expand_ordinal(m): 128 | return _inflect.number_to_words(m.group(0)) 129 | 130 | 131 | def _expand_number(m): 132 | num = int(m.group(0)) 133 | if num > 1000 and num < 3000: 134 | if num == 2000: 135 | return 'two thousand' 136 | elif num > 2000 and num < 2010: 137 | return 'two thousand ' + _inflect.number_to_words(num % 100) 138 | elif num % 100 == 0: 139 | return _inflect.number_to_words(num // 100) + ' hundred' 140 | else: 141 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 142 | else: 143 | return _inflect.number_to_words(num, andword='') 144 | 145 | 146 | def normalize_numbers(text): 147 | text = re.sub(_comma_number_re, _remove_commas, text) 148 | text = re.sub(_pounds_re, r'\1 pounds', text) 149 | text = re.sub(_dollars_re, _expand_dollars, text) 150 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 151 | text = re.sub(_ordinal_re, _expand_ordinal, text) 152 | text = re.sub(_number_re, _expand_number, text) 153 | return text 154 | 155 | 156 | def mark_dark_l(text): 157 | return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text) 158 | 159 | 160 | def english_to_ipa(text): 161 | text = unidecode(text).lower() 162 | text = expand_abbreviations(text) 163 | text = normalize_numbers(text) 164 | phonemes = ipa.convert(text) 165 | phonemes = collapse_whitespace(phonemes) 166 | return phonemes 167 | 168 | 169 | def english_to_lazy_ipa(text): 170 | text = english_to_ipa(text) 171 | for regex, replacement in _lazy_ipa: 172 | text = re.sub(regex, replacement, text) 173 | return text 174 | 175 | 
176 | def english_to_ipa2(text): 177 | text = english_to_ipa(text) 178 | text = mark_dark_l(text) 179 | for regex, replacement in _ipa_to_ipa2: 180 | text = re.sub(regex, replacement, text) 181 | return text.replace('...', '…') 182 | 183 | 184 | def english_to_lazy_ipa2(text): 185 | text = english_to_ipa(text) 186 | for regex, replacement in _lazy_ipa2: 187 | text = re.sub(regex, replacement, text) 188 | return text 189 | -------------------------------------------------------------------------------- /text/japanese.py: -------------------------------------------------------------------------------- 1 | import re 2 | from unidecode import unidecode 3 | import pyopenjtalk 4 | 5 | 6 | # Regular expression matching Japanese without punctuation marks: 7 | _japanese_characters = re.compile( 8 | r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 9 | 10 | # Regular expression matching non-Japanese characters or punctuation marks: 11 | _japanese_marks = re.compile( 12 | r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 13 | 14 | # List of (symbol, Japanese) pairs for marks: 15 | _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [ 16 | ('%', 'パーセント') 17 | ]] 18 | 19 | # List of (romaji, ipa) pairs for marks: 20 | _romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 21 | ('ts', 'ʦ'), 22 | ('u', 'ɯ'), 23 | ('j', 'ʥ'), 24 | ('y', 'j'), 25 | ('ni', 'n^i'), 26 | ('nj', 'n^'), 27 | ('hi', 'çi'), 28 | ('hj', 'ç'), 29 | ('f', 'ɸ'), 30 | ('I', 'i*'), 31 | ('U', 'ɯ*'), 32 | ('r', 'ɾ') 33 | ]] 34 | 35 | # List of (romaji, ipa2) pairs for marks: 36 | _romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 37 | ('u', 'ɯ'), 38 | ('ʧ', 'tʃ'), 39 | ('j', 'dʑ'), 40 | ('y', 'j'), 41 | ('ni', 'n^i'), 42 | ('nj', 'n^'), 43 | ('hi', 'çi'), 44 | ('hj', 'ç'), 45 | ('f', 'ɸ'), 46 | ('I', 'i*'), 47 | ('U', 'ɯ*'), 48 | ('r', 'ɾ') 49 | ]] 50 | 51 | # List of (consonant, 
sokuon) pairs: 52 | _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [ 53 | (r'Q([↑↓]*[kg])', r'k#\1'), 54 | (r'Q([↑↓]*[tdjʧ])', r't#\1'), 55 | (r'Q([↑↓]*[sʃ])', r's\1'), 56 | (r'Q([↑↓]*[pb])', r'p#\1') 57 | ]] 58 | 59 | # List of (consonant, hatsuon) pairs: 60 | _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [ 61 | (r'N([↑↓]*[pbm])', r'm\1'), 62 | (r'N([↑↓]*[ʧʥj])', r'n^\1'), 63 | (r'N([↑↓]*[tdn])', r'n\1'), 64 | (r'N([↑↓]*[kg])', r'ŋ\1') 65 | ]] 66 | 67 | 68 | def symbols_to_japanese(text): 69 | for regex, replacement in _symbols_to_japanese: 70 | text = re.sub(regex, replacement, text) 71 | return text 72 | 73 | 74 | def japanese_to_romaji_with_accent(text): 75 | '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html''' 76 | text = symbols_to_japanese(text) 77 | sentences = re.split(_japanese_marks, text) 78 | marks = re.findall(_japanese_marks, text) 79 | text = '' 80 | for i, sentence in enumerate(sentences): 81 | if re.match(_japanese_characters, sentence): 82 | if text != '': 83 | text += ' ' 84 | labels = pyopenjtalk.extract_fullcontext(sentence) 85 | for n, label in enumerate(labels): 86 | phoneme = re.search(r'\-([^\+]*)\+', label).group(1) 87 | if phoneme not in ['sil', 'pau']: 88 | text += phoneme.replace('ch', 'ʧ').replace('sh', 89 | 'ʃ').replace('cl', 'Q') 90 | else: 91 | continue 92 | # n_moras = int(re.search(r'/F:(\d+)_', label).group(1)) 93 | a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1)) 94 | a2 = int(re.search(r"\+(\d+)\+", label).group(1)) 95 | a3 = int(re.search(r"\+(\d+)/", label).group(1)) 96 | if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']: 97 | a2_next = -1 98 | else: 99 | a2_next = int( 100 | re.search(r"\+(\d+)\+", labels[n + 1]).group(1)) 101 | # Accent phrase boundary 102 | if a3 == 1 and a2_next == 1: 103 | text += ' ' 104 | # Falling 105 | elif a1 == 0 and a2_next == a2 + 1: 106 | text += '↓' 107 | # Rising 108 | elif a2 == 1 and a2_next == 2: 109 
| text += '↑' 110 | if i < len(marks): 111 | text += unidecode(marks[i]).replace(' ', '') 112 | return text 113 | 114 | 115 | def get_real_sokuon(text): 116 | for regex, replacement in _real_sokuon: 117 | text = re.sub(regex, replacement, text) 118 | return text 119 | 120 | 121 | def get_real_hatsuon(text): 122 | for regex, replacement in _real_hatsuon: 123 | text = re.sub(regex, replacement, text) 124 | return text 125 | 126 | 127 | def japanese_to_ipa(text): 128 | text = japanese_to_romaji_with_accent(text).replace('...', '…') 129 | text = re.sub( 130 | r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text) 131 | text = get_real_sokuon(text) 132 | text = get_real_hatsuon(text) 133 | for regex, replacement in _romaji_to_ipa: 134 | text = re.sub(regex, replacement, text) 135 | return text 136 | 137 | 138 | def japanese_to_ipa2(text): 139 | text = japanese_to_romaji_with_accent(text).replace('...', '…') 140 | text = get_real_sokuon(text) 141 | text = get_real_hatsuon(text) 142 | for regex, replacement in _romaji_to_ipa2: 143 | text = re.sub(regex, replacement, text) 144 | return text 145 | 146 | 147 | def japanese_to_ipa3(text): 148 | text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace( 149 | 'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a') 150 | text = re.sub( 151 | r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text) 152 | text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text) 153 | return text 154 | -------------------------------------------------------------------------------- /text/korean.py: -------------------------------------------------------------------------------- 1 | import re 2 | from jamo import h2j, j2hcj 3 | import ko_pron 4 | 5 | 6 | # This is a list of Korean classifiers preceded by pure Korean numerals. 
7 | _korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통' 8 | 9 | # List of (hangul, hangul divided) pairs: 10 | _hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [ 11 | ('ㄳ', 'ㄱㅅ'), 12 | ('ㄵ', 'ㄴㅈ'), 13 | ('ㄶ', 'ㄴㅎ'), 14 | ('ㄺ', 'ㄹㄱ'), 15 | ('ㄻ', 'ㄹㅁ'), 16 | ('ㄼ', 'ㄹㅂ'), 17 | ('ㄽ', 'ㄹㅅ'), 18 | ('ㄾ', 'ㄹㅌ'), 19 | ('ㄿ', 'ㄹㅍ'), 20 | ('ㅀ', 'ㄹㅎ'), 21 | ('ㅄ', 'ㅂㅅ'), 22 | ('ㅘ', 'ㅗㅏ'), 23 | ('ㅙ', 'ㅗㅐ'), 24 | ('ㅚ', 'ㅗㅣ'), 25 | ('ㅝ', 'ㅜㅓ'), 26 | ('ㅞ', 'ㅜㅔ'), 27 | ('ㅟ', 'ㅜㅣ'), 28 | ('ㅢ', 'ㅡㅣ'), 29 | ('ㅑ', 'ㅣㅏ'), 30 | ('ㅒ', 'ㅣㅐ'), 31 | ('ㅕ', 'ㅣㅓ'), 32 | ('ㅖ', 'ㅣㅔ'), 33 | ('ㅛ', 'ㅣㅗ'), 34 | ('ㅠ', 'ㅣㅜ') 35 | ]] 36 | 37 | # List of (Latin alphabet, hangul) pairs: 38 | _latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 39 | ('a', '에이'), 40 | ('b', '비'), 41 | ('c', '시'), 42 | ('d', '디'), 43 | ('e', '이'), 44 | ('f', '에프'), 45 | ('g', '지'), 46 | ('h', '에이치'), 47 | ('i', '아이'), 48 | ('j', '제이'), 49 | ('k', '케이'), 50 | ('l', '엘'), 51 | ('m', '엠'), 52 | ('n', '엔'), 53 | ('o', '오'), 54 | ('p', '피'), 55 | ('q', '큐'), 56 | ('r', '아르'), 57 | ('s', '에스'), 58 | ('t', '티'), 59 | ('u', '유'), 60 | ('v', '브이'), 61 | ('w', '더블유'), 62 | ('x', '엑스'), 63 | ('y', '와이'), 64 | ('z', '제트') 65 | ]] 66 | 67 | # List of (ipa, lazy ipa) pairs: 68 | _ipa_to_lazy_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 69 | ('t͡ɕ','ʧ'), 70 | ('d͡ʑ','ʥ'), 71 | ('ɲ','n^'), 72 | ('ɕ','ʃ'), 73 | ('ʷ','w'), 74 | ('ɭ','l`'), 75 | ('ʎ','ɾ'), 76 | ('ɣ','ŋ'), 77 | ('ɰ','ɯ'), 78 | ('ʝ','j'), 79 | ('ʌ','ə'), 80 | ('ɡ','g'), 81 | ('\u031a','#'), 82 | ('\u0348','='), 83 | ('\u031e',''), 84 | ('\u0320',''), 85 | ('\u0339','') 86 | ]] 87 | 88 | 89 | def latin_to_hangul(text): 90 | for regex, replacement in _latin_to_hangul: 91 | text = re.sub(regex, replacement, text) 92 | return text 93 | 94 | 95 | def divide_hangul(text): 96 | text = j2hcj(h2j(text)) 97 | for regex, replacement in _hangul_divided: 98 | text = re.sub(regex, replacement, 
text) 99 | return text 100 | 101 | 102 | def hangul_number(num, sino=True): 103 | '''Reference https://github.com/Kyubyong/g2pK''' 104 | num = re.sub(',', '', num) 105 | 106 | if num == '0': 107 | return '영' 108 | if not sino and num == '20': 109 | return '스무' 110 | 111 | digits = '123456789' 112 | names = '일이삼사오육칠팔구' 113 | digit2name = {d: n for d, n in zip(digits, names)} 114 | 115 | modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉' 116 | decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔' 117 | digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())} 118 | digit2dec = {d: dec for d, dec in zip(digits, decimals.split())} 119 | 120 | spelledout = [] 121 | for i, digit in enumerate(num): 122 | i = len(num) - i - 1 123 | if sino: 124 | if i == 0: 125 | name = digit2name.get(digit, '') 126 | elif i == 1: 127 | name = digit2name.get(digit, '') + '십' 128 | name = name.replace('일십', '십') 129 | else: 130 | if i == 0: 131 | name = digit2mod.get(digit, '') 132 | elif i == 1: 133 | name = digit2dec.get(digit, '') 134 | if digit == '0': 135 | if i % 4 == 0: 136 | last_three = spelledout[-min(3, len(spelledout)):] 137 | if ''.join(last_three) == '': 138 | spelledout.append('') 139 | continue 140 | else: 141 | spelledout.append('') 142 | continue 143 | if i == 2: 144 | name = digit2name.get(digit, '') + '백' 145 | name = name.replace('일백', '백') 146 | elif i == 3: 147 | name = digit2name.get(digit, '') + '천' 148 | name = name.replace('일천', '천') 149 | elif i == 4: 150 | name = digit2name.get(digit, '') + '만' 151 | name = name.replace('일만', '만') 152 | elif i == 5: 153 | name = digit2name.get(digit, '') + '십' 154 | name = name.replace('일십', '십') 155 | elif i == 6: 156 | name = digit2name.get(digit, '') + '백' 157 | name = name.replace('일백', '백') 158 | elif i == 7: 159 | name = digit2name.get(digit, '') + '천' 160 | name = name.replace('일천', '천') 161 | elif i == 8: 162 | name = digit2name.get(digit, '') + '억' 163 | elif i == 9: 164 | name = digit2name.get(digit, '') + '십' 165 | elif i == 10: 166 | name 
= digit2name.get(digit, '') + '백' 167 | elif i == 11: 168 | name = digit2name.get(digit, '') + '천' 169 | elif i == 12: 170 | name = digit2name.get(digit, '') + '조' 171 | elif i == 13: 172 | name = digit2name.get(digit, '') + '십' 173 | elif i == 14: 174 | name = digit2name.get(digit, '') + '백' 175 | elif i == 15: 176 | name = digit2name.get(digit, '') + '천' 177 | spelledout.append(name) 178 | return ''.join(elem for elem in spelledout) 179 | 180 | 181 | def number_to_hangul(text): 182 | '''Reference https://github.com/Kyubyong/g2pK''' 183 | tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text)) 184 | for token in tokens: 185 | num, classifier = token 186 | if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers: 187 | spelledout = hangul_number(num, sino=False) 188 | else: 189 | spelledout = hangul_number(num, sino=True) 190 | text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}') 191 | # digit by digit for remaining digits 192 | digits = '0123456789' 193 | names = '영일이삼사오육칠팔구' 194 | for d, n in zip(digits, names): 195 | text = text.replace(d, n) 196 | return text 197 | 198 | 199 | def korean_to_lazy_ipa(text): 200 | text = latin_to_hangul(text) 201 | text = number_to_hangul(text) 202 | text=re.sub('[\uac00-\ud7af]+',lambda x:ko_pron.romanise(x.group(0),'ipa').split('] ~ [')[0],text) 203 | for regex, replacement in _ipa_to_lazy_ipa: 204 | text = re.sub(regex, replacement, text) 205 | return text 206 | 207 | 208 | def korean_to_ipa(text): 209 | text = korean_to_lazy_ipa(text) 210 | return text.replace('ʧ','tʃ').replace('ʥ','dʑ') 211 | -------------------------------------------------------------------------------- /text/mandarin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | from pypinyin import lazy_pinyin, BOPOMOFO 5 | import jieba 6 | import cn2an 7 | import logging 8 | 9 | 10 | # List of (Latin alphabet, bopomofo) pairs: 11 | 
_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 12 | ('a', 'ㄟˉ'), 13 | ('b', 'ㄅㄧˋ'), 14 | ('c', 'ㄙㄧˉ'), 15 | ('d', 'ㄉㄧˋ'), 16 | ('e', 'ㄧˋ'), 17 | ('f', 'ㄝˊㄈㄨˋ'), 18 | ('g', 'ㄐㄧˋ'), 19 | ('h', 'ㄝˇㄑㄩˋ'), 20 | ('i', 'ㄞˋ'), 21 | ('j', 'ㄐㄟˋ'), 22 | ('k', 'ㄎㄟˋ'), 23 | ('l', 'ㄝˊㄛˋ'), 24 | ('m', 'ㄝˊㄇㄨˋ'), 25 | ('n', 'ㄣˉ'), 26 | ('o', 'ㄡˉ'), 27 | ('p', 'ㄆㄧˉ'), 28 | ('q', 'ㄎㄧㄡˉ'), 29 | ('r', 'ㄚˋ'), 30 | ('s', 'ㄝˊㄙˋ'), 31 | ('t', 'ㄊㄧˋ'), 32 | ('u', 'ㄧㄡˉ'), 33 | ('v', 'ㄨㄧˉ'), 34 | ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'), 35 | ('x', 'ㄝˉㄎㄨˋㄙˋ'), 36 | ('y', 'ㄨㄞˋ'), 37 | ('z', 'ㄗㄟˋ') 38 | ]] 39 | 40 | # List of (bopomofo, romaji) pairs: 41 | _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [ 42 | ('ㄅㄛ', 'p⁼wo'), 43 | ('ㄆㄛ', 'pʰwo'), 44 | ('ㄇㄛ', 'mwo'), 45 | ('ㄈㄛ', 'fwo'), 46 | ('ㄅ', 'p⁼'), 47 | ('ㄆ', 'pʰ'), 48 | ('ㄇ', 'm'), 49 | ('ㄈ', 'f'), 50 | ('ㄉ', 't⁼'), 51 | ('ㄊ', 'tʰ'), 52 | ('ㄋ', 'n'), 53 | ('ㄌ', 'l'), 54 | ('ㄍ', 'k⁼'), 55 | ('ㄎ', 'kʰ'), 56 | ('ㄏ', 'h'), 57 | ('ㄐ', 'ʧ⁼'), 58 | ('ㄑ', 'ʧʰ'), 59 | ('ㄒ', 'ʃ'), 60 | ('ㄓ', 'ʦ`⁼'), 61 | ('ㄔ', 'ʦ`ʰ'), 62 | ('ㄕ', 's`'), 63 | ('ㄖ', 'ɹ`'), 64 | ('ㄗ', 'ʦ⁼'), 65 | ('ㄘ', 'ʦʰ'), 66 | ('ㄙ', 's'), 67 | ('ㄚ', 'a'), 68 | ('ㄛ', 'o'), 69 | ('ㄜ', 'ə'), 70 | ('ㄝ', 'e'), 71 | ('ㄞ', 'ai'), 72 | ('ㄟ', 'ei'), 73 | ('ㄠ', 'au'), 74 | ('ㄡ', 'ou'), 75 | ('ㄧㄢ', 'yeNN'), 76 | ('ㄢ', 'aNN'), 77 | ('ㄧㄣ', 'iNN'), 78 | ('ㄣ', 'əNN'), 79 | ('ㄤ', 'aNg'), 80 | ('ㄧㄥ', 'iNg'), 81 | ('ㄨㄥ', 'uNg'), 82 | ('ㄩㄥ', 'yuNg'), 83 | ('ㄥ', 'əNg'), 84 | ('ㄦ', 'əɻ'), 85 | ('ㄧ', 'i'), 86 | ('ㄨ', 'u'), 87 | ('ㄩ', 'ɥ'), 88 | ('ˉ', '→'), 89 | ('ˊ', '↑'), 90 | ('ˇ', '↓↑'), 91 | ('ˋ', '↓'), 92 | ('˙', ''), 93 | ('，', ','), 94 | ('。', '.'), 95 | ('！', '!'), 96 | ('？', '?'), 97 | ('—', '-') 98 | ]] 99 | 100 | # List of (romaji, ipa) pairs: 101 | _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 102 | ('ʃy', 'ʃ'), 103 | ('ʧʰy', 'ʧʰ'), 104 | ('ʧ⁼y', 'ʧ⁼'), 105 | ('NN', 'n'), 106 | ('Ng', 'ŋ'), 107 | ('y', 'j'), 108 | ('h', 'x') 109 | ]]
110 | 111 | # List of (bopomofo, ipa) pairs: 112 | _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 113 | ('ㄅㄛ', 'p⁼wo'), 114 | ('ㄆㄛ', 'pʰwo'), 115 | ('ㄇㄛ', 'mwo'), 116 | ('ㄈㄛ', 'fwo'), 117 | ('ㄅ', 'p⁼'), 118 | ('ㄆ', 'pʰ'), 119 | ('ㄇ', 'm'), 120 | ('ㄈ', 'f'), 121 | ('ㄉ', 't⁼'), 122 | ('ㄊ', 'tʰ'), 123 | ('ㄋ', 'n'), 124 | ('ㄌ', 'l'), 125 | ('ㄍ', 'k⁼'), 126 | ('ㄎ', 'kʰ'), 127 | ('ㄏ', 'x'), 128 | ('ㄐ', 'tʃ⁼'), 129 | ('ㄑ', 'tʃʰ'), 130 | ('ㄒ', 'ʃ'), 131 | ('ㄓ', 'ts`⁼'), 132 | ('ㄔ', 'ts`ʰ'), 133 | ('ㄕ', 's`'), 134 | ('ㄖ', 'ɹ`'), 135 | ('ㄗ', 'ts⁼'), 136 | ('ㄘ', 'tsʰ'), 137 | ('ㄙ', 's'), 138 | ('ㄚ', 'a'), 139 | ('ㄛ', 'o'), 140 | ('ㄜ', 'ə'), 141 | ('ㄝ', 'ɛ'), 142 | ('ㄞ', 'aɪ'), 143 | ('ㄟ', 'eɪ'), 144 | ('ㄠ', 'ɑʊ'), 145 | ('ㄡ', 'oʊ'), 146 | ('ㄧㄢ', 'jɛn'), 147 | ('ㄩㄢ', 'ɥæn'), 148 | ('ㄢ', 'an'), 149 | ('ㄧㄣ', 'in'), 150 | ('ㄩㄣ', 'ɥn'), 151 | ('ㄣ', 'ən'), 152 | ('ㄤ', 'ɑŋ'), 153 | ('ㄧㄥ', 'iŋ'), 154 | ('ㄨㄥ', 'ʊŋ'), 155 | ('ㄩㄥ', 'jʊŋ'), 156 | ('ㄥ', 'əŋ'), 157 | ('ㄦ', 'əɻ'), 158 | ('ㄧ', 'i'), 159 | ('ㄨ', 'u'), 160 | ('ㄩ', 'ɥ'), 161 | ('ˉ', '→'), 162 | ('ˊ', '↑'), 163 | ('ˇ', '↓↑'), 164 | ('ˋ', '↓'), 165 | ('˙', ''), 166 | ('，', ','), 167 | ('。', '.'), 168 | ('！', '!'), 169 | ('？', '?'), 170 | ('—', '-') 171 | ]] 172 | 173 | # List of (bopomofo, ipa2) pairs: 174 | _bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 175 | ('ㄅㄛ', 'pwo'), 176 | ('ㄆㄛ', 'pʰwo'), 177 | ('ㄇㄛ', 'mwo'), 178 | ('ㄈㄛ', 'fwo'), 179 | ('ㄅ', 'p'), 180 | ('ㄆ', 'pʰ'), 181 | ('ㄇ', 'm'), 182 | ('ㄈ', 'f'), 183 | ('ㄉ', 't'), 184 | ('ㄊ', 'tʰ'), 185 | ('ㄋ', 'n'), 186 | ('ㄌ', 'l'), 187 | ('ㄍ', 'k'), 188 | ('ㄎ', 'kʰ'), 189 | ('ㄏ', 'h'), 190 | ('ㄐ', 'tɕ'), 191 | ('ㄑ', 'tɕʰ'), 192 | ('ㄒ', 'ɕ'), 193 | ('ㄓ', 'tʂ'), 194 | ('ㄔ', 'tʂʰ'), 195 | ('ㄕ', 'ʂ'), 196 | ('ㄖ', 'ɻ'), 197 | ('ㄗ', 'ts'), 198 | ('ㄘ', 'tsʰ'), 199 | ('ㄙ', 's'), 200 | ('ㄚ', 'a'), 201 | ('ㄛ', 'o'), 202 | ('ㄜ', 'ɤ'), 203 | ('ㄝ', 'ɛ'), 204 | ('ㄞ', 'aɪ'), 205 | ('ㄟ', 'eɪ'), 206 | ('ㄠ', 'ɑʊ'), 207 | ('ㄡ', 'oʊ'), 208 | ('ㄧㄢ', 'jɛn'), 209 | 
('ㄩㄢ', 'yæn'), 210 | ('ㄢ', 'an'), 211 | ('ㄧㄣ', 'in'), 212 | ('ㄩㄣ', 'yn'), 213 | ('ㄣ', 'ən'), 214 | ('ㄤ', 'ɑŋ'), 215 | ('ㄧㄥ', 'iŋ'), 216 | ('ㄨㄥ', 'ʊŋ'), 217 | ('ㄩㄥ', 'jʊŋ'), 218 | ('ㄥ', 'ɤŋ'), 219 | ('ㄦ', 'əɻ'), 220 | ('ㄧ', 'i'), 221 | ('ㄨ', 'u'), 222 | ('ㄩ', 'y'), 223 | ('ˉ', '˥'), 224 | ('ˊ', '˧˥'), 225 | ('ˇ', '˨˩˦'), 226 | ('ˋ', '˥˩'), 227 | ('˙', ''), 228 | ('，', ','), 229 | ('。', '.'), 230 | ('！', '!'), 231 | ('？', '?'), 232 | ('—', '-') 233 | ]] 234 | 235 | 236 | def number_to_chinese(text): 237 | numbers = re.findall(r'\d+(?:\.?\d+)?', text) 238 | for number in numbers: 239 | text = text.replace(number, cn2an.an2cn(number), 1) 240 | return text 241 | 242 | 243 | def chinese_to_bopomofo(text): 244 | text = text.replace('、', '，').replace('；', '，').replace('：', '，') 245 | words = jieba.lcut(text, cut_all=False) 246 | text = '' 247 | for word in words: 248 | bopomofos = lazy_pinyin(word, BOPOMOFO) 249 | if not re.search('[\u4e00-\u9fff]', word): 250 | text += word 251 | continue 252 | for i in range(len(bopomofos)): 253 | bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i]) 254 | if text != '': 255 | text += ' ' 256 | text += ''.join(bopomofos) 257 | return text 258 | 259 | 260 | def latin_to_bopomofo(text): 261 | for regex, replacement in _latin_to_bopomofo: 262 | text = re.sub(regex, replacement, text) 263 | return text 264 | 265 | 266 | def bopomofo_to_romaji(text): 267 | for regex, replacement in _bopomofo_to_romaji: 268 | text = re.sub(regex, replacement, text) 269 | return text 270 | 271 | 272 | def bopomofo_to_ipa(text): 273 | for regex, replacement in _bopomofo_to_ipa: 274 | text = re.sub(regex, replacement, text) 275 | return text 276 | 277 | 278 | def bopomofo_to_ipa2(text): 279 | for regex, replacement in _bopomofo_to_ipa2: 280 | text = re.sub(regex, replacement, text) 281 | return text 282 | 283 | 284 | def chinese_to_romaji(text): 285 | text = number_to_chinese(text) 286 | text = chinese_to_bopomofo(text) 287 | text = 
latin_to_bopomofo(text) 288 | text = bopomofo_to_romaji(text) 289 | text = re.sub('i([aoe])', r'y\1', text) 290 | text = re.sub('u([aoəe])', r'w\1', text) 291 | text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', 292 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') 293 | text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) 294 | return text 295 | 296 | 297 | def chinese_to_lazy_ipa(text): 298 | text = chinese_to_romaji(text) 299 | for regex, replacement in _romaji_to_ipa: 300 | text = re.sub(regex, replacement, text) 301 | return text 302 | 303 | 304 | def chinese_to_ipa(text): 305 | text = number_to_chinese(text) 306 | text = chinese_to_bopomofo(text) 307 | text = latin_to_bopomofo(text) 308 | text = bopomofo_to_ipa(text) 309 | text = re.sub('i([aoe])', r'j\1', text) 310 | text = re.sub('u([aoəe])', r'w\1', text) 311 | text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', 312 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') 313 | text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) 314 | return text 315 | 316 | 317 | def chinese_to_ipa2(text): 318 | text = number_to_chinese(text) 319 | text = chinese_to_bopomofo(text) 320 | text = latin_to_bopomofo(text) 321 | text = bopomofo_to_ipa2(text) 322 | text = re.sub(r'i([aoe])', r'j\1', text) 323 | text = re.sub(r'u([aoəe])', r'w\1', text) 324 | text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text) 325 | text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text) 326 | return text 327 | -------------------------------------------------------------------------------- /text/ngu_dialect.py: -------------------------------------------------------------------------------- 1 | import re 2 | import opencc 3 | 4 | 5 | dialects = {'SZ': 'suzhou', 'WX': 'wuxi', 'CZ': 'changzhou', 'HZ': 'hangzhou', 6 | 'SX': 'shaoxing', 'NB': 'ningbo', 'JJ': 'jingjiang', 'YX': 'yixing', 7 | 'JD': 'jiading', 'ZR': 'zhenru', 'PH': 'pinghu', 'TX': 'tongxiang', 8 | 'JS': 'jiashan', 'HN': 'xiashi', 'LP': 'linping', 'XS': 'xiaoshan', 9 | 'FY': 'fuyang', 'RA': 'ruao', 'CX': 'cixi', 'SM': 'sanmen', 
10 | 'TT': 'tiantai', 'WZ': 'wenzhou', 'SC': 'suichang', 'YB': 'youbu'} 11 | 12 | converters = {} 13 | 14 | for dialect in dialects.values(): 15 | try: 16 | converters[dialect] = opencc.OpenCC(dialect) 17 | except: 18 | pass 19 | 20 | 21 | def ngu_dialect_to_ipa(text, dialect): 22 | dialect = dialects[dialect] 23 | text = converters[dialect].convert(text).replace('-','').replace('$',' ') 24 | text = re.sub(r'[、；：]', '，', text) 25 | text = re.sub(r'\s*，\s*', ', ', text) 26 | text = re.sub(r'\s*。\s*', '. ', text) 27 | text = re.sub(r'\s*？\s*', '? ', text) 28 | text = re.sub(r'\s*！\s*', '! ', text) 29 | text = re.sub(r'\s*$', '', text) 30 | return text 31 | -------------------------------------------------------------------------------- /text/sanskrit.py: -------------------------------------------------------------------------------- 1 | import re 2 | from indic_transliteration import sanscript 3 | 4 | 5 | # List of (iast, ipa) pairs: 6 | _iast_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 7 | ('a', 'ə'), 8 | ('ā', 'aː'), 9 | ('ī', 'iː'), 10 | ('ū', 'uː'), 11 | ('ṛ', 'ɹ`'), 12 | ('ṝ', 'ɹ`ː'), 13 | ('ḷ', 'l`'), 14 | ('ḹ', 'l`ː'), 15 | ('e', 'eː'), 16 | ('o', 'oː'), 17 | ('k', 'k⁼'), 18 | ('k⁼h', 'kʰ'), 19 | ('g', 'g⁼'), 20 | ('g⁼h', 'gʰ'), 21 | ('ṅ', 'ŋ'), 22 | ('c', 'ʧ⁼'), 23 | ('ʧ⁼h', 'ʧʰ'), 24 | ('j', 'ʥ⁼'), 25 | ('ʥ⁼h', 'ʥʰ'), 26 | ('ñ', 'n^'), 27 | ('ṭ', 't`⁼'), 28 | ('t`⁼h', 't`ʰ'), 29 | ('ḍ', 'd`⁼'), 30 | ('d`⁼h', 'd`ʰ'), 31 | ('ṇ', 'n`'), 32 | ('t', 't⁼'), 33 | ('t⁼h', 'tʰ'), 34 | ('d', 'd⁼'), 35 | ('d⁼h', 'dʰ'), 36 | ('p', 'p⁼'), 37 | ('p⁼h', 'pʰ'), 38 | ('b', 'b⁼'), 39 | ('b⁼h', 'bʰ'), 40 | ('y', 'j'), 41 | ('ś', 'ʃ'), 42 | ('ṣ', 's`'), 43 | ('r', 'ɾ'), 44 | ('l̤', 'l`'), 45 | ('h', 'ɦ'), 46 | ("'", ''), 47 | ('~', '^'), 48 | ('ṃ', '^') 49 | ]] 50 | 51 | 52 | def devanagari_to_ipa(text): 53 | text = text.replace('ॐ', 'ओम्') 54 | text = re.sub(r'\s*।\s*$', '.', text) 55 | text = re.sub(r'\s*।\s*', ', ', text) 56 | text = re.sub(r'\s*॥', '.', text) 57 | 
 text = sanscript.transliterate(text, sanscript.DEVANAGARI, sanscript.IAST) 58 | for regex, replacement in _iast_to_ipa: 59 | text = re.sub(regex, replacement, text) 60 | text = re.sub('(.)[`ː]*ḥ', lambda x: x.group(0) 61 | [:-1]+'h'+x.group(1)+'*', text) 62 | return text 63 | -------------------------------------------------------------------------------- /text/shanghainese.py: -------------------------------------------------------------------------------- 1 | import re 2 | import cn2an 3 | import opencc 4 | 5 | 6 | converter = opencc.OpenCC('zaonhe') 7 | 8 | # List of (Latin alphabet, ipa) pairs: 9 | _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 10 | ('A', 'ᴇ'), 11 | ('B', 'bi'), 12 | ('C', 'si'), 13 | ('D', 'di'), 14 | ('E', 'i'), 15 | ('F', 'ᴇf'), 16 | ('G', 'dʑi'), 17 | ('H', 'ᴇtɕʰ'), 18 | ('I', 'ᴀi'), 19 | ('J', 'dʑᴇ'), 20 | ('K', 'kʰᴇ'), 21 | ('L', 'ᴇl'), 22 | ('M', 'ᴇm'), 23 | ('N', 'ᴇn'), 24 | ('O', 'o'), 25 | ('P', 'pʰi'), 26 | ('Q', 'kʰiu'), 27 | ('R', 'ᴀl'), 28 | ('S', 'ᴇs'), 29 | ('T', 'tʰi'), 30 | ('U', 'ɦiu'), 31 | ('V', 'vi'), 32 | ('W', 'dᴀbɤliu'), 33 | ('X', 'ᴇks'), 34 | ('Y', 'uᴀi'), 35 | ('Z', 'zᴇ') 36 | ]] 37 | 38 | 39 | def _number_to_shanghainese(num): 40 | num = cn2an.an2cn(num).replace('一十','十').replace('二十', '廿').replace('二', '两') 41 | return re.sub(r'((?:^|[^三四五六七八九])十|廿)两', r'\1二', num) 42 | 43 | 44 | def number_to_shanghainese(text): 45 | return re.sub(r'\d+(?:\.?\d+)?', lambda x: _number_to_shanghainese(x.group()), text) 46 | 47 | 48 | def latin_to_ipa(text): 49 | for regex, replacement in _latin_to_ipa: 50 | text = re.sub(regex, replacement, text) 51 | return text 52 | 53 | 54 | def shanghainese_to_ipa(text): 55 | text = number_to_shanghainese(text.upper()) 56 | text = converter.convert(text).replace('-','').replace('$',' ') 57 | text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group())+' ', text) 58 | text = re.sub(r'[、；：]', '，', text) 59 | text = re.sub(r'\s*，\s*', ', ', text) 60 | text = re.sub(r'\s*。\s*', '. 
', text) 61 | text = re.sub(r'\s*？\s*', '? ', text) 62 | text = re.sub(r'\s*！\s*', '! ', text) 63 | text = re.sub(r'\s*$', '', text) 64 | return text 65 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Defines the set of symbols used in text input to the model. 3 | ''' 4 | 5 | # japanese_cleaners 6 | # _pad = '_' 7 | # _punctuation = ',.!?-' 8 | # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ ' 9 | 10 | 11 | '''# japanese_cleaners2 12 | _pad = '_' 13 | _punctuation = ',.!?-~…' 14 | _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ ' 15 | ''' 16 | 17 | 18 | '''# korean_cleaners 19 | _pad = '_' 20 | _punctuation = ',.!?…~' 21 | _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ ' 22 | ''' 23 | 24 | '''# chinese_cleaners 25 | _pad = '_' 26 | _punctuation = '，。！？—…' 27 | _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ ' 28 | ''' 29 | 30 | # # zh_ja_mixture_cleaners 31 | # _pad = '_' 32 | # _punctuation = ',.!?-~…' 33 | # _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ ' 34 | 35 | 36 | '''# sanskrit_cleaners 37 | _pad = '_' 38 | _punctuation = '।' 39 | _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ ' 40 | ''' 41 | 42 | '''# cjks_cleaners 43 | _pad = '_' 44 | _punctuation = ',.!?-~…' 45 | _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ ' 46 | ''' 47 | 48 | '''# thai_cleaners 49 | _pad = '_' 50 | _punctuation = '.!? 
' 51 | _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์' 52 | ''' 53 | 54 | # # cjke_cleaners2 55 | _pad = '_' 56 | _punctuation = ',.!?-~…' 57 | _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ' 58 | 59 | 60 | '''# shanghainese_cleaners 61 | _pad = '_' 62 | _punctuation = ',.!?…' 63 | _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 ' 64 | ''' 65 | 66 | '''# chinese_dialect_cleaners 67 | _pad = '_' 68 | _punctuation = ',.!?~…─' 69 | _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ ' 70 | ''' 71 | 72 | # Export all symbols: 73 | symbols = [_pad] + list(_punctuation) + list(_letters) 74 | 75 | # Special symbol ids 76 | SPACE_ID = symbols.index(" ") 77 | -------------------------------------------------------------------------------- /text/thai.py: -------------------------------------------------------------------------------- 1 | import re 2 | from num_thai.thainumbers import NumThai 3 | 4 | 5 | num = NumThai() 6 | 7 | # List of (Latin alphabet, Thai) pairs: 8 | _latin_to_thai = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 9 | ('a', 'เอ'), 10 | ('b','บี'), 11 | ('c','ซี'), 12 | ('d','ดี'), 13 | ('e','อี'), 14 | ('f','เอฟ'), 15 | ('g','จี'), 16 | ('h','เอช'), 17 | ('i','ไอ'), 18 | ('j','เจ'), 19 | ('k','เค'), 20 | ('l','แอล'), 21 | ('m','เอ็ม'), 22 | ('n','เอ็น'), 23 | ('o','โอ'), 24 | ('p','พี'), 25 | ('q','คิว'), 26 | ('r','แอร์'), 27 | ('s','เอส'), 28 | ('t','ที'), 29 | ('u','ยู'), 30 | ('v','วี'), 31 | ('w','ดับเบิลยู'), 32 | ('x','เอ็กซ์'), 33 | ('y','วาย'), 34 | ('z','ซี') 35 | ]] 36 | 37 | 38 | def num_to_thai(text): 39 | return re.sub(r'(?:\d+(?:,?\d+)?)+(?:\.\d+(?:,?\d+)?)?', lambda x: ''.join(num.NumberToTextThai(float(x.group(0).replace(',', '')))), text) 40 | 41 | def latin_to_thai(text): 42 | for regex, replacement in _latin_to_thai: 43 | text = re.sub(regex, replacement, text) 44 | return text 45 | 
-------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform(inputs, 13 | unnormalized_widths, 14 | unnormalized_heights, 15 | unnormalized_derivatives, 16 | inverse=False, 17 | tails=None, 18 | tail_bound=1., 19 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 20 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 21 | min_derivative=DEFAULT_MIN_DERIVATIVE): 22 | 23 | if tails is None: 24 | spline_fn = rational_quadratic_spline 25 | spline_kwargs = {} 26 | else: 27 | spline_fn = unconstrained_rational_quadratic_spline 28 | spline_kwargs = { 29 | 'tails': tails, 30 | 'tail_bound': tail_bound 31 | } 32 | 33 | outputs, logabsdet = spline_fn( 34 | inputs=inputs, 35 | unnormalized_widths=unnormalized_widths, 36 | unnormalized_heights=unnormalized_heights, 37 | unnormalized_derivatives=unnormalized_derivatives, 38 | inverse=inverse, 39 | min_bin_width=min_bin_width, 40 | min_bin_height=min_bin_height, 41 | min_derivative=min_derivative, 42 | **spline_kwargs 43 | ) 44 | return outputs, logabsdet 45 | 46 | 47 | def searchsorted(bin_locations, inputs, eps=1e-6): 48 | bin_locations[..., -1] += eps 49 | return torch.sum( 50 | inputs[..., None] >= bin_locations, 51 | dim=-1 52 | ) - 1 53 | 54 | 55 | def unconstrained_rational_quadratic_spline(inputs, 56 | unnormalized_widths, 57 | unnormalized_heights, 58 | unnormalized_derivatives, 59 | inverse=False, 60 | tails='linear', 61 | tail_bound=1., 62 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 63 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 64 | min_derivative=DEFAULT_MIN_DERIVATIVE): 65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 66 | outside_interval_mask = 
~inside_interval_mask 67 | 68 | outputs = torch.zeros_like(inputs) 69 | logabsdet = torch.zeros_like(inputs) 70 | 71 | if tails == 'linear': 72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 73 | constant = np.log(np.exp(1 - min_derivative) - 1) 74 | unnormalized_derivatives[..., 0] = constant 75 | unnormalized_derivatives[..., -1] = constant 76 | 77 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 78 | logabsdet[outside_interval_mask] = 0 79 | else: 80 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 81 | 82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 89 | min_bin_width=min_bin_width, 90 | min_bin_height=min_bin_height, 91 | min_derivative=min_derivative 92 | ) 93 | 94 | return outputs, logabsdet 95 | 96 | def rational_quadratic_spline(inputs, 97 | unnormalized_widths, 98 | unnormalized_heights, 99 | unnormalized_derivatives, 100 | inverse=False, 101 | left=0., right=1., bottom=0., top=1., 102 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 103 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 104 | min_derivative=DEFAULT_MIN_DERIVATIVE): 105 | if torch.min(inputs) < left or torch.max(inputs) > right: 106 | raise ValueError('Input to a transform is not within its domain') 107 | 108 | num_bins = unnormalized_widths.shape[-1] 109 | 110 | if min_bin_width * num_bins > 1.0: 111 | raise ValueError('Minimal bin width too large for the number of bins') 112 | if min_bin_height * num_bins > 1.0: 113 | raise ValueError('Minimal bin height too large for the number of bins') 114 | 115 | widths = F.softmax(unnormalized_widths, 
dim=-1) 116 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 117 | cumwidths = torch.cumsum(widths, dim=-1) 118 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 119 | cumwidths = (right - left) * cumwidths + left 120 | cumwidths[..., 0] = left 121 | cumwidths[..., -1] = right 122 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 123 | 124 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 125 | 126 | heights = F.softmax(unnormalized_heights, dim=-1) 127 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 128 | cumheights = torch.cumsum(heights, dim=-1) 129 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 130 | cumheights = (top - bottom) * cumheights + bottom 131 | cumheights[..., 0] = bottom 132 | cumheights[..., -1] = top 133 | heights = cumheights[..., 1:] - cumheights[..., :-1] 134 | 135 | if inverse: 136 | bin_idx = searchsorted(cumheights, inputs)[..., None] 137 | else: 138 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 139 | 140 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 141 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 142 | 143 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 144 | delta = heights / widths 145 | input_delta = delta.gather(-1, bin_idx)[..., 0] 146 | 147 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 148 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 149 | 150 | input_heights = heights.gather(-1, bin_idx)[..., 0] 151 | 152 | if inverse: 153 | a = (((inputs - input_cumheights) * (input_derivatives 154 | + input_derivatives_plus_one 155 | - 2 * input_delta) 156 | + input_heights * (input_delta - input_derivatives))) 157 | b = (input_heights * input_derivatives 158 | - (inputs - input_cumheights) * (input_derivatives 159 | + input_derivatives_plus_one 160 | - 2 * input_delta)) 161 | c = - input_delta * (inputs - input_cumheights) 162 | 163 | 
discriminant = b.pow(2) - 4 * a * c 164 | assert (discriminant >= 0).all() 165 | 166 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 167 | outputs = root * input_bin_widths + input_cumwidths 168 | 169 | theta_one_minus_theta = root * (1 - root) 170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 171 | * theta_one_minus_theta) 172 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 173 | + 2 * input_delta * theta_one_minus_theta 174 | + input_derivatives * (1 - root).pow(2)) 175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 176 | 177 | return outputs, -logabsdet 178 | else: 179 | theta = (inputs - input_cumwidths) / input_bin_widths 180 | theta_one_minus_theta = theta * (1 - theta) 181 | 182 | numerator = input_heights * (input_delta * theta.pow(2) 183 | + input_derivatives * theta_one_minus_theta) 184 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 185 | * theta_one_minus_theta) 186 | outputs = input_cumheights + numerator / denominator 187 | 188 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 189 | + 2 * input_delta * theta_one_minus_theta 190 | + input_derivatives * (1 - theta).pow(2)) 191 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 192 | 193 | return outputs, logabsdet 194 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | from scipy.io.wavfile import read 10 | import torch 11 | import regex as re 12 | 13 | MATPLOTLIB_FLAG = False 14 | 15 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 16 | logger = logging 17 | 18 | 19 | 20 | zh_pattern = 
 re.compile(r'[\u4e00-\u9fa5]') 21 | en_pattern = re.compile(r'[a-zA-Z]') 22 | jp_pattern = re.compile(r'[\u3040-\u30ff\u31f0-\u31ff]') 23 | kr_pattern = re.compile(r'[\uac00-\ud7af\u1100-\u11ff\u3130-\u318f\ua960-\ua97f]') 24 | num_pattern=re.compile(r'[0-9]') 25 | comma=r"(?<=[.。!！?？;；,，、:：'\"‘“”’()（）《》「」~——])" # lookbehind: the pattern must be fixed-width 26 | tags={'ZH':'[ZH]','EN':'[EN]','JP':'[JA]','KR':'[KR]'} 27 | 28 | def tag_cjke(text): 29 | '''Tag Chinese, English, Japanese and Korean text. Chinese and Japanese cannot be told apart by regex alone, so the text is first split into sentences and each sentence is classified separately, which handles most cases.''' 30 | sentences = re.split(r"([.。!！?？;；,，、:：'\"‘“”’()（）【】《》「」~——]+ *(?![0-9]))", text) # split into sentences without splitting on decimal points 31 | sentences.append("") 32 | sentences = ["".join(i) for i in zip(sentences[0::2],sentences[1::2])] 33 | # print(sentences) 34 | prev_lang=None 35 | tagged_text = "" 36 | for s in sentences: 37 | # skip sentences that contain only whitespace and punctuation 38 | nu = re.sub(r'[\s\p{P}]+', '', s, flags=re.U).strip() 39 | if len(nu)==0: 40 | continue 41 | s = re.sub(r'[()（）《》「」【】‘“”’]+', '', s) 42 | jp=re.findall(jp_pattern, s) 43 | # a sentence that contains kana is treated as Japanese 44 | if len(jp)>0: 45 | prev_lang,tagged_jke=tag_jke(s,prev_lang) 46 | tagged_text +=tagged_jke 47 | else: 48 | prev_lang,tagged_cke=tag_cke(s,prev_lang) 49 | tagged_text +=tagged_cke 50 | return tagged_text 51 | 52 | def tag_jke(text,prev_sentence=None): 53 | '''Tag English, Japanese and Korean text.''' 54 | # initialize tagging state 55 | tagged_text = "" 56 | prev_lang = None 57 | tagged=0 58 | # iterate over the text 59 | for char in text: 60 | # determine which language the current character belongs to 61 | if jp_pattern.match(char): 62 | lang = "JP" 63 | elif zh_pattern.match(char): # kanji inside a Japanese sentence are tagged as Japanese 64 | lang = "JP" 65 | elif kr_pattern.match(char): 66 | lang = "KR" 67 | elif en_pattern.match(char): 68 | lang = "EN" 69 | # elif num_pattern.match(char): 70 | # lang = prev_sentence 71 | else: 72 | lang = None 73 | tagged_text += char 74 | continue 75 | # if the language differs from the previous one, insert tags 76 | if lang != prev_lang: 77 | tagged=1 78 | if prev_lang is None: # start of text 79 | tagged_text =tags[lang]+tagged_text 80 | else: 81 | tagged_text =tagged_text+tags[prev_lang]+tags[lang] 82 | 83 | # update the previous-language marker 84 | prev_lang = lang 85 | 86 | # append the current character to the tagged text 87 | tagged_text += char 88 |
 89 | # close the last language run with its tag 90 | if prev_lang: 91 | tagged_text += tags[prev_lang] 92 | if not tagged: 93 | prev_lang=prev_sentence 94 | tagged_text =tags[prev_lang]+tagged_text+tags[prev_lang] 95 | 96 | return prev_lang,tagged_text 97 | 98 | def tag_cke(text,prev_sentence=None): 99 | '''Tag Chinese, English and Korean text.''' 100 | # initialize tagging state 101 | tagged_text = "" 102 | prev_lang = None 103 | # whether any character received a tag (0 = every character was skipped) 104 | tagged=0 105 | 106 | # iterate over the text 107 | for char in text: 108 | # determine which language the current character belongs to 109 | if zh_pattern.match(char): 110 | lang = "ZH" 111 | elif kr_pattern.match(char): 112 | lang = "KR" 113 | elif en_pattern.match(char): 114 | lang = "EN" 115 | # elif num_pattern.match(char): 116 | # lang = prev_sentence 117 | else: 118 | # skip characters that belong to no language 119 | lang = None 120 | tagged_text += char 121 | continue 122 | 123 | # if the language differs from the previous one, insert tags 124 | if lang != prev_lang: 125 | tagged=1 126 | if prev_lang is None: # start of text 127 | tagged_text =tags[lang]+tagged_text 128 | else: 129 | tagged_text =tagged_text+tags[prev_lang]+tags[lang] 130 | 131 | # update the previous-language marker 132 | prev_lang = lang 133 | 134 | # append the current character to the tagged text 135 | tagged_text += char 136 | 137 | # close the last language run with its tag 138 | if prev_lang: 139 | tagged_text += tags[prev_lang] 140 | # if nothing was tagged, inherit the previous sentence's tag (raises KeyError when prev_sentence is None) 141 | if tagged==0: 142 | prev_lang=prev_sentence 143 | tagged_text =tags[prev_lang]+tagged_text+tags[prev_lang] 144 | return prev_lang,tagged_text 145 | 146 | 147 | 148 | def load_checkpoint(checkpoint_path, model, optimizer=None, drop_speaker_emb=False): 149 | assert os.path.isfile(checkpoint_path) 150 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 151 | iteration = checkpoint_dict['iteration'] 152 | learning_rate = checkpoint_dict['learning_rate'] 153 | if optimizer is not None: 154 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 155 | saved_state_dict = checkpoint_dict['model'] 156 | if hasattr(model, 'module'): 157 | state_dict = model.module.state_dict() 158 | else: 159 | state_dict = model.state_dict() 160 | new_state_dict = {} 161 | for k, v in
state_dict.items(): 162 | try: 163 | if k == 'emb_g.weight': 164 | if drop_speaker_emb: 165 | new_state_dict[k] = v 166 | continue 167 | v[:saved_state_dict[k].shape[0], :] = saved_state_dict[k] 168 | new_state_dict[k] = v 169 | else: 170 | new_state_dict[k] = saved_state_dict[k] 171 | except: 172 | logger.info("%s is not in the checkpoint" % k) 173 | new_state_dict[k] = v 174 | if hasattr(model, 'module'): 175 | model.module.load_state_dict(new_state_dict) 176 | else: 177 | model.load_state_dict(new_state_dict) 178 | logger.info("Loaded checkpoint '{}' (iteration {})".format( 179 | checkpoint_path, iteration)) 180 | return model, optimizer, learning_rate, iteration 181 | 182 | 183 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 184 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 185 | iteration, checkpoint_path)) 186 | if hasattr(model, 'module'): 187 | state_dict = model.module.state_dict() 188 | else: 189 | state_dict = model.state_dict() 190 | torch.save({'model': state_dict, 191 | 'iteration': iteration, 192 | 'optimizer': optimizer.state_dict() if optimizer is not None else None, 193 | 'learning_rate': learning_rate}, checkpoint_path) 194 | 195 | 196 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): 197 | for k, v in scalars.items(): 198 | writer.add_scalar(k, v, global_step) 199 | for k, v in histograms.items(): 200 | writer.add_histogram(k, v, global_step) 201 | for k, v in images.items(): 202 | writer.add_image(k, v, global_step, dataformats='HWC') 203 | for k, v in audios.items(): 204 | writer.add_audio(k, v, global_step, audio_sampling_rate) 205 | 206 | 207 | def extract_digits(f): 208 | digits = "".join(filter(str.isdigit, f)) 209 | return int(digits) if digits else -1 210 | 211 | 212 | def latest_checkpoint_path(dir_path, regex="G_[0-9]*.pth"): 213 | f_list = glob.glob(os.path.join(dir_path, regex)) 214 | 
f_list.sort(key=lambda f: extract_digits(f)) 215 | x = f_list[-1] 216 | print(f"latest_checkpoint_path:{x}") 217 | return x 218 | 219 | 220 | def oldest_checkpoint_path(dir_path, regex="G_[0-9]*.pth", preserved=4): 221 | f_list = glob.glob(os.path.join(dir_path, regex)) 222 | f_list.sort(key=lambda f: extract_digits(f)) 223 | if len(f_list) > preserved: 224 | x = f_list[0] 225 | print(f"oldest_checkpoint_path:{x}") 226 | return x 227 | return "" 228 | 229 | 230 | def plot_spectrogram_to_numpy(spectrogram): 231 | global MATPLOTLIB_FLAG 232 | if not MATPLOTLIB_FLAG: 233 | import matplotlib 234 | matplotlib.use("Agg") 235 | MATPLOTLIB_FLAG = True 236 | mpl_logger = logging.getLogger('matplotlib') 237 | mpl_logger.setLevel(logging.WARNING) 238 | import matplotlib.pylab as plt 239 | import numpy as np 240 | 241 | fig, ax = plt.subplots(figsize=(10, 2)) 242 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 243 | interpolation='none') 244 | plt.colorbar(im, ax=ax) 245 | plt.xlabel("Frames") 246 | plt.ylabel("Channels") 247 | plt.tight_layout() 248 | 249 | fig.canvas.draw() 250 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 251 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 252 | plt.close() 253 | return data 254 | 255 | 256 | def plot_alignment_to_numpy(alignment, info=None): 257 | global MATPLOTLIB_FLAG 258 | if not MATPLOTLIB_FLAG: 259 | import matplotlib 260 | matplotlib.use("Agg") 261 | MATPLOTLIB_FLAG = True 262 | mpl_logger = logging.getLogger('matplotlib') 263 | mpl_logger.setLevel(logging.WARNING) 264 | import matplotlib.pylab as plt 265 | import numpy as np 266 | 267 | fig, ax = plt.subplots(figsize=(6, 4)) 268 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', 269 | interpolation='none') 270 | fig.colorbar(im, ax=ax) 271 | xlabel = 'Decoder timestep' 272 | if info is not None: 273 | xlabel += '\n\n' + info 274 | plt.xlabel(xlabel) 275 | plt.ylabel('Encoder timestep') 276 | 
plt.tight_layout() 277 | 278 | fig.canvas.draw() 279 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 280 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 281 | plt.close() 282 | return data 283 | 284 | 285 | def load_wav_to_torch(full_path): 286 | sampling_rate, data = read(full_path) 287 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 288 | 289 | 290 | def load_filepaths_and_text(filename, split="|"): 291 | with open(filename, encoding='utf-8') as f: 292 | filepaths_and_text = [line.strip().split(split) for line in f] 293 | return filepaths_and_text 294 | 295 | 296 | def str2bool(v): 297 | if isinstance(v, bool): 298 | return v 299 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 300 | return True 301 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 302 | return False 303 | else: 304 | raise argparse.ArgumentTypeError('Boolean value expected.') 305 | 306 | 307 | def get_hparams(init=True): 308 | parser = argparse.ArgumentParser() 309 | parser.add_argument('-c', '--config', type=str, default="./configs/modified_finetune_speaker.json", 310 | help='JSON file for configuration') 311 | parser.add_argument('-m', '--model', type=str, default="pretrained_models", 312 | help='Model name') 313 | parser.add_argument('-n', '--max_epochs', type=int, default=50, 314 | help='finetune epochs') 315 | parser.add_argument('--cont', type=str2bool, default=False, help='whether to continue training on the latest checkpoint') 316 | parser.add_argument('--drop_speaker_embed', type=str2bool, default=False, help='whether to drop existing characters') 317 | parser.add_argument('--train_with_pretrained_model', type=str2bool, default=True, 318 | help='whether to train with pretrained model') 319 | parser.add_argument('--preserved', type=int, default=4, 320 | help='Number of preserved models') 321 | 322 | args = parser.parse_args() 323 | model_dir = os.path.join("./", args.model) 324 | 325 | if not os.path.exists(model_dir): 326 
| os.makedirs(model_dir) 327 | 328 | config_path = args.config 329 | config_save_path = os.path.join(model_dir, "config.json") 330 | if init: 331 | with open(config_path, "r") as f: 332 | data = f.read() 333 | with open(config_save_path, "w") as f: 334 | f.write(data) 335 | else: 336 | with open(config_save_path, "r") as f: 337 | data = f.read() 338 | config = json.loads(data) 339 | 340 | hparams = HParams(**config) 341 | hparams.model_dir = model_dir 342 | hparams.max_epochs = args.max_epochs 343 | hparams.cont = args.cont 344 | hparams.drop_speaker_embed = args.drop_speaker_embed 345 | hparams.train_with_pretrained_model = args.train_with_pretrained_model 346 | hparams.preserved = args.preserved 347 | return hparams 348 | 349 | 350 | def get_hparams_from_dir(model_dir): 351 | config_save_path = os.path.join(model_dir, "config.json") 352 | with open(config_save_path, "r") as f: 353 | data = f.read() 354 | config = json.loads(data) 355 | 356 | hparams = HParams(**config) 357 | hparams.model_dir = model_dir 358 | return hparams 359 | 360 | 361 | def get_hparams_from_file(config_path): 362 | with open(config_path, "r", encoding="utf-8") as f: 363 | data = f.read() 364 | config = json.loads(data) 365 | 366 | hparams = HParams(**config) 367 | return hparams 368 | 369 | 370 | def check_git_hash(model_dir): 371 | source_dir = os.path.dirname(os.path.realpath(__file__)) 372 | if not os.path.exists(os.path.join(source_dir, ".git")): 373 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( 374 | source_dir 375 | )) 376 | return 377 | 378 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 379 | 380 | path = os.path.join(model_dir, "githash") 381 | if os.path.exists(path): 382 | saved_hash = open(path).read() 383 | if saved_hash != cur_hash: 384 | logger.warn("git hash values are different. 
{}(saved) != {}(current)".format( 385 | saved_hash[:8], cur_hash[:8])) 386 | else: 387 | open(path, "w").write(cur_hash) 388 | 389 | 390 | def get_logger(model_dir, filename="train.log"): 391 | global logger 392 | logger = logging.getLogger(os.path.basename(model_dir)) 393 | logger.setLevel(logging.DEBUG) 394 | 395 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 396 | if not os.path.exists(model_dir): 397 | os.makedirs(model_dir) 398 | h = logging.FileHandler(os.path.join(model_dir, filename),encoding="utf-8") 399 | h.setLevel(logging.DEBUG) 400 | h.setFormatter(formatter) 401 | logger.addHandler(h) 402 | return logger 403 | 404 | 405 | class HParams(): 406 | def __init__(self, **kwargs): 407 | for k, v in kwargs.items(): 408 | if type(v) == dict: 409 | v = HParams(**v) 410 | self[k] = v 411 | 412 | def keys(self): 413 | return self.__dict__.keys() 414 | 415 | def items(self): 416 | return self.__dict__.items() 417 | 418 | def values(self): 419 | return self.__dict__.values() 420 | 421 | def __len__(self): 422 | return len(self.__dict__) 423 | 424 | def __getitem__(self, key): 425 | return getattr(self, key) 426 | 427 | def __setitem__(self, key, value): 428 | return setattr(self, key, value) 429 | 430 | def __contains__(self, key): 431 | return key in self.__dict__ 432 | 433 | def __repr__(self): 434 | return self.__dict__.__repr__() --------------------------------------------------------------------------------