├── LICENSE ├── README.md ├── configs ├── requirements.txt └── singing_base.yaml ├── pit_export.py ├── pit_train.py ├── pitch ├── __init__.py ├── base.py ├── data_utils.py ├── diffusion.py ├── models.py └── utils.py ├── pitch_extend ├── dataloader.py ├── plotting.py ├── train.py ├── validation.py └── writer.py ├── resource ├── vising_loss.png ├── vising_mel.png └── vising_sample.wav ├── svs ├── __init__.py ├── midi-HZ.scp ├── midi-note.scp ├── phone_map.py └── phone_uv.py ├── svs_export.py ├── svs_infer.py ├── svs_infer.txt ├── svs_infer_pitch.py ├── svs_song.py ├── svs_song.txt ├── svs_song_pitch.py ├── svs_train.py ├── util ├── __init__.py ├── generate_index.py ├── generate_label.py └── resample.py ├── vits ├── __init__.py ├── attentions.py ├── commons.py ├── data_utils.py ├── losses.py ├── models.py ├── modules.py ├── spectrogram.py └── utils.py ├── vits_decoder ├── __init__.py ├── alias │ ├── __init__.py │ ├── act.py │ ├── filter.py │ └── resample.py ├── bigv.py ├── discriminator.py ├── generator.py ├── mpd.py ├── mrd.py ├── msd.py └── nsf.py └── vits_extend ├── __init__.py ├── dataloader.py ├── plotting.py ├── stft.py ├── stft_loss.py ├── train.py ├── validation.py └── writer.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | <div align="center"> 2 | <h1> Variational Inference with adversarial learning for end-to-end Singing Voice Synthesis </h1> 3 | 4 | Different from VISinger, It is just VITS without MAS and DurationPredictor. 
5 | 6 | As a project meant for learning, this is where it stands: pitch prediction is the part that still needs improvement 7 | 8 | 9 | 10 | 11 | 12 | </div> 13 | 14 | **Pitch and Duration will be developed as add-ons!** 15 | 16 | # Training Steps 17 | 18 | - 1 Download the dataset segments.zip and unzip it 19 | 20 | ``` 21 | segments 22 | |-- test.txt 23 | |-- train.txt 24 | |-- transcriptions.txt 25 | `-- wavs 26 | |-- 2001000001.wav 27 | |-- 2001000002.wav 28 | |-- 2001000003.wav 29 | ``` 30 | 31 | - 2 Convert the sampling rate: this project uses 32kHz 32 | ``` 33 | python util/resample.py -w segments/wavs/ -o data_svs/wavs -s 32000 34 | ``` 35 | 36 | - 3 Generate the data labels 37 | ``` 38 | python util/generate_label.py --config configs/singing_base.yaml --data data_svs/ --file segments/transcriptions.txt 39 | ``` 40 | 41 | This produces data_svs/labels.txt, with one line per sample in the format: wave path|label path|score path|pitch path|slurs path 42 | 43 | - 4 Split the training/validation index 44 | ``` 45 | python util/generate_index.py --file data_svs/labels.txt 46 | ``` 47 | 48 | This generates filelists/singing_train.txt and filelists/singing_valid.txt 49 | 50 | - 5 Start training 51 | ``` 52 | python svs_train.py -c configs/singing_base.yaml -n vits_svs 53 | ``` 54 | 55 | - 6 Train the pitch model 56 | ``` 57 | python pit_train.py -c configs/singing_base.yaml -n pitch 58 | ``` 59 | 60 | # Inference 61 | 62 | - 0 Export the model 63 | ``` 64 | python svs_export.py --config configs/singing_base.yaml --model chkpt/vits_svs/vits_svs_****.pt 65 | ``` 66 | 67 | - 1 Run inference: F0 is generated from the musical score 68 | ``` 69 | python svs_infer.py --config configs/singing_base.yaml --model svs_opencpop.pt 70 | ``` 71 | 72 | - 2 Synthesize a full song ([use the release model](https://github.com/PlayVoice/VI-SVS/releases/tag/0.0.3)) 73 | ``` 74 | python svs_song.py --config configs/singing_base.yaml --model svs_opencpop.pt 75 | ``` 76 | 77 | # Inference with predicted pitch (results are not good yet) 78 | 79 | - 0 Export the models 80 | ``` 81 | python svs_export.py --config configs/singing_base.yaml --model chkpt/vits_svs/vits_svs_****.pt 82 | ``` 83 | 84 | ``` 85 | python pit_export.py --config configs/singing_base.yaml --model chkpt/pitch/pitch_****.pt 86 | ``` 87 | 88 | - 1 Run inference 89 | ``` 90 | python svs_infer_pitch.py --config configs/singing_base.yaml --model svs_opencpop.pt --pitch pit_opencpop.pt 91 | ``` 92 | 93 | - 2 Synthesize a full song ([use the release model](https://github.com/PlayVoice/VI-SVS/releases/tag/0.0.3)) 94 | ``` 95 | python svs_song_pitch.py --config configs/singing_base.yaml --model svs_opencpop.pt --pitch pit_opencpop.pt 96 | ``` 97 | 98 | # Dataset 99 | 100 | https://wenet.org.cn/opencpop/ 101 | 102 | # Singing Voice Synthesis References 103 | 104 | https://github.com/SJTMusicTeam/Muskits 105 | 106 | https://github.com/MoonInTheRiver/DiffSinger 107 | 108 | [VISinger: Variational Inference with Adversarial Learning for End-to-End Singing Voice Synthesis](https://arxiv.org/abs/2110.08813) 109 | 110 | # Model Design References 111 | 112 | https://github.com/NVIDIA/BigVGAN 113 | 114 | https://github.com/jaywalnut310/vits 115 | 116 | https://github.com/mindslab-ai/univnet 117 | 118 | https://github.com/PlayVoice/so-vits-svc-5.0 119 | 120 | https://github.com/shivammehta25/Matcha-TTS 121 | 122 | [RoFormer: Enhanced Transformer with rotary position embedding](https://arxiv.org/abs/2104.09864) 123 | 124 | # Diffusion Pitch 125 | 126 | https://github.com/thuhcsi/DiffVar 127 | 128 | https://github.com/hayeong0/Diff-HierVC 129 | 130 | https://github.com/tonnetonne814/SiFi-VITS2-44100-Ja 131 | 132 | [Grad-TTS: A Diffusion Probabilistic Model for Text-to-Speech](https://arxiv.org/abs/2105.06337) 133 | 134 | # Diffusion Pitch of Diff-HierVC 135 | 136 | -------------------------------------------------------------------------------- /configs/requirements.txt: -------------------------------------------------------------------------------- 1 |
Cython==0.29.21 2 | librosa==0.8.0 3 | matplotlib==3.3.1 4 | numpy==1.18.5 5 | phonemizer==2.2.1 6 | scipy==1.5.2 7 | tensorboard==2.3.0 8 | torch==1.6.0 9 | torchvision==0.7.0 10 | Unidecode==1.1.1 11 | -------------------------------------------------------------------------------- /configs/singing_base.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | model: "vits-svs" 3 | seed: 1234 4 | epochs: 10000 5 | learning_rate: 1e-4 6 | betas: [0.8, 0.99] 7 | lr_decay: 0.999875 8 | eps: 1e-9 9 | batch_size: 6 10 | c_stft: 9 11 | c_mel: 1. 12 | c_kl: 0.2 13 | port: 8001 14 | pretrain: "" 15 | ############################# 16 | data: 17 | training_files: "filelists/singing_train.txt" 18 | validation_files: "filelists/singing_valid.txt" 19 | segment_size: 8000 # WARNING: base on hop_length 20 | max_wav_value: 32768.0 21 | sampling_rate: 32000 22 | filter_length: 1024 23 | hop_length: 320 24 | win_length: 1024 25 | mel_channels: 100 26 | mel_fmin: 50.0 27 | mel_fmax: 16000.0 28 | ############################# 29 | vits: 30 | gin_channels: 0 31 | inter_channels: 192 32 | hidden_channels: 192 33 | filter_channels: 640 34 | ############################# 35 | gen: 36 | upsample_input: 192 37 | upsample_rates: [5,4,4,2,2] 38 | upsample_kernel_sizes: [15,8,8,4,4] 39 | upsample_initial_channel: 480 40 | resblock_kernel_sizes: [3,7,11] 41 | resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] 42 | ############################# 43 | mpd: 44 | periods: [2,3,5,7,11] 45 | kernel_size: 5 46 | stride: 3 47 | use_spectral_norm: False 48 | lReLU_slope: 0.2 49 | ############################# 50 | mrd: 51 | resolutions: "[(1024, 120, 600), (2048, 240, 1200), (4096, 480, 2400), (512, 50, 240)]" # (filter_length, hop_length, win_length) 52 | use_spectral_norm: False 53 | lReLU_slope: 0.2 54 | ############################# 55 | log: 56 | info_interval: 100 57 | eval_interval: 1 58 | save_interval: 5 59 | num_audio: 6 60 | pth_dir: 'chkpt' 61 | log_dir: 'logs' 62 | keep_ckpts: 0 63 | ############################# 64 | dist_config: 65 | dist_backend: "nccl" 66 | dist_url: "tcp://localhost:54321" 67 | world_size: 1 68 | 69 | -------------------------------------------------------------------------------- /pit_export.py: -------------------------------------------------------------------------------- 1 | import sys,os 2 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 3 | import torch 4 | import argparse 5 | 6 | from pitch.models import PitchDiffusion 7 | 8 | 9 | def load_model(checkpoint_path, model): 10 | assert os.path.isfile(checkpoint_path) 11 | checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") 12 | saved_state_dict = checkpoint_dict["model_g"] 13 | if hasattr(model, "module"): 14 | state_dict = model.module.state_dict() 15 | else: 16 | state_dict = model.state_dict() 17 | new_state_dict = {} 18 | for k, v in state_dict.items(): 19 | try: 20 | new_state_dict[k] = saved_state_dict[k] 21 | except: 22 | new_state_dict[k] = v 23 | if hasattr(model, "module"): 24 | model.module.load_state_dict(new_state_dict) 25 | else: 26 | model.load_state_dict(new_state_dict) 27 | return model 28 | 29 | 30 | def save_model(model, checkpoint_path): 31 | if hasattr(model, 'module'): 32 | state_dict = model.module.state_dict() 33 | else: 34 | state_dict = model.state_dict() 35 | torch.save({'model_g': state_dict}, checkpoint_path) 36 | 37 | 38 | def main(args): 39 | model = PitchDiffusion() 40 | load_model(args.model, model) 41 | save_model(model, 
"pit_opencpop.pt") 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('-c', '--config', type=str, required=True, 47 | help="yaml file for config. will use hp_str from checkpoint if not given.") 48 | parser.add_argument('-m', '--model', type=str, required=True, 49 | help="path of checkpoint pt file for evaluation") 50 | args = parser.parse_args() 51 | 52 | main(args) 53 | -------------------------------------------------------------------------------- /pit_train.py: -------------------------------------------------------------------------------- 1 | import sys,os 2 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 3 | import argparse 4 | import torch 5 | import torch.multiprocessing as mp 6 | from omegaconf import OmegaConf 7 | 8 | from pitch_extend.train import train 9 | 10 | torch.backends.cudnn.benchmark = True 11 | 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('-c', '--config', type=str, required=True, 16 | help="yaml file for configuration") 17 | parser.add_argument('-p', '--checkpoint_path', type=str, default=None, 18 | help="path of checkpoint pt file to resume training") 19 | parser.add_argument('-n', '--name', type=str, required=True, 20 | help="name of the model for logging, saving checkpoint") 21 | args = parser.parse_args() 22 | 23 | hp = OmegaConf.load(args.config) 24 | with open(args.config, 'r') as f: 25 | hp_str = ''.join(f.readlines()) 26 | 27 | assert hp.data.hop_length == 320, \ 28 | 'hp.data.hop_length must be equal to 320, got %d' % hp.data.hop_length 29 | 30 | args.num_gpus = 0 31 | torch.manual_seed(hp.train.seed) 32 | if torch.cuda.is_available(): 33 | torch.cuda.manual_seed(hp.train.seed) 34 | args.num_gpus = torch.cuda.device_count() 35 | print('Batch size per GPU :', hp.train.batch_size) 36 | 37 | if args.num_gpus > 1: 38 | mp.spawn(train, nprocs=args.num_gpus, 39 | args=(args, args.checkpoint_path, hp, hp_str,)) 40 | else: 41 | train(0, args, args.checkpoint_path, hp, hp_str) 42 | else: 43 | print('No GPU find!') 44 | -------------------------------------------------------------------------------- /pitch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlayVoice/VI-SVS/d9768caa05c742bc5f580a4754b3cd7b444a5eb4/pitch/__init__.py -------------------------------------------------------------------------------- /pitch/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class BaseModule(torch.nn.Module): 6 | def __init__(self): 7 | super(BaseModule, self).__init__() 8 | 9 | @property 10 | def nparams(self): 11 | """ 12 | Returns number of trainable parameters of the module. 13 | """ 14 | num_params = 0 15 | for name, param in self.named_parameters(): 16 | if param.requires_grad: 17 | num_params += np.prod(param.detach().cpu().numpy().shape) 18 | return num_params 19 | 20 | 21 | def relocate_input(self, x: list): 22 | """ 23 | Relocates provided tensors to the same device set for the module. 
24 | """ 25 | device = next(self.parameters()).device 26 | for i in range(len(x)): 27 | if isinstance(x[i], torch.Tensor) and x[i].device != device: 28 | x[i] = x[i].to(device) 29 | return x 30 | -------------------------------------------------------------------------------- /pitch/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.utils.data 5 | 6 | from vits.utils import load_wav_to_torch 7 | from vits.spectrogram import spectrogram_torch 8 | from pitch.utils import fix_len_compatibility 9 | 10 | 11 | def load_filepaths(filename, split="|"): 12 | with open(filename, encoding='utf-8') as f: 13 | filepaths = [line.strip().split(split) for line in f] 14 | return filepaths 15 | 16 | 17 | class TextAudioLoader(torch.utils.data.Dataset): 18 | """ 19 | 1) loads audio, text pairs 20 | 2) normalizes text and converts them to sequences of integers 21 | 3) computes spectrograms from audio files. 22 | """ 23 | 24 | def __init__(self, audiopaths_and_text, hparams): 25 | self.audiopaths_and_text = load_filepaths(audiopaths_and_text) 26 | self.max_wav_value = hparams.max_wav_value 27 | self.sampling_rate = hparams.sampling_rate 28 | self.filter_length = hparams.filter_length 29 | self.hop_length = hparams.hop_length 30 | self.win_length = hparams.win_length 31 | self.sampling_rate = hparams.sampling_rate 32 | self.min_text_len = getattr(hparams, "min_text_len", 1) 33 | self.max_text_len = getattr(hparams, "max_text_len", 5000) 34 | self._filter() 35 | print(f"~~~~~~~~~~~~~~~~~~~~~{len(self)}~~~~~~~~~~~~~~~~~~~~~~~~~~~~") 36 | 37 | def _filter(self): 38 | """ 39 | Filter text & store spec lengths 40 | """ 41 | # Store spectrogram lengths for Bucketing 42 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) 43 | # spec_length = wav_length // hop_length 44 | audiopaths_and_text_new = [] 45 | lengths = [] 46 | for audiopath, text, score, pitch, slur in self.audiopaths_and_text: 47 | if self.min_text_len <= len(text) and len(text) <= self.max_text_len: 48 | wav_len = os.path.getsize(audiopath) // (2 * self.hop_length) 49 | if wav_len < 50: # no use short wave 50 | continue 51 | audiopaths_and_text_new.append([audiopath, text, score, pitch, slur]) 52 | lengths.append(wav_len) 53 | self.audiopaths_and_text = audiopaths_and_text_new 54 | self.lengths = lengths 55 | 56 | def get_audio_text_pair(self, audiopath_and_text): 57 | # separate filename and text 58 | file = audiopath_and_text[0] 59 | phone = audiopath_and_text[1] 60 | score = audiopath_and_text[2] 61 | pitch = audiopath_and_text[3] 62 | slurs = audiopath_and_text[4] 63 | 64 | phone, score, pitch, slurs = self.get_labels(phone, score, pitch, slurs) 65 | spec, wav = self.get_audio(file) 66 | 67 | len_phone = phone.size()[0] 68 | len_spec = spec.size()[-1] 69 | 70 | if len_phone != len_spec: 71 | # print("**************CareFull*******************") 72 | # print(f"filepath={audiopath_and_text[0]}") 73 | # print(f"len_text={len_phone}") 74 | # print(f"len_spec={len_spec}") 75 | if len_phone > len_spec: 76 | print(file) 77 | print("len_phone", len_phone) 78 | print("len_spec", len_spec) 79 | assert len_phone < len_spec 80 | len_min = min(len_phone, len_spec) 81 | len_wav = len_min * self.hop_length 82 | # print(wav.size()) 83 | # print(f"len_min={len_min}") 84 | # print(f"len_wav={len_wav}") 85 | spec = spec[:, :len_min] 86 | wav = wav[:, :len_wav] 87 | return (phone, score, pitch, slurs, spec, wav) 88 | 89 | def 
get_labels(self, phone, score, pitch, slurs): 90 | phone = np.load(phone) 91 | score = np.load(score) 92 | pitch = np.load(pitch) 93 | slurs = np.load(slurs) 94 | phone = torch.LongTensor(phone) 95 | score = torch.LongTensor(score) 96 | pitch = torch.FloatTensor(pitch) 97 | slurs = torch.LongTensor(slurs) 98 | return phone, score, pitch, slurs 99 | 100 | def get_audio(self, filename): 101 | audio, sampling_rate = load_wav_to_torch(filename) 102 | if sampling_rate != self.sampling_rate: 103 | raise ValueError( 104 | "{} SR doesn't match target {} SR".format( 105 | sampling_rate, self.sampling_rate 106 | ) 107 | ) 108 | audio_norm = audio / self.max_wav_value 109 | audio_norm = audio_norm.unsqueeze(0) 110 | spec_filename = filename.replace(".wav", ".spec.pt") 111 | if os.path.exists(spec_filename): 112 | spec = torch.load(spec_filename) 113 | else: 114 | spec = spectrogram_torch( 115 | audio_norm, 116 | self.filter_length, 117 | self.sampling_rate, 118 | self.hop_length, 119 | self.win_length, 120 | center=False, 121 | ) 122 | spec = torch.squeeze(spec, 0) 123 | torch.save(spec, spec_filename) 124 | return spec, audio_norm 125 | 126 | def __getitem__(self, index): 127 | return self.get_audio_text_pair(self.audiopaths_and_text[index]) 128 | 129 | def __len__(self): 130 | return len(self.audiopaths_and_text) 131 | 132 | 133 | class TextAudioCollate: 134 | """Zero-pads model inputs and targets""" 135 | 136 | def __init__(self, return_ids=False): 137 | self.return_ids = return_ids 138 | 139 | def __call__(self, batch): 140 | """Collates training batch from normalized text and audio 141 | PARAMS 142 | ------ 143 | batch: [text_normalized, spec_normalized, wav_normalized] 144 | """ 145 | # Right zero-pad all one-hot text sequences to max input length 146 | _, ids_sorted_decreasing = torch.sort( 147 | torch.LongTensor([x[4].size(1) for x in batch]), dim=0, descending=True 148 | ) 149 | 150 | max_phone_len = max([len(x[0]) for x in batch]) 151 | # For Unet 152 | max_phone_len = fix_len_compatibility(max_phone_len) 153 | 154 | phone_lengths = torch.LongTensor(len(batch)) 155 | phone_padded = torch.LongTensor(len(batch), max_phone_len) 156 | score_padded = torch.LongTensor(len(batch), max_phone_len) 157 | pitch_padded = torch.FloatTensor(len(batch), max_phone_len) 158 | slurs_padded = torch.LongTensor(len(batch), max_phone_len) 159 | phone_padded.zero_() 160 | score_padded.zero_() 161 | pitch_padded.zero_() 162 | slurs_padded.zero_() 163 | 164 | for i in range(len(ids_sorted_decreasing)): 165 | row = batch[ids_sorted_decreasing[i]] 166 | 167 | phone = row[0] 168 | phone_padded[i, : phone.size(0)] = phone 169 | phone_lengths[i] = phone.size(0) 170 | 171 | score = row[1] 172 | score_padded[i, : score.size(0)] = score 173 | 174 | pitch = row[2] 175 | pitch_padded[i, : pitch.size(0)] = pitch 176 | 177 | slurs = row[3] 178 | slurs_padded[i, : slurs.size(0)] = slurs 179 | 180 | return ( 181 | phone_padded, 182 | phone_lengths, 183 | score_padded, 184 | pitch_padded, 185 | slurs_padded, 186 | ) 187 | 188 | 189 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): 190 | """ 191 | Maintain similar input lengths in a batch. 192 | Length groups are specified by boundaries. 193 | Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. 194 | 195 | It removes samples which are not included in the boundaries. 196 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
197 | """ 198 | 199 | def __init__( 200 | self, 201 | dataset, 202 | batch_size, 203 | boundaries, 204 | num_replicas=None, 205 | rank=None, 206 | shuffle=True, 207 | ): 208 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 209 | self.lengths = dataset.lengths 210 | self.batch_size = batch_size 211 | self.boundaries = boundaries 212 | 213 | self.buckets, self.num_samples_per_bucket = self._create_buckets() 214 | self.total_size = sum(self.num_samples_per_bucket) 215 | self.num_samples = self.total_size // self.num_replicas 216 | 217 | def _create_buckets(self): 218 | buckets = [[] for _ in range(len(self.boundaries) - 1)] 219 | for i in range(len(self.lengths)): 220 | length = self.lengths[i] 221 | idx_bucket = self._bisect(length) 222 | if idx_bucket != -1: 223 | buckets[idx_bucket].append(i) 224 | 225 | for i in range(len(buckets) - 1, 0, -1): 226 | if len(buckets[i]) == 0: 227 | buckets.pop(i) 228 | self.boundaries.pop(i + 1) 229 | 230 | num_samples_per_bucket = [] 231 | for i in range(len(buckets)): 232 | len_bucket = len(buckets[i]) 233 | total_batch_size = self.num_replicas * self.batch_size 234 | rem = ( 235 | total_batch_size - (len_bucket % total_batch_size) 236 | ) % total_batch_size 237 | num_samples_per_bucket.append(len_bucket + rem) 238 | return buckets, num_samples_per_bucket 239 | 240 | def __iter__(self): 241 | # deterministically shuffle based on epoch 242 | g = torch.Generator() 243 | g.manual_seed(self.epoch) 244 | 245 | indices = [] 246 | if self.shuffle: 247 | for bucket in self.buckets: 248 | indices.append(torch.randperm(len(bucket), generator=g).tolist()) 249 | else: 250 | for bucket in self.buckets: 251 | indices.append(list(range(len(bucket)))) 252 | 253 | batches = [] 254 | for i in range(len(self.buckets)): 255 | bucket = self.buckets[i] 256 | len_bucket = len(bucket) 257 | if (len_bucket == 0): 258 | continue 259 | ids_bucket = indices[i] 260 | num_samples_bucket = self.num_samples_per_bucket[i] 261 | 262 | # add extra samples to make it evenly divisible 263 | rem = num_samples_bucket - len_bucket 264 | ids_bucket = ( 265 | ids_bucket 266 | + ids_bucket * (rem // len_bucket) 267 | + ids_bucket[: (rem % len_bucket)] 268 | ) 269 | 270 | # subsample 271 | ids_bucket = ids_bucket[self.rank:: self.num_replicas] 272 | 273 | # batching 274 | for j in range(len(ids_bucket) // self.batch_size): 275 | batch = [ 276 | bucket[idx] 277 | for idx in ids_bucket[ 278 | j * self.batch_size: (j + 1) * self.batch_size 279 | ] 280 | ] 281 | batches.append(batch) 282 | 283 | if self.shuffle: 284 | batch_ids = torch.randperm(len(batches), generator=g).tolist() 285 | batches = [batches[i] for i in batch_ids] 286 | self.batches = batches 287 | 288 | assert len(self.batches) * self.batch_size == self.num_samples 289 | return iter(self.batches) 290 | 291 | def _bisect(self, x, lo=0, hi=None): 292 | if hi is None: 293 | hi = len(self.boundaries) - 1 294 | 295 | if hi > lo: 296 | mid = (hi + lo) // 2 297 | if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: 298 | return mid 299 | elif x <= self.boundaries[mid]: 300 | return self._bisect(x, lo, mid) 301 | else: 302 | return self._bisect(x, mid + 1, hi) 303 | else: 304 | return -1 305 | 306 | def __len__(self): 307 | return self.num_samples // self.batch_size 308 | -------------------------------------------------------------------------------- /pitch/diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from einops 
import rearrange 4 | from pitch.base import BaseModule 5 | 6 | 7 | class Mish(BaseModule): 8 | def forward(self, x): 9 | return x * torch.tanh(torch.nn.functional.softplus(x)) 10 | 11 | 12 | class Rezero(BaseModule): 13 | def __init__(self, fn): 14 | super(Rezero, self).__init__() 15 | self.fn = fn 16 | self.g = torch.nn.Parameter(torch.zeros(1)) 17 | 18 | def forward(self, x): 19 | return self.fn(x) * self.g 20 | 21 | 22 | class Block(BaseModule): 23 | def __init__(self, dim, dim_out, groups=8): 24 | super(Block, self).__init__() 25 | self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3, 26 | padding=1), torch.nn.GroupNorm( 27 | groups, dim_out), Mish()) 28 | 29 | def forward(self, x, mask): 30 | output = self.block(x * mask) 31 | return output * mask 32 | 33 | 34 | class ResnetBlock(BaseModule): 35 | def __init__(self, dim, dim_out, time_emb_dim, groups=8): 36 | super(ResnetBlock, self).__init__() 37 | self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, 38 | dim_out)) 39 | 40 | self.block1 = Block(dim, dim_out, groups=groups) 41 | self.block2 = Block(dim_out, dim_out, groups=groups) 42 | if dim != dim_out: 43 | self.res_conv = torch.nn.Conv2d(dim, dim_out, 1) 44 | else: 45 | self.res_conv = torch.nn.Identity() 46 | 47 | def forward(self, x, mask, time_emb): 48 | h = self.block1(x, mask) 49 | h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1) 50 | h = self.block2(h, mask) 51 | output = h + self.res_conv(x * mask) 52 | return output 53 | 54 | 55 | class LinearAttention(BaseModule): 56 | def __init__(self, dim, heads=4, dim_head=32): 57 | super(LinearAttention, self).__init__() 58 | self.heads = heads 59 | hidden_dim = dim_head * heads 60 | self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 61 | self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1) 62 | 63 | def forward(self, x): 64 | b, c, h, w = x.shape 65 | qkv = self.to_qkv(x) 66 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', 67 | heads = self.heads, qkv=3) 68 | k = k.softmax(dim=-1) 69 | context = torch.einsum('bhdn,bhen->bhde', k, v) 70 | out = torch.einsum('bhde,bhdn->bhen', context, q) 71 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', 72 | heads=self.heads, h=h, w=w) 73 | return self.to_out(out) 74 | 75 | 76 | class Residual(BaseModule): 77 | def __init__(self, fn): 78 | super(Residual, self).__init__() 79 | self.fn = fn 80 | 81 | def forward(self, x, *args, **kwargs): 82 | output = self.fn(x, *args, **kwargs) + x 83 | return output 84 | 85 | 86 | class SinusoidalPosEmb(BaseModule): 87 | def __init__(self, dim): 88 | super(SinusoidalPosEmb, self).__init__() 89 | self.dim = dim 90 | 91 | def forward(self, x, scale=1000): 92 | device = x.device 93 | half_dim = self.dim // 2 94 | emb = math.log(10000) / (half_dim - 1) 95 | emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) 96 | emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) 97 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 98 | return emb 99 | 100 | 101 | class GradLogPEstimator2d(BaseModule): 102 | def __init__(self, n_feat, n_cond, dim, dim_mults=(1, 2, 4), groups=8, pe_scale=1000): 103 | super(GradLogPEstimator2d, self).__init__() 104 | self.dim = dim 105 | self.dim_mults = dim_mults 106 | self.groups = groups 107 | self.pe_scale = pe_scale 108 | 109 | self.cond = torch.nn.Sequential(torch.nn.Conv1d(n_cond, dim * 4, 1), Mish(), 110 | torch.nn.Conv1d(dim * 4, n_feat, 1)) 111 | self.time_pos_emb = SinusoidalPosEmb(dim) 112 | self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 
4), Mish(), 113 | torch.nn.Linear(dim * 4, dim)) 114 | 115 | dims = [2 + 1, *map(lambda m: dim * m, dim_mults)] 116 | in_out = list(zip(dims[:-1], dims[1:])) 117 | self.downs = torch.nn.ModuleList([]) 118 | self.ups = torch.nn.ModuleList([]) 119 | num_resolutions = len(in_out) 120 | 121 | for ind, (dim_in, dim_out) in enumerate(in_out): # 2 downs 122 | is_last = ind >= (num_resolutions - 1) 123 | self.downs.append(torch.nn.ModuleList([ 124 | ResnetBlock(dim_in, dim_out, time_emb_dim=dim), 125 | ResnetBlock(dim_out, dim_out, time_emb_dim=dim), 126 | Residual(Rezero(LinearAttention(dim_out))), 127 | torch.nn.Identity()])) 128 | 129 | mid_dim = dims[-1] 130 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) 131 | self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) 132 | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) 133 | 134 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): # 2 ups 135 | self.ups.append(torch.nn.ModuleList([ 136 | ResnetBlock(dim_out, dim_in, time_emb_dim=dim), 137 | ResnetBlock(dim_in, dim_in, time_emb_dim=dim), 138 | Residual(Rezero(LinearAttention(dim_in))), 139 | torch.nn.Identity()])) 140 | self.final_block = Block(dim, dim) 141 | self.final_conv = torch.nn.Conv2d(dim, 1, 1) 142 | 143 | def forward(self, x, mask, mu, c, t): 144 | 145 | t = self.time_pos_emb(t, scale=self.pe_scale) 146 | t = self.mlp(t) 147 | c = self.cond(c) 148 | 149 | x = torch.stack([mu, x, c], 1) 150 | mask = mask.unsqueeze(1) 151 | 152 | for resnet1, resnet2, attn, downsample in self.downs: 153 | x = resnet1(x, mask, t) 154 | x = resnet2(x, mask, t) 155 | x = attn(x) 156 | x = downsample(x * mask) 157 | 158 | x = self.mid_block1(x, mask, t) 159 | x = self.mid_attn(x) 160 | x = self.mid_block2(x, mask, t) 161 | 162 | for resnet1, resnet2, attn, upsample in self.ups: 163 | x = resnet1(x, mask, t) 164 | x = resnet2(x, mask, t) 165 | x = attn(x) 166 | x = upsample(x * mask) 167 | 168 | x = self.final_block(x, mask) 169 | output = self.final_conv(x * mask) 170 | 171 | return (output * mask).squeeze(1) 172 | 173 | 174 | class Diffusion(BaseModule): 175 | def __init__(self, n_feat, n_cond, dim, beta_min=0.05, beta_max=20, pe_scale=1000): 176 | super(Diffusion, self).__init__() 177 | self.estimator = GradLogPEstimator2d(n_feat, n_cond, dim, pe_scale=pe_scale) 178 | self.n_feat = n_feat 179 | self.beta_min = beta_min 180 | self.beta_max = beta_max 181 | 182 | def get_beta(self, t): 183 | beta = self.beta_min + (self.beta_max - self.beta_min) * t 184 | return beta 185 | 186 | def get_gamma(self, s, t, p=1.0, use_torch=False): 187 | beta_integral = self.beta_min + 0.5 * (self.beta_max - self.beta_min) * (t + s) 188 | beta_integral *= (t - s) 189 | if use_torch: 190 | gamma = torch.exp(-0.5 * p * beta_integral).unsqueeze(-1).unsqueeze(-1) 191 | else: 192 | gamma = math.exp(-0.5 * p * beta_integral) 193 | return gamma 194 | 195 | def get_mu(self, s, t): 196 | a = self.get_gamma(s, t) 197 | b = 1.0 - self.get_gamma(0, s, p=2.0) 198 | c = 1.0 - self.get_gamma(0, t, p=2.0) 199 | return a * b / c 200 | 201 | def get_nu(self, s, t): 202 | a = self.get_gamma(0, s) 203 | b = 1.0 - self.get_gamma(s, t, p=2.0) 204 | c = 1.0 - self.get_gamma(0, t, p=2.0) 205 | return a * b / c 206 | 207 | def get_sigma(self, s, t): 208 | a = 1.0 - self.get_gamma(0, s, p=2.0) 209 | b = 1.0 - self.get_gamma(s, t, p=2.0) 210 | c = 1.0 - self.get_gamma(0, t, p=2.0) 211 | return math.sqrt(a * b / c) 212 | 213 | @torch.no_grad() 214 | def reverse_diffusion(self, z, mask, mu, mu_c, 
n_timesteps): 215 | h = 1.0 / n_timesteps 216 | xt = z * mask 217 | 218 | for i in range(n_timesteps): 219 | t = 1.0 - i * h 220 | time = t * torch.ones(z.shape[0], dtype=z.dtype, device=z.device) 221 | beta_t = self.get_beta(t) 222 | 223 | kappa = self.get_gamma(0, t - h) * (1.0 - self.get_gamma(t - h, t, p=2.0)) 224 | kappa /= (self.get_gamma(0, t) * beta_t * h) 225 | kappa -= 1.0 226 | omega = self.get_nu(t - h, t) / self.get_gamma(0, t) 227 | omega += self.get_mu(t - h, t) 228 | omega -= (0.5 * beta_t * h + 1.0) 229 | sigma = self.get_sigma(t - h, t) 230 | 231 | dxt = (mu - xt) * (0.5 * beta_t * h + omega) 232 | dxt -= (self.estimator(xt, mask, mu, mu_c, time)) * (1.0 + kappa) * (beta_t * h) 233 | dxt += torch.randn_like(z, device=z.device) * sigma 234 | xt = (xt - dxt) * mask 235 | 236 | return xt 237 | 238 | @torch.no_grad() 239 | def forward(self, z, mask, mu, mu_c, n_timesteps): 240 | return self.reverse_diffusion(z, mask, mu, mu_c, n_timesteps) 241 | 242 | # train: mel means f0_groun_truth 243 | def get_noise(self, t, beta_init, beta_term, cumulative=False): 244 | if cumulative: 245 | noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2) 246 | else: 247 | noise = beta_init + (beta_term - beta_init)*t 248 | return noise 249 | 250 | def forward_diffusion(self, mel, mask, mu, t): 251 | time = t.unsqueeze(-1).unsqueeze(-1) 252 | cum_noise = self.get_noise(time, self.beta_min, self.beta_max, cumulative=True) 253 | mean = mel*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise)) 254 | variance = 1.0 - torch.exp(-cum_noise) 255 | z = torch.randn(mel.shape, dtype=mel.dtype, device=mel.device, 256 | requires_grad=False) 257 | xt = mean + z * torch.sqrt(variance) 258 | return xt * mask, z * mask 259 | 260 | def loss_t(self, mel, mask, mu, mu_c, t): 261 | xt, z = self.forward_diffusion(mel, mask, mu, t) 262 | time = t.unsqueeze(-1).unsqueeze(-1) 263 | cum_noise = self.get_noise(time, self.beta_min, self.beta_max, cumulative=True) 264 | noise_estimation = self.estimator(xt, mask, mu, mu_c, t) 265 | noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise)) 266 | loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_feat) 267 | return loss, xt 268 | 269 | def compute_loss(self, mel, mask, mu, mu_c, offset=1e-5): 270 | t = torch.rand(mel.shape[0], dtype=mel.dtype, device=mel.device, requires_grad=False) 271 | t = torch.clamp(t, offset, 1.0 - offset) 272 | return self.loss_t(mel, mask, mu, mu_c, t) 273 | -------------------------------------------------------------------------------- /pitch/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from torch import nn 6 | from pitch.diffusion import Diffusion 7 | from pitch.utils import rand_ids_segments, slice_segments 8 | 9 | from vits import attentions 10 | from vits import commons 11 | 12 | 13 | class TextEncoder(nn.Module): 14 | def __init__(self, 15 | hidden_channels, 16 | filter_channels, 17 | n_heads, 18 | n_layers, 19 | kernel_size, 20 | p_dropout): 21 | super().__init__() 22 | self.hidden_channels = hidden_channels 23 | self.emb_phone = nn.Embedding(63, hidden_channels) # phone lables 24 | self.emb_score = nn.Embedding(128, hidden_channels) # pitch notes 25 | self.emb_slurs = nn.Embedding(2, hidden_channels) # phone slur 26 | nn.init.normal_(self.emb_phone.weight, 0.0, hidden_channels**-0.5) 27 | nn.init.normal_(self.emb_score.weight, 0.0, hidden_channels**-0.5) 28 | 
nn.init.normal_(self.emb_slurs.weight, 0.0, hidden_channels**-0.5) 29 | self.enc = attentions.Encoder( 30 | hidden_channels, 31 | filter_channels, 32 | n_heads, 33 | n_layers, 34 | kernel_size, 35 | p_dropout) 36 | self.proj = nn.Conv1d(hidden_channels, 2, 1) # pitch + uv 37 | 38 | def forward(self, phone, lengths, score, slurs): 39 | x = self.emb_phone(phone) + self.emb_score(score) + self.emb_slurs(slurs) 40 | x = x * math.sqrt(self.hidden_channels) # [b, t, h] 41 | x = torch.transpose(x, 1, -1) # [b, h, t] 42 | x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to( 43 | x.dtype 44 | ) 45 | x = self.enc(x * x_mask, x_mask) 46 | c = x 47 | x = self.proj(x) 48 | return x, x_mask, c 49 | 50 | 51 | class PitchDiffusion(nn.Module): 52 | def __init__(self): 53 | super().__init__() 54 | self.pit_encoder = TextEncoder(hidden_channels=192, filter_channels=768, 55 | n_heads=2, n_layers=5, kernel_size=5, p_dropout=0.1) 56 | self.decoder = Diffusion(2, 192, 64, beta_min=0.05, beta_max=20.0, pe_scale=1000) 57 | 58 | 59 | @torch.no_grad() 60 | def forward(self, phone, lengths, score, slurs, n_timesteps, temperature=1.0): 61 | # Encoder 62 | mu_x, mask_x, c = self.pit_encoder(phone, lengths, score, slurs) 63 | encoder_outputs = mu_x 64 | 65 | # Sample latent representation from terminal distribution N(mu_y, I) 66 | z = mu_x + torch.randn_like(mu_x, device=mu_x.device) / temperature 67 | # Generate sample by performing reverse dynamics 68 | decoder_outputs = self.decoder(z, mask_x, mu_x, c, n_timesteps) 69 | return encoder_outputs, decoder_outputs 70 | 71 | def compute_loss(self, phone, lengths, score, slurs, pitch, out_size): 72 | # Get encoder_outputs `mu_x` 73 | mu_x, mask_x, c = self.pit_encoder(phone, lengths, score, slurs) 74 | 75 | # Compute loss between encoder outputs and pitch 76 | floor = torch.ones_like(pitch) 77 | pitch = torch.maximum(pitch, floor) 78 | pitch = torch.log2(pitch) 79 | # Loss 80 | loss_f0 = F.l1_loss(mu_x[:, 0, :], pitch) 81 | uv_gt = (pitch > 0).to(pitch.dtype) 82 | loss_uv = F.binary_cross_entropy_with_logits(mu_x[:, 1, :], uv_gt) 83 | prior_loss = loss_f0 + loss_uv 84 | # pitch_gt 85 | pitch_gt = torch.zeros_like(mu_x, device=mu_x.device) 86 | pitch_gt[:, 0, :] = pitch 87 | pitch_gt[:, 1, :] = uv_gt 88 | # Compute loss of score-based decoder 89 | # Cut a small segment of pitch in order to increase batch size 90 | if not isinstance(out_size, type(None)) and out_size < pitch_gt.shape[1]: 91 | ids = rand_ids_segments(lengths, out_size) 92 | pitch_gt = slice_segments(pitch_gt, ids, out_size) 93 | 94 | mask_x = slice_segments(mask_x, ids, out_size) 95 | mu_x = slice_segments(mu_x, ids, out_size) 96 | c = slice_segments(c, ids, out_size) 97 | 98 | diff_loss, xt = self.decoder.compute_loss(pitch_gt, mask_x, mu_x, c) 99 | return prior_loss, diff_loss 100 | 101 | -------------------------------------------------------------------------------- /pitch/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import inspect 3 | 4 | 5 | def sequence_mask(length, max_length=None): 6 | if max_length is None: 7 | max_length = length.max() 8 | x = torch.arange(int(max_length), dtype=length.dtype, device=length.device) 9 | return x.unsqueeze(0) < length.unsqueeze(1) 10 | 11 | 12 | def fix_len_compatibility(length, num_downsamplings_in_unet=2): 13 | while True: 14 | if length % (2**num_downsamplings_in_unet) == 0: 15 | return length 16 | length += 1 17 | 18 | 19 | def convert_pad_shape(pad_shape): 20 | l = 
pad_shape[::-1] 21 | pad_shape = [item for sublist in l for item in sublist] 22 | return pad_shape 23 | 24 | 25 | def generate_path(duration, mask): 26 | device = duration.device 27 | 28 | b, t_x, t_y = mask.shape 29 | cum_duration = torch.cumsum(duration, 1) 30 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) 31 | 32 | cum_duration_flat = cum_duration.view(b * t_x) 33 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 34 | path = path.view(b, t_x, t_y) 35 | path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], 36 | [1, 0], [0, 0]]))[:, :-1] 37 | path = path * mask 38 | return path 39 | 40 | 41 | def duration_loss(logw, logw_, lengths): 42 | loss = torch.sum((logw - logw_)**2) / torch.sum(lengths) 43 | return loss 44 | 45 | 46 | def rand_ids_segments(lengths, segment_size=200): 47 | b = lengths.shape[0] 48 | ids_str_max = lengths - segment_size 49 | ids_str = (torch.rand([b]).to(device=lengths.device) * ids_str_max).to(dtype=torch.long) 50 | ids_str = torch.where(ids_str < 0, 0, ids_str) # fix error 51 | return ids_str 52 | 53 | 54 | def slice_segments(x, ids_str, segment_size=200): 55 | ret = torch.zeros_like(x[:, :, :segment_size]) 56 | for i in range(x.size(0)): 57 | idx_str = ids_str[i] 58 | idx_end = idx_str + segment_size 59 | ret[i] = x[i, :, idx_str:idx_end] 60 | return ret 61 | 62 | 63 | def retrieve_name(var): 64 | for fi in reversed(inspect.stack()): 65 | names = [var_name for var_name, 66 | var_val in fi.frame.f_locals.items() if var_val is var] 67 | if len(names) > 0: 68 | return names[0] 69 | 70 | 71 | Debug_Enable = True 72 | 73 | 74 | def debug_shapes(var): 75 | if Debug_Enable: 76 | print(retrieve_name(var), var.shape) 77 | -------------------------------------------------------------------------------- /pitch_extend/dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from pitch.data_utils import DistributedBucketSampler 3 | from pitch.data_utils import TextAudioLoader 4 | from pitch.data_utils import TextAudioCollate 5 | 6 | 7 | def create_dataloader_train(hps, n_gpus, rank): 8 | collate_fn = TextAudioCollate() 9 | train_dataset = TextAudioLoader(hps.data.training_files, hps.data) 10 | train_sampler = DistributedBucketSampler( 11 | train_dataset, 12 | hps.train.batch_size, 13 | [32, 300, 400, 500, 600, 700, 800, 900, 1000], 14 | num_replicas=n_gpus, 15 | rank=rank, 16 | shuffle=True) 17 | train_loader = DataLoader( 18 | train_dataset, 19 | num_workers=4, 20 | shuffle=False, 21 | pin_memory=True, 22 | collate_fn=collate_fn, 23 | batch_sampler=train_sampler) 24 | return train_loader 25 | 26 | 27 | def create_dataloader_eval(hps): 28 | collate_fn = TextAudioCollate() 29 | eval_dataset = TextAudioLoader(hps.data.validation_files, hps.data) 30 | eval_loader = DataLoader( 31 | eval_dataset, 32 | num_workers=2, 33 | shuffle=False, 34 | batch_size=hps.train.batch_size, 35 | pin_memory=True, 36 | drop_last=False, 37 | collate_fn=collate_fn) 38 | return eval_loader 39 | -------------------------------------------------------------------------------- /pitch_extend/plotting.py: -------------------------------------------------------------------------------- 1 | import logging 2 | mpl_logger = logging.getLogger('matplotlib') # must before import matplotlib 3 | mpl_logger.setLevel(logging.WARNING) 4 | import matplotlib 5 | matplotlib.use("Agg") 6 | 7 | import numpy as np 8 | import matplotlib.pylab as plt 9 | 10 | 11 | def 
save_figure_to_numpy(fig): 12 | # save it to a numpy array. 13 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 14 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 15 | data = np.transpose(data, (2, 0, 1)) 16 | return data 17 | 18 | 19 | def plot_f0_to_numpy(f0_pre, f0_gt=None): 20 | fig = plt.figure(figsize=(12, 6)) 21 | plt.plot(f0_pre.T, "g") 22 | if f0_gt is not None: 23 | plt.plot(f0_gt.T, "r") 24 | fig.canvas.draw() 25 | data = save_figure_to_numpy(fig) 26 | plt.close() 27 | return data 28 | -------------------------------------------------------------------------------- /pitch_extend/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | import math 5 | import tqdm 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.distributed import init_process_group 10 | from torch.nn.parallel import DistributedDataParallel 11 | 12 | from vits.commons import clip_grad_value_ 13 | 14 | from pitch.utils import fix_len_compatibility 15 | from pitch.models import PitchDiffusion 16 | from pitch_extend.validation import validate 17 | from pitch_extend.writer import MyWriter 18 | from pitch_extend.dataloader import create_dataloader_train 19 | from pitch_extend.dataloader import create_dataloader_eval 20 | 21 | 22 | def load_model(model, saved_state_dict): 23 | if hasattr(model, 'module'): 24 | state_dict = model.module.state_dict() 25 | else: 26 | state_dict = model.state_dict() 27 | new_state_dict = {} 28 | for k, v in state_dict.items(): 29 | try: 30 | new_state_dict[k] = saved_state_dict[k] 31 | except: 32 | print("%s is not in the checkpoint" % k) 33 | new_state_dict[k] = v 34 | if hasattr(model, 'module'): 35 | model.module.load_state_dict(new_state_dict) 36 | else: 37 | model.load_state_dict(new_state_dict) 38 | return model 39 | 40 | 41 | # 400 frames 42 | out_size = fix_len_compatibility(400) 43 | 44 | 45 | def train(rank, args, chkpt_path, hp, hp_str): 46 | 47 | if args.num_gpus > 1: 48 | init_process_group(backend=hp.dist_config.dist_backend, init_method=hp.dist_config.dist_url, 49 | world_size=hp.dist_config.world_size * args.num_gpus, rank=rank) 50 | 51 | torch.cuda.manual_seed(hp.train.seed) 52 | device = torch.device('cuda:{:d}'.format(rank)) 53 | 54 | model_g = PitchDiffusion().to(device) 55 | 56 | optim_g = torch.optim.AdamW(model_g.parameters(), 57 | lr=hp.train.learning_rate, betas=hp.train.betas, eps=hp.train.eps) 58 | 59 | init_epoch = 1 60 | step = 0 61 | 62 | # define logger, writer, valloader, stft at rank_zero 63 | if rank == 0: 64 | pth_dir = os.path.join(hp.log.pth_dir, args.name) 65 | log_dir = os.path.join(hp.log.log_dir, args.name) 66 | os.makedirs(pth_dir, exist_ok=True) 67 | os.makedirs(log_dir, exist_ok=True) 68 | 69 | logging.basicConfig( 70 | level=logging.INFO, 71 | format='%(asctime)s - %(levelname)s - %(message)s', 72 | handlers=[ 73 | logging.FileHandler(os.path.join(log_dir, '%s-%d.log' % (args.name, time.time()))), 74 | logging.StreamHandler() 75 | ] 76 | ) 77 | logger = logging.getLogger() 78 | writer = MyWriter(hp, log_dir) 79 | valloader = create_dataloader_eval(hp) 80 | 81 | if chkpt_path is not None: 82 | if rank == 0: 83 | logger.info("Resuming from checkpoint: %s" % chkpt_path) 84 | checkpoint = torch.load(chkpt_path, map_location='cpu') 85 | load_model(model_g, checkpoint['model_g']) 86 | optim_g.load_state_dict(checkpoint['optim_g']) 87 | init_epoch = checkpoint['epoch'] 88 | step = 
checkpoint['step'] 89 | 90 | if rank == 0: 91 | if hp_str != checkpoint['hp_str']: 92 | logger.warning("New hparams is different from checkpoint. Will use new.") 93 | else: 94 | if rank == 0: 95 | logger.info("Starting new training run.") 96 | 97 | if args.num_gpus > 1: 98 | model_g = DistributedDataParallel(model_g, device_ids=[rank]) 99 | 100 | # this accelerates training when the size of minibatch is always consistent. 101 | # if not consistent, it'll horribly slow down. 102 | torch.backends.cudnn.benchmark = True 103 | 104 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hp.train.lr_decay, last_epoch=init_epoch-2) 105 | trainloader = create_dataloader_train(hp, args.num_gpus, rank) 106 | 107 | for epoch in range(init_epoch, hp.train.epochs): 108 | 109 | trainloader.batch_sampler.set_epoch(epoch) 110 | 111 | if rank == 0 and epoch % hp.log.eval_interval == 0: 112 | with torch.no_grad(): 113 | validate(hp, model_g, valloader, writer, step, device) 114 | 115 | if rank == 0: 116 | loader = tqdm.tqdm(trainloader, desc='Loading train data') 117 | else: 118 | loader = trainloader 119 | 120 | model_g.train() 121 | 122 | for phone, phone_l, score, pitch, slurs in loader: 123 | 124 | phone = phone.to(device) 125 | phone_l = phone_l.to(device) 126 | score = score.to(device) 127 | pitch = pitch.to(device) 128 | slurs = slurs.to(device) 129 | 130 | # generator 131 | optim_g.zero_grad() 132 | # 133 | prior_loss, diff_loss = model_g.compute_loss(phone, phone_l, score, slurs, pitch, out_size=out_size) 134 | loss_g = sum([prior_loss, diff_loss]) 135 | loss_g.backward() 136 | clip_grad_value_(model_g.parameters(), None) 137 | optim_g.step() 138 | 139 | step += 1 140 | # logging 141 | loss_g = loss_g.item() 142 | if rank == 0 and step % hp.log.info_interval == 0: 143 | writer.log_training(loss_g, prior_loss, diff_loss, step) 144 | logger.info("epoch %d | g %.04f prior_loss %.04f diff_loss %.04f | step %d" % ( 145 | epoch, loss_g, prior_loss, diff_loss, step)) 146 | 147 | if rank == 0 and epoch % hp.log.save_interval == 0: 148 | save_path = os.path.join(pth_dir, '%s_%04d.pt' 149 | % (args.name, epoch)) 150 | torch.save({ 151 | 'model_g': (model_g.module if args.num_gpus > 1 else model_g).state_dict(), 152 | 'optim_g': optim_g.state_dict(), 153 | 'step': step, 154 | 'epoch': epoch, 155 | 'hp_str': hp_str, 156 | }, save_path) 157 | logger.info("Saved checkpoint to: %s" % save_path) 158 | 159 | scheduler_g.step() 160 | -------------------------------------------------------------------------------- /pitch_extend/validation.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def validate(hp, generator, valloader, writer, step, device): 7 | generator.eval() 8 | torch.backends.cudnn.benchmark = False 9 | 10 | loader = tqdm.tqdm(valloader, desc='Validation loop') 11 | vali_loss = 0.0 12 | for idx, (phone, phone_l, score, pitch, slurs) in enumerate(loader): 13 | phone = phone.to(device) 14 | phone_l = phone_l.to(device) 15 | score = score.to(device) 16 | pitch = pitch.to(device) 17 | slurs = slurs.to(device) 18 | 19 | pitch_pri, pitch_pre = generator(phone, phone_l, score, slurs, n_timesteps=50) 20 | 21 | # De-Log 22 | pitch_pri = torch.pow(2, pitch_pri) 23 | pitch_pre = torch.pow(2, pitch_pre) 24 | 25 | loss_f0 = F.l1_loss(pitch_pre[:, 0, :], pitch) 26 | vali_loss += loss_f0.item() 27 | 28 | if idx < hp.log.num_audio: 29 | writer.log_fig_pitch(pitch_pri, pitch_pre, pitch, idx, 
step) 30 | 31 | vali_loss = vali_loss / len(valloader.dataset) 32 | writer.log_validation(vali_loss, step) 33 | 34 | torch.backends.cudnn.benchmark = True 35 | -------------------------------------------------------------------------------- /pitch_extend/writer.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | from .plotting import plot_f0_to_numpy 3 | 4 | 5 | class MyWriter(SummaryWriter): 6 | def __init__(self, hp, logdir): 7 | super(MyWriter, self).__init__(logdir) 8 | 9 | def log_training(self, loss_g, prior_loss, diff_loss, step): 10 | self.add_scalar('train/loss_g', loss_g, step) 11 | self.add_scalar('train/loss_prior', prior_loss, step) 12 | self.add_scalar('train/loss_diff', diff_loss, step) 13 | 14 | def log_validation(self, vali_loss, step): 15 | self.add_scalar('validation/vali_loss', vali_loss, step) 16 | 17 | def log_fig_pitch(self, pitch_prio, pitch_fake, pitch_real, idx, step): 18 | if idx == 0: 19 | pitch_prio = pitch_prio[0, 0, :].data.cpu().numpy() 20 | pitch_fake = pitch_fake[0, 0, :].data.cpu().numpy() 21 | pitch_prio[pitch_prio > 1000] = 1000 22 | pitch_fake[pitch_fake > 1000] = 1000 23 | pitch_real = pitch_real[0].data.cpu().numpy() 24 | self.add_image(f'pitch_prio/{step}', plot_f0_to_numpy(pitch_prio, pitch_real), step) 25 | self.add_image(f'pitch_fake/{step}', plot_f0_to_numpy(pitch_fake, pitch_real), step) 26 | # self.add_image(f'pitch_real/{step}', plot_f0_to_numpy(pitch_real), step) 27 | -------------------------------------------------------------------------------- /resource/vising_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlayVoice/VI-SVS/d9768caa05c742bc5f580a4754b3cd7b444a5eb4/resource/vising_loss.png -------------------------------------------------------------------------------- /resource/vising_mel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlayVoice/VI-SVS/d9768caa05c742bc5f580a4754b3cd7b444a5eb4/resource/vising_mel.png -------------------------------------------------------------------------------- /resource/vising_sample.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlayVoice/VI-SVS/d9768caa05c742bc5f580a4754b3cd7b444a5eb4/resource/vising_sample.wav -------------------------------------------------------------------------------- /svs/__init__.py: -------------------------------------------------------------------------------- 1 | from svs.phone_map import label_to_ids 2 | from svs.phone_uv import uv_map 3 | 4 | 5 | def load_midi_map(): 6 | notemap = {} 7 | notemap["rest"] = 0 8 | fo = open("./svs/midi-note.scp", "r+") 9 | while True: 10 | try: 11 | message = fo.readline().strip() 12 | except Exception as e: 13 | print("nothing of except:", e) 14 | break 15 | if message == None: 16 | break 17 | if message == "": 18 | break 19 | infos = message.split() 20 | notemap[infos[1]] = int(infos[0]) 21 | fo.close() 22 | return notemap 23 | -------------------------------------------------------------------------------- /svs/midi-HZ.scp: -------------------------------------------------------------------------------- 1 | 127 G9 12543.9 2 | 126 F#9/Gb9 11839.8 3 | 125 F9 11175.3 4 | 124 E9 10548.1 5 | 123 D#9/Eb9 9956.1 6 | 122 D9 9397.3 7 | 121 C#9/Db9 8869.8 8 | 120 C9 8372 9 | 119 B8 7902.1 10 | 118 A#8/Bb8 7458.6 11 | 117 A8 7040 12 | 116 
G#8/Ab8 6644.9 13 | 115 G8 6271.9 14 | 114 F#8/Gb8 5919.9 15 | 113 F8 5587.7 16 | 112 E8 5274 17 | 111 D#8/Eb8 4978 18 | 110 D8 4698.6 19 | 109 C#8/Db8 4434.9 20 | 108 C8 4186 21 | 107 B7 3951.1 22 | 106 A#7/Bb7 3729.3 23 | 105 A7 3520 24 | 104 G#7/Ab7 3322.4 25 | 103 G7 3136 26 | 102 F#7/Gb7 2960 27 | 101 F7 2793.8 28 | 100 E7 2637 29 | 99 D#7/Eb7 2489 30 | 98 D7 2349.3 31 | 97 C#7/Db7 2217.5 32 | 96 C7 2093 33 | 95 B6 1975.5 34 | 94 A#6/Bb6 1864.7 35 | 93 A6 1760 36 | 92 G#6/Ab6 1661.2 37 | 91 G6 1568 38 | 90 F#6/Gb6 1480 39 | 89 F6 1396.9 40 | 88 E6 1318.5 41 | 87 D#6/Eb6 1244.5 42 | 86 D6 1174.7 43 | 85 C#6/Db6 1108.7 44 | 84 C6 1046.5 45 | 83 B5 987.8 46 | 82 A#5/Bb5 932.3 47 | 81 A5 880 48 | 80 G#5/Ab5 830.6 49 | 79 G5 784 50 | 78 F#5/Gb5 740 51 | 77 F5 698.5 52 | 76 E5 659.3 53 | 75 D#5/Eb5 622.3 54 | 74 D5 587.3 55 | 73 C#5/Db5 554.4 56 | 72 C5 523.3 57 | 71 B4 493.9 58 | 70 A#4/Bb4 466.2 59 | 69 A4 440 60 | 68 G#4/Ab4 415.3 61 | 67 G4 392 62 | 66 F#4/Gb4 370 63 | 65 F4 349.2 64 | 64 E4 329.6 65 | 63 D#4/Eb4 311.1 66 | 62 D4 293.7 67 | 61 C#4/Db4 277.2 68 | 60 C4 261.6 69 | 59 B3 246.9 70 | 58 A#3/Bb3 233.1 71 | 57 A3 220 72 | 56 G#3/Ab3 207.7 73 | 55 G3 196 74 | 54 F#3/Gb3 185 75 | 53 F3 174.6 76 | 52 E3 164.8 77 | 51 D#3/Eb3 155.6 78 | 50 D3 146.8 79 | 49 C#3/Db3 138.6 80 | 48 C3 130.8 81 | 47 B2 123.5 82 | 46 A#2/Bb2 116.5 83 | 45 A2 110 84 | 44 G#2/Ab2 103. 85 | 43 G2 98 86 | 42 F#2/Gb2 92.5 87 | 41 F2 87.3 88 | 40 E2 82.4 89 | 39 D#2/Eb2 77.8 90 | 38 D2 73.4 91 | 37 C#2/Db2 69.3 92 | 36 C2 65.4 93 | 35 B1 61.7 94 | 34 A#1/Bb1 58.3 95 | 33 A1 55 96 | 32 G#1/Ab1 51.9 97 | 31 G1 49 98 | 30 F#1/Gb1 46.2 99 | 29 F1 43.7 100 | 28 E1 41.2 101 | 27 D#1/Eb1 38.9 102 | 26 D1 36.7 103 | 25 C#1/Db1 34.6 104 | 24 C1 32.7 105 | 23 B0 30.9 106 | 22 A#0/Bb0 29.1 107 | 21 A0 27.5 108 | 0 rest 0 -------------------------------------------------------------------------------- /svs/midi-note.scp: -------------------------------------------------------------------------------- 1 | 127 G9 2 | 126 F#9/Gb9 3 | 125 F9 4 | 124 E9 5 | 123 D#9/Eb9 6 | 122 D9 7 | 121 C#9/Db9 8 | 120 C9 9 | 119 B8 10 | 118 A#8/Bb8 11 | 117 A8 12 | 116 G#8/Ab8 13 | 115 G8 14 | 114 F#8/Gb8 15 | 113 F8 16 | 112 E8 17 | 111 D#8/Eb8 18 | 110 D8 19 | 109 C#8/Db8 20 | 108 C8 21 | 107 B7 22 | 106 A#7/Bb7 23 | 105 A7 24 | 104 G#7/Ab7 25 | 103 G7 26 | 102 F#7/Gb7 27 | 101 F7 28 | 100 E7 29 | 99 D#7/Eb7 30 | 98 D7 31 | 97 C#7/Db7 32 | 96 C7 33 | 95 B6 34 | 94 A#6/Bb6 35 | 93 A6 36 | 92 G#6/Ab6 37 | 91 G6 38 | 90 F#6/Gb6 39 | 89 F6 40 | 88 E6 41 | 87 D#6/Eb6 42 | 86 D6 43 | 85 C#6/Db6 44 | 84 C6 45 | 83 B5 46 | 82 A#5/Bb5 47 | 81 A5 48 | 80 G#5/Ab5 49 | 79 G5 50 | 78 F#5/Gb5 51 | 77 F5 52 | 76 E5 53 | 75 D#5/Eb5 54 | 74 D5 55 | 73 C#5/Db5 56 | 72 C5 57 | 71 B4 58 | 70 A#4/Bb4 59 | 69 A4 60 | 68 G#4/Ab4 61 | 67 G4 62 | 66 F#4/Gb4 63 | 65 F4 64 | 64 E4 65 | 63 D#4/Eb4 66 | 62 D4 67 | 61 C#4/Db4 68 | 60 C4 69 | 59 B3 70 | 58 A#3/Bb3 71 | 57 A3 72 | 56 G#3/Ab3 73 | 55 G3 74 | 54 F#3/Gb3 75 | 53 F3 76 | 52 E3 77 | 51 D#3/Eb3 78 | 50 D3 79 | 49 C#3/Db3 80 | 48 C3 81 | 47 B2 82 | 46 A#2/Bb2 83 | 45 A2 84 | 44 G#2/Ab2 85 | 43 G2 86 | 42 F#2/Gb2 87 | 41 F2 88 | 40 E2 89 | 39 D#2/Eb2 90 | 38 D2 91 | 37 C#2/Db2 92 | 36 C2 93 | 35 B1 94 | 34 A#1/Bb1 95 | 33 A1 96 | 32 G#1/Ab1 97 | 31 G1 98 | 30 F#1/Gb1 99 | 29 F1 100 | 28 E1 101 | 27 D#1/Eb1 102 | 26 D1 103 | 25 C#1/Db1 104 | 24 C1 105 | 23 B0 106 | 22 A#0/Bb0 107 | 21 A0 -------------------------------------------------------------------------------- /svs/phone_map.py: 
-------------------------------------------------------------------------------- 1 | _pause = ["unk", "sos", "eos", "ap", "sp"] 2 | 3 | _initials = [ 4 | "b", 5 | "c", 6 | "ch", 7 | "d", 8 | "f", 9 | "g", 10 | "h", 11 | "j", 12 | "k", 13 | "l", 14 | "m", 15 | "n", 16 | "p", 17 | "q", 18 | "r", 19 | "s", 20 | "sh", 21 | "t", 22 | "w", 23 | "x", 24 | "y", 25 | "z", 26 | "zh", 27 | ] 28 | 29 | _finals = [ 30 | "a", 31 | "ai", 32 | "an", 33 | "ang", 34 | "ao", 35 | "e", 36 | "ei", 37 | "en", 38 | "eng", 39 | "er", 40 | "i", 41 | "ia", 42 | "ian", 43 | "iang", 44 | "iao", 45 | "ie", 46 | "in", 47 | "ing", 48 | "iong", 49 | "iu", 50 | "o", 51 | "ong", 52 | "ou", 53 | "u", 54 | "ua", 55 | "uai", 56 | "uan", 57 | "uang", 58 | "ui", 59 | "un", 60 | "uo", 61 | "v", 62 | "van", 63 | "ve", 64 | "vn", 65 | ] 66 | 67 | 68 | symbols = _pause + _initials + _finals 69 | 70 | # Mappings from symbol to numeric ID and vice versa: 71 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 72 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 73 | 74 | 75 | def label_to_ids(phones): 76 | # use lower letter 77 | sequence = [_symbol_to_id[symbol.lower()] for symbol in phones] 78 | return sequence 79 | -------------------------------------------------------------------------------- /svs/phone_uv.py: -------------------------------------------------------------------------------- 1 | # 普通话发音基础声母韵母 2 | # 普通话声母只有 4 个浊音:m、n、l、r,其余 17 个辅音声母都是清音 3 | # 汉语拼音的 y 和 w 只出现在零声母音节的开头,它们的作用主要是使音节界限清楚。 4 | # https://baijiahao.baidu.com/s?id=1655739561730224990&wfr=spider&for=pc 5 | 6 | uv_map = { 7 | "unk":0, 8 | "sos":0, 9 | "eos":0, 10 | "ap":0, 11 | "sp":0, 12 | "b":0, 13 | "c":0, 14 | "ch":0, 15 | "d":0, 16 | "f":0, 17 | "g":0, 18 | "h":0, 19 | "j":0, 20 | "k":0, 21 | "l":1, 22 | "m":1, 23 | "n":1, 24 | "p":0, 25 | "q":0, 26 | "r":1, 27 | "s":0, 28 | "sh":0, 29 | "t":0, 30 | "w":1, 31 | "x":0, 32 | "y":1, 33 | "z":0, 34 | "zh":0, 35 | "a":1, 36 | "ai":1, 37 | "an":1, 38 | "ang":1, 39 | "ao":1, 40 | "e":1, 41 | "ei":1, 42 | "en":1, 43 | "eng":1, 44 | "er":1, 45 | "i":1, 46 | "ia":1, 47 | "ian":1, 48 | "iang":1, 49 | "iao":1, 50 | "ie":1, 51 | "in":1, 52 | "ing":1, 53 | "iong":1, 54 | "iu":1, 55 | "o":1, 56 | "ong":1, 57 | "ou":1, 58 | "u":1, 59 | "ua":1, 60 | "uai":1, 61 | "uan":1, 62 | "uang":1, 63 | "ui":1, 64 | "un":1, 65 | "uo":1, 66 | "v":1, 67 | "van":1, 68 | "ve":1, 69 | "vn":1 70 | } -------------------------------------------------------------------------------- /svs_export.py: -------------------------------------------------------------------------------- 1 | import sys,os 2 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 3 | import torch 4 | import argparse 5 | from omegaconf import OmegaConf 6 | 7 | from vits.models import SynthesizerTrn 8 | 9 | 10 | def load_model(checkpoint_path, model): 11 | assert os.path.isfile(checkpoint_path) 12 | checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") 13 | saved_state_dict = checkpoint_dict["model_g"] 14 | if hasattr(model, "module"): 15 | state_dict = model.module.state_dict() 16 | else: 17 | state_dict = model.state_dict() 18 | new_state_dict = {} 19 | for k, v in state_dict.items(): 20 | try: 21 | new_state_dict[k] = saved_state_dict[k] 22 | except: 23 | new_state_dict[k] = v 24 | if hasattr(model, "module"): 25 | model.module.load_state_dict(new_state_dict) 26 | else: 27 | model.load_state_dict(new_state_dict) 28 | return model 29 | 30 | 31 | def save_pretrain(checkpoint_path, save_path): 32 | assert os.path.isfile(checkpoint_path) 33 | 
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") 34 | torch.save({ 35 | 'model_g': checkpoint_dict['model_g'], 36 | 'model_d': checkpoint_dict['model_d'], 37 | }, save_path) 38 | 39 | 40 | def save_model(model, checkpoint_path): 41 | if hasattr(model, 'module'): 42 | state_dict = model.module.state_dict() 43 | else: 44 | state_dict = model.state_dict() 45 | torch.save({'model_g': state_dict}, checkpoint_path) 46 | 47 | 48 | def main(args): 49 | hp = OmegaConf.load(args.config) 50 | model = SynthesizerTrn( 51 | hp.data.filter_length // 2 + 1, 52 | hp.data.segment_size // hp.data.hop_length, 53 | hp) 54 | 55 | load_model(args.model, model) 56 | save_model(model, "svs_opencpop.pt") 57 | 58 | 59 | if __name__ == '__main__': 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument('-c', '--config', type=str, required=True, 62 | help="yaml file for config. will use hp_str from checkpoint if not given.") 63 | parser.add_argument('-m', '--model', type=str, required=True, 64 | help="path of checkpoint pt file for evaluation") 65 | args = parser.parse_args() 66 | 67 | main(args) 68 | -------------------------------------------------------------------------------- /svs_infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | from scipy.io import wavfile 6 | from time import * 7 | 8 | import torch 9 | import argparse 10 | 11 | from vits.models import SynthesizerTrn 12 | from util import SingInput 13 | from util import FeatureInput 14 | from omegaconf import OmegaConf 15 | 16 | 17 | def save_wav(wav, path, rate): 18 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6 19 | wavfile.write(path, rate, wav.astype(np.int16)) 20 | 21 | 22 | def load_svs_model(checkpoint_path, model): 23 | assert os.path.isfile(checkpoint_path) 24 | checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") 25 | saved_state_dict = checkpoint_dict["model_g"] 26 | state_dict = model.state_dict() 27 | new_state_dict = {} 28 | for k, v in state_dict.items(): 29 | try: 30 | new_state_dict[k] = saved_state_dict[k] 31 | except: 32 | print("%s is not in the checkpoint" % k) 33 | new_state_dict[k] = v 34 | model.load_state_dict(new_state_dict) 35 | return model 36 | 37 | 38 | if __name__ == '__main__': 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument('-c', '--config', type=str, required=True, 41 | help="yaml file for configuration") 42 | parser.add_argument('-m', '--model', type=str, required=True, 43 | help="path of checkpoint pt file") 44 | args = parser.parse_args() 45 | 46 | # define model and load checkpoint 47 | hps = OmegaConf.load(args.config) 48 | 49 | singInput = SingInput(hps.data.sampling_rate, hps.data.hop_length) 50 | featureInput = FeatureInput(hps.data.sampling_rate, hps.data.hop_length) 51 | 52 | net_g = SynthesizerTrn( 53 | hps.data.filter_length // 2 + 1, 54 | hps.data.segment_size // hps.data.hop_length, 55 | hps).cuda() 56 | net_g.eval() 57 | 58 | load_svs_model(args.model, net_g) 59 | 60 | # check directory existence 61 | os.makedirs("./svs_out", exist_ok=True) 62 | fo = open("./svs_infer.txt", "r+") 63 | while True: 64 | try: 65 | message = fo.readline().strip() 66 | except Exception as e: 67 | print("nothing of except:", e) 68 | break 69 | if message == None: 70 | break 71 | if message == "": 72 | break 73 | print(message) 74 | ( 75 | file, 76 | labels_ids, 77 | labels_frames, 78 | scores_ids, 79 | scores_dur, 80 | labels_slr, 81 | labels_uvs, 82 | ) 
= singInput.parseInput(message) 83 | labels_ids = singInput.expandInput(labels_ids, labels_frames) 84 | labels_uvs = singInput.expandInput(labels_uvs, labels_frames) 85 | labels_slr = singInput.expandInput(labels_slr, labels_frames) 86 | scores_ids = singInput.expandInput(scores_ids, labels_frames) 87 | scores_pit = singInput.scorePitch(scores_ids) 88 | # elments by elments 89 | scores_pit_ = scores_pit * labels_uvs 90 | scores_pit = singInput.smoothPitch(scores_pit_) 91 | 92 | fig = plt.figure(figsize=(12, 6)) 93 | plt.plot(scores_pit_.T, "g") 94 | plt.plot(scores_pit.T, "r") 95 | plt.savefig(f"./svs_out/{file}_f0_.png", format="png") 96 | plt.close(fig) 97 | 98 | phone = torch.LongTensor(labels_ids) 99 | score = torch.LongTensor(scores_ids) 100 | slurs = torch.LongTensor(labels_slr) 101 | pitch = torch.FloatTensor(scores_pit) 102 | 103 | phone_lengths = phone.size()[0] 104 | 105 | with torch.no_grad(): 106 | phone = phone.cuda().unsqueeze(0) 107 | score = score.cuda().unsqueeze(0) 108 | pitch = pitch.cuda().unsqueeze(0) 109 | slurs = slurs.cuda().unsqueeze(0) 110 | phone_lengths = torch.LongTensor([phone_lengths]).cuda() 111 | audio = ( 112 | net_g.infer(phone, phone_lengths, score, pitch, slurs)[0, 0] 113 | .data.cpu() 114 | .float() 115 | .numpy() 116 | ) 117 | 118 | save_wav(audio, f"./svs_out/{file}.wav", hps.data.sampling_rate) 119 | fo.close() 120 | # can be deleted 121 | os.system("chmod 777 ./svs_out -R") 122 | -------------------------------------------------------------------------------- /svs_infer.txt: -------------------------------------------------------------------------------- 1 | 2001000001|感受停在我发端的指尖|g an sh ou t ing z ai w o f a d uan d e SP zh i j ian AP|G#4/Ab4 G#4/Ab4 G#4/Ab4 G#4/Ab4 F#4/Gb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 E4 E4 E4 E4 D#4/Eb4 D#4/Eb4 D#4/Eb4 D#4/Eb4 rest E4 E4 E4 E4 rest|0.253030 0.253030 0.428030 0.428030 0.320870 0.320870 0.358110 0.358110 0.218610 0.218610 0.519380 0.519380 0.351070 0.351070 0.152260 0.152260 0.089470 0.405810 0.405810 0.696660 0.696660 0.284630|0.0317 0.22133 0.15421 0.27382 0.06335 0.25752 0.07101 0.2871 0.03623 0.18238 0.18629 0.33309 0.01471 0.33636 0.01415 0.13811 0.08947 0.12862 0.27719 0.07962 0.61704 0.28463|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 | 2001000002|如何瞬间冻结时间|r u h e sh un j ian AP SP d ong j ie sh i j ian SP|B3 B3 B3 B3 B3 B3 G#4/Ab4 G#4/Ab4 rest rest B3 B3 B3 B3 B3 B3 F#4/Gb4 F#4/Gb4 rest|0.294760 0.294760 0.283550 0.283550 0.795250 0.795250 0.992200 0.992200 0.297130 0.104830 0.311040 0.311040 0.214620 0.214620 0.782750 0.782750 1.519540 1.519540 1.179120|0.06588 0.22888 0.11684 0.16671 0.18746 0.60779 0.11194 0.88026 0.29713 0.10483 0.03166 0.27938 0.05057 0.16405 0.21149 0.57126 0.13926 1.38028 1.17912|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 | 2001000003|记住望着我坚定的双眼|j i zh u w ang zh e w o SP j ian d ing d e sh uang y an AP|G#4/Ab4 G#4/Ab4 G#4/Ab4 G#4/Ab4 F#4/Gb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 E4 E4 rest E4 E4 D#4/Eb4 D#4/Eb4 D#4/Eb4 D#4/Eb4 E4 E4 E4 E4 rest|0.388470 0.388470 0.368320 0.368320 0.363510 0.363510 0.316690 0.316690 0.161350 0.161350 0.055570 0.495580 0.495580 0.342860 0.342860 0.141750 0.141750 0.398360 0.398360 0.785070 0.785070 0.317450|0.09945 0.28902 0.08103 0.28729 0.05083 0.31268 0.04303 0.27366 0.03603 0.12532 0.05557 0.15191 0.34367 0.02357 0.31929 0.02939 0.11236 0.21916 0.1792 0.22549 0.55958 0.31745|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 4 | 2001000004|也许已经没有明天|y e x v y i j ing AP m ei y ou m ing t ian SP AP|B3 B3 B3 B3 B3 B3 G#4/Ab4 G#4/Ab4 rest B3 B3 B3 B3 B3 B3 F#4/Gb4 F#4/Gb4 rest 
rest|0.236860 0.236860 0.426110 0.426110 0.660620 0.660620 1.021220 1.021220 0.409380 0.243270 0.243270 0.327560 0.327560 0.741700 0.741700 1.335140 1.335140 0.591900 0.515310|0.07979 0.15707 0.2089 0.21721 0.12179 0.53883 0.16915 0.85207 0.40938 0.06617 0.1771 0.04273 0.28483 0.11939 0.62231 0.17586 1.15928 0.5919 0.51531|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 5 | 2001000005|面对浩瀚的星海我们微小得像尘埃|m ian d ui h ao h an an d e x ing h ai ai ai AP w o m en w ei x iao d e x iang ch en ai ai ai SP|C#4/Db4 C#4/Db4 D#4/Eb4 D#4/Eb4 C#4/Db4 C#4/Db4 D#4/Eb4 D#4/Eb4 E4 D#4/Eb4 D#4/Eb4 E4 E4 G#4/Ab4 G#4/Ab4 A4 G#4/Ab4 rest C#4/Db4 C#4/Db4 C#4/Db4 C#4/Db4 D#4/Eb4 D#4/Eb4 C#4/Db4 C#4/Db4 D#4/Eb4 D#4/Eb4 E4 E4 E4 E4 G#4/Ab4 A4 G#4/Ab4 rest|0.196990 0.196990 0.102120 0.102120 0.304680 0.304680 0.096780 0.096780 0.100220 0.150010 0.150010 0.361460 0.361460 0.221070 0.221070 0.183240 0.478670 0.384620 0.106510 0.106510 0.143020 0.143020 0.169480 0.169480 0.224180 0.224180 0.089360 0.089360 0.414460 0.414460 0.378050 0.378050 0.162790 0.207380 0.317260 0.297040|0.02765 0.16934 0.01874 0.08338 0.0821 0.22258 0.0693 0.02748 0.10022 0.07137 0.07864 0.12471 0.23675 0.12356 0.09751 0.18324 0.47867 0.38462 0.0405 0.06601 0.08303 0.05999 0.04687 0.12261 0.09778 0.1264 0.02321 0.06615 0.11958 0.29488 0.06723 0.31082 0.16279 0.20738 0.31726 0.29704|0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 6 | 2001000006|漂浮在一片无奈|p iao f u z ai ai ai AP SP y i i p ian ian ian w u n ai SP AP|E4 E4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A4 G#4/Ab4 rest rest E4 E4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A4 G#4/Ab4 E4 E4 F#4/Gb4 F#4/Gb4 rest rest|0.185230 0.185230 0.177410 0.177410 0.193930 0.193930 0.259670 0.299340 0.215550 0.031770 0.197520 0.197520 0.165450 0.184760 0.184760 0.212290 0.246960 0.440370 0.440370 1.524950 1.524950 0.855830 0.559100|0.06011 0.12512 0.07517 0.10224 0.08603 0.1079 0.25967 0.29934 0.21555 0.03177 0.05175 0.14577 0.16545 0.0748 0.10996 0.21229 0.24696 0.09617 0.3442 0.1437 1.38125 0.85583 0.5591|0 0 0 0 0 0 1 1 0 0 0 0 1 0 0 1 1 0 0 0 0 0 0 7 | 2001000007|缘份让我们相遇乱世以外|y van f en r ang w o m en x iang y v AP l uan sh i y i w ai AP|D#4/Eb4 D#4/Eb4 E4 E4 E4 E4 F#4/Gb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 G#4/Ab4 G#4/Ab4 rest B4 B4 B4 B4 C#5/Db5 C#5/Db5 C#5/Db5 C#5/Db5 rest|0.323070 0.323070 0.325290 0.325290 0.483290 0.483290 0.212040 0.212040 0.294600 0.294600 0.465110 0.465110 0.364020 0.364020 0.137130 0.151270 0.151270 0.270860 0.270860 0.434770 0.434770 1.570380 1.570380 0.462970|0.12204 0.20103 0.11182 0.21347 0.09912 0.38417 0.05549 0.15655 0.10139 0.19321 0.17622 0.28889 0.0609 0.30312 0.13713 0.03605 0.11522 0.14541 0.12545 0.12186 0.31291 0.09403 1.47635 0.46297|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 8 | 2001000008|命运却要我们危难中相爱|m ing y van q ve y ao w o m en w ei n an zh ong x iang ai SP AP|E4 E4 F#4/Gb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 G#4/Ab4 G#4/Ab4 B4 B4 B4 B4 E4 E4 G#4/Ab4 G#4/Ab4 F#4/Gb4 F#4/Gb4 F#4/Gb4 rest rest|0.332160 0.332160 0.315140 0.315140 0.371590 0.371590 0.285140 0.285140 0.394510 0.394510 0.358480 0.358480 0.524060 0.524060 0.176940 0.176940 0.239510 0.239510 0.494880 0.494880 1.260320 0.317390 0.358080|0.03995 0.29221 0.08516 0.22998 0.12953 0.24206 0.09533 0.18981 0.09528 0.29923 0.06899 0.28949 0.03119 0.49287 0.048 0.12894 0.04204 0.19747 0.1539 0.34098 1.26032 0.31739 0.35808|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 9 | 2001000009|也许未来遥远在光年之外|y e x v w ei l ai y ao y van z ai SP g uang n ian zh i w ai SP AP|E4 E4 E4 E4 E4 E4 F#4/Gb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 
G#4/Ab4 G#4/Ab4 G#4/Ab4 G#4/Ab4 rest B4 B4 B4 B4 C#5/Db5 C#5/Db5 C#5/Db5 C#5/Db5 rest rest|0.226010 0.226010 0.367780 0.367780 0.377380 0.377380 0.308330 0.308330 0.397890 0.397890 0.369570 0.369570 0.452320 0.452320 0.075060 0.237700 0.237700 0.272190 0.272190 0.325600 0.325600 1.446250 1.446250 0.243310 0.346690|0.11666 0.10935 0.23067 0.13711 0.14195 0.23543 0.12932 0.17901 0.16096 0.23693 0.19611 0.17346 0.08484 0.36748 0.07506 0.0593 0.1784 0.06402 0.20817 0.07175 0.25385 0.093 1.35325 0.24331 0.34669|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 10 | 2001000010|我愿守候未知里为你等待|w o y van sh ou h ou w ei zh i l i AP w ei n i SP d eng d ai AP|E4 E4 F#4/Gb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 G#4/Ab4 G#4/Ab4 B4 B4 B4 B4 rest E4 E4 G#4/Ab4 G#4/Ab4 rest F#4/Gb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 rest|0.302770 0.302770 0.288530 0.288530 0.402910 0.402910 0.447020 0.447020 0.296470 0.296470 0.202850 0.202850 0.466880 0.466880 0.207550 0.135530 0.135530 0.337900 0.337900 0.070010 0.249830 0.249830 0.392400 0.392400 0.210080|0.10342 0.19935 0.06127 0.22726 0.16322 0.23969 0.12336 0.32366 0.07033 0.22614 0.09677 0.10608 0.18788 0.279 0.20755 0.06243 0.0731 0.09532 0.24258 0.07001 0.03048 0.21935 0.10486 0.28754 0.21008|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 11 | -------------------------------------------------------------------------------- /svs_infer_pitch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | from scipy.io import wavfile 6 | from time import * 7 | 8 | import torch 9 | import argparse 10 | 11 | from vits.models import SynthesizerTrn 12 | from util import SingInput 13 | from util import FeatureInput 14 | from omegaconf import OmegaConf 15 | 16 | from pitch.models import PitchDiffusion 17 | from pitch.utils import fix_len_compatibility 18 | 19 | 20 | def save_wav(wav, path, rate): 21 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6 22 | wavfile.write(path, rate, wav.astype(np.int16)) 23 | 24 | 25 | def load_svs_model(checkpoint_path, model): 26 | assert os.path.isfile(checkpoint_path) 27 | checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") 28 | saved_state_dict = checkpoint_dict["model_g"] 29 | state_dict = model.state_dict() 30 | new_state_dict = {} 31 | for k, v in state_dict.items(): 32 | try: 33 | new_state_dict[k] = saved_state_dict[k] 34 | except: 35 | print("%s is not in the checkpoint" % k) 36 | new_state_dict[k] = v 37 | model.load_state_dict(new_state_dict) 38 | return model 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('-c', '--config', type=str, required=True, 44 | help="yaml file for configuration") 45 | parser.add_argument('-m', '--model', type=str, required=True, 46 | help="path of checkpoint pt file") 47 | parser.add_argument('-p', '--pitch', type=str, required=True, 48 | help="path of checkpoint pt file") 49 | args = parser.parse_args() 50 | 51 | # define model and load checkpoint 52 | hps = OmegaConf.load(args.config) 53 | 54 | singInput = SingInput(hps.data.sampling_rate, hps.data.hop_length) 55 | featureInput = FeatureInput(hps.data.sampling_rate, hps.data.hop_length) 56 | 57 | net_g = SynthesizerTrn( 58 | hps.data.filter_length // 2 + 1, 59 | hps.data.segment_size // hps.data.hop_length, 60 | hps).cuda() 61 | net_g.eval() 62 | 63 | load_svs_model(args.model, net_g) 64 | 65 | net_p = PitchDiffusion().cuda() 66 | net_p.eval() 67 | load_svs_model(args.pitch, net_p) 68 
| 69 | # check directory existence 70 | os.makedirs("./svs_out", exist_ok=True) 71 | fo = open("./svs_infer.txt", "r+") 72 | while True: 73 | try: 74 | message = fo.readline().strip() 75 | except Exception as e: 76 | print("nothing of except:", e) 77 | break 78 | if message == None: 79 | break 80 | if message == "": 81 | break 82 | print(message) 83 | ( 84 | file, 85 | labels_ids, 86 | labels_frames, 87 | scores_ids, 88 | scores_dur, 89 | labels_slr, 90 | labels_uvs, 91 | ) = singInput.parseInput(message) 92 | labels_ids = singInput.expandInput(labels_ids, labels_frames) 93 | labels_uvs = singInput.expandInput(labels_uvs, labels_frames) 94 | labels_slr = singInput.expandInput(labels_slr, labels_frames) 95 | scores_ids = singInput.expandInput(scores_ids, labels_frames) 96 | 97 | phone = torch.LongTensor(labels_ids) 98 | score = torch.LongTensor(scores_ids) 99 | slurs = torch.LongTensor(labels_slr) 100 | 101 | lengths = phone.size()[0] 102 | lengths_fix = fix_len_compatibility(lengths) 103 | 104 | phone_fix = torch.zeros((1, lengths_fix), dtype=torch.long).cuda() 105 | score_fix = torch.zeros((1, lengths_fix), dtype=torch.long).cuda() 106 | slurs_fix = torch.zeros((1, lengths_fix), dtype=torch.long).cuda() 107 | phone_fix[0, :lengths] = phone 108 | score_fix[0, :lengths] = score 109 | slurs_fix[0, :lengths] = slurs 110 | 111 | with torch.no_grad(): 112 | n_timesteps = 50 113 | temperature = 1 114 | # PIT 115 | phone_lengths = torch.LongTensor([lengths_fix]).cuda() 116 | pit_pri, pit_pre = net_p(phone_fix, phone_lengths, score_fix, slurs_fix, n_timesteps, temperature) 117 | pitch = pit_pre[:, 0, :] 118 | pitch = 2**pitch 119 | print('~~~~~~~') 120 | # SVS 121 | audio = ( 122 | net_g.infer(phone_fix, phone_lengths, score_fix, pitch, slurs_fix)[0, 0] 123 | .data.cpu() 124 | .float() 125 | .numpy() 126 | ) 127 | 128 | save_wav(audio, f"./svs_out/{file}.wav", hps.data.sampling_rate) 129 | fo.close() 130 | # can be deleted 131 | os.system("chmod 777 ./svs_out -R") 132 | -------------------------------------------------------------------------------- /svs_song.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from scipy.io import wavfile 5 | from time import * 6 | 7 | import torch 8 | import argparse 9 | 10 | from vits.models import SynthesizerTrn 11 | from util import SingInput 12 | from util import FeatureInput 13 | from omegaconf import OmegaConf 14 | 15 | 16 | def save_wav(wav, path, rate): 17 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6 18 | wavfile.write(path, rate, wav.astype(np.int16)) 19 | 20 | 21 | def load_svs_model(checkpoint_path, model): 22 | assert os.path.isfile(checkpoint_path) 23 | checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") 24 | saved_state_dict = checkpoint_dict["model_g"] 25 | state_dict = model.state_dict() 26 | new_state_dict = {} 27 | for k, v in state_dict.items(): 28 | try: 29 | new_state_dict[k] = saved_state_dict[k] 30 | except: 31 | print("%s is not in the checkpoint" % k) 32 | new_state_dict[k] = v 33 | model.load_state_dict(new_state_dict) 34 | return model 35 | 36 | 37 | if __name__ == '__main__': 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('-c', '--config', type=str, required=True, 40 | help="yaml file for configuration") 41 | parser.add_argument('-m', '--model', type=str, required=True, 42 | help="path of checkpoint pt file") 43 | args = parser.parse_args() 44 | 45 | # define model and load checkpoint 46 | hps = 
OmegaConf.load(args.config) 47 | 48 | singInput = SingInput(hps.data.sampling_rate, hps.data.hop_length) 49 | featureInput = FeatureInput(hps.data.sampling_rate, hps.data.hop_length) 50 | 51 | net_g = SynthesizerTrn( 52 | hps.data.filter_length // 2 + 1, 53 | hps.data.segment_size // hps.data.hop_length, 54 | hps).cuda() 55 | net_g.eval() 56 | 57 | load_svs_model(args.model, net_g) 58 | 59 | # check directory existence 60 | os.makedirs("./svs_out", exist_ok=True) 61 | fo = open("./svs_song.txt", "r+") 62 | song_rate = hps.data.sampling_rate 63 | song_time = fo.readline().strip().split("|")[1] 64 | song_length = int(song_rate * (float(song_time) + 30)) 65 | song_data = np.zeros(song_length, dtype="float32") 66 | while True: 67 | try: 68 | message = fo.readline().strip() 69 | except Exception as e: 70 | print("nothing of except:", e) 71 | break 72 | if message == None: 73 | break 74 | if message == "": 75 | break 76 | ( 77 | item_indx, 78 | item_time, 79 | labels_ids, 80 | labels_frames, 81 | scores_ids, 82 | scores_dur, 83 | labels_slr, 84 | labels_uvs, 85 | ) = singInput.parseSong(message) 86 | labels_ids = singInput.expandInput(labels_ids, labels_frames) 87 | labels_uvs = singInput.expandInput(labels_uvs, labels_frames) 88 | labels_slr = singInput.expandInput(labels_slr, labels_frames) 89 | scores_ids = singInput.expandInput(scores_ids, labels_frames) 90 | scores_pit = singInput.scorePitch(scores_ids) 91 | # elments by elments 92 | scores_pit = scores_pit * labels_uvs 93 | # scores_pit = singInput.smoothPitch(scores_pit) 94 | # scores_pit = scores_pit * labels_uvs 95 | phone = torch.LongTensor(labels_ids) 96 | score = torch.LongTensor(scores_ids) 97 | slurs = torch.LongTensor(labels_slr) 98 | pitch = torch.FloatTensor(scores_pit) 99 | 100 | phone_lengths = phone.size()[0] 101 | 102 | begin_time = time() 103 | with torch.no_grad(): 104 | phone = phone.cuda().unsqueeze(0) 105 | score = score.cuda().unsqueeze(0) 106 | pitch = pitch.cuda().unsqueeze(0) 107 | slurs = slurs.cuda().unsqueeze(0) 108 | phone_lengths = torch.LongTensor([phone_lengths]).cuda() 109 | audio = ( 110 | net_g.infer(phone, phone_lengths, score, pitch, slurs)[0, 0] 111 | .data.cpu() 112 | .float() 113 | .numpy() 114 | ) 115 | 116 | save_wav(audio, f"./svs_out/{item_indx}.wav", hps.data.sampling_rate) 117 | # wav 118 | item_start = int(song_rate * float(item_time)) 119 | item_end = item_start + len(audio) 120 | song_data[item_start:item_end] = audio 121 | # out of for 122 | song_data = np.array(song_data, dtype="float32") 123 | save_wav(song_data, f"./svs_out/_song.wav", hps.data.sampling_rate) 124 | fo.close() 125 | # can be deleted 126 | os.system("chmod 777 ./svs_out -R") 127 | -------------------------------------------------------------------------------- /svs_song.txt: -------------------------------------------------------------------------------- 1 | song_time|116.88723672656248 2 | 0|0000.694| 化 外 山 间 岁 月 皆 看 老|h ua w ai sh an j ian s ui y ve j ie k an l ao|57 57 64 64 62 62 60 60 59 59 60 60 62 62 64 64 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.506 0.506|0.064 0.249 0.088 0.249 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.273 0.064 0.273 0.064 0.241 0.096 0.506|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 | 1|0006.140| 洛 雪 无 声 天 地 掩 尘 嚣|l uo x ve w u sh eng t ian d i y an ch en x iao|57 57 64 64 62 62 60 60 59 59 60 60 62 62 64 64 69 69|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.590 0.590|0.096 0.249 0.088 
0.249 0.088 0.249 0.088 0.305 0.032 0.305 0.032 0.249 0.088 0.273 0.064 0.249 0.088 0.590|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 4 | 2|0010.923| 他 看 尽 晨 曦 日 暮 AP 饮 罢 腰 间 酒 一 壶 AP 依 稀 当 年 孤 旅 踏 苍 霞 尽 处|t a k an j in ch en x i r i m u AP y in b a y ao j ian j iu y i h u AP y i x i d ang n ian g u l v t a c ang x ia j in ch u|60 60 62 62 64 64 62 62 67 67 64 64 62 62 rest 64 64 67 67 72 72 71 71 69 69 67 67 69 69 rest 67 67 64 64 62 62 64 64 62 62 60 60 59 59 60 60 62 62 64 64 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 1.180 1.180|0.032 0.273 0.064 0.273 0.064 0.273 0.064 0.249 0.088 0.249 0.088 0.249 0.088 0.337 0.249 0.088 0.297 0.040 0.249 0.088 0.273 0.064 0.273 0.064 0.249 0.088 0.273 0.064 0.421 0.165 0.088 0.249 0.088 0.305 0.032 0.249 0.088 0.273 0.064 0.241 0.096 0.305 0.032 0.249 0.088 0.249 0.088 0.273 0.064 0.273 0.064 1.180|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 5 | 3|0021.678| 风 霜 冷 冽 他 眉 目 AP 时 光 雕 琢 他 风 骨 AP 浮 世 南 柯 一 梦 冷 暖 都 藏 住|f eng sh uang l eng l ie t a m ei m u AP sh i g uang d iao z uo t a f eng g u AP f u sh i n an k e y i m eng l eng n uan d ou c ang zh u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 rest 64 64 62 62 64 64 62 62 67 67 64 64 60 60 rest 57 57 60 60 64 64 62 62 60 60 57 57 60 60 57 57 67 67 62 62 64 64|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674|0.064 0.249 0.088 0.241 0.096 0.241 0.096 0.305 0.032 0.249 0.088 0.249 0.088 0.337 0.249 0.088 0.273 0.064 0.305 0.032 0.249 0.088 0.305 0.032 0.273 0.064 0.273 0.064 0.337 0.273 0.064 0.249 0.088 0.249 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.241 0.096 0.249 0.088 0.305 0.032 0.249 0.088 0.273 0.064 0.674|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 6 | 4|0032.356| 哪 杯 酒 烫 过 肺 腑 AP 曾 换 他 睥 睨 一 顾 AP 剑 破 乾 坤 轮 转 山 河 倾 覆|n a b ei j iu t ang g uo f ei f u AP c eng h uan t a p i n i y i g u AP j ian p o q ian k un l un zh uan sh an h e q ing f u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 rest 64 64 62 62 64 64 62 62 67 67 64 64 60 60 rest 57 57 64 64 62 62 64 64 67 67 62 62 60 60 59 59 60 60 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.337 0.337 1.348 1.348|0.088 0.297 0.040 0.273 0.064 0.305 0.032 0.273 0.064 0.273 0.064 0.273 0.064 0.337 0.249 0.088 0.273 0.064 0.305 0.032 0.249 0.088 0.249 0.088 0.249 0.088 0.273 0.064 0.337 0.273 0.064 0.249 0.088 0.241 0.096 0.273 0.064 0.241 0.096 0.273 0.064 0.249 0.088 0.610 0.064 0.241 0.096 0.273 0.064 1.348|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 7 | 5|0043.620| 他 三 清 尘 外 剔 去 心 中 毒|t a s an q ing ch en w ai t i q v x in zh ong d u|57 57 60 60 64 64 62 62 60 60 59 59 60 60 62 62 64 64 57 57|0.169 0.169 0.169 0.169 0.674 0.674 0.337 0.337 
0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.590 0.590|0.032 0.081 0.088 0.073 0.096 0.610 0.064 0.249 0.088 0.305 0.032 0.241 0.096 0.249 0.088 0.273 0.064 0.305 0.032 0.590|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 8 | 6|0048.981| 尝 世 间 百 味 甘 醇 与 涩 苦|ch ang sh i j ian b ai w ei g an ch un y v s e k u|57 57 60 60 64 64 62 62 60 60 59 59 60 60 62 62 64 64 69 69|0.169 0.169 0.169 0.169 0.674 0.674 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 1.180 1.180|0.064 0.081 0.088 0.105 0.064 0.634 0.040 0.249 0.088 0.273 0.064 0.273 0.064 0.249 0.088 0.249 0.088 0.273 0.064 1.180|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 9 | 7|0053.929| 曾 有 谁 偏 执 不 悟 AP 谈 笑 斗 酒 至 酣 处 AP 而 今 不 过 拍 去 肩 上 红 尘 土|c eng y ou sh ui p ian zh i b u w u AP t an x iao d ou j iu zh i h an ch u AP er j in b u g uo p ai q v j ian sh ang h ong ch en t u|60 60 62 62 64 64 67 67 64 64 67 67 62 62 rest 62 62 67 67 72 72 71 71 69 69 67 67 69 69 rest 67 64 64 62 62 62 62 64 64 67 67 60 60 60 60 59 59 60 60 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674|0.088 0.249 0.088 0.249 0.088 0.249 0.088 0.273 0.064 0.297 0.040 0.249 0.088 0.337 0.305 0.032 0.249 0.088 0.305 0.032 0.273 0.064 0.273 0.064 0.273 0.064 0.273 0.064 0.337 0.337 0.273 0.064 0.297 0.040 0.273 0.064 0.249 0.088 0.241 0.096 0.273 0.064 0.249 0.088 0.273 0.064 0.273 0.064 0.305 0.032 0.674|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 10 | 8|0064.655| 风 霜 冷 冽 他 眉 目 时 光 雕 琢 他 风 骨 浮 世 南 柯 一 梦 冷 暖 都 藏 住|f eng sh uang l eng l ie t a m ei m u sh i g uang d iao z uo t a f eng g u f u sh i n an k e y i m eng l eng n uan d ou c ang zh u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 64 64 62 62 64 64 62 62 67 67 64 64 60 60 57 57 60 60 64 64 62 62 60 60 57 57 60 60 57 57 67 67 62 62 64 64|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.506 0.506|0.064 0.249 0.088 0.241 0.096 0.241 0.096 0.305 0.032 0.249 0.088 0.249 0.088 0.586 0.088 0.273 0.064 0.305 0.032 0.249 0.088 0.305 0.032 0.273 0.064 0.273 0.064 0.610 0.064 0.249 0.088 0.249 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.241 0.096 0.249 0.088 0.305 0.032 0.249 0.088 0.273 0.064 0.506|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 11 | 9|0075.418| 哪 杯 酒 烫 过 肺 腑 曾 换 他 睥 睨 一 顾 AP 剑 破 乾 坤 轮 转 山 河 倾 覆|n a b ei j iu t ang g uo f ei f u c eng h uan t a p i n i y i g u AP j ian p o q ian k un l un zh uan sh an h e q ing f u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 64 64 62 62 64 64 62 62 67 67 64 64 60 60 rest 57 57 64 64 62 62 64 64 67 67 62 62 60 60 59 59 60 60 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.337 0.337 0.674 0.674|0.088 0.297 0.040 0.273 0.064 0.305 0.032 0.273 0.064 0.273 0.064 0.273 0.064 0.586 0.088 0.273 0.064 0.305 0.032 0.249 0.088 0.249 0.088 
0.249 0.088 0.273 0.064 0.421 0.189 0.064 0.249 0.088 0.241 0.096 0.273 0.064 0.241 0.096 0.273 0.064 0.249 0.088 0.610 0.064 0.241 0.096 0.273 0.064 0.674|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 12 | 10|0086.260| 到 最 后 沧 海 一 粟 AP 何 必 江 湖 多 殊 途 AP 当 年 论 剑 峰 顶 谁 几 笔 成 书|d ao z ui h ou c ang h ai y i s u AP h e b i j iang h u d uo sh u t u AP d ang n ian l un j ian f eng d ing sh ui j i b i ch eng sh u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 rest 64 64 62 62 64 64 62 62 67 67 64 64 60 60 rest 57 57 60 60 64 64 62 62 60 60 57 57 60 60 57 57 67 67 62 62 64 64|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674|0.032 0.249 0.088 0.273 0.064 0.249 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.421 0.189 0.064 0.297 0.040 0.273 0.064 0.273 0.064 0.305 0.032 0.249 0.088 0.305 0.032 0.421 0.221 0.032 0.249 0.088 0.241 0.096 0.273 0.064 0.273 0.064 0.305 0.032 0.249 0.088 0.273 0.064 0.297 0.040 0.273 0.064 0.249 0.088 0.674|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 13 | 11|0096.991| 纵 他 朝 众 生 再 晤 AP 奈 何 明 月 终 辜 负 AP 坐 听 晨 钟 难 算 太 虚 有 无|z ong t a ch ao zh ong sh eng z ai w u AP n ai h e m ing y ve zh ong g u f u AP z uo t ing ch en zh ong n an s uan t ai x v y ou w u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 rest 64 64 62 62 64 64 62 62 67 67 64 64 60 60 rest 57 57 64 64 62 62 64 64 62 62 60 60 59 59 60 60 59 59 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.169 0.169 1.264 1.264|0.088 0.305 0.032 0.273 0.064 0.273 0.064 0.249 0.088 0.249 0.088 0.249 0.088 0.421 0.165 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.273 0.064 0.273 0.064 0.273 0.064 0.421 0.165 0.088 0.305 0.032 0.273 0.064 0.273 0.064 0.249 0.088 0.249 0.088 0.305 0.032 0.586 0.088 0.249 0.088 0.081 0.088 1.264|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 14 | 12|0107.917| 天 道 勘 破 敢 问 一 句 悟 不|t ian d ao k an p o g an w en y i j v w u b u|57 57 64 64 62 62 64 64 62 62 60 60 59 59 60 60 62 62 64 64|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.506 0.506 0.337 0.337 0.590 0.590|0.032 0.305 0.032 0.273 0.064 0.249 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.357 0.064 0.418 0.088 0.297 0.040 0.590|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 15 | 13|0112.496| 悟 悟|w u w u|68 68 69 69|0.506 0.506 3.792 3.792|0.088 0.418 0.088 3.792|0 0 0 0 -------------------------------------------------------------------------------- /svs_song_pitch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from scipy.io import wavfile 5 | from time import * 6 | 7 | import torch 8 | import argparse 9 | 10 | from vits.models import SynthesizerTrn 11 | from util import SingInput 12 | from util import FeatureInput 13 | from omegaconf import OmegaConf 14 | 15 | 16 | from pitch.models import PitchDiffusion 17 | from pitch.utils import fix_len_compatibility 18 | 19 | 20 | def save_wav(wav, path, rate): 
21 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6 22 | wavfile.write(path, rate, wav.astype(np.int16)) 23 | 24 | 25 | def load_svs_model(checkpoint_path, model): 26 | assert os.path.isfile(checkpoint_path) 27 | checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") 28 | saved_state_dict = checkpoint_dict["model_g"] 29 | state_dict = model.state_dict() 30 | new_state_dict = {} 31 | for k, v in state_dict.items(): 32 | try: 33 | new_state_dict[k] = saved_state_dict[k] 34 | except: 35 | print("%s is not in the checkpoint" % k) 36 | new_state_dict[k] = v 37 | model.load_state_dict(new_state_dict) 38 | return model 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('-c', '--config', type=str, required=True, 44 | help="yaml file for configuration") 45 | parser.add_argument('-m', '--model', type=str, required=True, 46 | help="path of checkpoint pt file") 47 | parser.add_argument('-p', '--pitch', type=str, required=True, 48 | help="path of checkpoint pt file") 49 | args = parser.parse_args() 50 | 51 | # define model and load checkpoint 52 | hps = OmegaConf.load(args.config) 53 | 54 | singInput = SingInput(hps.data.sampling_rate, hps.data.hop_length) 55 | featureInput = FeatureInput(hps.data.sampling_rate, hps.data.hop_length) 56 | 57 | net_g = SynthesizerTrn( 58 | hps.data.filter_length // 2 + 1, 59 | hps.data.segment_size // hps.data.hop_length, 60 | hps).cuda() 61 | net_g.eval() 62 | 63 | load_svs_model(args.model, net_g) 64 | 65 | net_p = PitchDiffusion().cuda() 66 | net_p.eval() 67 | load_svs_model(args.pitch, net_p) 68 | 69 | # check directory existence 70 | os.makedirs("./svs_out", exist_ok=True) 71 | fo = open("./svs_song.txt", "r+") 72 | song_rate = hps.data.sampling_rate 73 | song_time = fo.readline().strip().split("|")[1] 74 | song_length = int(song_rate * (float(song_time) + 30)) 75 | song_data = np.zeros(song_length, dtype="float32") 76 | while True: 77 | try: 78 | message = fo.readline().strip() 79 | except Exception as e: 80 | print("nothing of except:", e) 81 | break 82 | if message == None: 83 | break 84 | if message == "": 85 | break 86 | ( 87 | item_indx, 88 | item_time, 89 | labels_ids, 90 | labels_frames, 91 | scores_ids, 92 | scores_dur, 93 | labels_slr, 94 | labels_uvs, 95 | ) = singInput.parseSong(message) 96 | labels_ids = singInput.expandInput(labels_ids, labels_frames) 97 | labels_uvs = singInput.expandInput(labels_uvs, labels_frames) 98 | labels_slr = singInput.expandInput(labels_slr, labels_frames) 99 | scores_ids = singInput.expandInput(scores_ids, labels_frames) 100 | 101 | phone = torch.LongTensor(labels_ids) 102 | score = torch.LongTensor(scores_ids) 103 | slurs = torch.LongTensor(labels_slr) 104 | 105 | lengths = phone.size()[0] 106 | lengths_fix = fix_len_compatibility(lengths) 107 | 108 | phone_fix = torch.zeros((1, lengths_fix), dtype=torch.long).cuda() 109 | score_fix = torch.zeros((1, lengths_fix), dtype=torch.long).cuda() 110 | slurs_fix = torch.zeros((1, lengths_fix), dtype=torch.long).cuda() 111 | phone_fix[0, :lengths] = phone 112 | score_fix[0, :lengths] = score 113 | slurs_fix[0, :lengths] = slurs 114 | 115 | begin_time = time() 116 | with torch.no_grad(): 117 | n_timesteps = 50 118 | temperature = 1 119 | # PIT 120 | phone_lengths = torch.LongTensor([lengths_fix]).cuda() 121 | pit_pri, pit_pre = net_p(phone_fix, phone_lengths, score_fix, slurs_fix, n_timesteps, temperature) 122 | pitch = pit_pre[:, 0, :] 123 | pitch = 2**pitch 124 | print('~~~~~~~') 125 | audio = ( 126 | 
net_g.infer(phone_fix, phone_lengths, score_fix, pitch, slurs_fix)[0, 0] 127 | .data.cpu() 128 | .float() 129 | .numpy() 130 | ) 131 | 132 | save_wav(audio, f"./svs_out/{item_indx}.wav", hps.data.sampling_rate) 133 | # wav 134 | item_start = int(song_rate * float(item_time)) 135 | item_end = item_start + len(audio) 136 | song_data[item_start:item_end] = audio 137 | # out of for 138 | song_data = np.array(song_data, dtype="float32") 139 | save_wav(song_data, f"./svs_out/_song.wav", hps.data.sampling_rate) 140 | fo.close() 141 | # can be deleted 142 | os.system("chmod 777 ./svs_out -R") 143 | -------------------------------------------------------------------------------- /svs_train.py: -------------------------------------------------------------------------------- 1 | import sys,os 2 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 3 | import argparse 4 | import torch 5 | import torch.multiprocessing as mp 6 | from omegaconf import OmegaConf 7 | 8 | from vits_extend.train import train 9 | 10 | torch.backends.cudnn.benchmark = True 11 | 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('-c', '--config', type=str, required=True, 16 | help="yaml file for configuration") 17 | parser.add_argument('-p', '--checkpoint_path', type=str, default=None, 18 | help="path of checkpoint pt file to resume training") 19 | parser.add_argument('-n', '--name', type=str, required=True, 20 | help="name of the model for logging, saving checkpoint") 21 | args = parser.parse_args() 22 | 23 | hp = OmegaConf.load(args.config) 24 | with open(args.config, 'r') as f: 25 | hp_str = ''.join(f.readlines()) 26 | 27 | assert hp.data.hop_length == 320, \ 28 | 'hp.data.hop_length must be equal to 320, got %d' % hp.data.hop_length 29 | 30 | args.num_gpus = 0 31 | torch.manual_seed(hp.train.seed) 32 | if torch.cuda.is_available(): 33 | torch.cuda.manual_seed(hp.train.seed) 34 | args.num_gpus = torch.cuda.device_count() 35 | print('Batch size per GPU :', hp.train.batch_size) 36 | 37 | if args.num_gpus > 1: 38 | mp.spawn(train, nprocs=args.num_gpus, 39 | args=(args, args.checkpoint_path, hp, hp_str,)) 40 | else: 41 | train(0, args, args.checkpoint_path, hp, hp_str) 42 | else: 43 | print('No GPU find!') 44 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import librosa 3 | import pyworld 4 | 5 | from svs import label_to_ids, load_midi_map, uv_map 6 | 7 | 8 | class SingInput(object): 9 | def __init__(self, samplerate=32000, hop_size=320): 10 | self.fs = samplerate 11 | self.hop = hop_size 12 | self.notemaper = load_midi_map() 13 | 14 | def phone_to_uv(self, phones): 15 | uv = [] 16 | for phone in phones: 17 | uv.append(uv_map[phone.lower()]) 18 | return uv 19 | 20 | def notes_to_id(self, notes): 21 | note_ids = [] 22 | for note in notes: 23 | note_ids.append(self.notemaper[note]) 24 | return note_ids 25 | 26 | def frame_duration(self, durations): 27 | ph_durs = [float(x) for x in durations] 28 | sentence_length = 0 29 | for ph_dur in ph_durs: 30 | sentence_length = sentence_length + ph_dur 31 | sentence_length = int(sentence_length * self.fs / self.hop + 0.5) 32 | 33 | sample_frame = [] 34 | startTime = 0 35 | for i_ph in range(len(ph_durs)): 36 | start_frame = int(startTime * self.fs / self.hop + 0.5) 37 | end_frame = int((startTime + ph_durs[i_ph]) * self.fs / self.hop + 0.5) 38 | count_frame = end_frame - 
start_frame 39 | sample_frame.append(count_frame) 40 | startTime = startTime + ph_durs[i_ph] 41 | all_frame = np.sum(sample_frame) 42 | assert all_frame == sentence_length 43 | # match mel length 44 | sample_frame[-1] = sample_frame[-1] - 1 45 | return sample_frame 46 | 47 | def score_duration(self, durations): 48 | ph_durs = [float(x) for x in durations] 49 | sample_frame = [] 50 | for i_ph in range(len(ph_durs)): 51 | count_frame = int(ph_durs[i_ph] * self.fs / self.hop + 0.5) 52 | if count_frame >= 256: 53 | print("count_frame", count_frame) 54 | count_frame = 255 55 | sample_frame.append(count_frame) 56 | return sample_frame 57 | 58 | def parseInput(self, singinfo: str): 59 | infos = singinfo.split("|") 60 | file = infos[0] 61 | # hanz = infos[1] 62 | phon = infos[2].split(" ") 63 | note = infos[3].split(" ") 64 | note_dur = infos[4].split(" ") 65 | phon_dur = infos[5].split(" ") 66 | phon_slr = infos[6].split(" ") 67 | 68 | labels_ids = label_to_ids(phon) 69 | labels_uvs = self.phone_to_uv(phon) 70 | labels_frames = self.frame_duration(phon_dur) 71 | scores_ids = self.notes_to_id(note) 72 | scores_dur = self.score_duration(note_dur) 73 | labels_slr = [int(x) for x in phon_slr] 74 | return ( 75 | file, 76 | labels_ids, 77 | labels_frames, 78 | scores_ids, 79 | scores_dur, 80 | labels_slr, 81 | labels_uvs, 82 | ) 83 | 84 | def parseSong(self, singinfo: str): 85 | infos = singinfo.split("|") 86 | item_indx = infos[0] 87 | item_time = infos[1] 88 | # hanz = infos[2] 89 | phon = infos[3].split(" ") 90 | note_ids = infos[4].split(" ") 91 | note_dur = infos[5].split(" ") 92 | phon_dur = infos[6].split(" ") 93 | phon_slr = infos[7].split(" ") 94 | 95 | labels_ids = label_to_ids(phon) 96 | labels_uvs = self.phone_to_uv(phon) 97 | labels_frames = self.frame_duration(phon_dur) 98 | scores_ids = [int(x) if x != "rest" else 0 for x in note_ids] 99 | scores_dur = self.score_duration(note_dur) 100 | labels_slr = [int(x) for x in phon_slr] 101 | return ( 102 | item_indx, 103 | item_time, 104 | labels_ids, 105 | labels_frames, 106 | scores_ids, 107 | scores_dur, 108 | labels_slr, 109 | labels_uvs, 110 | ) 111 | 112 | def expandInput(self, labels_ids, labels_frames): 113 | assert len(labels_ids) == len(labels_frames) 114 | frame_num = np.sum(labels_frames) 115 | frame_labels = np.zeros(frame_num, dtype=np.int64) 116 | start = 0 117 | for index, num in enumerate(labels_frames): 118 | frame_labels[start : start + num] = labels_ids[index] 119 | start += num 120 | return frame_labels 121 | 122 | def scorePitch(self, scores_id): 123 | score_pitch = np.zeros(len(scores_id), dtype=np.float64) 124 | for index, score_id in enumerate(scores_id): 125 | if score_id == 0: 126 | score_pitch[index] = 0 127 | else: 128 | pitch = librosa.midi_to_hz(score_id) 129 | score_pitch[index] = round(pitch, 1) 130 | return score_pitch 131 | 132 | def smoothPitch(self, pitch): 133 | # smooth the pitch contour with a convolution 134 | kernel = np.hanning(5) # symmetric Hanning window used as the smoothing kernel 135 | kernel /= kernel.sum() 136 | smooth_pitch = np.convolve(pitch, kernel, "same") 137 | return smooth_pitch 138 | 139 | 140 | class FeatureInput(object): 141 | def __init__(self, samplerate=32000, hop_size=320): 142 | self.fs = samplerate 143 | self.hop = hop_size 144 | 145 | self.f0_bin = 256 146 | self.f0_max = 1100.0 147 | self.f0_min = 50.0 148 | self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700) 149 | self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700) 150 | 151 | def compute_f0(self, file): 152 | x, sr = librosa.load(file, sr=self.fs) 153 | assert sr == self.fs 154 | f0, t = 
pyworld.dio( 155 | x.astype(np.double), 156 | fs=sr, 157 | f0_ceil=900, 158 | frame_period=1000 * self.hop / sr, 159 | ) 160 | f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.fs) 161 | for index, pitch in enumerate(f0): 162 | f0[index] = round(pitch, 1) 163 | return f0 164 | 165 | def coarse_f0(self, f0): 166 | f0_mel = 1127 * np.log(1 + f0 / 700) 167 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * ( 168 | self.f0_bin - 2 169 | ) / (self.f0_mel_max - self.f0_mel_min) + 1 170 | 171 | # use 0 or 1 172 | f0_mel[f0_mel <= 1] = 1 173 | f0_mel[f0_mel > self.f0_bin - 1] = self.f0_bin - 1 174 | f0_coarse = np.rint(f0_mel).astype(np.int64) 175 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, ( 176 | f0_coarse.max(), 177 | f0_coarse.min(), 178 | ) 179 | return f0_coarse 180 | -------------------------------------------------------------------------------- /util/generate_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--file', type=str, required=True) 9 | args = parser.parse_args() 10 | alls = [] 11 | fo = open(args.file, "r+") 12 | while True: 13 | try: 14 | message = fo.readline().strip() 15 | except Exception as e: 16 | print("nothing of except:", e) 17 | break 18 | if message == None: 19 | break 20 | if message == "": 21 | break 22 | alls.append(message) 23 | fo.close() 24 | 25 | valids = alls[:5] 26 | trains = alls[5:] 27 | 28 | random.shuffle(trains) 29 | os.makedirs("filelists", exist_ok=True) 30 | 31 | fw = open("./filelists/singing_valid.txt", "w", encoding="utf-8") 32 | for strs in valids: 33 | print(strs, file=fw) 34 | fw.close() 35 | 36 | fw = open("./filelists/singing_train.txt", "w", encoding="utf-8") 37 | for strs in trains: 38 | print(strs, file=fw) 39 | fw.close() 40 | -------------------------------------------------------------------------------- /util/generate_label.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append(os.getcwd()) 3 | import logging 4 | logging.basicConfig(level=logging.INFO) # ERROR & INFO 5 | import argparse 6 | import numpy as np 7 | 8 | from omegaconf import OmegaConf 9 | from util import SingInput, FeatureInput 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--config', type=str, required=True) 15 | parser.add_argument('--data', type=str, required=True) 16 | parser.add_argument('--file', type=str, required=True) 17 | args = parser.parse_args() 18 | 19 | hps = OmegaConf.load(args.config) 20 | 21 | assert os.path.exists(args.file) 22 | assert os.path.exists(os.path.join(args.data, "wavs")) 23 | os.makedirs(os.path.join(args.data, "labels"), exist_ok=True) 24 | 25 | singInput = SingInput(hps.data.sampling_rate, hps.data.hop_length) 26 | featureInput = FeatureInput(hps.data.sampling_rate, hps.data.hop_length) 27 | 28 | raw_file = open(args.file, "r+") 29 | vits_file = open(os.path.join(args.data, "labels.txt"), 30 | "w", encoding="utf-8") 31 | label_path = os.path.join(args.data, "labels") 32 | i = 0 33 | all_txt = [] # collect sentences so the number of unique ones can be reported 34 | while True: 35 | try: 36 | message = raw_file.readline().strip() 37 | except Exception as e: 38 | print("nothing of except:", e) 39 | break 40 | if message == None: 41 | break 42 | if message == "": 43 | break 44 | # i = i + 1 45 | # if i > 5: 46 | # break 47 | infos = 
message.split("|") 48 | file = infos[0] 49 | hanz = infos[1] 50 | all_txt.append(hanz) 51 | phon = infos[2].split(" ") 52 | note = infos[3].split(" ") 53 | note_dur = infos[4].split(" ") 54 | phon_dur = infos[5].split(" ") 55 | phon_slur = infos[6].split(" ") 56 | 57 | logging.info("----------------------------") 58 | logging.info(file) 59 | logging.info(hanz) 60 | logging.info(phon) 61 | # logging.info(note_dur) 62 | # logging.info(phon_dur) 63 | # logging.info(phon_slur) 64 | path_wave = os.path.join(args.data, "wavs", f"{file}.wav") 65 | path_label = os.path.join(label_path, f"{file}_label.npy") 66 | path_score = os.path.join(label_path, f"{file}_score.npy") 67 | path_pitch = os.path.join(label_path, f"{file}_pitch.npy") 68 | path_slurs = os.path.join(label_path, f"{file}_slurs.npy") 69 | 70 | ( 71 | file, 72 | labels_ids, 73 | labels_frames, 74 | scores_ids, 75 | scores_dur, 76 | labels_slr, 77 | labels_uvs, 78 | ) = singInput.parseInput(message) 79 | labels_ids = singInput.expandInput(labels_ids, labels_frames) 80 | labels_uvs = singInput.expandInput(labels_uvs, labels_frames) 81 | labels_slr = singInput.expandInput(labels_slr, labels_frames) 82 | scores_ids = singInput.expandInput(scores_ids, labels_frames) 83 | 84 | featur_pit = featureInput.compute_f0(path_wave) 85 | featur_pit = featur_pit[: len(labels_ids)] 86 | featur_pit = featur_pit * labels_uvs 87 | 88 | assert len(labels_ids) == len(featur_pit) 89 | 90 | np.save(path_label, labels_ids, allow_pickle=False) 91 | np.save(path_score, scores_ids, allow_pickle=False) 92 | np.save(path_pitch, featur_pit, allow_pickle=False) 93 | np.save(path_slurs, labels_slr, allow_pickle=False) 94 | 95 | print( 96 | f"{path_wave}|{path_label}|{path_score}|{path_pitch}|{path_slurs}", 97 | file=vits_file, 98 | ) 99 | 100 | raw_file.close() 101 | vits_file.close() 102 | print(len(set(all_txt))) # 统计非重复的句子个数 103 | -------------------------------------------------------------------------------- /util/resample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import librosa 3 | import argparse 4 | import numpy as np 5 | from tqdm import tqdm 6 | from concurrent.futures import ThreadPoolExecutor, as_completed 7 | from scipy.io import wavfile 8 | 9 | 10 | def resample_wave(wav_in, wav_out, sample_rate): 11 | wav, _ = librosa.load(wav_in, sr=sample_rate) 12 | wav = wav / np.abs(wav).max() * 0.6 13 | wav = wav / max(0.01, np.max(np.abs(wav))) * 32767 * 0.6 14 | wavfile.write(wav_out, sample_rate, wav.astype(np.int16)) 15 | 16 | 17 | def process_file(file, wavPath, outPath, sr): 18 | if file.endswith(".wav"): 19 | file = file[:-4] 20 | resample_wave(f"{wavPath}/{file}.wav", f"{outPath}/{file}.wav", sr) 21 | 22 | 23 | def process_files_with_thread_pool(wavPath, outPath, sr, thread_num=None): 24 | files = [f for f in os.listdir(f"./{wavPath}") if f.endswith(".wav")] 25 | 26 | with ThreadPoolExecutor(max_workers=thread_num) as executor: 27 | futures = {executor.submit(process_file, file, wavPath, outPath, sr): file for file in files} 28 | 29 | for future in tqdm(as_completed(futures), total=len(futures), desc=f'Processing {sr}'): 30 | future.result() 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True) 36 | parser.add_argument("-o", "--out", help="out", dest="out", required=True) 37 | parser.add_argument("-s", "--sr", help="sample rate", dest="sr", type=int, required=True) 38 | 
parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1) 39 | 40 | args = parser.parse_args() 41 | print(args.wav) 42 | print(args.out) 43 | print(args.sr) 44 | 45 | os.makedirs(args.out, exist_ok=True) 46 | wavPath = args.wav 47 | outPath = args.out 48 | 49 | if args.thread_count == 0: 50 | process_num = os.cpu_count() // 2 + 1 51 | else: 52 | process_num = args.thread_count 53 | process_files_with_thread_pool(wavPath, outPath, args.sr, process_num) 54 | -------------------------------------------------------------------------------- /vits/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlayVoice/VI-SVS/d9768caa05c742bc5f580a4754b3cd7b444a5eb4/vits/__init__.py -------------------------------------------------------------------------------- /vits/attentions.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from einops import rearrange 5 | 6 | 7 | class LayerNorm(nn.Module): 8 | def __init__(self, channels, eps=1e-4): 9 | super(LayerNorm, self).__init__() 10 | self.channels = channels 11 | self.eps = eps 12 | 13 | self.gamma = torch.nn.Parameter(torch.ones(channels)) 14 | self.beta = torch.nn.Parameter(torch.zeros(channels)) 15 | 16 | def forward(self, x): 17 | n_dims = len(x.shape) 18 | mean = torch.mean(x, 1, keepdim=True) 19 | variance = torch.mean((x - mean)**2, 1, keepdim=True) 20 | 21 | x = (x - mean) * torch.rsqrt(variance + self.eps) 22 | 23 | shape = [1, -1] + [1] * (n_dims - 2) 24 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 25 | return x 26 | 27 | 28 | class ConvReluNorm(nn.Module): 29 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, 30 | n_layers, p_dropout, eps=1e-5): 31 | super(ConvReluNorm, self).__init__() 32 | self.in_channels = in_channels 33 | self.hidden_channels = hidden_channels 34 | self.out_channels = out_channels 35 | self.kernel_size = kernel_size 36 | self.n_layers = n_layers 37 | self.p_dropout = p_dropout 38 | self.eps = eps 39 | 40 | self.conv_layers = torch.nn.ModuleList() 41 | self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, 42 | kernel_size, padding=kernel_size//2)) 43 | self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) 44 | for _ in range(n_layers - 1): 45 | self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels, 46 | kernel_size, padding=kernel_size//2)) 47 | self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) 48 | self.proj.weight.data.zero_() 49 | self.proj.bias.data.zero_() 50 | 51 | def forward(self, x, x_mask): 52 | for i in range(self.n_layers): 53 | x = self.conv_layers[i](x * x_mask) 54 | x = self.instance_norm(x, x_mask) 55 | x = self.relu_drop(x) 56 | x = self.proj(x) 57 | return x * x_mask 58 | 59 | def instance_norm(self, x, mask, return_mean_std=False): 60 | mean, std = self.calc_mean_std(x, mask) 61 | x = (x - mean) / std 62 | if return_mean_std: 63 | return x, mean, std 64 | else: 65 | return x 66 | 67 | def calc_mean_std(self, x, mask=None): 68 | x = x * mask 69 | B, C = x.shape[:2] 70 | mn = x.view(B, C, -1).mean(-1) 71 | sd = (x.view(B, C, -1).var(-1) + self.eps).sqrt() 72 | mn = mn.view(B, C, *((len(x.shape) - 2) * [1])) 73 | sd = sd.view(B, C, *((len(x.shape) - 2) * [1])) 74 | return mn, sd 75 | 76 | 77 | class RotaryPositionalEmbeddings(nn.Module): 78 
| """ 79 | ## RoPE module 80 | https://github.com/labmlai/annotated_deep_learning_paper_implementations 81 | 82 | Rotary encoding transforms pairs of features by rotating in the 2D plane. 83 | That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. 84 | Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it 85 | by an angle depending on the position of the token. 86 | """ 87 | def __init__(self, d: int, base: int = 10_000): 88 | r""" 89 | * `d` is the number of features $d$ 90 | * `base` is the constant used for calculating $\Theta$ 91 | """ 92 | super().__init__() 93 | self.base = base 94 | self.d = int(d) 95 | self.cos_cached = None 96 | self.sin_cached = None 97 | 98 | def _build_cache(self, x: torch.Tensor): 99 | r""" 100 | Cache $\cos$ and $\sin$ values 101 | """ 102 | # Return if cache is already built 103 | if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]: 104 | return 105 | # Get sequence length 106 | seq_len = x.shape[0] 107 | theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device) 108 | # Create position indexes `[0, 1, ..., seq_len - 1]` 109 | seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device) 110 | # Calculate the product of position index and $\theta_i$ 111 | idx_theta = torch.einsum("n,d->nd", seq_idx, theta) 112 | # Concatenate so that for row $m$ we have 113 | idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) 114 | # Cache them 115 | self.cos_cached = idx_theta2.cos()[:, None, None, :] 116 | self.sin_cached = idx_theta2.sin()[:, None, None, :] 117 | 118 | def _neg_half(self, x: torch.Tensor): 119 | d_2 = self.d // 2 120 | return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) 121 | 122 | def forward(self, x: torch.Tensor): 123 | """ 124 | * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]` 125 | """ 126 | x = rearrange(x, "b h t d -> t b h d") 127 | self._build_cache(x) 128 | # Split the features, we can choose to apply rotary embeddings only to a partial set of features. 
129 | x_rope, x_pass = x[..., : self.d], x[..., self.d :] 130 | # Calculate 131 | neg_half_x = self._neg_half(x_rope) 132 | x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]]) 133 | return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d") 134 | 135 | 136 | class MultiHeadAttention(nn.Module): 137 | def __init__(self, channels, out_channels, n_heads, 138 | heads_share=True, p_dropout=0.0, proximal_bias=False, 139 | proximal_init=False): 140 | super(MultiHeadAttention, self).__init__() 141 | assert channels % n_heads == 0 142 | 143 | self.channels = channels 144 | self.out_channels = out_channels 145 | self.n_heads = n_heads 146 | self.heads_share = heads_share 147 | self.proximal_bias = proximal_bias 148 | self.p_dropout = p_dropout 149 | self.attn = None 150 | 151 | self.k_channels = channels // n_heads 152 | self.conv_q = torch.nn.Conv1d(channels, channels, 1) 153 | self.conv_k = torch.nn.Conv1d(channels, channels, 1) 154 | self.conv_v = torch.nn.Conv1d(channels, channels, 1) 155 | 156 | # from https://nn.labml.ai/transformers/rope/index.html 157 | self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) 158 | self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) 159 | 160 | self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) 161 | self.drop = torch.nn.Dropout(p_dropout) 162 | 163 | torch.nn.init.xavier_uniform_(self.conv_q.weight) 164 | torch.nn.init.xavier_uniform_(self.conv_k.weight) 165 | if proximal_init: 166 | self.conv_k.weight.data.copy_(self.conv_q.weight.data) 167 | self.conv_k.bias.data.copy_(self.conv_q.bias.data) 168 | torch.nn.init.xavier_uniform_(self.conv_v.weight) 169 | 170 | def forward(self, x, c, attn_mask=None): 171 | q = self.conv_q(x) 172 | k = self.conv_k(c) 173 | v = self.conv_v(c) 174 | 175 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 176 | 177 | x = self.conv_o(x) 178 | return x 179 | 180 | def attention(self, query, key, value, mask=None): 181 | b, d, t_s, t_t = (*key.size(), query.size(2)) 182 | query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads) 183 | key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads) 184 | value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads) 185 | 186 | query = self.query_rotary_pe(query) 187 | key = self.key_rotary_pe(key) 188 | 189 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) 190 | 191 | if self.proximal_bias: 192 | assert t_s == t_t, "Proximal bias is only available for self-attention." 
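            # _attention_bias_proximal adds -log(1 + |i - j|) to the attention
            # scores, softly biasing each query toward nearby key positions.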
193 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, 194 | dtype=scores.dtype) 195 | if mask is not None: 196 | scores = scores.masked_fill(mask == 0, -1e4) 197 | p_attn = torch.nn.functional.softmax(scores, dim=-1) 198 | p_attn = self.drop(p_attn) 199 | output = torch.matmul(p_attn, value) 200 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) 201 | return output, p_attn 202 | 203 | def _attention_bias_proximal(self, length): 204 | r = torch.arange(length, dtype=torch.float32) 205 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 206 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 207 | 208 | 209 | class FFN(nn.Module): 210 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, 211 | p_dropout=0.0): 212 | super(FFN, self).__init__() 213 | self.in_channels = in_channels 214 | self.out_channels = out_channels 215 | self.filter_channels = filter_channels 216 | self.kernel_size = kernel_size 217 | self.p_dropout = p_dropout 218 | 219 | self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, 220 | padding=kernel_size//2) 221 | self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, 222 | padding=kernel_size//2) 223 | self.drop = torch.nn.Dropout(p_dropout) 224 | 225 | def forward(self, x, x_mask): 226 | x = self.conv_1(x * x_mask) 227 | x = torch.relu(x) 228 | x = self.drop(x) 229 | x = self.conv_2(x * x_mask) 230 | return x * x_mask 231 | 232 | 233 | class Encoder(nn.Module): 234 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, 235 | kernel_size=1, p_dropout=0.0, **kwargs): 236 | super(Encoder, self).__init__() 237 | self.hidden_channels = hidden_channels 238 | self.filter_channels = filter_channels 239 | self.n_heads = n_heads 240 | self.n_layers = n_layers 241 | self.kernel_size = kernel_size 242 | self.p_dropout = p_dropout 243 | 244 | self.drop = torch.nn.Dropout(p_dropout) 245 | self.attn_layers = torch.nn.ModuleList() 246 | self.norm_layers_1 = torch.nn.ModuleList() 247 | self.ffn_layers = torch.nn.ModuleList() 248 | self.norm_layers_2 = torch.nn.ModuleList() 249 | for _ in range(self.n_layers): 250 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, 251 | n_heads, p_dropout=p_dropout)) 252 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 253 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, 254 | filter_channels, kernel_size, p_dropout=p_dropout)) 255 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 256 | 257 | def forward(self, x, x_mask): 258 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 259 | for i in range(self.n_layers): 260 | x = x * x_mask 261 | y = self.attn_layers[i](x, x, attn_mask) 262 | y = self.drop(y) 263 | x = self.norm_layers_1[i](x + y) 264 | y = self.ffn_layers[i](x, x_mask) 265 | y = self.drop(y) 266 | x = self.norm_layers_2[i](x + y) 267 | x = x * x_mask 268 | return x 269 | -------------------------------------------------------------------------------- /vits/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def slice_pitch_segments(x, ids_str, segment_size=4): 9 | ret = torch.zeros_like(x[:, :segment_size]) 10 | for i in range(x.size(0)): 11 | idx_str = ids_str[i] 12 | idx_end = idx_str + segment_size 13 | ret[i] = x[i, idx_str:idx_end] 14 | return ret 15 | 16 | 17 | 
def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4): 18 | b, d, t = x.size() 19 | if x_lengths is None: 20 | x_lengths = t 21 | ids_str_max = x_lengths - segment_size + 1 22 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 23 | ret = slice_segments(x, ids_str, segment_size) 24 | ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size) 25 | return ret, ret_pitch, ids_str 26 | 27 | 28 | def rand_spec_segments(x, x_lengths=None, segment_size=4): 29 | b, d, t = x.size() 30 | if x_lengths is None: 31 | x_lengths = t 32 | ids_str_max = x_lengths - segment_size 33 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 34 | ret = slice_segments(x, ids_str, segment_size) 35 | return ret, ids_str 36 | 37 | 38 | def init_weights(m, mean=0.0, std=0.01): 39 | classname = m.__class__.__name__ 40 | if classname.find("Conv") != -1: 41 | m.weight.data.normal_(mean, std) 42 | 43 | 44 | def get_padding(kernel_size, dilation=1): 45 | return int((kernel_size * dilation - dilation) / 2) 46 | 47 | 48 | def convert_pad_shape(pad_shape): 49 | l = pad_shape[::-1] 50 | pad_shape = [item for sublist in l for item in sublist] 51 | return pad_shape 52 | 53 | 54 | def kl_divergence(m_p, logs_p, m_q, logs_q): 55 | """KL(P||Q)""" 56 | kl = (logs_q - logs_p) - 0.5 57 | kl += ( 58 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 59 | ) 60 | return kl 61 | 62 | 63 | def rand_gumbel(shape): 64 | """Sample from the Gumbel distribution, protect from overflows.""" 65 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 66 | return -torch.log(-torch.log(uniform_samples)) 67 | 68 | 69 | def rand_gumbel_like(x): 70 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 71 | return g 72 | 73 | 74 | def slice_segments(x, ids_str, segment_size=4): 75 | ret = torch.zeros_like(x[:, :, :segment_size]) 76 | for i in range(x.size(0)): 77 | idx_str = ids_str[i] 78 | idx_end = idx_str + segment_size 79 | ret[i] = x[i, :, idx_str:idx_end] 80 | return ret 81 | 82 | 83 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 84 | b, d, t = x.size() 85 | if x_lengths is None: 86 | x_lengths = t 87 | ids_str_max = x_lengths - segment_size + 1 88 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 89 | ret = slice_segments(x, ids_str, segment_size) 90 | return ret, ids_str 91 | 92 | 93 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): 94 | position = torch.arange(length, dtype=torch.float) 95 | num_timescales = channels // 2 96 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( 97 | num_timescales - 1 98 | ) 99 | inv_timescales = min_timescale * torch.exp( 100 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment 101 | ) 102 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 103 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 104 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 105 | signal = signal.view(1, channels, length) 106 | return signal 107 | 108 | 109 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 110 | b, channels, length = x.size() 111 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 112 | return x + signal.to(dtype=x.dtype, device=x.device) 113 | 114 | 115 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 116 | b, channels, length = 
x.size() 117 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 118 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 119 | 120 | 121 | def subsequent_mask(length): 122 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 123 | return mask 124 | 125 | 126 | @torch.jit.script 127 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 128 | n_channels_int = n_channels[0] 129 | in_act = input_a + input_b 130 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 131 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 132 | acts = t_act * s_act 133 | return acts 134 | 135 | 136 | def convert_pad_shape(pad_shape): 137 | l = pad_shape[::-1] 138 | pad_shape = [item for sublist in l for item in sublist] 139 | return pad_shape 140 | 141 | 142 | def shift_1d(x): 143 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 144 | return x 145 | 146 | 147 | def sequence_mask(length, max_length=None): 148 | if max_length is None: 149 | max_length = length.max() 150 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 151 | return x.unsqueeze(0) < length.unsqueeze(1) 152 | 153 | 154 | def generate_path(duration, mask): 155 | """ 156 | duration: [b, 1, t_x] 157 | mask: [b, 1, t_y, t_x] 158 | """ 159 | device = duration.device 160 | 161 | b, _, t_y, t_x = mask.shape 162 | cum_duration = torch.cumsum(duration, -1) 163 | 164 | cum_duration_flat = cum_duration.view(b * t_x) 165 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 166 | path = path.view(b, t_x, t_y) 167 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 168 | path = path.unsqueeze(1).transpose(2, 3) * mask 169 | return path 170 | 171 | 172 | def clip_grad_value_(parameters, clip_value, norm_type=2): 173 | if isinstance(parameters, torch.Tensor): 174 | parameters = [parameters] 175 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 176 | norm_type = float(norm_type) 177 | if clip_value is not None: 178 | clip_value = float(clip_value) 179 | 180 | total_norm = 0 181 | for p in parameters: 182 | param_norm = p.grad.data.norm(norm_type) 183 | total_norm += param_norm.item() ** norm_type 184 | if clip_value is not None: 185 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 186 | total_norm = total_norm ** (1.0 / norm_type) 187 | return total_norm 188 | -------------------------------------------------------------------------------- /vits/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.utils.data 5 | 6 | from vits.spectrogram import spectrogram_torch 7 | from vits.utils import load_wav_to_torch 8 | 9 | 10 | def load_filepaths(filename, split="|"): 11 | with open(filename, encoding='utf-8') as f: 12 | filepaths = [line.strip().split(split) for line in f] 13 | return filepaths 14 | 15 | 16 | class TextAudioLoader(torch.utils.data.Dataset): 17 | """ 18 | 1) loads audio, text pairs 19 | 2) normalizes text and converts them to sequences of integers 20 | 3) computes spectrograms from audio files. 
21 | """ 22 | 23 | def __init__(self, audiopaths_and_text, hparams): 24 | self.audiopaths_and_text = load_filepaths(audiopaths_and_text) 25 | self.max_wav_value = hparams.max_wav_value 26 | self.sampling_rate = hparams.sampling_rate 27 | self.filter_length = hparams.filter_length 28 | self.hop_length = hparams.hop_length 29 | self.win_length = hparams.win_length 30 | self.sampling_rate = hparams.sampling_rate 31 | self.min_text_len = getattr(hparams, "min_text_len", 1) 32 | self.max_text_len = getattr(hparams, "max_text_len", 5000) 33 | self._filter() 34 | print(f"~~~~~~~~~~~~~~~~~~~~~{len(self)}~~~~~~~~~~~~~~~~~~~~~~~~~~~~") 35 | 36 | def _filter(self): 37 | """ 38 | Filter text & store spec lengths 39 | """ 40 | # Store spectrogram lengths for Bucketing 41 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) 42 | # spec_length = wav_length // hop_length 43 | audiopaths_and_text_new = [] 44 | lengths = [] 45 | for audiopath, text, score, pitch, slur in self.audiopaths_and_text: 46 | if self.min_text_len <= len(text) and len(text) <= self.max_text_len: 47 | wav_len = os.path.getsize(audiopath) // (2 * self.hop_length) 48 | if wav_len < 50: # no use short wave 49 | continue 50 | audiopaths_and_text_new.append([audiopath, text, score, pitch, slur]) 51 | lengths.append(wav_len) 52 | self.audiopaths_and_text = audiopaths_and_text_new 53 | self.lengths = lengths 54 | 55 | def get_audio_text_pair(self, audiopath_and_text): 56 | # separate filename and text 57 | file = audiopath_and_text[0] 58 | phone = audiopath_and_text[1] 59 | score = audiopath_and_text[2] 60 | pitch = audiopath_and_text[3] 61 | slurs = audiopath_and_text[4] 62 | 63 | phone, score, pitch, slurs = self.get_labels(phone, score, pitch, slurs) 64 | spec, wav = self.get_audio(file) 65 | 66 | len_phone = phone.size()[0] 67 | len_spec = spec.size()[-1] 68 | 69 | if len_phone != len_spec: 70 | # print("**************CareFull*******************") 71 | # print(f"filepath={audiopath_and_text[0]}") 72 | # print(f"len_text={len_phone}") 73 | # print(f"len_spec={len_spec}") 74 | if len_phone > len_spec: 75 | print(file) 76 | print("len_phone", len_phone) 77 | print("len_spec", len_spec) 78 | assert len_phone < len_spec 79 | len_min = min(len_phone, len_spec) 80 | len_wav = len_min * self.hop_length 81 | # print(wav.size()) 82 | # print(f"len_min={len_min}") 83 | # print(f"len_wav={len_wav}") 84 | spec = spec[:, :len_min] 85 | wav = wav[:, :len_wav] 86 | return (phone, score, pitch, slurs, spec, wav) 87 | 88 | def get_labels(self, phone, score, pitch, slurs): 89 | phone = np.load(phone) 90 | score = np.load(score) 91 | pitch = np.load(pitch) 92 | slurs = np.load(slurs) 93 | phone = torch.LongTensor(phone) 94 | score = torch.LongTensor(score) 95 | pitch = torch.FloatTensor(pitch) 96 | slurs = torch.LongTensor(slurs) 97 | return phone, score, pitch, slurs 98 | 99 | def get_audio(self, filename): 100 | audio, sampling_rate = load_wav_to_torch(filename) 101 | if sampling_rate != self.sampling_rate: 102 | raise ValueError( 103 | "{} {} SR doesn't match target {} SR".format( 104 | sampling_rate, self.sampling_rate 105 | ) 106 | ) 107 | audio_norm = audio / self.max_wav_value 108 | audio_norm = audio_norm.unsqueeze(0) 109 | spec_filename = filename.replace(".wav", ".spec.pt") 110 | if os.path.exists(spec_filename): 111 | spec = torch.load(spec_filename) 112 | else: 113 | spec = spectrogram_torch( 114 | audio_norm, 115 | self.filter_length, 116 | self.sampling_rate, 117 | self.hop_length, 118 | self.win_length, 119 | 
center=False, 120 | ) 121 | spec = torch.squeeze(spec, 0) 122 | torch.save(spec, spec_filename) 123 | return spec, audio_norm 124 | 125 | def __getitem__(self, index): 126 | return self.get_audio_text_pair(self.audiopaths_and_text[index]) 127 | 128 | def __len__(self): 129 | return len(self.audiopaths_and_text) 130 | 131 | 132 | class TextAudioCollate: 133 | """Zero-pads model inputs and targets""" 134 | 135 | def __init__(self, return_ids=False): 136 | self.return_ids = return_ids 137 | 138 | def __call__(self, batch): 139 | """Collate's training batch from normalized text and aduio 140 | PARAMS 141 | ------ 142 | batch: [text_normalized, spec_normalized, wav_normalized] 143 | """ 144 | # Right zero-pad all one-hot text sequences to max input length 145 | _, ids_sorted_decreasing = torch.sort( 146 | torch.LongTensor([x[4].size(1) for x in batch]), dim=0, descending=True 147 | ) 148 | 149 | max_phone_len = max([len(x[0]) for x in batch]) 150 | phone_lengths = torch.LongTensor(len(batch)) 151 | phone_padded = torch.LongTensor(len(batch), max_phone_len) 152 | score_padded = torch.LongTensor(len(batch), max_phone_len) 153 | pitch_padded = torch.FloatTensor(len(batch), max_phone_len) 154 | slurs_padded = torch.LongTensor(len(batch), max_phone_len) 155 | phone_padded.zero_() 156 | score_padded.zero_() 157 | pitch_padded.zero_() 158 | slurs_padded.zero_() 159 | 160 | max_spec_len = max([x[4].size(1) for x in batch]) 161 | max_wave_len = max([x[5].size(1) for x in batch]) 162 | spec_lengths = torch.LongTensor(len(batch)) 163 | wave_lengths = torch.LongTensor(len(batch)) 164 | spec_padded = torch.FloatTensor(len(batch), batch[0][4].size(0), max_spec_len) 165 | wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len) 166 | spec_padded.zero_() 167 | wave_padded.zero_() 168 | 169 | for i in range(len(ids_sorted_decreasing)): 170 | row = batch[ids_sorted_decreasing[i]] 171 | 172 | phone = row[0] 173 | phone_padded[i, : phone.size(0)] = phone 174 | phone_lengths[i] = phone.size(0) 175 | 176 | score = row[1] 177 | score_padded[i, : score.size(0)] = score 178 | 179 | pitch = row[2] 180 | pitch_padded[i, : pitch.size(0)] = pitch 181 | 182 | slurs = row[3] 183 | slurs_padded[i, : slurs.size(0)] = slurs 184 | 185 | spec = row[4] 186 | spec_padded[i, :, : spec.size(1)] = spec 187 | spec_lengths[i] = spec.size(1) 188 | 189 | wave = row[5] 190 | wave_padded[i, :, : wave.size(1)] = wave 191 | wave_lengths[i] = wave.size(1) 192 | 193 | return ( 194 | phone_padded, 195 | phone_lengths, 196 | score_padded, 197 | pitch_padded, 198 | slurs_padded, 199 | spec_padded, 200 | spec_lengths, 201 | wave_padded, 202 | wave_lengths, 203 | ) 204 | 205 | 206 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): 207 | """ 208 | Maintain similar input lengths in a batch. 209 | Length groups are specified by boundaries. 210 | Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. 211 | 212 | It removes samples which are not included in the boundaries. 213 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. 
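    Each kept bucket is padded by repeating its own indices until its size is
    divisible by num_replicas * batch_size, so every replica draws the same
    number of same-bucket batches per epoch.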
214 | """ 215 | 216 | def __init__( 217 | self, 218 | dataset, 219 | batch_size, 220 | boundaries, 221 | num_replicas=None, 222 | rank=None, 223 | shuffle=True, 224 | ): 225 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 226 | self.lengths = dataset.lengths 227 | self.batch_size = batch_size 228 | self.boundaries = boundaries 229 | 230 | self.buckets, self.num_samples_per_bucket = self._create_buckets() 231 | self.total_size = sum(self.num_samples_per_bucket) 232 | self.num_samples = self.total_size // self.num_replicas 233 | 234 | def _create_buckets(self): 235 | buckets = [[] for _ in range(len(self.boundaries) - 1)] 236 | for i in range(len(self.lengths)): 237 | length = self.lengths[i] 238 | idx_bucket = self._bisect(length) 239 | if idx_bucket != -1: 240 | buckets[idx_bucket].append(i) 241 | 242 | for i in range(len(buckets) - 1, 0, -1): 243 | if len(buckets[i]) == 0: 244 | buckets.pop(i) 245 | self.boundaries.pop(i + 1) 246 | 247 | num_samples_per_bucket = [] 248 | for i in range(len(buckets)): 249 | len_bucket = len(buckets[i]) 250 | total_batch_size = self.num_replicas * self.batch_size 251 | rem = ( 252 | total_batch_size - (len_bucket % total_batch_size) 253 | ) % total_batch_size 254 | num_samples_per_bucket.append(len_bucket + rem) 255 | return buckets, num_samples_per_bucket 256 | 257 | def __iter__(self): 258 | # deterministically shuffle based on epoch 259 | g = torch.Generator() 260 | g.manual_seed(self.epoch) 261 | 262 | indices = [] 263 | if self.shuffle: 264 | for bucket in self.buckets: 265 | indices.append(torch.randperm(len(bucket), generator=g).tolist()) 266 | else: 267 | for bucket in self.buckets: 268 | indices.append(list(range(len(bucket)))) 269 | 270 | batches = [] 271 | for i in range(len(self.buckets)): 272 | bucket = self.buckets[i] 273 | len_bucket = len(bucket) 274 | if (len_bucket == 0): 275 | continue 276 | ids_bucket = indices[i] 277 | num_samples_bucket = self.num_samples_per_bucket[i] 278 | 279 | # add extra samples to make it evenly divisible 280 | rem = num_samples_bucket - len_bucket 281 | ids_bucket = ( 282 | ids_bucket 283 | + ids_bucket * (rem // len_bucket) 284 | + ids_bucket[: (rem % len_bucket)] 285 | ) 286 | 287 | # subsample 288 | ids_bucket = ids_bucket[self.rank:: self.num_replicas] 289 | 290 | # batching 291 | for j in range(len(ids_bucket) // self.batch_size): 292 | batch = [ 293 | bucket[idx] 294 | for idx in ids_bucket[ 295 | j * self.batch_size: (j + 1) * self.batch_size 296 | ] 297 | ] 298 | batches.append(batch) 299 | 300 | if self.shuffle: 301 | batch_ids = torch.randperm(len(batches), generator=g).tolist() 302 | batches = [batches[i] for i in batch_ids] 303 | self.batches = batches 304 | 305 | assert len(self.batches) * self.batch_size == self.num_samples 306 | return iter(self.batches) 307 | 308 | def _bisect(self, x, lo=0, hi=None): 309 | if hi is None: 310 | hi = len(self.boundaries) - 1 311 | 312 | if hi > lo: 313 | mid = (hi + lo) // 2 314 | if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: 315 | return mid 316 | elif x <= self.boundaries[mid]: 317 | return self._bisect(x, lo, mid) 318 | else: 319 | return self._bisect(x, mid + 1, hi) 320 | else: 321 | return -1 322 | 323 | def __len__(self): 324 | return self.num_samples // self.batch_size 325 | -------------------------------------------------------------------------------- /vits/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def feature_loss(fmap_r, 
fmap_g): 5 | loss = 0 6 | for dr, dg in zip(fmap_r, fmap_g): 7 | for rl, gl in zip(dr, dg): 8 | rl = rl.float().detach() 9 | gl = gl.float() 10 | loss += torch.mean(torch.abs(rl - gl)) 11 | 12 | return loss * 2 13 | 14 | 15 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 16 | loss = 0 17 | r_losses = [] 18 | g_losses = [] 19 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 20 | dr = dr.float() 21 | dg = dg.float() 22 | r_loss = torch.mean((1 - dr) ** 2) 23 | g_loss = torch.mean(dg**2) 24 | loss += r_loss + g_loss 25 | r_losses.append(r_loss.item()) 26 | g_losses.append(g_loss.item()) 27 | 28 | return loss, r_losses, g_losses 29 | 30 | 31 | def generator_loss(disc_outputs): 32 | loss = 0 33 | gen_losses = [] 34 | for dg in disc_outputs: 35 | dg = dg.float() 36 | l = torch.mean((1 - dg) ** 2) 37 | gen_losses.append(l) 38 | loss += l 39 | 40 | return loss, gen_losses 41 | 42 | 43 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): 44 | """ 45 | z_p, logs_q: [b, h, t_t] 46 | m_p, logs_p: [b, h, t_t] 47 | """ 48 | z_p = z_p.float() 49 | logs_q = logs_q.float() 50 | m_p = m_p.float() 51 | logs_p = logs_p.float() 52 | z_mask = z_mask.float() 53 | 54 | kl = logs_p - logs_q - 0.5 55 | kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) 56 | kl = torch.sum(kl * z_mask) 57 | l = kl / torch.sum(z_mask) 58 | return l 59 | -------------------------------------------------------------------------------- /vits/models.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import math 4 | 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from vits import attentions 8 | from vits import commons 9 | from vits import modules 10 | from vits.utils import f0_to_coarse 11 | from vits_decoder.generator import Generator 12 | 13 | 14 | class TextEncoder(nn.Module): 15 | def __init__(self, 16 | out_channels, 17 | hidden_channels, 18 | filter_channels, 19 | n_heads, 20 | n_layers, 21 | kernel_size, 22 | p_dropout): 23 | super().__init__() 24 | self.out_channels = out_channels 25 | self.hidden_channels = hidden_channels 26 | self.emb_phone = nn.Embedding(63, hidden_channels) # phone lables 27 | self.emb_score = nn.Embedding(128, hidden_channels) # pitch notes 28 | self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256 29 | self.emb_slurs = nn.Embedding(2, hidden_channels) # phone slur 30 | nn.init.normal_(self.emb_phone.weight, 0.0, hidden_channels**-0.5) 31 | nn.init.normal_(self.emb_score.weight, 0.0, hidden_channels**-0.5) 32 | nn.init.normal_(self.emb_pitch.weight, 0.0, hidden_channels**-0.5) 33 | nn.init.normal_(self.emb_slurs.weight, 0.0, hidden_channels**-0.5) 34 | self.enc = attentions.Encoder( 35 | hidden_channels, 36 | filter_channels, 37 | n_heads, 38 | n_layers, 39 | kernel_size, 40 | p_dropout) 41 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 42 | 43 | def forward(self, phone, lengths, score, slurs, pitch): 44 | x = self.emb_phone(phone) + self.emb_score(score) + self.emb_pitch(pitch) + self.emb_slurs(slurs) 45 | x = x * math.sqrt(self.hidden_channels) # [b, t, h] 46 | x = torch.transpose(x, 1, -1) # [b, h, t] 47 | x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to( 48 | x.dtype 49 | ) 50 | x = self.enc(x * x_mask, x_mask) 51 | stats = self.proj(x) * x_mask 52 | m, logs = torch.split(stats, self.out_channels, dim=1) 53 | z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask 54 | return z, m, logs, x_mask, x 55 | 56 | 57 | class 
ResidualCouplingBlock(nn.Module): 58 | def __init__( 59 | self, 60 | channels, 61 | hidden_channels, 62 | kernel_size, 63 | dilation_rate, 64 | n_layers, 65 | n_flows=3, 66 | gin_channels=0, 67 | ): 68 | super().__init__() 69 | self.flows = nn.ModuleList() 70 | for i in range(n_flows): 71 | self.flows.append( 72 | modules.ResidualCouplingLayer( 73 | channels, 74 | hidden_channels, 75 | kernel_size, 76 | dilation_rate, 77 | n_layers, 78 | gin_channels=gin_channels, 79 | mean_only=True, 80 | ) 81 | ) 82 | self.flows.append(modules.Flip()) 83 | 84 | def forward(self, x, x_mask, g=None, reverse=False): 85 | if not reverse: 86 | total_logdet = 0 87 | for flow in self.flows: 88 | x, log_det = flow(x, x_mask, g=g, reverse=reverse) 89 | total_logdet += log_det 90 | return x, total_logdet 91 | else: 92 | total_logdet = 0 93 | for flow in reversed(self.flows): 94 | x, log_det = flow(x, x_mask, g=g, reverse=reverse) 95 | total_logdet += log_det 96 | return x, total_logdet 97 | 98 | 99 | class PosteriorEncoder(nn.Module): 100 | def __init__( 101 | self, 102 | in_channels, 103 | out_channels, 104 | hidden_channels, 105 | kernel_size, 106 | dilation_rate, 107 | n_layers, 108 | gin_channels=0, 109 | ): 110 | super().__init__() 111 | self.out_channels = out_channels 112 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 113 | self.enc = modules.WN( 114 | hidden_channels, 115 | kernel_size, 116 | dilation_rate, 117 | n_layers, 118 | gin_channels=gin_channels, 119 | ) 120 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 121 | 122 | def forward(self, x, x_lengths, g=None): 123 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( 124 | x.dtype 125 | ) 126 | x = self.pre(x) * x_mask 127 | x = self.enc(x, x_mask, g=g) 128 | stats = self.proj(x) * x_mask 129 | m, logs = torch.split(stats, self.out_channels, dim=1) 130 | z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask 131 | return z, m, logs, x_mask 132 | 133 | 134 | class SynthesizerTrn(nn.Module): 135 | def __init__( 136 | self, 137 | spec_channels, 138 | segment_size, 139 | hp 140 | ): 141 | super().__init__() 142 | self.segment_size = segment_size 143 | self.enc_p = TextEncoder( 144 | hp.vits.inter_channels, 145 | hp.vits.hidden_channels, 146 | hp.vits.filter_channels, 147 | 2, 148 | 6, 149 | 3, 150 | 0.1, 151 | ) 152 | self.enc_q = PosteriorEncoder( 153 | spec_channels, 154 | hp.vits.inter_channels, 155 | hp.vits.hidden_channels, 156 | 5, 157 | 1, 158 | 16, 159 | gin_channels=hp.vits.gin_channels, 160 | ) 161 | self.flow = ResidualCouplingBlock( 162 | hp.vits.inter_channels, 163 | hp.vits.hidden_channels, 164 | 5, 165 | 1, 166 | 4, 167 | gin_channels=hp.vits.gin_channels 168 | ) 169 | self.dec = Generator(hp=hp) 170 | 171 | def forward(self, phone, phone_l, score, pitch, slurs, spec, spec_l): 172 | 173 | z_p, m_p, logs_p, ppg_mask, x = self.enc_p( 174 | phone, phone_l, score, slurs, f0_to_coarse(pitch)) 175 | z_q, m_q, logs_q, spec_mask = self.enc_q(spec, spec_l) 176 | 177 | z_slice, pit_slice, ids_slice = commons.rand_slice_segments_with_pitch( 178 | z_q, pitch, spec_l, self.segment_size) 179 | audio = self.dec(z_slice, pit_slice) 180 | 181 | # SNAC to flow 182 | z_f, logdet_f = self.flow(z_q, spec_mask) 183 | z_r, logdet_r = self.flow(z_p, spec_mask, reverse=True) 184 | return audio, ids_slice, spec_mask, (z_f, z_r, z_p, m_p, logs_p, z_q, m_q, logs_q, logdet_f, logdet_r) 185 | 186 | def infer(self, phone, phone_l, score, pitch, slurs): 187 | z_p, m_p, logs_p, ppg_mask, x = self.enc_p( 188 | phone, 
phone_l, score, slurs, f0_to_coarse(pitch)) 189 | z, _ = self.flow(z_p, ppg_mask, reverse=True) 190 | o = self.dec(z * ppg_mask, pitch) 191 | return o 192 | -------------------------------------------------------------------------------- /vits/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from vits import commons 4 | 5 | 6 | class WN(torch.nn.Module): 7 | def __init__( 8 | self, 9 | hidden_channels, 10 | kernel_size, 11 | dilation_rate, 12 | n_layers, 13 | gin_channels=0, 14 | p_dropout=0, 15 | ): 16 | super(WN, self).__init__() 17 | assert kernel_size % 2 == 1 18 | self.hidden_channels = hidden_channels 19 | self.kernel_size = (kernel_size,) 20 | self.dilation_rate = dilation_rate 21 | self.n_layers = n_layers 22 | self.gin_channels = gin_channels 23 | self.p_dropout = p_dropout 24 | 25 | self.in_layers = torch.nn.ModuleList() 26 | self.res_skip_layers = torch.nn.ModuleList() 27 | self.drop = nn.Dropout(p_dropout) 28 | 29 | if gin_channels != 0: 30 | cond_layer = torch.nn.Conv1d( 31 | gin_channels, 2 * hidden_channels * n_layers, 1 32 | ) 33 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") 34 | 35 | for i in range(n_layers): 36 | dilation = dilation_rate**i 37 | padding = int((kernel_size * dilation - dilation) / 2) 38 | in_layer = torch.nn.Conv1d( 39 | hidden_channels, 40 | 2 * hidden_channels, 41 | kernel_size, 42 | dilation=dilation, 43 | padding=padding, 44 | ) 45 | in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") 46 | self.in_layers.append(in_layer) 47 | 48 | # last one is not necessary 49 | if i < n_layers - 1: 50 | res_skip_channels = 2 * hidden_channels 51 | else: 52 | res_skip_channels = hidden_channels 53 | 54 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 55 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") 56 | self.res_skip_layers.append(res_skip_layer) 57 | 58 | def forward(self, x, x_mask, g=None, **kwargs): 59 | output = torch.zeros_like(x) 60 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 61 | 62 | if g is not None: 63 | g = self.cond_layer(g) 64 | 65 | for i in range(self.n_layers): 66 | x_in = self.in_layers[i](x) 67 | if g is not None: 68 | cond_offset = i * 2 * self.hidden_channels 69 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] 70 | else: 71 | g_l = torch.zeros_like(x_in) 72 | 73 | acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) 74 | acts = self.drop(acts) 75 | 76 | res_skip_acts = self.res_skip_layers[i](acts) 77 | if i < self.n_layers - 1: 78 | res_acts = res_skip_acts[:, : self.hidden_channels, :] 79 | x = (x + res_acts) * x_mask 80 | output = output + res_skip_acts[:, self.hidden_channels:, :] 81 | else: 82 | output = output + res_skip_acts 83 | return output * x_mask 84 | 85 | def remove_weight_norm(self): 86 | if self.gin_channels != 0: 87 | torch.nn.utils.remove_weight_norm(self.cond_layer) 88 | for l in self.in_layers: 89 | torch.nn.utils.remove_weight_norm(l) 90 | for l in self.res_skip_layers: 91 | torch.nn.utils.remove_weight_norm(l) 92 | 93 | 94 | class Flip(nn.Module): 95 | def forward(self, x, *args, reverse=False, **kwargs): 96 | x = torch.flip(x, [1]) 97 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 98 | return x, logdet 99 | 100 | 101 | class ResidualCouplingLayer(nn.Module): 102 | def __init__( 103 | self, 104 | channels, 105 | hidden_channels, 106 | kernel_size, 107 | 
dilation_rate, 108 | n_layers, 109 | p_dropout=0, 110 | gin_channels=0, 111 | mean_only=False, 112 | ): 113 | assert channels % 2 == 0, "channels should be divisible by 2" 114 | super().__init__() 115 | self.channels = channels 116 | self.hidden_channels = hidden_channels 117 | self.kernel_size = kernel_size 118 | self.dilation_rate = dilation_rate 119 | self.n_layers = n_layers 120 | self.half_channels = channels // 2 121 | self.mean_only = mean_only 122 | 123 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 124 | self.enc = WN( 125 | hidden_channels, 126 | kernel_size, 127 | dilation_rate, 128 | n_layers, 129 | p_dropout=p_dropout, 130 | gin_channels=gin_channels, 131 | ) 132 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 133 | self.post.weight.data.zero_() 134 | self.post.bias.data.zero_() 135 | 136 | def forward(self, x, x_mask, g=None, reverse=False): 137 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 138 | h = self.pre(x0) * x_mask 139 | h = self.enc(h, x_mask, g=g) 140 | stats = self.post(h) * x_mask 141 | if not self.mean_only: 142 | m, logs = torch.split(stats, [self.half_channels] * 2, 1) 143 | else: 144 | m = stats 145 | logs = torch.zeros_like(m) 146 | 147 | if not reverse: 148 | x1 = m + x1 * torch.exp(logs) * x_mask 149 | x = torch.cat([x0, x1], 1) 150 | logdet = torch.sum(logs, [1, 2]) 151 | return x, logdet 152 | else: 153 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 154 | x = torch.cat([x0, x1], 1) 155 | logdet = torch.sum(logs, [1, 2]) 156 | return x, logdet 157 | -------------------------------------------------------------------------------- /vits/spectrogram.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | 4 | from librosa.filters import mel as librosa_mel_fn 5 | 6 | MAX_WAV_VALUE = 32768.0 7 | 8 | 9 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 10 | """ 11 | PARAMS 12 | ------ 13 | C: compression factor 14 | """ 15 | return torch.log(torch.clamp(x, min=clip_val) * C) 16 | 17 | 18 | def dynamic_range_decompression_torch(x, C=1): 19 | """ 20 | PARAMS 21 | ------ 22 | C: compression factor used to compress 23 | """ 24 | return torch.exp(x) / C 25 | 26 | 27 | def spectral_normalize_torch(magnitudes): 28 | output = dynamic_range_compression_torch(magnitudes) 29 | return output 30 | 31 | 32 | def spectral_de_normalize_torch(magnitudes): 33 | output = dynamic_range_decompression_torch(magnitudes) 34 | return output 35 | 36 | 37 | mel_basis = {} 38 | hann_window = {} 39 | 40 | 41 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 42 | if torch.min(y) < -1.0: 43 | print("min value is ", torch.min(y)) 44 | if torch.max(y) > 1.0: 45 | print("max value is ", torch.max(y)) 46 | 47 | global hann_window 48 | dtype_device = str(y.dtype) + "_" + str(y.device) 49 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 50 | if wnsize_dtype_device not in hann_window: 51 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 52 | dtype=y.dtype, device=y.device 53 | ) 54 | 55 | y = torch.nn.functional.pad( 56 | y.unsqueeze(1), 57 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 58 | mode="reflect", 59 | ) 60 | y = y.squeeze(1) 61 | 62 | spec = torch.stft( 63 | y, 64 | n_fft, 65 | hop_length=hop_size, 66 | win_length=win_size, 67 | window=hann_window[wnsize_dtype_device], 68 | center=center, 69 | pad_mode="reflect", 70 | normalized=False, 71 | onesided=True, 72 | 
return_complex=False, 73 | ) 74 | 75 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 76 | return spec 77 | 78 | 79 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 80 | global mel_basis 81 | dtype_device = str(spec.dtype) + "_" + str(spec.device) 82 | fmax_dtype_device = str(fmax) + "_" + dtype_device 83 | if fmax_dtype_device not in mel_basis: 84 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 85 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 86 | dtype=spec.dtype, device=spec.device 87 | ) 88 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 89 | spec = spectral_normalize_torch(spec) 90 | return spec 91 | 92 | 93 | def mel_spectrogram_torch( 94 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False 95 | ): 96 | if torch.min(y) < -1.0: 97 | print("min value is ", torch.min(y)) 98 | if torch.max(y) > 1.0: 99 | print("max value is ", torch.max(y)) 100 | 101 | global mel_basis, hann_window 102 | dtype_device = str(y.dtype) + "_" + str(y.device) 103 | fmax_dtype_device = str(fmax) + "_" + dtype_device 104 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 105 | if fmax_dtype_device not in mel_basis: 106 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 107 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 108 | dtype=y.dtype, device=y.device 109 | ) 110 | if wnsize_dtype_device not in hann_window: 111 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 112 | dtype=y.dtype, device=y.device 113 | ) 114 | 115 | y = torch.nn.functional.pad( 116 | y.unsqueeze(1), 117 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 118 | mode="reflect", 119 | ) 120 | y = y.squeeze(1) 121 | 122 | spec = torch.stft( 123 | y, 124 | n_fft, 125 | hop_length=hop_size, 126 | win_length=win_size, 127 | window=hann_window[wnsize_dtype_device], 128 | center=center, 129 | pad_mode="reflect", 130 | normalized=False, 131 | onesided=True, 132 | return_complex=False, 133 | ) 134 | 135 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 136 | 137 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 138 | spec = spectral_normalize_torch(spec) 139 | 140 | return spec 141 | -------------------------------------------------------------------------------- /vits/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.io.wavfile import read 4 | 5 | MATPLOTLIB_FLAG = False 6 | 7 | 8 | def load_wav_to_torch(full_path): 9 | sampling_rate, data = read(full_path) 10 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 11 | 12 | 13 | f0_bin = 256 14 | f0_max = 1100.0 15 | f0_min = 50.0 16 | f0_mel_min = 1127 * np.log(1 + f0_min / 700) 17 | f0_mel_max = 1127 * np.log(1 + f0_max / 700) 18 | 19 | 20 | def f0_to_coarse(f0): 21 | is_torch = isinstance(f0, torch.Tensor) 22 | f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * \ 23 | np.log(1 + f0 / 700) 24 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * \ 25 | (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 26 | 27 | f0_mel[f0_mel <= 1] = 1 28 | f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 29 | f0_coarse = ( 30 | f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int) 31 | assert f0_coarse.max() <= 255 and f0_coarse.min( 32 | ) >= 1, (f0_coarse.max(), f0_coarse.min()) 33 | return f0_coarse 
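# A minimal sketch of how f0_to_coarse quantizes Hz onto the 255 usable mel
# bins expected by TextEncoder's pitch embedding (nn.Embedding(256, ...));
# the five-frame contour here is made up for illustration.
if __name__ == "__main__":
    f0_hz = torch.tensor([0.0, 110.0, 220.0, 440.0, 880.0])  # 0 = unvoiced
    coarse = f0_to_coarse(f0_hz)  # LongTensor in [1, 255]; unvoiced maps to 1
    print(coarse)

    # Approximate inverse of the mapping, handy as a sanity check.
    mel = (coarse.float() - 1) / (f0_bin - 2) * (f0_mel_max - f0_mel_min) + f0_mel_min
    print(700 * (torch.exp(mel / 1127) - 1))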
-------------------------------------------------------------------------------- /vits_decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .alias.act import SnakeAlias -------------------------------------------------------------------------------- /vits_decoder/alias/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | from .filter import * 5 | from .resample import * 6 | from .act import * -------------------------------------------------------------------------------- /vits_decoder/alias/act.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from torch import sin, pow 9 | from torch.nn import Parameter 10 | from .resample import UpSample1d, DownSample1d 11 | 12 | 13 | class Activation1d(nn.Module): 14 | def __init__(self, 15 | activation, 16 | up_ratio: int = 2, 17 | down_ratio: int = 2, 18 | up_kernel_size: int = 12, 19 | down_kernel_size: int = 12): 20 | super().__init__() 21 | self.up_ratio = up_ratio 22 | self.down_ratio = down_ratio 23 | self.act = activation 24 | self.upsample = UpSample1d(up_ratio, up_kernel_size) 25 | self.downsample = DownSample1d(down_ratio, down_kernel_size) 26 | 27 | # x: [B,C,T] 28 | def forward(self, x): 29 | x = self.upsample(x) 30 | x = self.act(x) 31 | x = self.downsample(x) 32 | 33 | return x 34 | 35 | 36 | class SnakeBeta(nn.Module): 37 | ''' 38 | A modified Snake function which uses separate parameters for the magnitude of the periodic components 39 | Shape: 40 | - Input: (B, C, T) 41 | - Output: (B, C, T), same shape as the input 42 | Parameters: 43 | - alpha - trainable parameter that controls frequency 44 | - beta - trainable parameter that controls magnitude 45 | References: 46 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 47 | https://arxiv.org/abs/2006.08195 48 | Examples: 49 | >>> a1 = snakebeta(256) 50 | >>> x = torch.randn(256) 51 | >>> x = a1(x) 52 | ''' 53 | 54 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 55 | ''' 56 | Initialization. 57 | INPUT: 58 | - in_features: shape of the input 59 | - alpha - trainable parameter that controls frequency 60 | - beta - trainable parameter that controls magnitude 61 | alpha is initialized to 1 by default, higher values = higher-frequency. 62 | beta is initialized to 1 by default, higher values = higher-magnitude. 63 | alpha will be trained along with the rest of your model. 
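        When alpha_logscale is True, alpha and beta are stored in log scale
        (initialized to zero) and exponentiated in the forward pass.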
64 | ''' 65 | super(SnakeBeta, self).__init__() 66 | self.in_features = in_features 67 | # initialize alpha 68 | self.alpha_logscale = alpha_logscale 69 | if self.alpha_logscale: # log scale alphas initialized to zeros 70 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 71 | self.beta = Parameter(torch.zeros(in_features) * alpha) 72 | else: # linear scale alphas initialized to ones 73 | self.alpha = Parameter(torch.ones(in_features) * alpha) 74 | self.beta = Parameter(torch.ones(in_features) * alpha) 75 | self.alpha.requires_grad = alpha_trainable 76 | self.beta.requires_grad = alpha_trainable 77 | self.no_div_by_zero = 0.000000001 78 | 79 | def forward(self, x): 80 | ''' 81 | Forward pass of the function. 82 | Applies the function to the input elementwise. 83 | SnakeBeta = x + 1/b * sin^2 (xa) 84 | ''' 85 | alpha = self.alpha.unsqueeze( 86 | 0).unsqueeze(-1) # line up with x to [B, C, T] 87 | beta = self.beta.unsqueeze(0).unsqueeze(-1) 88 | if self.alpha_logscale: 89 | alpha = torch.exp(alpha) 90 | beta = torch.exp(beta) 91 | x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 92 | return x 93 | 94 | 95 | class Mish(nn.Module): 96 | """ 97 | Mish activation function is proposed in "Mish: A Self 98 | Regularized Non-Monotonic Neural Activation Function" 99 | paper, https://arxiv.org/abs/1908.08681. 100 | """ 101 | 102 | def __init__(self): 103 | super().__init__() 104 | 105 | def forward(self, x): 106 | return x * torch.tanh(F.softplus(x)) 107 | 108 | 109 | class SnakeAlias(nn.Module): 110 | def __init__(self, 111 | channels, 112 | up_ratio: int = 2, 113 | down_ratio: int = 2, 114 | up_kernel_size: int = 12, 115 | down_kernel_size: int = 12): 116 | super().__init__() 117 | self.up_ratio = up_ratio 118 | self.down_ratio = down_ratio 119 | self.act = SnakeBeta(channels, alpha_logscale=True) 120 | self.upsample = UpSample1d(up_ratio, up_kernel_size) 121 | self.downsample = DownSample1d(down_ratio, down_kernel_size) 122 | 123 | # x: [B,C,T] 124 | def forward(self, x): 125 | x = self.upsample(x) 126 | x = self.act(x) 127 | x = self.downsample(x) 128 | 129 | return x -------------------------------------------------------------------------------- /vits_decoder/alias/filter.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | 9 | if 'sinc' in dir(torch): 10 | sinc = torch.sinc 11 | else: 12 | # This code is adopted from adefossez's julius.core.sinc under the MIT License 13 | # https://adefossez.github.io/julius/julius/core.html 14 | # LICENSE is in incl_licenses directory. 15 | def sinc(x: torch.Tensor): 16 | """ 17 | Implementation of sinc, i.e. sin(pi * x) / (pi * x) 18 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 19 | """ 20 | return torch.where(x == 0, 21 | torch.tensor(1., device=x.device, dtype=x.dtype), 22 | torch.sin(math.pi * x) / math.pi / x) 23 | 24 | 25 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License 26 | # https://adefossez.github.io/julius/julius/lowpass.html 27 | # LICENSE is in incl_licenses directory. 
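# Builds a Kaiser-windowed sinc low-pass FIR kernel: the window beta is chosen
# from the estimated stop-band attenuation A (standard Kaiser design rule), and
# the taps are normalized to sum to 1 to avoid leaking the DC component.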
28 | def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] 29 | even = (kernel_size % 2 == 0) 30 | half_size = kernel_size // 2 31 | 32 | #For kaiser window 33 | delta_f = 4 * half_width 34 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 35 | if A > 50.: 36 | beta = 0.1102 * (A - 8.7) 37 | elif A >= 21.: 38 | beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) 39 | else: 40 | beta = 0. 41 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) 42 | 43 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio 44 | if even: 45 | time = (torch.arange(-half_size, half_size) + 0.5) 46 | else: 47 | time = torch.arange(kernel_size) - half_size 48 | if cutoff == 0: 49 | filter_ = torch.zeros_like(time) 50 | else: 51 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) 52 | # Normalize filter to have sum = 1, otherwise we will have a small leakage 53 | # of the constant component in the input signal. 54 | filter_ /= filter_.sum() 55 | filter = filter_.view(1, 1, kernel_size) 56 | 57 | return filter 58 | 59 | 60 | class LowPassFilter1d(nn.Module): 61 | def __init__(self, 62 | cutoff=0.5, 63 | half_width=0.6, 64 | stride: int = 1, 65 | padding: bool = True, 66 | padding_mode: str = 'replicate', 67 | kernel_size: int = 12): 68 | # kernel_size should be even number for stylegan3 setup, 69 | # in this implementation, odd number is also possible. 70 | super().__init__() 71 | if cutoff < -0.: 72 | raise ValueError("Minimum cutoff must be larger than zero.") 73 | if cutoff > 0.5: 74 | raise ValueError("A cutoff above 0.5 does not make sense.") 75 | self.kernel_size = kernel_size 76 | self.even = (kernel_size % 2 == 0) 77 | self.pad_left = kernel_size // 2 - int(self.even) 78 | self.pad_right = kernel_size // 2 79 | self.stride = stride 80 | self.padding = padding 81 | self.padding_mode = padding_mode 82 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) 83 | self.register_buffer("filter", filter) 84 | 85 | #input [B, C, T] 86 | def forward(self, x): 87 | _, C, _ = x.shape 88 | 89 | if self.padding: 90 | x = F.pad(x, (self.pad_left, self.pad_right), 91 | mode=self.padding_mode) 92 | out = F.conv1d(x, self.filter.expand(C, -1, -1), 93 | stride=self.stride, groups=C) 94 | 95 | return out -------------------------------------------------------------------------------- /vits_decoder/alias/resample.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 
3 | 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from .filter import LowPassFilter1d 7 | from .filter import kaiser_sinc_filter1d 8 | 9 | 10 | class UpSample1d(nn.Module): 11 | def __init__(self, ratio=2, kernel_size=None): 12 | super().__init__() 13 | self.ratio = ratio 14 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 15 | self.stride = ratio 16 | self.pad = self.kernel_size // ratio - 1 17 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 18 | self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 19 | filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, 20 | half_width=0.6 / ratio, 21 | kernel_size=self.kernel_size) 22 | self.register_buffer("filter", filter) 23 | 24 | # x: [B, C, T] 25 | def forward(self, x): 26 | _, C, _ = x.shape 27 | 28 | x = F.pad(x, (self.pad, self.pad), mode='replicate') 29 | x = self.ratio * F.conv_transpose1d( 30 | x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) 31 | x = x[..., self.pad_left:-self.pad_right] 32 | 33 | return x 34 | 35 | 36 | class DownSample1d(nn.Module): 37 | def __init__(self, ratio=2, kernel_size=None): 38 | super().__init__() 39 | self.ratio = ratio 40 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 41 | self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, 42 | half_width=0.6 / ratio, 43 | stride=ratio, 44 | kernel_size=self.kernel_size) 45 | 46 | def forward(self, x): 47 | xx = self.lowpass(x) 48 | 49 | return xx -------------------------------------------------------------------------------- /vits_decoder/bigv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch.nn import Conv1d 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | from .alias.act import SnakeAlias 7 | 8 | 9 | def init_weights(m, mean=0.0, std=0.01): 10 | classname = m.__class__.__name__ 11 | if classname.find("Conv") != -1: 12 | m.weight.data.normal_(mean, std) 13 | 14 | 15 | def get_padding(kernel_size, dilation=1): 16 | return int((kernel_size*dilation - dilation)/2) 17 | 18 | 19 | class AMPBlock(torch.nn.Module): 20 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 21 | super(AMPBlock, self).__init__() 22 | self.convs1 = nn.ModuleList([ 23 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 24 | padding=get_padding(kernel_size, dilation[0]))), 25 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 26 | padding=get_padding(kernel_size, dilation[1]))), 27 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 28 | padding=get_padding(kernel_size, dilation[2]))) 29 | ]) 30 | self.convs1.apply(init_weights) 31 | 32 | self.convs2 = nn.ModuleList([ 33 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 34 | padding=get_padding(kernel_size, 1))), 35 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 36 | padding=get_padding(kernel_size, 1))), 37 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 38 | padding=get_padding(kernel_size, 1))) 39 | ]) 40 | self.convs2.apply(init_weights) 41 | 42 | # total number of conv layers 43 | self.num_layers = len(self.convs1) + len(self.convs2) 44 | 45 | # periodic nonlinearity with snakebeta function and anti-aliasing 46 | self.activations = nn.ModuleList([ 47 | SnakeAlias(channels) for _ in range(self.num_layers) 48 | 
]) 49 | 50 | def forward(self, x): 51 | acts1, acts2 = self.activations[::2], self.activations[1::2] 52 | for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): 53 | xt = a1(x) 54 | xt = c1(xt) 55 | xt = a2(xt) 56 | xt = c2(xt) 57 | x = xt + x 58 | return x 59 | 60 | def remove_weight_norm(self): 61 | for l in self.convs1: 62 | remove_weight_norm(l) 63 | for l in self.convs2: 64 | remove_weight_norm(l) -------------------------------------------------------------------------------- /vits_decoder/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from omegaconf import OmegaConf 5 | from .msd import ScaleDiscriminator 6 | from .mpd import MultiPeriodDiscriminator 7 | from .mrd import MultiResolutionDiscriminator 8 | 9 | 10 | class Discriminator(nn.Module): 11 | def __init__(self, hp): 12 | super(Discriminator, self).__init__() 13 | self.MRD = MultiResolutionDiscriminator(hp) 14 | self.MPD = MultiPeriodDiscriminator(hp) 15 | self.MSD = ScaleDiscriminator() 16 | 17 | def forward(self, x): 18 | r = self.MRD(x) 19 | p = self.MPD(x) 20 | s = self.MSD(x) 21 | return r + p + s 22 | 23 | 24 | if __name__ == '__main__': 25 | hp = OmegaConf.load('../config/base.yaml') 26 | model = Discriminator(hp) 27 | 28 | x = torch.randn(3, 1, 16384) 29 | print(x.shape) 30 | 31 | output = model(x) 32 | for features, score in output: 33 | for feat in features: 34 | print(feat.shape) 35 | print(score.shape) 36 | 37 | pytorch_total_params = sum(p.numel() 38 | for p in model.parameters() if p.requires_grad) 39 | print(pytorch_total_params) 40 | -------------------------------------------------------------------------------- /vits_decoder/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from torch.nn import Conv1d 7 | from torch.nn import ConvTranspose1d 8 | from torch.nn.utils import weight_norm 9 | from torch.nn.utils import remove_weight_norm 10 | 11 | from .nsf import SourceModuleHnNSF 12 | from .bigv import init_weights, AMPBlock, SnakeAlias 13 | 14 | 15 | class Generator(torch.nn.Module): 16 | # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks. 17 | def __init__(self, hp): 18 | super(Generator, self).__init__() 19 | self.hp = hp 20 | self.num_kernels = len(hp.gen.resblock_kernel_sizes) 21 | self.num_upsamples = len(hp.gen.upsample_rates) 22 | # pre conv 23 | self.conv_pre = Conv1d(hp.gen.upsample_input, 24 | hp.gen.upsample_initial_channel, 7, 1, padding=3) 25 | # nsf 26 | self.f0_upsamp = torch.nn.Upsample( 27 | scale_factor=np.prod(hp.gen.upsample_rates)) 28 | self.m_source = SourceModuleHnNSF(sampling_rate=hp.data.sampling_rate) 29 | self.noise_convs = nn.ModuleList() 30 | # transposed conv-based upsamplers. 
does not apply anti-aliasing 31 | self.ups = nn.ModuleList() 32 | for i, (u, k) in enumerate(zip(hp.gen.upsample_rates, hp.gen.upsample_kernel_sizes)): 33 | # print(f'ups: {i} {k}, {u}, {(k - u) // 2}') 34 | # base 35 | self.ups.append( 36 | weight_norm( 37 | ConvTranspose1d( 38 | hp.gen.upsample_initial_channel // (2 ** i), 39 | hp.gen.upsample_initial_channel // (2 ** (i + 1)), 40 | k, 41 | u, 42 | padding=(k - u) // 2) 43 | ) 44 | ) 45 | # nsf 46 | if i + 1 < len(hp.gen.upsample_rates): 47 | stride_f0 = np.prod(hp.gen.upsample_rates[i + 1:]) 48 | stride_f0 = int(stride_f0) 49 | self.noise_convs.append( 50 | Conv1d( 51 | 1, 52 | hp.gen.upsample_initial_channel // (2 ** (i + 1)), 53 | kernel_size=stride_f0 * 2, 54 | stride=stride_f0, 55 | padding=stride_f0 // 2, 56 | ) 57 | ) 58 | else: 59 | self.noise_convs.append( 60 | Conv1d(1, hp.gen.upsample_initial_channel // 61 | (2 ** (i + 1)), kernel_size=1) 62 | ) 63 | 64 | # residual blocks using anti-aliased multi-periodicity composition modules (AMP) 65 | self.resblocks = nn.ModuleList() 66 | for i in range(len(self.ups)): 67 | ch = hp.gen.upsample_initial_channel // (2 ** (i + 1)) 68 | for k, d in zip(hp.gen.resblock_kernel_sizes, hp.gen.resblock_dilation_sizes): 69 | self.resblocks.append(AMPBlock(ch, k, d)) 70 | 71 | # post conv 72 | self.activation_post = SnakeAlias(ch) 73 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) 74 | # weight initialization 75 | self.ups.apply(init_weights) 76 | 77 | def forward(self, x, f0): 78 | # Perturbation 79 | x = x + torch.randn_like(x) 80 | x = self.conv_pre(x) 81 | x = x * torch.tanh(F.softplus(x)) 82 | # nsf 83 | f0 = f0[:, None] 84 | f0 = self.f0_upsamp(f0).transpose(1, 2) 85 | har_source = self.m_source(f0) 86 | har_source = har_source.transpose(1, 2) 87 | 88 | for i in range(self.num_upsamples): 89 | # upsampling 90 | x = self.ups[i](x) 91 | # nsf 92 | x_source = self.noise_convs[i](har_source) 93 | x = x + x_source 94 | # AMP blocks 95 | xs = None 96 | for j in range(self.num_kernels): 97 | if xs is None: 98 | xs = self.resblocks[i * self.num_kernels + j](x) 99 | else: 100 | xs += self.resblocks[i * self.num_kernels + j](x) 101 | x = xs / self.num_kernels 102 | 103 | # post conv 104 | x = self.activation_post(x) 105 | x = self.conv_post(x) 106 | x = torch.tanh(x) 107 | return x 108 | 109 | def remove_weight_norm(self): 110 | for l in self.ups: 111 | remove_weight_norm(l) 112 | for l in self.resblocks: 113 | l.remove_weight_norm() 114 | 115 | def eval(self, inference=False): 116 | super(Generator, self).eval() 117 | # don't remove weight norm while validation in training loop 118 | if inference: 119 | self.remove_weight_norm() 120 | 121 | def pitch2source(self, f0): 122 | f0 = f0[:, None] 123 | f0 = self.f0_upsamp(f0).transpose(1, 2) # [1,len,1] 124 | har_source = self.m_source(f0) 125 | har_source = har_source.transpose(1, 2) # [1,1,len] 126 | return har_source 127 | 128 | def source2wav(self, audio): 129 | MAX_WAV_VALUE = 32768.0 130 | audio = audio.squeeze() 131 | audio = MAX_WAV_VALUE * audio 132 | audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1) 133 | audio = audio.short() 134 | return audio.cpu().detach().numpy() 135 | 136 | def inference(self, x, har_source): 137 | # Perturbation 138 | x = x + torch.randn_like(x) * 0.1 139 | x = self.conv_pre(x) 140 | x = x * torch.tanh(F.softplus(x)) 141 | 142 | for i in range(self.num_upsamples): 143 | # upsampling 144 | x = self.ups[i](x) 145 | # nsf 146 | x_source = self.noise_convs[i](har_source) 147 | x = x + x_source 148 | # 
AMP blocks 149 | xs = None 150 | for j in range(self.num_kernels): 151 | if xs is None: 152 | xs = self.resblocks[i * self.num_kernels + j](x) 153 | else: 154 | xs += self.resblocks[i * self.num_kernels + j](x) 155 | x = xs / self.num_kernels 156 | 157 | # post conv 158 | x = self.activation_post(x) 159 | x = self.conv_post(x) 160 | x = torch.tanh(x) 161 | return x 162 | -------------------------------------------------------------------------------- /vits_decoder/mpd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import weight_norm, spectral_norm 5 | 6 | class DiscriminatorP(nn.Module): 7 | def __init__(self, hp, period): 8 | super(DiscriminatorP, self).__init__() 9 | 10 | self.LRELU_SLOPE = hp.mpd.lReLU_slope 11 | self.period = period 12 | 13 | kernel_size = hp.mpd.kernel_size 14 | stride = hp.mpd.stride 15 | norm_f = weight_norm if hp.mpd.use_spectral_norm == False else spectral_norm 16 | 17 | self.convs = nn.ModuleList([ 18 | norm_f(nn.Conv2d(1, 64, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 19 | norm_f(nn.Conv2d(64, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 20 | norm_f(nn.Conv2d(128, 256, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 21 | norm_f(nn.Conv2d(256, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 22 | norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), 1, padding=(kernel_size // 2, 0))), 23 | ]) 24 | self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 25 | 26 | def forward(self, x): 27 | fmap = [] 28 | 29 | # 1d to 2d 30 | b, c, t = x.shape 31 | if t % self.period != 0: # pad first 32 | n_pad = self.period - (t % self.period) 33 | x = F.pad(x, (0, n_pad), "reflect") 34 | t = t + n_pad 35 | x = x.view(b, c, t // self.period, self.period) 36 | 37 | for l in self.convs: 38 | x = l(x) 39 | x = F.leaky_relu(x, self.LRELU_SLOPE) 40 | fmap.append(x) 41 | x = self.conv_post(x) 42 | fmap.append(x) 43 | x = torch.flatten(x, 1, -1) 44 | 45 | return fmap, x 46 | 47 | 48 | class MultiPeriodDiscriminator(nn.Module): 49 | def __init__(self, hp): 50 | super(MultiPeriodDiscriminator, self).__init__() 51 | 52 | self.discriminators = nn.ModuleList( 53 | [DiscriminatorP(hp, period) for period in hp.mpd.periods] 54 | ) 55 | 56 | def forward(self, x): 57 | ret = list() 58 | for disc in self.discriminators: 59 | ret.append(disc(x)) 60 | 61 | return ret # [(feat, score), (feat, score), (feat, score), (feat, score), (feat, score)] 62 | -------------------------------------------------------------------------------- /vits_decoder/mrd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import weight_norm, spectral_norm 5 | 6 | class DiscriminatorR(torch.nn.Module): 7 | def __init__(self, hp, resolution): 8 | super(DiscriminatorR, self).__init__() 9 | 10 | self.resolution = resolution 11 | self.LRELU_SLOPE = hp.mpd.lReLU_slope 12 | 13 | norm_f = weight_norm if hp.mrd.use_spectral_norm == False else spectral_norm 14 | 15 | self.convs = nn.ModuleList([ 16 | norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))), 17 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 18 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 19 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 20 | 
norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), 21 | ]) 22 | self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) 23 | 24 | def forward(self, x): 25 | fmap = [] 26 | 27 | x = self.spectrogram(x) 28 | x = x.unsqueeze(1) 29 | for l in self.convs: 30 | x = l(x) 31 | x = F.leaky_relu(x, self.LRELU_SLOPE) 32 | fmap.append(x) 33 | x = self.conv_post(x) 34 | fmap.append(x) 35 | x = torch.flatten(x, 1, -1) 36 | 37 | return fmap, x 38 | 39 | def spectrogram(self, x): 40 | n_fft, hop_length, win_length = self.resolution 41 | x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect') 42 | x = x.squeeze(1) 43 | x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=False) #[B, F, TT, 2] 44 | mag = torch.norm(x, p=2, dim =-1) #[B, F, TT] 45 | 46 | return mag 47 | 48 | 49 | class MultiResolutionDiscriminator(torch.nn.Module): 50 | def __init__(self, hp): 51 | super(MultiResolutionDiscriminator, self).__init__() 52 | self.resolutions = eval(hp.mrd.resolutions) 53 | self.discriminators = nn.ModuleList( 54 | [DiscriminatorR(hp, resolution) for resolution in self.resolutions] 55 | ) 56 | 57 | def forward(self, x): 58 | ret = list() 59 | for disc in self.discriminators: 60 | ret.append(disc(x)) 61 | 62 | return ret # [(feat, score), (feat, score), (feat, score)] 63 | -------------------------------------------------------------------------------- /vits_decoder/msd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import weight_norm 5 | 6 | 7 | class ScaleDiscriminator(torch.nn.Module): 8 | def __init__(self): 9 | super(ScaleDiscriminator, self).__init__() 10 | self.convs = nn.ModuleList([ 11 | weight_norm(nn.Conv1d(1, 16, 15, 1, padding=7)), 12 | weight_norm(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), 13 | weight_norm(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), 14 | weight_norm(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), 15 | weight_norm(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), 16 | weight_norm(nn.Conv1d(1024, 1024, 5, 1, padding=2)), 17 | ]) 18 | self.conv_post = weight_norm(nn.Conv1d(1024, 1, 3, 1, padding=1)) 19 | 20 | def forward(self, x): 21 | fmap = [] 22 | for l in self.convs: 23 | x = l(x) 24 | x = F.leaky_relu(x, 0.1) 25 | fmap.append(x) 26 | x = self.conv_post(x) 27 | fmap.append(x) 28 | x = torch.flatten(x, 1, -1) 29 | return [(fmap, x)] 30 | -------------------------------------------------------------------------------- /vits_extend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlayVoice/VI-SVS/d9768caa05c742bc5f580a4754b3cd7b444a5eb4/vits_extend/__init__.py -------------------------------------------------------------------------------- /vits_extend/dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from vits.data_utils import DistributedBucketSampler 3 | from vits.data_utils import TextAudioLoader 4 | from vits.data_utils import TextAudioCollate 5 | 6 | 7 | def create_dataloader_train(hps, n_gpus, rank): 8 | collate_fn = TextAudioCollate() 9 | train_dataset = TextAudioLoader(hps.data.training_files, hps.data) 10 | train_sampler = DistributedBucketSampler( 11 | train_dataset, 12 | hps.train.batch_size, 13 | [32, 300, 400, 500, 600, 700, 800, 900, 1000], 
14 | num_replicas=n_gpus, 15 | rank=rank, 16 | shuffle=True) 17 | train_loader = DataLoader( 18 | train_dataset, 19 | num_workers=4, 20 | shuffle=False, 21 | pin_memory=True, 22 | collate_fn=collate_fn, 23 | batch_sampler=train_sampler) 24 | return train_loader 25 | 26 | 27 | def create_dataloader_eval(hps): 28 | collate_fn = TextAudioCollate() 29 | eval_dataset = TextAudioLoader(hps.data.validation_files, hps.data) 30 | eval_loader = DataLoader( 31 | eval_dataset, 32 | num_workers=2, 33 | shuffle=False, 34 | batch_size=hps.train.batch_size, 35 | pin_memory=True, 36 | drop_last=False, 37 | collate_fn=collate_fn) 38 | return eval_loader 39 | -------------------------------------------------------------------------------- /vits_extend/plotting.py: -------------------------------------------------------------------------------- 1 | import logging 2 | mpl_logger = logging.getLogger('matplotlib') # must before import matplotlib 3 | mpl_logger.setLevel(logging.WARNING) 4 | import matplotlib 5 | matplotlib.use("Agg") 6 | 7 | import numpy as np 8 | import matplotlib.pylab as plt 9 | 10 | 11 | def save_figure_to_numpy(fig): 12 | # save it to a numpy array. 13 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 14 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 15 | data = np.transpose(data, (2, 0, 1)) 16 | return data 17 | 18 | 19 | def plot_waveform_to_numpy(waveform): 20 | fig, ax = plt.subplots(figsize=(12, 4)) 21 | ax.plot() 22 | ax.plot(range(len(waveform)), waveform, 23 | linewidth=0.1, alpha=0.7, color='blue') 24 | 25 | plt.xlabel("Samples") 26 | plt.ylabel("Amplitude") 27 | plt.ylim(-1, 1) 28 | plt.tight_layout() 29 | 30 | fig.canvas.draw() 31 | data = save_figure_to_numpy(fig) 32 | plt.close() 33 | 34 | return data 35 | 36 | 37 | def plot_spectrogram_to_numpy(spectrogram): 38 | fig, ax = plt.subplots(figsize=(12, 4)) 39 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 40 | interpolation='none') 41 | plt.colorbar(im, ax=ax) 42 | plt.xlabel("Frames") 43 | plt.ylabel("Channels") 44 | plt.tight_layout() 45 | 46 | fig.canvas.draw() 47 | data = save_figure_to_numpy(fig) 48 | plt.close() 49 | return data 50 | -------------------------------------------------------------------------------- /vits_extend/stft.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Jungil Kong 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import math 24 | import os 25 | import random 26 | import torch 27 | import torch.utils.data 28 | import numpy as np 29 | from librosa.util import normalize 30 | from scipy.io.wavfile import read 31 | from librosa.filters import mel as librosa_mel_fn 32 | 33 | 34 | class TacotronSTFT(torch.nn.Module): 35 | def __init__(self, filter_length=512, hop_length=160, win_length=512, 36 | n_mel_channels=80, sampling_rate=16000, mel_fmin=0.0, 37 | mel_fmax=None, center=False, device='cpu'): 38 | super(TacotronSTFT, self).__init__() 39 | self.n_mel_channels = n_mel_channels 40 | self.sampling_rate = sampling_rate 41 | self.n_fft = filter_length 42 | self.hop_size = hop_length 43 | self.win_size = win_length 44 | self.fmin = mel_fmin 45 | self.fmax = mel_fmax 46 | self.center = center 47 | 48 | mel = librosa_mel_fn( 49 | sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax) 50 | 51 | mel_basis = torch.from_numpy(mel).float().to(device) 52 | hann_window = torch.hann_window(win_length).to(device) 53 | 54 | self.register_buffer('mel_basis', mel_basis) 55 | self.register_buffer('hann_window', hann_window) 56 | 57 | def linear_spectrogram(self, y): 58 | # assert (torch.min(y.data) >= -1) 59 | # assert (torch.max(y.data) <= 1) 60 | 61 | y = torch.nn.functional.pad(y.unsqueeze(1), 62 | (int((self.n_fft - self.hop_size) / 2), int((self.n_fft - self.hop_size) / 2)), 63 | mode='reflect') 64 | y = y.squeeze(1) 65 | spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window, 66 | center=self.center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 67 | spec = torch.norm(spec, p=2, dim=-1) 68 | 69 | return spec 70 | 71 | def mel_spectrogram(self, y): 72 | """Computes mel-spectrograms from a batch of waves 73 | PARAMS 74 | ------ 75 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 76 | 77 | RETURNS 78 | ------- 79 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 80 | """ 81 | # assert(torch.min(y.data) >= -1) 82 | # assert(torch.max(y.data) <= 1) 83 | 84 | y = torch.nn.functional.pad(y.unsqueeze(1), 85 | (int((self.n_fft - self.hop_size) / 2), int((self.n_fft - self.hop_size) / 2)), 86 | mode='reflect') 87 | y = y.squeeze(1) 88 | 89 | spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window, 90 | center=self.center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 91 | 92 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 93 | 94 | spec = torch.matmul(self.mel_basis, spec) 95 | spec = self.spectral_normalize_torch(spec) 96 | 97 | return spec 98 | 99 | def spectral_normalize_torch(self, magnitudes): 100 | output = self.dynamic_range_compression_torch(magnitudes) 101 | return output 102 | 103 | def dynamic_range_compression_torch(self, x, C=1, clip_val=1e-5): 104 | return torch.log(torch.clamp(x, min=clip_val) * C) 105 | -------------------------------------------------------------------------------- /vits_extend/stft_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Tomoki Hayashi 4 | # MIT License 
(https://opensource.org/licenses/MIT) 5 | 6 | """STFT-based Loss modules.""" 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | def stft(x, fft_size, hop_size, win_length, window): 13 | """Perform STFT and convert to magnitude spectrogram. 14 | Args: 15 | x (Tensor): Input signal tensor (B, T). 16 | fft_size (int): FFT size. 17 | hop_size (int): Hop size. 18 | win_length (int): Window length. 19 | window (str): Window function type. 20 | Returns: 21 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). 22 | """ 23 | x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=False) 24 | real = x_stft[..., 0] 25 | imag = x_stft[..., 1] 26 | 27 | # NOTE(kan-bayashi): clamp is needed to avoid nan or inf 28 | return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) 29 | 30 | 31 | class SpectralConvergengeLoss(torch.nn.Module): 32 | """Spectral convergence loss module.""" 33 | 34 | def __init__(self): 35 | """Initilize spectral convergence loss module.""" 36 | super(SpectralConvergengeLoss, self).__init__() 37 | 38 | def forward(self, x_mag, y_mag): 39 | """Calculate forward propagation. 40 | Args: 41 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 42 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 43 | Returns: 44 | Tensor: Spectral convergence loss value. 45 | """ 46 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") 47 | 48 | 49 | class LogSTFTMagnitudeLoss(torch.nn.Module): 50 | """Log STFT magnitude loss module.""" 51 | 52 | def __init__(self): 53 | """Initilize los STFT magnitude loss module.""" 54 | super(LogSTFTMagnitudeLoss, self).__init__() 55 | 56 | def forward(self, x_mag, y_mag): 57 | """Calculate forward propagation. 58 | Args: 59 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 60 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 61 | Returns: 62 | Tensor: Log STFT magnitude loss value. 63 | """ 64 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) 65 | 66 | 67 | class STFTLoss(torch.nn.Module): 68 | """STFT loss module.""" 69 | 70 | def __init__(self, device, fft_size=1024, shift_size=120, win_length=600, window="hann_window"): 71 | """Initialize STFT loss module.""" 72 | super(STFTLoss, self).__init__() 73 | self.fft_size = fft_size 74 | self.shift_size = shift_size 75 | self.win_length = win_length 76 | self.window = getattr(torch, window)(win_length).to(device) 77 | self.spectral_convergenge_loss = SpectralConvergengeLoss() 78 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 79 | 80 | def forward(self, x, y): 81 | """Calculate forward propagation. 82 | Args: 83 | x (Tensor): Predicted signal (B, T). 84 | y (Tensor): Groundtruth signal (B, T). 85 | Returns: 86 | Tensor: Spectral convergence loss value. 87 | Tensor: Log STFT magnitude loss value. 88 | """ 89 | x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) 90 | y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) 91 | sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) 92 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 93 | 94 | return sc_loss, mag_loss 95 | 96 | 97 | class MultiResolutionSTFTLoss(torch.nn.Module): 98 | """Multi resolution STFT loss module.""" 99 | 100 | def __init__(self, 101 | device, 102 | resolutions, 103 | window="hann_window"): 104 | """Initialize Multi resolution STFT loss module. 
105 | Args: 106 | resolutions (list): List of (FFT size, hop size, window length). 107 | window (str): Window function type. 108 | """ 109 | super(MultiResolutionSTFTLoss, self).__init__() 110 | self.stft_losses = torch.nn.ModuleList() 111 | for fs, ss, wl in resolutions: 112 | self.stft_losses += [STFTLoss(device, fs, ss, wl, window)] 113 | 114 | def forward(self, x, y): 115 | """Calculate forward propagation. 116 | Args: 117 | x (Tensor): Predicted signal (B, T). 118 | y (Tensor): Groundtruth signal (B, T). 119 | Returns: 120 | Tensor: Multi resolution spectral convergence loss value. 121 | Tensor: Multi resolution log STFT magnitude loss value. 122 | """ 123 | sc_loss = 0.0 124 | mag_loss = 0.0 125 | for f in self.stft_losses: 126 | sc_l, mag_l = f(x, y) 127 | sc_loss += sc_l 128 | mag_loss += mag_l 129 | 130 | sc_loss /= len(self.stft_losses) 131 | mag_loss /= len(self.stft_losses) 132 | 133 | return sc_loss, mag_loss 134 | -------------------------------------------------------------------------------- /vits_extend/validation.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def validate(hp, args, generator, discriminator, valloader, stft, writer, step, device): 7 | generator.eval() 8 | discriminator.eval() 9 | torch.backends.cudnn.benchmark = False 10 | 11 | loader = tqdm.tqdm(valloader, desc='Validation loop') 12 | mel_loss = 0.0 13 | for idx, (phone, phone_l, score, pitch, slurs, spec, spec_l, audio, audio_l) in enumerate(loader): 14 | phone = phone.to(device) 15 | phone_l = phone_l.to(device) 16 | score = score.to(device) 17 | pitch = pitch.to(device) 18 | slurs = slurs.to(device) 19 | audio = audio.to(device) 20 | 21 | if hasattr(generator, 'module'): 22 | fake_audio = generator.module.infer(phone, phone_l, score, pitch, slurs)[ 23 | :, :, :audio.size(2)] 24 | else: 25 | fake_audio = generator.infer(phone, phone_l, score, pitch, slurs)[ 26 | :, :, :audio.size(2)] 27 | 28 | mel_fake = stft.mel_spectrogram(fake_audio.squeeze(1)) 29 | mel_real = stft.mel_spectrogram(audio.squeeze(1)) 30 | 31 | mel_loss += F.l1_loss(mel_fake, mel_real).item() 32 | 33 | if idx < hp.log.num_audio: 34 | spec_fake = stft.linear_spectrogram(fake_audio.squeeze(1)) 35 | spec_real = stft.linear_spectrogram(audio.squeeze(1)) 36 | 37 | audio = audio[0][0].cpu().detach().numpy() 38 | fake_audio = fake_audio[0][0].cpu().detach().numpy() 39 | spec_fake = spec_fake[0].cpu().detach().numpy() 40 | spec_real = spec_real[0].cpu().detach().numpy() 41 | writer.log_fig_audio( 42 | audio, fake_audio, spec_fake, spec_real, idx, step) 43 | 44 | mel_loss = mel_loss / len(valloader.dataset) 45 | 46 | writer.log_validation(mel_loss, generator, discriminator, step) 47 | 48 | torch.backends.cudnn.benchmark = True 49 | -------------------------------------------------------------------------------- /vits_extend/writer.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | import numpy as np 3 | import librosa 4 | 5 | from .plotting import plot_waveform_to_numpy, plot_spectrogram_to_numpy 6 | 7 | class MyWriter(SummaryWriter): 8 | def __init__(self, hp, logdir): 9 | super(MyWriter, self).__init__(logdir) 10 | self.sample_rate = hp.data.sampling_rate 11 | 12 | def log_training(self, g_loss, d_loss, mel_loss, stft_loss, k_loss, r_loss, score_loss, step): 13 | self.add_scalar('train/g_loss', g_loss, step) 14 | 
self.add_scalar('train/d_loss', d_loss, step) 15 | 16 | self.add_scalar('train/score_loss', score_loss, step) 17 | self.add_scalar('train/stft_loss', stft_loss, step) 18 | self.add_scalar('train/mel_loss', mel_loss, step) 19 | self.add_scalar('train/kl_f_loss', k_loss, step) 20 | self.add_scalar('train/kl_r_loss', r_loss, step) 21 | 22 | def log_validation(self, mel_loss, generator, discriminator, step): 23 | self.add_scalar('validation/mel_loss', mel_loss, step) 24 | 25 | def log_fig_audio(self, real, fake, spec_fake, spec_real, idx, step): 26 | if idx == 0: 27 | spec_fake = librosa.amplitude_to_db(spec_fake, ref=np.max,top_db=80.) 28 | spec_real = librosa.amplitude_to_db(spec_real, ref=np.max,top_db=80.) 29 | self.add_image(f'spec_fake/{step}', plot_spectrogram_to_numpy(spec_fake), step) 30 | self.add_image(f'wave_fake/{step}', plot_waveform_to_numpy(fake), step) 31 | self.add_image(f'spec_real/{step}', plot_spectrogram_to_numpy(spec_real), step) 32 | self.add_image(f'wave_real/{step}', plot_waveform_to_numpy(real), step) 33 | 34 | self.add_audio(f'fake/{step}', fake, step, self.sample_rate) 35 | self.add_audio(f'real/{step}', real, step, self.sample_rate) 36 | 37 | def log_histogram(self, model, step): 38 | for tag, value in model.named_parameters(): 39 | self.add_histogram(tag.replace('.', '/'), value.cpu().detach().numpy(), step) 40 | --------------------------------------------------------------------------------
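Editor's note: the snippet below is an illustrative smoke test, not a file from the repository. It sketches how the decoder-side pieces shown above fit together — Generator (the NSF/BigVGAN-style decoder), Discriminator (MRD + MPD + MSD), TacotronSTFT, and MultiResolutionSTFTLoss — under the assumption that it is run from the repository root and that configs/singing_base.yaml defines the hp.gen.*, hp.mpd.*, hp.mrd.*, and hp.data.* fields these modules read. The frame count, pitch value, and STFT resolutions are example values only, and the exact output length depends on the upsample rates in the config.

import torch
import torch.nn.functional as F
import numpy as np
from omegaconf import OmegaConf

from vits_decoder.generator import Generator
from vits_decoder.discriminator import Discriminator
from vits_extend.stft import TacotronSTFT
from vits_extend.stft_loss import MultiResolutionSTFTLoss

# Assumption: this config provides the gen/mpd/mrd/data sections referenced above.
hp = OmegaConf.load('configs/singing_base.yaml')
device = 'cpu'

# Decoder: frame-level features plus frame-level F0 -> waveform.
generator = Generator(hp)
hop = int(np.prod(hp.gen.upsample_rates))            # total upsampling factor (samples per frame)
frames = 100                                          # illustrative number of frames
x = torch.randn(1, hp.gen.upsample_input, frames)     # [B, C, T_frames], matches conv_pre in_channels
f0 = torch.full((1, frames), 220.0)                   # [B, T_frames], constant 220 Hz pitch for the NSF source
fake = generator(x, f0)
print(fake.shape)                                     # expected roughly [1, 1, frames * hop]

# Discriminators: each entry is a (feature_maps, score) pair,
# mirroring the __main__ check in vits_decoder/discriminator.py.
disc = Discriminator(hp)
for fmap, score in disc(fake):
    print(len(fmap), score.shape)

# Reconstruction terms roughly mirroring the losses logged by MyWriter.log_training.
stft = TacotronSTFT(sampling_rate=hp.data.sampling_rate, device=device)
real = torch.randn_like(fake)                         # stand-in for ground-truth audio
mel_loss = F.l1_loss(stft.mel_spectrogram(fake.squeeze(1)),
                     stft.mel_spectrogram(real.squeeze(1)))

# Example (FFT size, hop size, window length) triples; in training these
# would come from the config, e.g. hp.mrd.resolutions.
resolutions = [(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)]
stft_criterion = MultiResolutionSTFTLoss(device, resolutions)
sc_loss, mag_loss = stft_criterion(fake.squeeze(1), real.squeeze(1))
print(mel_loss.item(), sc_loss.item(), mag_loss.item())

The same pattern (generate, score with the three discriminators, compare mel and multi-resolution STFT magnitudes against ground truth) is what vits_extend/train.py and validation.py drive during training; the sketch is only meant to show the expected tensor shapes and call signatures of the modules reproduced in this dump.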