├── LICENSE
├── README.md
├── configs
    ├── requirements.txt
    └── singing_base.yaml
├── pit_export.py
├── pit_train.py
├── pitch
    ├── __init__.py
    ├── base.py
    ├── data_utils.py
    ├── diffusion.py
    ├── models.py
    └── utils.py
├── pitch_extend
    ├── dataloader.py
    ├── plotting.py
    ├── train.py
    ├── validation.py
    └── writer.py
├── resource
    ├── vising_loss.png
    ├── vising_mel.png
    └── vising_sample.wav
├── svs
    ├── __init__.py
    ├── midi-HZ.scp
    ├── midi-note.scp
    ├── phone_map.py
    └── phone_uv.py
├── svs_export.py
├── svs_infer.py
├── svs_infer.txt
├── svs_infer_pitch.py
├── svs_song.py
├── svs_song.txt
├── svs_song_pitch.py
├── svs_train.py
├── util
    ├── __init__.py
    ├── generate_index.py
    ├── generate_label.py
    └── resample.py
├── vits
    ├── __init__.py
    ├── attentions.py
    ├── commons.py
    ├── data_utils.py
    ├── losses.py
    ├── models.py
    ├── modules.py
    ├── spectrogram.py
    └── utils.py
├── vits_decoder
    ├── __init__.py
    ├── alias
    │   ├── __init__.py
    │   ├── act.py
    │   ├── filter.py
    │   └── resample.py
    ├── bigv.py
    ├── discriminator.py
    ├── generator.py
    ├── mpd.py
    ├── mrd.py
    ├── msd.py
    └── nsf.py
└── vits_extend
    ├── __init__.py
    ├── dataloader.py
    ├── plotting.py
    ├── stft.py
    ├── stft_loss.py
    ├── train.py
    ├── validation.py
    └── writer.py


/LICENSE:
--------------------------------------------------------------------------------
  1 |                                  Apache License
  2 |                            Version 2.0, January 2004
  3 |                         http://www.apache.org/licenses/
  4 | 
  5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
  6 | 
  7 |    1. Definitions.
  8 | 
  9 |       "License" shall mean the terms and conditions for use, reproduction,
 10 |       and distribution as defined by Sections 1 through 9 of this document.
 11 | 
 12 |       "Licensor" shall mean the copyright owner or entity authorized by
 13 |       the copyright owner that is granting the License.
 14 | 
 15 |       "Legal Entity" shall mean the union of the acting entity and all
 16 |       other entities that control, are controlled by, or are under common
 17 |       control with that entity. For the purposes of this definition,
 18 |       "control" means (i) the power, direct or indirect, to cause the
 19 |       direction or management of such entity, whether by contract or
 20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
 21 |       outstanding shares, or (iii) beneficial ownership of such entity.
 22 | 
 23 |       "You" (or "Your") shall mean an individual or Legal Entity
 24 |       exercising permissions granted by this License.
 25 | 
 26 |       "Source" form shall mean the preferred form for making modifications,
 27 |       including but not limited to software source code, documentation
 28 |       source, and configuration files.
 29 | 
 30 |       "Object" form shall mean any form resulting from mechanical
 31 |       transformation or translation of a Source form, including but
 32 |       not limited to compiled object code, generated documentation,
 33 |       and conversions to other media types.
 34 | 
 35 |       "Work" shall mean the work of authorship, whether in Source or
 36 |       Object form, made available under the License, as indicated by a
 37 |       copyright notice that is included in or attached to the work
 38 |       (an example is provided in the Appendix below).
 39 | 
 40 |       "Derivative Works" shall mean any work, whether in Source or Object
 41 |       form, that is based on (or derived from) the Work and for which the
 42 |       editorial revisions, annotations, elaborations, or other modifications
 43 |       represent, as a whole, an original work of authorship. For the purposes
 44 |       of this License, Derivative Works shall not include works that remain
 45 |       separable from, or merely link (or bind by name) to the interfaces of,
 46 |       the Work and Derivative Works thereof.
 47 | 
 48 |       "Contribution" shall mean any work of authorship, including
 49 |       the original version of the Work and any modifications or additions
 50 |       to that Work or Derivative Works thereof, that is intentionally
 51 |       submitted to Licensor for inclusion in the Work by the copyright owner
 52 |       or by an individual or Legal Entity authorized to submit on behalf of
 53 |       the copyright owner. For the purposes of this definition, "submitted"
 54 |       means any form of electronic, verbal, or written communication sent
 55 |       to the Licensor or its representatives, including but not limited to
 56 |       communication on electronic mailing lists, source code control systems,
 57 |       and issue tracking systems that are managed by, or on behalf of, the
 58 |       Licensor for the purpose of discussing and improving the Work, but
 59 |       excluding communication that is conspicuously marked or otherwise
 60 |       designated in writing by the copyright owner as "Not a Contribution."
 61 | 
 62 |       "Contributor" shall mean Licensor and any individual or Legal Entity
 63 |       on behalf of whom a Contribution has been received by Licensor and
 64 |       subsequently incorporated within the Work.
 65 | 
 66 |    2. Grant of Copyright License. Subject to the terms and conditions of
 67 |       this License, each Contributor hereby grants to You a perpetual,
 68 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 69 |       copyright license to reproduce, prepare Derivative Works of,
 70 |       publicly display, publicly perform, sublicense, and distribute the
 71 |       Work and such Derivative Works in Source or Object form.
 72 | 
 73 |    3. Grant of Patent License. Subject to the terms and conditions of
 74 |       this License, each Contributor hereby grants to You a perpetual,
 75 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 76 |       (except as stated in this section) patent license to make, have made,
 77 |       use, offer to sell, sell, import, and otherwise transfer the Work,
 78 |       where such license applies only to those patent claims licensable
 79 |       by such Contributor that are necessarily infringed by their
 80 |       Contribution(s) alone or by combination of their Contribution(s)
 81 |       with the Work to which such Contribution(s) was submitted. If You
 82 |       institute patent litigation against any entity (including a
 83 |       cross-claim or counterclaim in a lawsuit) alleging that the Work
 84 |       or a Contribution incorporated within the Work constitutes direct
 85 |       or contributory patent infringement, then any patent licenses
 86 |       granted to You under this License for that Work shall terminate
 87 |       as of the date such litigation is filed.
 88 | 
 89 |    4. Redistribution. You may reproduce and distribute copies of the
 90 |       Work or Derivative Works thereof in any medium, with or without
 91 |       modifications, and in Source or Object form, provided that You
 92 |       meet the following conditions:
 93 | 
 94 |       (a) You must give any other recipients of the Work or
 95 |           Derivative Works a copy of this License; and
 96 | 
 97 |       (b) You must cause any modified files to carry prominent notices
 98 |           stating that You changed the files; and
 99 | 
100 |       (c) You must retain, in the Source form of any Derivative Works
101 |           that You distribute, all copyright, patent, trademark, and
102 |           attribution notices from the Source form of the Work,
103 |           excluding those notices that do not pertain to any part of
104 |           the Derivative Works; and
105 | 
106 |       (d) If the Work includes a "NOTICE" text file as part of its
107 |           distribution, then any Derivative Works that You distribute must
108 |           include a readable copy of the attribution notices contained
109 |           within such NOTICE file, excluding those notices that do not
110 |           pertain to any part of the Derivative Works, in at least one
111 |           of the following places: within a NOTICE text file distributed
112 |           as part of the Derivative Works; within the Source form or
113 |           documentation, if provided along with the Derivative Works; or,
114 |           within a display generated by the Derivative Works, if and
115 |           wherever such third-party notices normally appear. The contents
116 |           of the NOTICE file are for informational purposes only and
117 |           do not modify the License. You may add Your own attribution
118 |           notices within Derivative Works that You distribute, alongside
119 |           or as an addendum to the NOTICE text from the Work, provided
120 |           that such additional attribution notices cannot be construed
121 |           as modifying the License.
122 | 
123 |       You may add Your own copyright statement to Your modifications and
124 |       may provide additional or different license terms and conditions
125 |       for use, reproduction, or distribution of Your modifications, or
126 |       for any such Derivative Works as a whole, provided Your use,
127 |       reproduction, and distribution of the Work otherwise complies with
128 |       the conditions stated in this License.
129 | 
130 |    5. Submission of Contributions. Unless You explicitly state otherwise,
131 |       any Contribution intentionally submitted for inclusion in the Work
132 |       by You to the Licensor shall be under the terms and conditions of
133 |       this License, without any additional terms or conditions.
134 |       Notwithstanding the above, nothing herein shall supersede or modify
135 |       the terms of any separate license agreement you may have executed
136 |       with Licensor regarding such Contributions.
137 | 
138 |    6. Trademarks. This License does not grant permission to use the trade
139 |       names, trademarks, service marks, or product names of the Licensor,
140 |       except as required for reasonable and customary use in describing the
141 |       origin of the Work and reproducing the content of the NOTICE file.
142 | 
143 |    7. Disclaimer of Warranty. Unless required by applicable law or
144 |       agreed to in writing, Licensor provides the Work (and each
145 |       Contributor provides its Contributions) on an "AS IS" BASIS,
146 |       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 |       implied, including, without limitation, any warranties or conditions
148 |       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 |       PARTICULAR PURPOSE. You are solely responsible for determining the
150 |       appropriateness of using or redistributing the Work and assume any
151 |       risks associated with Your exercise of permissions under this License.
152 | 
153 |    8. Limitation of Liability. In no event and under no legal theory,
154 |       whether in tort (including negligence), contract, or otherwise,
155 |       unless required by applicable law (such as deliberate and grossly
156 |       negligent acts) or agreed to in writing, shall any Contributor be
157 |       liable to You for damages, including any direct, indirect, special,
158 |       incidental, or consequential damages of any character arising as a
159 |       result of this License or out of the use or inability to use the
160 |       Work (including but not limited to damages for loss of goodwill,
161 |       work stoppage, computer failure or malfunction, or any and all
162 |       other commercial damages or losses), even if such Contributor
163 |       has been advised of the possibility of such damages.
164 | 
165 |    9. Accepting Warranty or Additional Liability. While redistributing
166 |       the Work or Derivative Works thereof, You may choose to offer,
167 |       and charge a fee for, acceptance of support, warranty, indemnity,
168 |       or other liability obligations and/or rights consistent with this
169 |       License. However, in accepting such obligations, You may act only
170 |       on Your own behalf and on Your sole responsibility, not on behalf
171 |       of any other Contributor, and only if You agree to indemnify,
172 |       defend, and hold each Contributor harmless for any liability
173 |       incurred by, or claims asserted against, such Contributor by reason
174 |       of your accepting any such warranty or additional liability.
175 | 
176 |    END OF TERMS AND CONDITIONS
177 | 
178 |    APPENDIX: How to apply the Apache License to your work.
179 | 
180 |       To apply the Apache License to your work, attach the following
181 |       boilerplate notice, with the fields enclosed by brackets "[]"
182 |       replaced with your own identifying information. (Don't include
183 |       the brackets!)  The text should be enclosed in the appropriate
184 |       comment syntax for the file format. We also recommend that a
185 |       file or class name and description of purpose be included on the
186 |       same "printed page" as the copyright notice for easier
187 |       identification within third-party archives.
188 | 
189 |    Copyright [yyyy] [name of copyright owner]
190 | 
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 | 
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 | 


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | <div align="center">
  2 | <h1> Variational Inference with adversarial learning for end-to-end Singing Voice Synthesis </h1>
  3 | 
  4 | Different from VISinger, it is just VITS without MAS and the DurationPredictor.
  5 | 
  6 | As a project meant for learning, this is where it stands: pitch prediction is the part that still needs improvement.
  7 | 
  8 | ![VISinger](https://github.com/MaxMax2016/VI-SVS/assets/16432329/c76ca716-b230-4852-b8f0-2c3041af7072)
  9 | 
 10 | ![VI-SVS](https://github.com/MaxMax2016/VI-SVS/assets/16432329/128c0f33-4428-4b57-9cd3-b6237f53c1a4)
 11 | 
 12 | </div>
 13 | 
 14 | **Pitch and Duration will be developed as add-ons!**
 15 | 
 16 | # Training Steps
 17 | 
 18 | - 1 Download the dataset segments.zip and extract it
 19 | 
 20 | ```
 21 | segments
 22 | |-- test.txt
 23 | |-- train.txt
 24 | |-- transcriptions.txt
 25 | `-- wavs
 26 |     |-- 2001000001.wav
 27 |     |-- 2001000002.wav
 28 |     |-- 2001000003.wav
 29 | ```
 30 | 
 31 | - 2 Convert the sampling rate: this project uses 32 kHz
 32 | ```
 33 | python util/resample.py -w segments/wavs/ -o data_svs/wavs -s 32000
 34 | ```
 35 | 
 36 | - 3 Generate the data labels
 37 | ```
 38 | python util/generate_label.py --config configs/singing_base.yaml --data data_svs/ --file segments/transcriptions.txt
 39 | ```
 40 | 
 41 | This produces data_svs/labels.txt, one line per utterance in the format: wave path|label path|score path|pitch path|slurs path (see the illustrative example below)
 42 | 
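For illustration only, one line of labels.txt might look like the following. The file names here are hypothetical (the actual names and label directory are whatever util/generate_label.py writes); only the five fields separated by | follow the documented format.

```
data_svs/wavs/2001000001.wav|data_svs/labels/2001000001_label.npy|data_svs/labels/2001000001_score.npy|data_svs/labels/2001000001_pitch.npy|data_svs/labels/2001000001_slurs.npy
```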
 43 | - 4 Split the training index
 44 | ```
 45 | python util/generate_index.py --file data_svs/labels.txt
 46 | ```
 47 | 
 48 | This generates filelists/singing_train.txt and filelists/singing_valid.txt
 49 | 
 50 | - 5 Start training
 51 | ```
 52 | python svs_train.py -c configs/singing_base.yaml -n vits_svs
 53 | ```
 54 | 
 55 | - 6 Train the pitch model
 56 | ```
 57 | python pit_train.py -c configs/singing_base.yaml -n pitch
 58 | ```
 59 | 
 60 | # Inference
 61 | 
 62 | - 0 Export the model
 63 | ```
 64 | python svs_export.py --config configs/singing_base.yaml --model chkpt/vits_svs/vits_svs_****.pt
 65 | ```
 66 | 
 67 | - 1 Run inference: F0 is generated from the musical score
 68 | ```
 69 | python svs_infer.py --config configs/singing_base.yaml --model svs_opencpop.pt
 70 | ```
 71 | 
 72 | - 2 Synthesize a full song ([use the release model](https://github.com/PlayVoice/VI-SVS/releases/tag/0.0.3))
 73 | ```
 74 | python svs_song.py --config configs/singing_base.yaml --model svs_opencpop.pt
 75 | ```
 76 | 
 77 | # Inference using pitch prediction (results are not good yet)
 78 | 
 79 | - 0 Export the models
 80 | ```
 81 | python svs_export.py --config configs/singing_base.yaml --model chkpt/vits_svs/vits_svs_****.pt
 82 | ```
 83 | 
 84 | ```
 85 | python pit_export.py --config configs/singing_base.yaml --model chkpt/pitch/pitch_****.pt
 86 | ```
 87 | 
 88 | - 1 Run inference
 89 | ```
 90 | python svs_infer_pitch.py --config configs/singing_base.yaml --model svs_opencpop.pt --pitch pit_opencpop.pt
 91 | ```
 92 | 
 93 | - 2 Synthesize a full song ([use the release model](https://github.com/PlayVoice/VI-SVS/releases/tag/0.0.3))
 94 | ```
 95 | python svs_song_pitch.py --config configs/singing_base.yaml --model svs_opencpop.pt --pitch pit_opencpop.pt
 96 | ```
 97 | 
 98 | # Dataset
 99 | 
100 | https://wenet.org.cn/opencpop/
101 | 
102 | # Singing voice synthesis references
103 | 
104 | https://github.com/SJTMusicTeam/Muskits
105 | 
106 | https://github.com/MoonInTheRiver/DiffSinger
107 | 
108 | [VISinger: Variational Inference with Adversarial Learning for End-to-End Singing Voice Synthesis](https://arxiv.org/abs/2110.08813)
109 | 
110 | # Model design references
111 | 
112 | https://github.com/NVIDIA/BigVGAN
113 | 
114 | https://github.com/jaywalnut310/vits
115 | 
116 | https://github.com/mindslab-ai/univnet
117 | 
118 | https://github.com/PlayVoice/so-vits-svc-5.0
119 | 
120 | https://github.com/shivammehta25/Matcha-TTS
121 | 
122 | [RoFormer: Enhanced Transformer with rotary position embedding](https://arxiv.org/abs/2104.09864)
123 | 
124 | # Diffusion Pitch
125 | 
126 | https://github.com/thuhcsi/DiffVar
127 | 
128 | https://github.com/hayeong0/Diff-HierVC
129 | 
130 | https://github.com/tonnetonne814/SiFi-VITS2-44100-Ja
131 | 
132 | [Grad-TTS: A Diffusion Probabilistic Model for Text-to-Speech](https://arxiv.org/abs/2105.06337)
133 | 
134 | # Diffusion Pitch of Diff-HierVC
135 | ![DiffPitch](https://github.com/PlayVoice/VI-SVS/assets/16432329/055d75a4-7009-46c1-8603-65254cec47dd)
136 | 


--------------------------------------------------------------------------------
/configs/requirements.txt:
--------------------------------------------------------------------------------
 1 | Cython==0.29.21
 2 | librosa==0.8.0
 3 | matplotlib==3.3.1
 4 | numpy==1.18.5
 5 | phonemizer==2.2.1
 6 | scipy==1.5.2
 7 | tensorboard==2.3.0
 8 | torch==1.6.0
 9 | torchvision==0.7.0
10 | Unidecode==1.1.1
11 | 


--------------------------------------------------------------------------------
/configs/singing_base.yaml:
--------------------------------------------------------------------------------
 1 | train:
 2 |   model: "vits-svs"
 3 |   seed: 1234
 4 |   epochs: 10000
 5 |   learning_rate: 1e-4
 6 |   betas: [0.8, 0.99]
 7 |   lr_decay: 0.999875
 8 |   eps: 1e-9
 9 |   batch_size: 6
10 |   c_stft: 9
11 |   c_mel: 1.
12 |   c_kl: 0.2
13 |   port: 8001
14 |   pretrain: ""
15 | #############################
16 | data: 
17 |   training_files: "filelists/singing_train.txt"
18 |   validation_files: "filelists/singing_valid.txt"
 19 |   segment_size: 8000  # WARNING: depends on hop_length (8000 = 25 * 320)
20 |   max_wav_value: 32768.0
21 |   sampling_rate: 32000
22 |   filter_length: 1024
23 |   hop_length: 320
24 |   win_length: 1024
25 |   mel_channels: 100
26 |   mel_fmin: 50.0
27 |   mel_fmax: 16000.0
28 | #############################
29 | vits:
30 |   gin_channels: 0
31 |   inter_channels: 192
32 |   hidden_channels: 192
33 |   filter_channels: 640
34 | #############################
35 | gen:
36 |   upsample_input: 192
37 |   upsample_rates: [5,4,4,2,2]
38 |   upsample_kernel_sizes: [15,8,8,4,4]
39 |   upsample_initial_channel: 480
40 |   resblock_kernel_sizes: [3,7,11]
41 |   resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
42 | #############################
43 | mpd:
44 |   periods: [2,3,5,7,11]
45 |   kernel_size: 5
46 |   stride: 3
47 |   use_spectral_norm: False
48 |   lReLU_slope: 0.2
49 | #############################
50 | mrd:
51 |   resolutions: "[(1024, 120, 600), (2048, 240, 1200), (4096, 480, 2400), (512, 50, 240)]" # (filter_length, hop_length, win_length)
52 |   use_spectral_norm: False
53 |   lReLU_slope: 0.2
54 | #############################
55 | log:
56 |   info_interval: 100
57 |   eval_interval: 1
58 |   save_interval: 5
59 |   num_audio: 6
60 |   pth_dir: 'chkpt'
61 |   log_dir: 'logs'
62 |   keep_ckpts: 0
63 | #############################
64 | dist_config:
65 |   dist_backend: "nccl"
66 |   dist_url: "tcp://localhost:54321"
67 |   world_size: 1
68 | 
69 | 
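The training scripts load this file with OmegaConf; a minimal sketch of how the fields above are accessed (mirroring pit_train.py):

```python
from omegaconf import OmegaConf

hp = OmegaConf.load("configs/singing_base.yaml")
print(hp.train.batch_size)    # 6
print(hp.data.sampling_rate)  # 32000
print(hp.data.hop_length)     # 320, asserted by pit_train.py
print(hp.gen.upsample_rates)  # [5, 4, 4, 2, 2]; product = 320 = hop_length
```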


--------------------------------------------------------------------------------
/pit_export.py:
--------------------------------------------------------------------------------
 1 | import sys,os
 2 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 3 | import torch
 4 | import argparse
 5 | 
 6 | from pitch.models import PitchDiffusion
 7 | 
 8 | 
 9 | def load_model(checkpoint_path, model):
10 |     assert os.path.isfile(checkpoint_path)
11 |     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
12 |     saved_state_dict = checkpoint_dict["model_g"]
13 |     if hasattr(model, "module"):
14 |         state_dict = model.module.state_dict()
15 |     else:
16 |         state_dict = model.state_dict()
17 |     new_state_dict = {}
18 |     for k, v in state_dict.items():
19 |         try:
20 |             new_state_dict[k] = saved_state_dict[k]
21 |         except:
22 |             new_state_dict[k] = v
23 |     if hasattr(model, "module"):
24 |         model.module.load_state_dict(new_state_dict)
25 |     else:
26 |         model.load_state_dict(new_state_dict)
27 |     return model
28 | 
29 | 
30 | def save_model(model, checkpoint_path):
31 |     if hasattr(model, 'module'):
32 |         state_dict = model.module.state_dict()
33 |     else:
34 |         state_dict = model.state_dict()
35 |     torch.save({'model_g': state_dict}, checkpoint_path)
36 | 
37 | 
38 | def main(args):
39 |     model = PitchDiffusion()
40 |     load_model(args.model, model)
41 |     save_model(model, "pit_opencpop.pt")
42 | 
43 | 
44 | if __name__ == '__main__':
45 |     parser = argparse.ArgumentParser()
46 |     parser.add_argument('-c', '--config', type=str, required=True,
47 |                         help="yaml file for config. will use hp_str from checkpoint if not given.")
48 |     parser.add_argument('-m', '--model', type=str, required=True,
49 |                         help="path of checkpoint pt file for evaluation")
50 |     args = parser.parse_args()
51 | 
52 |     main(args)
53 | 


--------------------------------------------------------------------------------
/pit_train.py:
--------------------------------------------------------------------------------
 1 | import sys,os
 2 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 3 | import argparse
 4 | import torch
 5 | import torch.multiprocessing as mp
 6 | from omegaconf import OmegaConf
 7 | 
 8 | from pitch_extend.train import train
 9 | 
10 | torch.backends.cudnn.benchmark = True
11 | 
12 | 
13 | if __name__ == '__main__':
14 |     parser = argparse.ArgumentParser()
15 |     parser.add_argument('-c', '--config', type=str, required=True,
16 |                         help="yaml file for configuration")
17 |     parser.add_argument('-p', '--checkpoint_path', type=str, default=None,
18 |                         help="path of checkpoint pt file to resume training")
19 |     parser.add_argument('-n', '--name', type=str, required=True,
20 |                         help="name of the model for logging, saving checkpoint")
21 |     args = parser.parse_args()
22 | 
23 |     hp = OmegaConf.load(args.config)
24 |     with open(args.config, 'r') as f:
25 |         hp_str = ''.join(f.readlines())
26 | 
27 |     assert hp.data.hop_length == 320, \
28 |         'hp.data.hop_length must be equal to 320, got %d' % hp.data.hop_length
29 | 
30 |     args.num_gpus = 0
31 |     torch.manual_seed(hp.train.seed)
32 |     if torch.cuda.is_available():
33 |         torch.cuda.manual_seed(hp.train.seed)
34 |         args.num_gpus = torch.cuda.device_count()
35 |         print('Batch size per GPU :', hp.train.batch_size)
36 | 
37 |         if args.num_gpus > 1:
38 |             mp.spawn(train, nprocs=args.num_gpus,
39 |                      args=(args, args.checkpoint_path, hp, hp_str,))
40 |         else:
41 |             train(0, args, args.checkpoint_path, hp, hp_str)
42 |     else:
 43 |         print('No GPU found!')
44 | 


--------------------------------------------------------------------------------
/pitch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlayVoice/VI-SVS/d9768caa05c742bc5f580a4754b3cd7b444a5eb4/pitch/__init__.py


--------------------------------------------------------------------------------
/pitch/base.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import torch
 3 | 
 4 | 
 5 | class BaseModule(torch.nn.Module):
 6 |     def __init__(self):
 7 |         super(BaseModule, self).__init__()
 8 | 
 9 |     @property
10 |     def nparams(self):
11 |         """
12 |         Returns number of trainable parameters of the module.
13 |         """
14 |         num_params = 0
15 |         for name, param in self.named_parameters():
16 |             if param.requires_grad:
17 |                 num_params += np.prod(param.detach().cpu().numpy().shape)
18 |         return num_params
19 | 
20 | 
21 |     def relocate_input(self, x: list):
22 |         """
23 |         Relocates provided tensors to the same device set for the module.
24 |         """
25 |         device = next(self.parameters()).device
26 |         for i in range(len(x)):
27 |             if isinstance(x[i], torch.Tensor) and x[i].device != device:
28 |                 x[i] = x[i].to(device)
29 |         return x
30 | 


--------------------------------------------------------------------------------
/pitch/data_utils.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import numpy as np
  3 | import torch
  4 | import torch.utils.data
  5 | 
  6 | from vits.utils import load_wav_to_torch
  7 | from vits.spectrogram import spectrogram_torch
  8 | from pitch.utils import fix_len_compatibility
  9 | 
 10 | 
 11 | def load_filepaths(filename, split="|"):
 12 |     with open(filename, encoding='utf-8') as f:
 13 |         filepaths = [line.strip().split(split) for line in f]
 14 |     return filepaths
 15 | 
 16 | 
 17 | class TextAudioLoader(torch.utils.data.Dataset):
 18 |     """
 19 |     1) loads audio, text pairs
 20 |     2) normalizes text and converts them to sequences of integers
 21 |     3) computes spectrograms from audio files.
 22 |     """
 23 | 
 24 |     def __init__(self, audiopaths_and_text, hparams):
 25 |         self.audiopaths_and_text = load_filepaths(audiopaths_and_text)
 26 |         self.max_wav_value  = hparams.max_wav_value
 27 |         self.sampling_rate  = hparams.sampling_rate
 28 |         self.filter_length  = hparams.filter_length 
 29 |         self.hop_length     = hparams.hop_length 
 30 |         self.win_length     = hparams.win_length
 31 |         self.sampling_rate  = hparams.sampling_rate 
 32 |         self.min_text_len   = getattr(hparams, "min_text_len", 1)
 33 |         self.max_text_len   = getattr(hparams, "max_text_len", 5000)
 34 |         self._filter()
 35 |         print(f"~~~~~~~~~~~~~~~~~~~~~{len(self)}~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
 36 | 
 37 |     def _filter(self):
 38 |         """
 39 |         Filter text & store spec lengths
 40 |         """
 41 |         # Store spectrogram lengths for Bucketing
 42 |         # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
 43 |         # spec_length = wav_length // hop_length
 44 |         audiopaths_and_text_new = []
 45 |         lengths = []
 46 |         for audiopath, text, score, pitch, slur in self.audiopaths_and_text:
 47 |             if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
 48 |                 wav_len = os.path.getsize(audiopath) // (2 * self.hop_length)
 49 |                 if wav_len < 50:  # skip waves that are too short
 50 |                     continue
 51 |                 audiopaths_and_text_new.append([audiopath, text, score, pitch, slur])
 52 |                 lengths.append(wav_len)
 53 |         self.audiopaths_and_text = audiopaths_and_text_new
 54 |         self.lengths = lengths
 55 | 
 56 |     def get_audio_text_pair(self, audiopath_and_text):
 57 |         # separate filename and text
 58 |         file = audiopath_and_text[0]
 59 |         phone = audiopath_and_text[1]
 60 |         score = audiopath_and_text[2]
 61 |         pitch = audiopath_and_text[3]
 62 |         slurs = audiopath_and_text[4]
 63 | 
 64 |         phone, score, pitch, slurs = self.get_labels(phone, score, pitch, slurs)
 65 |         spec, wav = self.get_audio(file)
 66 | 
 67 |         len_phone = phone.size()[0]
 68 |         len_spec = spec.size()[-1]
 69 | 
 70 |         if len_phone != len_spec:
 71 |             # print("**************CareFull*******************")
 72 |             # print(f"filepath={audiopath_and_text[0]}")
 73 |             # print(f"len_text={len_phone}")
 74 |             # print(f"len_spec={len_spec}")
 75 |             if len_phone > len_spec:
 76 |                 print(file)
 77 |                 print("len_phone", len_phone)
 78 |                 print("len_spec", len_spec)
 79 |             assert len_phone < len_spec
 80 |             len_min = min(len_phone, len_spec)
 81 |             len_wav = len_min * self.hop_length
 82 |             # print(wav.size())
 83 |             # print(f"len_min={len_min}")
 84 |             # print(f"len_wav={len_wav}")
 85 |             spec = spec[:, :len_min]
 86 |             wav = wav[:, :len_wav]
 87 |         return (phone, score, pitch, slurs, spec, wav)
 88 | 
 89 |     def get_labels(self, phone, score, pitch, slurs):
 90 |         phone = np.load(phone)
 91 |         score = np.load(score)
 92 |         pitch = np.load(pitch)
 93 |         slurs = np.load(slurs)
 94 |         phone = torch.LongTensor(phone)
 95 |         score = torch.LongTensor(score)
 96 |         pitch = torch.FloatTensor(pitch)
 97 |         slurs = torch.LongTensor(slurs)
 98 |         return phone, score, pitch, slurs
 99 | 
100 |     def get_audio(self, filename):
101 |         audio, sampling_rate = load_wav_to_torch(filename)
102 |         if sampling_rate != self.sampling_rate:
103 |             raise ValueError(
104 |                 "{} SR doesn't match target {} SR".format(
105 |                     sampling_rate, self.sampling_rate
106 |                 )
107 |             )
108 |         audio_norm = audio / self.max_wav_value
109 |         audio_norm = audio_norm.unsqueeze(0)
110 |         spec_filename = filename.replace(".wav", ".spec.pt")
111 |         if os.path.exists(spec_filename):
112 |             spec = torch.load(spec_filename)
113 |         else:
114 |             spec = spectrogram_torch(
115 |                 audio_norm,
116 |                 self.filter_length,
117 |                 self.sampling_rate,
118 |                 self.hop_length,
119 |                 self.win_length,
120 |                 center=False,
121 |             )
122 |             spec = torch.squeeze(spec, 0)
123 |             torch.save(spec, spec_filename)
124 |         return spec, audio_norm
125 | 
126 |     def __getitem__(self, index):
127 |         return self.get_audio_text_pair(self.audiopaths_and_text[index])
128 | 
129 |     def __len__(self):
130 |         return len(self.audiopaths_and_text)
131 | 
132 | 
133 | class TextAudioCollate:
134 |     """Zero-pads model inputs and targets"""
135 | 
136 |     def __init__(self, return_ids=False):
137 |         self.return_ids = return_ids
138 | 
139 |     def __call__(self, batch):
140 |         """Collates a training batch from normalized text and audio
141 |         PARAMS
142 |         ------
143 |         batch: [text_normalized, spec_normalized, wav_normalized]
144 |         """
145 |         # Right zero-pad all one-hot text sequences to max input length
146 |         _, ids_sorted_decreasing = torch.sort(
147 |             torch.LongTensor([x[4].size(1) for x in batch]), dim=0, descending=True
148 |         )
149 | 
150 |         max_phone_len = max([len(x[0]) for x in batch])
151 |         # For Unet
152 |         max_phone_len = fix_len_compatibility(max_phone_len)
153 | 
154 |         phone_lengths = torch.LongTensor(len(batch))
155 |         phone_padded = torch.LongTensor(len(batch), max_phone_len)
156 |         score_padded = torch.LongTensor(len(batch), max_phone_len)
157 |         pitch_padded = torch.FloatTensor(len(batch), max_phone_len)
158 |         slurs_padded = torch.LongTensor(len(batch), max_phone_len)
159 |         phone_padded.zero_()
160 |         score_padded.zero_()
161 |         pitch_padded.zero_()
162 |         slurs_padded.zero_()
163 | 
164 |         for i in range(len(ids_sorted_decreasing)):
165 |             row = batch[ids_sorted_decreasing[i]]
166 | 
167 |             phone = row[0]
168 |             phone_padded[i, : phone.size(0)] = phone
169 |             phone_lengths[i] = phone.size(0)
170 | 
171 |             score = row[1]
172 |             score_padded[i, : score.size(0)] = score
173 | 
174 |             pitch = row[2]
175 |             pitch_padded[i, : pitch.size(0)] = pitch
176 | 
177 |             slurs = row[3]
178 |             slurs_padded[i, : slurs.size(0)] = slurs
179 | 
180 |         return (
181 |             phone_padded,
182 |             phone_lengths,
183 |             score_padded,
184 |             pitch_padded,
185 |             slurs_padded,
186 |         )
187 | 
188 | 
189 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
190 |     """
191 |     Maintain similar input lengths in a batch.
192 |     Length groups are specified by boundaries.
193 |     Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.
194 | 
195 |     It removes samples which are not included in the boundaries.
196 |     Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
197 |     """
198 | 
199 |     def __init__(
200 |         self,
201 |         dataset,
202 |         batch_size,
203 |         boundaries,
204 |         num_replicas=None,
205 |         rank=None,
206 |         shuffle=True,
207 |     ):
208 |         super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
209 |         self.lengths = dataset.lengths
210 |         self.batch_size = batch_size
211 |         self.boundaries = boundaries
212 | 
213 |         self.buckets, self.num_samples_per_bucket = self._create_buckets()
214 |         self.total_size = sum(self.num_samples_per_bucket)
215 |         self.num_samples = self.total_size // self.num_replicas
216 | 
217 |     def _create_buckets(self):
218 |         buckets = [[] for _ in range(len(self.boundaries) - 1)]
219 |         for i in range(len(self.lengths)):
220 |             length = self.lengths[i]
221 |             idx_bucket = self._bisect(length)
222 |             if idx_bucket != -1:
223 |                 buckets[idx_bucket].append(i)
224 | 
225 |         for i in range(len(buckets) - 1, 0, -1):
226 |             if len(buckets[i]) == 0:
227 |                 buckets.pop(i)
228 |                 self.boundaries.pop(i + 1)
229 | 
230 |         num_samples_per_bucket = []
231 |         for i in range(len(buckets)):
232 |             len_bucket = len(buckets[i])
233 |             total_batch_size = self.num_replicas * self.batch_size
234 |             rem = (
235 |                 total_batch_size - (len_bucket % total_batch_size)
236 |             ) % total_batch_size
237 |             num_samples_per_bucket.append(len_bucket + rem)
238 |         return buckets, num_samples_per_bucket
239 | 
240 |     def __iter__(self):
241 |         # deterministically shuffle based on epoch
242 |         g = torch.Generator()
243 |         g.manual_seed(self.epoch)
244 | 
245 |         indices = []
246 |         if self.shuffle:
247 |             for bucket in self.buckets:
248 |                 indices.append(torch.randperm(len(bucket), generator=g).tolist())
249 |         else:
250 |             for bucket in self.buckets:
251 |                 indices.append(list(range(len(bucket))))
252 | 
253 |         batches = []
254 |         for i in range(len(self.buckets)):
255 |             bucket = self.buckets[i]
256 |             len_bucket = len(bucket)
257 |             if (len_bucket == 0):
258 |                 continue
259 |             ids_bucket = indices[i]
260 |             num_samples_bucket = self.num_samples_per_bucket[i]
261 | 
262 |             # add extra samples to make it evenly divisible
263 |             rem = num_samples_bucket - len_bucket
264 |             ids_bucket = (
265 |                 ids_bucket
266 |                 + ids_bucket * (rem // len_bucket)
267 |                 + ids_bucket[: (rem % len_bucket)]
268 |             )
269 | 
270 |             # subsample
271 |             ids_bucket = ids_bucket[self.rank:: self.num_replicas]
272 | 
273 |             # batching
274 |             for j in range(len(ids_bucket) // self.batch_size):
275 |                 batch = [
276 |                     bucket[idx]
277 |                     for idx in ids_bucket[
278 |                         j * self.batch_size: (j + 1) * self.batch_size
279 |                     ]
280 |                 ]
281 |                 batches.append(batch)
282 | 
283 |         if self.shuffle:
284 |             batch_ids = torch.randperm(len(batches), generator=g).tolist()
285 |             batches = [batches[i] for i in batch_ids]
286 |         self.batches = batches
287 | 
288 |         assert len(self.batches) * self.batch_size == self.num_samples
289 |         return iter(self.batches)
290 | 
291 |     def _bisect(self, x, lo=0, hi=None):
292 |         if hi is None:
293 |             hi = len(self.boundaries) - 1
294 | 
295 |         if hi > lo:
296 |             mid = (hi + lo) // 2
297 |             if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
298 |                 return mid
299 |             elif x <= self.boundaries[mid]:
300 |                 return self._bisect(x, lo, mid)
301 |             else:
302 |                 return self._bisect(x, mid + 1, hi)
303 |         else:
304 |             return -1
305 | 
306 |     def __len__(self):
307 |         return self.num_samples // self.batch_size
308 | 


--------------------------------------------------------------------------------
/pitch/diffusion.py:
--------------------------------------------------------------------------------
  1 | import math
  2 | import torch
  3 | from einops import rearrange
  4 | from pitch.base import BaseModule
  5 | 
  6 | 
  7 | class Mish(BaseModule):
  8 |     def forward(self, x):
  9 |         return x * torch.tanh(torch.nn.functional.softplus(x))
 10 | 
 11 | 
 12 | class Rezero(BaseModule):
 13 |     def __init__(self, fn):
 14 |         super(Rezero, self).__init__()
 15 |         self.fn = fn
 16 |         self.g = torch.nn.Parameter(torch.zeros(1))
 17 | 
 18 |     def forward(self, x):
 19 |         return self.fn(x) * self.g
 20 | 
 21 | 
 22 | class Block(BaseModule):
 23 |     def __init__(self, dim, dim_out, groups=8):
 24 |         super(Block, self).__init__()
 25 |         self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3, 
 26 |                                          padding=1), torch.nn.GroupNorm(
 27 |                                          groups, dim_out), Mish())
 28 | 
 29 |     def forward(self, x, mask):
 30 |         output = self.block(x * mask)
 31 |         return output * mask
 32 | 
 33 | 
 34 | class ResnetBlock(BaseModule):
 35 |     def __init__(self, dim, dim_out, time_emb_dim, groups=8):
 36 |         super(ResnetBlock, self).__init__()
 37 |         self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, 
 38 |                                                                dim_out))
 39 | 
 40 |         self.block1 = Block(dim, dim_out, groups=groups)
 41 |         self.block2 = Block(dim_out, dim_out, groups=groups)
 42 |         if dim != dim_out:
 43 |             self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
 44 |         else:
 45 |             self.res_conv = torch.nn.Identity()
 46 | 
 47 |     def forward(self, x, mask, time_emb):
 48 |         h = self.block1(x, mask)
 49 |         h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
 50 |         h = self.block2(h, mask)
 51 |         output = h + self.res_conv(x * mask)
 52 |         return output
 53 | 
 54 | 
 55 | class LinearAttention(BaseModule):
 56 |     def __init__(self, dim, heads=4, dim_head=32):
 57 |         super(LinearAttention, self).__init__()
 58 |         self.heads = heads
 59 |         hidden_dim = dim_head * heads
 60 |         self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
 61 |         self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)            
 62 | 
 63 |     def forward(self, x):
 64 |         b, c, h, w = x.shape
 65 |         qkv = self.to_qkv(x)
 66 |         q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', 
 67 |                             heads = self.heads, qkv=3)            
 68 |         k = k.softmax(dim=-1)
 69 |         context = torch.einsum('bhdn,bhen->bhde', k, v)
 70 |         out = torch.einsum('bhde,bhdn->bhen', context, q)
 71 |         out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', 
 72 |                         heads=self.heads, h=h, w=w)
 73 |         return self.to_out(out)
 74 | 
 75 | 
 76 | class Residual(BaseModule):
 77 |     def __init__(self, fn):
 78 |         super(Residual, self).__init__()
 79 |         self.fn = fn
 80 | 
 81 |     def forward(self, x, *args, **kwargs):
 82 |         output = self.fn(x, *args, **kwargs) + x
 83 |         return output
 84 | 
 85 | 
 86 | class SinusoidalPosEmb(BaseModule):
 87 |     def __init__(self, dim):
 88 |         super(SinusoidalPosEmb, self).__init__()
 89 |         self.dim = dim
 90 | 
 91 |     def forward(self, x, scale=1000):
 92 |         device = x.device
 93 |         half_dim = self.dim // 2
 94 |         emb = math.log(10000) / (half_dim - 1)
 95 |         emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
 96 |         emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
 97 |         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
 98 |         return emb
 99 | 
100 | 
101 | class GradLogPEstimator2d(BaseModule):
102 |     def __init__(self, n_feat, n_cond, dim, dim_mults=(1, 2, 4), groups=8, pe_scale=1000):
103 |         super(GradLogPEstimator2d, self).__init__()
104 |         self.dim = dim
105 |         self.dim_mults = dim_mults
106 |         self.groups = groups
107 |         self.pe_scale = pe_scale
108 | 
109 |         self.cond = torch.nn.Sequential(torch.nn.Conv1d(n_cond, dim * 4, 1), Mish(),
110 |                                         torch.nn.Conv1d(dim * 4, n_feat, 1))
111 |         self.time_pos_emb = SinusoidalPosEmb(dim)
112 |         self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
113 |                                        torch.nn.Linear(dim * 4, dim))
114 | 
115 |         dims = [2 + 1, *map(lambda m: dim * m, dim_mults)]
116 |         in_out = list(zip(dims[:-1], dims[1:]))
117 |         self.downs = torch.nn.ModuleList([])
118 |         self.ups = torch.nn.ModuleList([])
119 |         num_resolutions = len(in_out)
120 | 
121 |         for ind, (dim_in, dim_out) in enumerate(in_out):  # 2 downs
122 |             is_last = ind >= (num_resolutions - 1)
123 |             self.downs.append(torch.nn.ModuleList([
124 |                        ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
125 |                        ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
126 |                        Residual(Rezero(LinearAttention(dim_out))),
127 |                        torch.nn.Identity()]))
128 | 
129 |         mid_dim = dims[-1]
130 |         self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
131 |         self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
132 |         self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
133 | 
134 |         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):  # 2 ups
135 |             self.ups.append(torch.nn.ModuleList([
136 |                      ResnetBlock(dim_out, dim_in, time_emb_dim=dim),
137 |                      ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
138 |                      Residual(Rezero(LinearAttention(dim_in))),
139 |                      torch.nn.Identity()]))
140 |         self.final_block = Block(dim, dim)
141 |         self.final_conv = torch.nn.Conv2d(dim, 1, 1)
142 | 
143 |     def forward(self, x, mask, mu, c, t):
144 | 
145 |         t = self.time_pos_emb(t, scale=self.pe_scale)
146 |         t = self.mlp(t)
147 |         c = self.cond(c)
148 |  
149 |         x = torch.stack([mu, x, c], 1)
150 |         mask = mask.unsqueeze(1)
151 | 
152 |         for resnet1, resnet2, attn, downsample in self.downs:
153 |             x = resnet1(x, mask, t)
154 |             x = resnet2(x, mask, t)
155 |             x = attn(x)
156 |             x = downsample(x * mask)
157 | 
158 |         x = self.mid_block1(x, mask, t)
159 |         x = self.mid_attn(x)
160 |         x = self.mid_block2(x, mask, t)
161 | 
162 |         for resnet1, resnet2, attn, upsample in self.ups:
163 |             x = resnet1(x, mask, t)
164 |             x = resnet2(x, mask, t)
165 |             x = attn(x)
166 |             x = upsample(x * mask)
167 | 
168 |         x = self.final_block(x, mask)
169 |         output = self.final_conv(x * mask)
170 | 
171 |         return (output * mask).squeeze(1)
172 | 
173 | 
174 | class Diffusion(BaseModule):
175 |     def __init__(self, n_feat, n_cond, dim, beta_min=0.05, beta_max=20, pe_scale=1000):
176 |         super(Diffusion, self).__init__()
177 |         self.estimator = GradLogPEstimator2d(n_feat, n_cond, dim, pe_scale=pe_scale)
178 |         self.n_feat = n_feat
179 |         self.beta_min = beta_min
180 |         self.beta_max = beta_max
181 | 
182 |     def get_beta(self, t):
183 |         beta = self.beta_min + (self.beta_max - self.beta_min) * t
184 |         return beta
185 | 
186 |     def get_gamma(self, s, t, p=1.0, use_torch=False):
187 |         beta_integral = self.beta_min + 0.5 * (self.beta_max - self.beta_min) * (t + s)
188 |         beta_integral *= (t - s)
189 |         if use_torch:
190 |             gamma = torch.exp(-0.5 * p * beta_integral).unsqueeze(-1).unsqueeze(-1)
191 |         else:
192 |             gamma = math.exp(-0.5 * p * beta_integral)
193 |         return gamma
194 | 
195 |     def get_mu(self, s, t):
196 |         a = self.get_gamma(s, t)
197 |         b = 1.0 - self.get_gamma(0, s, p=2.0)
198 |         c = 1.0 - self.get_gamma(0, t, p=2.0)
199 |         return a * b / c
200 | 
201 |     def get_nu(self, s, t):
202 |         a = self.get_gamma(0, s)
203 |         b = 1.0 - self.get_gamma(s, t, p=2.0)
204 |         c = 1.0 - self.get_gamma(0, t, p=2.0)
205 |         return a * b / c
206 | 
207 |     def get_sigma(self, s, t):
208 |         a = 1.0 - self.get_gamma(0, s, p=2.0)
209 |         b = 1.0 - self.get_gamma(s, t, p=2.0)
210 |         c = 1.0 - self.get_gamma(0, t, p=2.0)
211 |         return math.sqrt(a * b / c)
212 | 
213 |     @torch.no_grad()
214 |     def reverse_diffusion(self, z, mask, mu, mu_c, n_timesteps):
215 |         h = 1.0 / n_timesteps
216 |         xt = z * mask
217 | 
218 |         for i in range(n_timesteps):
219 |             t = 1.0 - i * h
220 |             time = t * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
221 |             beta_t = self.get_beta(t) 
222 |             
223 |             kappa = self.get_gamma(0, t - h) * (1.0 - self.get_gamma(t - h, t, p=2.0))
224 |             kappa /= (self.get_gamma(0, t) * beta_t * h)
225 |             kappa -= 1.0
226 |             omega = self.get_nu(t - h, t) / self.get_gamma(0, t)
227 |             omega += self.get_mu(t - h, t)
228 |             omega -= (0.5 * beta_t * h + 1.0)
229 |             sigma = self.get_sigma(t - h, t)  
230 | 
231 |             dxt = (mu - xt) * (0.5 * beta_t * h + omega) 
232 |             dxt -= (self.estimator(xt, mask, mu, mu_c, time)) * (1.0 + kappa) * (beta_t * h)            
233 |             dxt += torch.randn_like(z, device=z.device) * sigma 
234 |             xt = (xt - dxt) * mask
235 | 
236 |         return xt
237 | 
238 |     @torch.no_grad()
239 |     def forward(self, z, mask, mu, mu_c, n_timesteps):
240 |         return self.reverse_diffusion(z, mask, mu, mu_c, n_timesteps)
241 | 
242 |     # train: here `mel` stands for the ground-truth f0
243 |     def get_noise(self, t, beta_init, beta_term, cumulative=False):
244 |         if cumulative:
245 |             noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)
246 |         else:
247 |             noise = beta_init + (beta_term - beta_init)*t
248 |         return noise
249 | 
250 |     def forward_diffusion(self, mel, mask, mu, t):
251 |         time = t.unsqueeze(-1).unsqueeze(-1)
252 |         cum_noise = self.get_noise(time, self.beta_min, self.beta_max, cumulative=True)
253 |         mean = mel*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise))
254 |         variance = 1.0 - torch.exp(-cum_noise)
255 |         z = torch.randn(mel.shape, dtype=mel.dtype, device=mel.device, 
256 |                         requires_grad=False)
257 |         xt = mean + z * torch.sqrt(variance)
258 |         return xt * mask, z * mask
259 | 
260 |     def loss_t(self, mel, mask, mu, mu_c, t):
261 |         xt, z = self.forward_diffusion(mel, mask, mu, t)
262 |         time = t.unsqueeze(-1).unsqueeze(-1)
263 |         cum_noise = self.get_noise(time, self.beta_min, self.beta_max, cumulative=True)
264 |         noise_estimation = self.estimator(xt, mask, mu, mu_c, t)
265 |         noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise))
266 |         loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_feat)
267 |         return loss, xt
268 | 
269 |     def compute_loss(self, mel, mask, mu, mu_c, offset=1e-5):
270 |         t = torch.rand(mel.shape[0], dtype=mel.dtype, device=mel.device, requires_grad=False)
271 |         t = torch.clamp(t, offset, 1.0 - offset)
272 |         return self.loss_t(mel, mask, mu, mu_c, t)
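For reference, the closed form that get_noise and forward_diffusion above implement (a Grad-TTS style forward diffusion) can be read off the code: with $\beta(s) = \beta_{min} + (\beta_{max} - \beta_{min})\,s$ and cumulative noise $\Lambda(t) = \int_0^t \beta(s)\,ds = \beta_{min} t + \tfrac{1}{2}(\beta_{max} - \beta_{min})\,t^2$, a noised sample at time $t$ is

$$x_t = x_0\,e^{-\frac{1}{2}\Lambda(t)} + \mu\left(1 - e^{-\frac{1}{2}\Lambda(t)}\right) + \sqrt{1 - e^{-\Lambda(t)}}\;z,\qquad z \sim \mathcal{N}(0, I),$$

which is exactly the `mean`, `variance`, and `xt` computed in forward_diffusion; the same $\sqrt{1 - e^{-\Lambda(t)}}$ factor reappears as the weight on the noise estimate in loss_t.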
273 | 


--------------------------------------------------------------------------------
/pitch/models.py:
--------------------------------------------------------------------------------
  1 | import math
  2 | import torch
  3 | import torch.nn.functional as F
  4 | 
  5 | from torch import nn
  6 | from pitch.diffusion import Diffusion
  7 | from pitch.utils import rand_ids_segments, slice_segments
  8 | 
  9 | from vits import attentions
 10 | from vits import commons
 11 | 
 12 | 
 13 | class TextEncoder(nn.Module):
 14 |     def __init__(self,
 15 |                  hidden_channels,
 16 |                  filter_channels,
 17 |                  n_heads,
 18 |                  n_layers,
 19 |                  kernel_size,
 20 |                  p_dropout):
 21 |         super().__init__()
 22 |         self.hidden_channels = hidden_channels
 23 |         self.emb_phone = nn.Embedding(63, hidden_channels)      # phone labels
 24 |         self.emb_score = nn.Embedding(128, hidden_channels)     # pitch notes
 25 |         self.emb_slurs = nn.Embedding(2, hidden_channels)       # phone slur
 26 |         nn.init.normal_(self.emb_phone.weight, 0.0, hidden_channels**-0.5)
 27 |         nn.init.normal_(self.emb_score.weight, 0.0, hidden_channels**-0.5)
 28 |         nn.init.normal_(self.emb_slurs.weight, 0.0, hidden_channels**-0.5)
 29 |         self.enc = attentions.Encoder(
 30 |             hidden_channels,
 31 |             filter_channels,
 32 |             n_heads,
 33 |             n_layers,
 34 |             kernel_size,
 35 |             p_dropout)
 36 |         self.proj = nn.Conv1d(hidden_channels, 2, 1)  # pitch + uv
 37 | 
 38 |     def forward(self, phone, lengths, score, slurs):
 39 |         x = self.emb_phone(phone) + self.emb_score(score) + self.emb_slurs(slurs)
 40 |         x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
 41 |         x = torch.transpose(x, 1, -1)  # [b, h, t]
 42 |         x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
 43 |             x.dtype
 44 |         )
 45 |         x = self.enc(x * x_mask, x_mask)
 46 |         c = x
 47 |         x = self.proj(x)
 48 |         return x, x_mask, c
 49 | 
 50 | 
 51 | class PitchDiffusion(nn.Module):
 52 |     def __init__(self):
 53 |         super().__init__()
 54 |         self.pit_encoder = TextEncoder(hidden_channels=192, filter_channels=768, 
 55 |                                        n_heads=2, n_layers=5, kernel_size=5, p_dropout=0.1)
 56 |         self.decoder = Diffusion(2, 192, 64, beta_min=0.05, beta_max=20.0, pe_scale=1000)
 57 | 
 58 | 
 59 |     @torch.no_grad()
 60 |     def forward(self, phone, lengths, score, slurs, n_timesteps, temperature=1.0):
 61 |         # Encoder
 62 |         mu_x, mask_x, c = self.pit_encoder(phone, lengths, score, slurs)
 63 |         encoder_outputs = mu_x
 64 | 
 65 |         # Sample latent representation from terminal distribution N(mu_y, I)
 66 |         z = mu_x + torch.randn_like(mu_x, device=mu_x.device) / temperature
 67 |         # Generate sample by performing reverse dynamics
 68 |         decoder_outputs = self.decoder(z, mask_x, mu_x, c, n_timesteps)
 69 |         return encoder_outputs, decoder_outputs
 70 | 
 71 |     def compute_loss(self, phone, lengths, score, slurs, pitch, out_size):
 72 |         # Get encoder_outputs `mu_x`
 73 |         mu_x, mask_x, c = self.pit_encoder(phone, lengths, score, slurs)
 74 | 
 75 |         # Compute loss between encoder outputs and pitch
 76 |         floor = torch.ones_like(pitch)
 77 |         pitch = torch.maximum(pitch, floor)
 78 |         pitch = torch.log2(pitch)
 79 |         # Loss
 80 |         loss_f0 = F.l1_loss(mu_x[:, 0, :], pitch)
 81 |         uv_gt = (pitch > 0).to(pitch.dtype)
 82 |         loss_uv = F.binary_cross_entropy_with_logits(mu_x[:, 1, :], uv_gt)
 83 |         prior_loss = loss_f0 + loss_uv
 84 |         # pitch_gt
 85 |         pitch_gt = torch.zeros_like(mu_x, device=mu_x.device)
 86 |         pitch_gt[:, 0, :] = pitch
 87 |         pitch_gt[:, 1, :] = uv_gt
 88 |         # Compute loss of score-based decoder
 89 |         # Cut a small segment of pitch in order to increase batch size
 90 |         if out_size is not None and out_size < pitch_gt.shape[-1]:  # compare against the time dimension
 91 |             ids = rand_ids_segments(lengths, out_size)
 92 |             pitch_gt = slice_segments(pitch_gt, ids, out_size)
 93 | 
 94 |             mask_x = slice_segments(mask_x, ids, out_size)
 95 |             mu_x = slice_segments(mu_x, ids, out_size)
 96 |             c = slice_segments(c, ids, out_size)
 97 | 
 98 |         diff_loss, xt = self.decoder.compute_loss(pitch_gt, mask_x, mu_x, c)
 99 |         return prior_loss, diff_loss
100 |  
101 | 
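A minimal, hypothetical inference sketch for PitchDiffusion follows (not part of the repo; it assumes the repo packages are importable and, optionally, a checkpoint exported by pit_export.py). The channel interpretation follows compute_loss above: channel 0 carries log2(F0) and channel 1 the voiced/unvoiced logit.

```python
import torch
from pitch.models import PitchDiffusion

model = PitchDiffusion().eval()
# state = torch.load("pit_opencpop.pt", map_location="cpu")["model_g"]  # hypothetical checkpoint path
# model.load_state_dict(state)

b, t = 1, 200                            # t kept a multiple of 4, matching fix_len_compatibility
phone = torch.randint(0, 63, (b, t))     # phone label ids (emb_phone has 63 entries)
score = torch.randint(0, 128, (b, t))    # MIDI note ids   (emb_score has 128 entries)
slurs = torch.randint(0, 2, (b, t))      # slur flags      (emb_slurs has 2 entries)
lengths = torch.LongTensor([t])

mu, out = model(phone, lengths, score, slurs, n_timesteps=50, temperature=1.0)

f0 = 2.0 ** out[:, 0, :]                    # back from the log2 domain
voiced = torch.sigmoid(out[:, 1, :]) > 0.5  # voiced/unvoiced decision
print(f0.shape, voiced.shape)
```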


--------------------------------------------------------------------------------
/pitch/utils.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import inspect
 3 | 
 4 | 
 5 | def sequence_mask(length, max_length=None):
 6 |     if max_length is None:
 7 |         max_length = length.max()
 8 |     x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
 9 |     return x.unsqueeze(0) < length.unsqueeze(1)
10 | 
11 | 
12 | def fix_len_compatibility(length, num_downsamplings_in_unet=2):
13 |     while True:
14 |         if length % (2**num_downsamplings_in_unet) == 0:
15 |             return length
16 |         length += 1
17 | 
18 | 
19 | def convert_pad_shape(pad_shape):
20 |     l = pad_shape[::-1]
21 |     pad_shape = [item for sublist in l for item in sublist]
22 |     return pad_shape
23 | 
24 | 
25 | def generate_path(duration, mask):
26 |     device = duration.device
27 | 
28 |     b, t_x, t_y = mask.shape
29 |     cum_duration = torch.cumsum(duration, 1)
30 |     path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
31 | 
32 |     cum_duration_flat = cum_duration.view(b * t_x)
33 |     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
34 |     path = path.view(b, t_x, t_y)
35 |     path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], 
36 |                                           [1, 0], [0, 0]]))[:, :-1]
37 |     path = path * mask
38 |     return path
39 | 
40 | 
41 | def duration_loss(logw, logw_, lengths):
42 |     loss = torch.sum((logw - logw_)**2) / torch.sum(lengths)
43 |     return loss
44 | 
45 | 
46 | def rand_ids_segments(lengths, segment_size=200):
47 |     b = lengths.shape[0]
48 |     ids_str_max = lengths - segment_size
49 |     ids_str = (torch.rand([b]).to(device=lengths.device) * ids_str_max).to(dtype=torch.long)
50 |     ids_str = torch.where(ids_str < 0, 0, ids_str)  # fix error
51 |     return ids_str
52 | 
53 | 
54 | def slice_segments(x, ids_str, segment_size=200):
55 |     ret = torch.zeros_like(x[:, :, :segment_size])
56 |     for i in range(x.size(0)):
57 |         idx_str = ids_str[i]
58 |         idx_end = idx_str + segment_size
59 |         ret[i] = x[i, :, idx_str:idx_end]
60 |     return ret
61 | 
62 | 
63 | def retrieve_name(var):
64 |     for fi in reversed(inspect.stack()):
65 |         names = [var_name for var_name,
66 |                  var_val in fi.frame.f_locals.items() if var_val is var]
67 |         if len(names) > 0:
68 |             return names[0]
69 | 
70 | 
71 | Debug_Enable = True
72 | 
73 | 
74 | def debug_shapes(var):
75 |     if Debug_Enable:
76 |         print(retrieve_name(var), var.shape)
77 | 
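A small, self-contained usage sketch, not part of the repository, showing how the helpers above are combined in pitch_extend/train.py: segment lengths are made divisible by 2**num_downsamplings_in_unet, then random fixed-size chunks are cut from a (batch, channels, frames) tensor.

import torch
from pitch.utils import fix_len_compatibility, rand_ids_segments, slice_segments

seg = fix_len_compatibility(400)        # 400 is already divisible by 4, so seg == 400
x = torch.randn(2, 2, 1000)             # (batch, channels, frames)
lengths = torch.tensor([1000, 800])
ids = rand_ids_segments(lengths, seg)   # one random start index per batch item
chunk = slice_segments(x, ids, seg)     # -> shape (2, 2, 400)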


--------------------------------------------------------------------------------
/pitch_extend/dataloader.py:
--------------------------------------------------------------------------------
 1 | from torch.utils.data import DataLoader
 2 | from pitch.data_utils import DistributedBucketSampler
 3 | from pitch.data_utils import TextAudioLoader
 4 | from pitch.data_utils import TextAudioCollate
 5 | 
 6 | 
 7 | def create_dataloader_train(hps, n_gpus, rank):
 8 |     collate_fn = TextAudioCollate()
 9 |     train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
10 |     train_sampler = DistributedBucketSampler(
11 |         train_dataset,
12 |         hps.train.batch_size,
13 |         [32, 300, 400, 500, 600, 700, 800, 900, 1000],
14 |         num_replicas=n_gpus,
15 |         rank=rank,
16 |         shuffle=True)
17 |     train_loader = DataLoader(
18 |         train_dataset,
19 |         num_workers=4,
20 |         shuffle=False,
21 |         pin_memory=True,
22 |         collate_fn=collate_fn,
23 |         batch_sampler=train_sampler)
24 |     return train_loader
25 | 
26 | 
27 | def create_dataloader_eval(hps):
28 |     collate_fn = TextAudioCollate()
29 |     eval_dataset = TextAudioLoader(hps.data.validation_files, hps.data)
30 |     eval_loader = DataLoader(
31 |         eval_dataset,
32 |         num_workers=2,
33 |         shuffle=False,
34 |         batch_size=hps.train.batch_size,
35 |         pin_memory=True,
36 |         drop_last=False,
37 |         collate_fn=collate_fn)
38 |     return eval_loader
39 | 
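A hedged usage sketch, not part of the repository; it assumes configs/singing_base.yaml provides data.training_files, data.validation_files and train.batch_size, which is what the loaders above read.

from omegaconf import OmegaConf
from pitch_extend.dataloader import create_dataloader_train, create_dataloader_eval

hps = OmegaConf.load("configs/singing_base.yaml")
train_loader = create_dataloader_train(hps, n_gpus=1, rank=0)   # bucketed, distributed-aware sampler (single process here)
eval_loader = create_dataloader_eval(hps)                       # plain batched loader for validation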


--------------------------------------------------------------------------------
/pitch_extend/plotting.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | mpl_logger = logging.getLogger('matplotlib')  # must before import matplotlib
 3 | mpl_logger.setLevel(logging.WARNING)
 4 | import matplotlib
 5 | matplotlib.use("Agg")
 6 | 
 7 | import numpy as np
 8 | import matplotlib.pylab as plt
 9 | 
10 | 
11 | def save_figure_to_numpy(fig):
12 |     # save it to a numpy array.
13 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated
14 |     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
15 |     data = np.transpose(data, (2, 0, 1))
16 |     return data
17 | 
18 | 
19 | def plot_f0_to_numpy(f0_pre, f0_gt=None):
20 |     fig = plt.figure(figsize=(12, 6))
21 |     plt.plot(f0_pre.T, "g")
22 |     if f0_gt is not None:
23 |         plt.plot(f0_gt.T, "r")
24 |     fig.canvas.draw()
25 |     data = save_figure_to_numpy(fig)
26 |     plt.close()
27 |     return data
28 | 


--------------------------------------------------------------------------------
/pitch_extend/train.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import time
  3 | import logging
  4 | import math
  5 | import tqdm
  6 | import torch
  7 | import torch.nn as nn
  8 | import torch.nn.functional as F
  9 | from torch.distributed import init_process_group
 10 | from torch.nn.parallel import DistributedDataParallel
 11 | 
 12 | from vits.commons import clip_grad_value_
 13 | 
 14 | from pitch.utils import fix_len_compatibility
 15 | from pitch.models import PitchDiffusion
 16 | from pitch_extend.validation import validate
 17 | from pitch_extend.writer import MyWriter
 18 | from pitch_extend.dataloader import create_dataloader_train
 19 | from pitch_extend.dataloader import create_dataloader_eval
 20 | 
 21 | 
 22 | def load_model(model, saved_state_dict):
 23 |     if hasattr(model, 'module'):
 24 |         state_dict = model.module.state_dict()
 25 |     else:
 26 |         state_dict = model.state_dict()
 27 |     new_state_dict = {}
 28 |     for k, v in state_dict.items():
 29 |         try:
 30 |             new_state_dict[k] = saved_state_dict[k]
 31 |         except KeyError:
 32 |             print("%s is not in the checkpoint" % k)
 33 |             new_state_dict[k] = v
 34 |     if hasattr(model, 'module'):
 35 |         model.module.load_state_dict(new_state_dict)
 36 |     else:
 37 |         model.load_state_dict(new_state_dict)
 38 |     return model
 39 | 
 40 | 
 41 | # 400 frames
 42 | out_size = fix_len_compatibility(400)
 43 | 
 44 | 
 45 | def train(rank, args, chkpt_path, hp, hp_str):
 46 | 
 47 |     if args.num_gpus > 1:
 48 |         init_process_group(backend=hp.dist_config.dist_backend, init_method=hp.dist_config.dist_url,
 49 |                            world_size=hp.dist_config.world_size * args.num_gpus, rank=rank)
 50 | 
 51 |     torch.cuda.manual_seed(hp.train.seed)
 52 |     device = torch.device('cuda:{:d}'.format(rank))
 53 | 
 54 |     model_g = PitchDiffusion().to(device)
 55 | 
 56 |     optim_g = torch.optim.AdamW(model_g.parameters(),
 57 |                                 lr=hp.train.learning_rate, betas=hp.train.betas, eps=hp.train.eps)
 58 | 
 59 |     init_epoch = 1
 60 |     step = 0
 61 | 
 62 |     # define logger, writer, valloader, stft at rank_zero
 63 |     if rank == 0:
 64 |         pth_dir = os.path.join(hp.log.pth_dir, args.name)
 65 |         log_dir = os.path.join(hp.log.log_dir, args.name)
 66 |         os.makedirs(pth_dir, exist_ok=True)
 67 |         os.makedirs(log_dir, exist_ok=True)
 68 | 
 69 |         logging.basicConfig(
 70 |             level=logging.INFO,
 71 |             format='%(asctime)s - %(levelname)s - %(message)s',
 72 |             handlers=[
 73 |                 logging.FileHandler(os.path.join(log_dir, '%s-%d.log' % (args.name, time.time()))),
 74 |                 logging.StreamHandler()
 75 |             ]
 76 |         )
 77 |         logger = logging.getLogger()
 78 |         writer = MyWriter(hp, log_dir)
 79 |         valloader = create_dataloader_eval(hp)
 80 | 
 81 |     if chkpt_path is not None:
 82 |         if rank == 0:
 83 |             logger.info("Resuming from checkpoint: %s" % chkpt_path)
 84 |         checkpoint = torch.load(chkpt_path, map_location='cpu')
 85 |         load_model(model_g, checkpoint['model_g'])
 86 |         optim_g.load_state_dict(checkpoint['optim_g'])
 87 |         init_epoch = checkpoint['epoch']
 88 |         step = checkpoint['step']
 89 | 
 90 |         if rank == 0:
 91 |             if hp_str != checkpoint['hp_str']:
 92 |                 logger.warning("New hparams is different from checkpoint. Will use new.")
 93 |     else:
 94 |         if rank == 0:
 95 |             logger.info("Starting new training run.")
 96 | 
 97 |     if args.num_gpus > 1:
 98 |         model_g = DistributedDataParallel(model_g, device_ids=[rank])
 99 | 
100 |     # cudnn benchmarking speeds up training when the minibatch size stays consistent;
101 |     # with varying batch sizes it slows training down badly.
102 |     torch.backends.cudnn.benchmark = True
103 | 
104 |     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hp.train.lr_decay, last_epoch=init_epoch-2)
105 |     trainloader = create_dataloader_train(hp, args.num_gpus, rank)
106 | 
107 |     for epoch in range(init_epoch, hp.train.epochs):
108 | 
109 |         trainloader.batch_sampler.set_epoch(epoch)
110 | 
111 |         if rank == 0 and epoch % hp.log.eval_interval == 0:
112 |             with torch.no_grad():
113 |                 validate(hp, model_g, valloader, writer, step, device)
114 | 
115 |         if rank == 0:
116 |             loader = tqdm.tqdm(trainloader, desc='Loading train data')
117 |         else:
118 |             loader = trainloader
119 | 
120 |         model_g.train()
121 | 
122 |         for phone, phone_l, score, pitch, slurs in loader:
123 | 
124 |             phone = phone.to(device)
125 |             phone_l = phone_l.to(device)
126 |             score = score.to(device)
127 |             pitch = pitch.to(device)
128 |             slurs = slurs.to(device)
129 | 
130 |             # generator
131 |             optim_g.zero_grad()
132 |             #
133 |             prior_loss, diff_loss = model_g.compute_loss(phone, phone_l, score, slurs, pitch, out_size=out_size)
134 |             loss_g = sum([prior_loss, diff_loss])
135 |             loss_g.backward()
136 |             clip_grad_value_(model_g.parameters(),  None)
137 |             optim_g.step()
138 | 
139 |             step += 1
140 |             # logging
141 |             loss_g = loss_g.item()
142 |             if rank == 0 and step % hp.log.info_interval == 0:
143 |                 writer.log_training(loss_g, prior_loss, diff_loss, step)
144 |                 logger.info("epoch %d | g %.04f prior_loss %.04f diff_loss %.04f | step %d" % (
145 |                     epoch, loss_g, prior_loss, diff_loss, step))
146 | 
147 |         if rank == 0 and epoch % hp.log.save_interval == 0:
148 |             save_path = os.path.join(pth_dir, '%s_%04d.pt'
149 |                                      % (args.name, epoch))
150 |             torch.save({
151 |                 'model_g': (model_g.module if args.num_gpus > 1 else model_g).state_dict(),
152 |                 'optim_g': optim_g.state_dict(),
153 |                 'step': step,
154 |                 'epoch': epoch,
155 |                 'hp_str': hp_str,
156 |             }, save_path)
157 |             logger.info("Saved checkpoint to: %s" % save_path)
158 | 
159 |         scheduler_g.step()
160 | 
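A minimal single-GPU launch sketch, not part of the repository; the actual entry point is pit_train.py, and the argparse.Namespace below only stands in for its parsed arguments. A CUDA device is assumed, since train() pins the model to cuda:{rank}.

import argparse
from omegaconf import OmegaConf
from pitch_extend.train import train

hp = OmegaConf.load("configs/singing_base.yaml")
hp_str = OmegaConf.to_yaml(hp)
args = argparse.Namespace(name="pitch_run", num_gpus=1)
train(0, args, None, hp, hp_str)   # rank 0, no checkpoint to resume from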


--------------------------------------------------------------------------------
/pitch_extend/validation.py:
--------------------------------------------------------------------------------
 1 | import tqdm
 2 | import torch
 3 | import torch.nn.functional as F
 4 | 
 5 | 
 6 | def validate(hp, generator, valloader, writer, step, device):
 7 |     generator.eval()
 8 |     torch.backends.cudnn.benchmark = False
 9 | 
10 |     loader = tqdm.tqdm(valloader, desc='Validation loop')
11 |     vali_loss = 0.0
12 |     for idx, (phone, phone_l, score, pitch, slurs) in enumerate(loader):
13 |         phone = phone.to(device)
14 |         phone_l = phone_l.to(device)
15 |         score = score.to(device)
16 |         pitch = pitch.to(device)
17 |         slurs = slurs.to(device)
18 | 
19 |         pitch_pri, pitch_pre = generator(phone, phone_l, score, slurs, n_timesteps=50)
20 | 
21 |         # De-Log
22 |         pitch_pri = torch.pow(2, pitch_pri)
23 |         pitch_pre = torch.pow(2, pitch_pre)
24 | 
25 |         loss_f0 = F.l1_loss(pitch_pre[:, 0, :], pitch)
26 |         vali_loss += loss_f0.item()
27 | 
28 |         if idx < hp.log.num_audio:
29 |             writer.log_fig_pitch(pitch_pri, pitch_pre, pitch, idx, step)
30 | 
31 |     vali_loss = vali_loss / len(valloader.dataset)
32 |     writer.log_validation(vali_loss, step)
33 | 
34 |     torch.backends.cudnn.benchmark = True
35 | 


--------------------------------------------------------------------------------
/pitch_extend/writer.py:
--------------------------------------------------------------------------------
 1 | from torch.utils.tensorboard import SummaryWriter
 2 | from .plotting import plot_f0_to_numpy
 3 | 
 4 | 
 5 | class MyWriter(SummaryWriter):
 6 |     def __init__(self, hp, logdir):
 7 |         super(MyWriter, self).__init__(logdir)
 8 | 
 9 |     def log_training(self, loss_g, prior_loss, diff_loss, step):
10 |         self.add_scalar('train/loss_g', loss_g, step)
11 |         self.add_scalar('train/loss_prior', prior_loss, step)
12 |         self.add_scalar('train/loss_diff', diff_loss, step)
13 | 
14 |     def log_validation(self, vali_loss, step):
15 |         self.add_scalar('validation/vali_loss', vali_loss, step)
16 | 
17 |     def log_fig_pitch(self, pitch_prio, pitch_fake, pitch_real, idx, step):
18 |         if idx == 0:
19 |             pitch_prio = pitch_prio[0, 0, :].data.cpu().numpy()
20 |             pitch_fake = pitch_fake[0, 0, :].data.cpu().numpy()
21 |             pitch_prio[pitch_prio > 1000] = 1000
22 |             pitch_fake[pitch_fake > 1000] = 1000
23 |             pitch_real = pitch_real[0].data.cpu().numpy()
24 |             self.add_image(f'pitch_prio/{step}', plot_f0_to_numpy(pitch_prio, pitch_real), step)
25 |             self.add_image(f'pitch_fake/{step}', plot_f0_to_numpy(pitch_fake, pitch_real), step)
26 |             # self.add_image(f'pitch_real/{step}', plot_f0_to_numpy(pitch_real), step)
27 | 


--------------------------------------------------------------------------------
/resource/vising_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlayVoice/VI-SVS/d9768caa05c742bc5f580a4754b3cd7b444a5eb4/resource/vising_loss.png


--------------------------------------------------------------------------------
/resource/vising_mel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlayVoice/VI-SVS/d9768caa05c742bc5f580a4754b3cd7b444a5eb4/resource/vising_mel.png


--------------------------------------------------------------------------------
/resource/vising_sample.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlayVoice/VI-SVS/d9768caa05c742bc5f580a4754b3cd7b444a5eb4/resource/vising_sample.wav


--------------------------------------------------------------------------------
/svs/__init__.py:
--------------------------------------------------------------------------------
 1 | from svs.phone_map import label_to_ids
 2 | from svs.phone_uv import uv_map
 3 | 
 4 | 
 5 | def load_midi_map():
 6 |     notemap = {}
 7 |     notemap["rest"] = 0
 8 |     fo = open("./svs/midi-note.scp", "r+")
 9 |     while True:
10 |         try:
11 |             message = fo.readline().strip()
12 |         except Exception as e:
 13 |             print("read error:", e)
 14 |             break
 15 |         if message is None:
16 |             break
17 |         if message == "":
18 |             break
19 |         infos = message.split()
20 |         notemap[infos[1]] = int(infos[0])
21 |     fo.close()
22 |     return notemap
23 | 
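A short usage sketch, not part of the repository; it assumes the working directory is the repo root so that ./svs/midi-note.scp can be found, as load_midi_map expects.

from svs import load_midi_map

notemap = load_midi_map()
score = "G#4/Ab4 F#4/Gb4 E4 rest".split()
score_ids = [notemap[n] for n in score]   # -> [68, 66, 64, 0]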


--------------------------------------------------------------------------------
/svs/midi-HZ.scp:
--------------------------------------------------------------------------------
  1 | 127     G9          12543.9
  2 | 126     F#9/Gb9     11839.8
  3 | 125     F9          11175.3
  4 | 124     E9          10548.1
  5 | 123     D#9/Eb9     9956.1 
  6 | 122     D9          9397.3 
  7 | 121     C#9/Db9     8869.8 
  8 | 120     C9          8372   
  9 | 119     B8          7902.1 
 10 | 118     A#8/Bb8     7458.6 
 11 | 117     A8          7040   
 12 | 116     G#8/Ab8     6644.9 
 13 | 115     G8          6271.9 
 14 | 114     F#8/Gb8     5919.9 
 15 | 113     F8          5587.7 
 16 | 112     E8          5274   
 17 | 111     D#8/Eb8     4978   
 18 | 110     D8          4698.6 
 19 | 109     C#8/Db8     4434.9 
 20 | 108     C8          4186   
 21 | 107     B7          3951.1 
 22 | 106     A#7/Bb7     3729.3 
 23 | 105     A7          3520   
 24 | 104     G#7/Ab7     3322.4 
 25 | 103     G7          3136   
 26 | 102     F#7/Gb7     2960   
 27 | 101     F7          2793.8 
 28 | 100     E7          2637   
 29 | 99      D#7/Eb7     2489   
 30 | 98      D7          2349.3 
 31 | 97      C#7/Db7     2217.5 
 32 | 96      C7          2093   
 33 | 95      B6          1975.5 
 34 | 94      A#6/Bb6     1864.7 
 35 | 93      A6          1760   
 36 | 92      G#6/Ab6     1661.2 
 37 | 91      G6          1568   
 38 | 90      F#6/Gb6     1480   
 39 | 89      F6          1396.9 
 40 | 88      E6          1318.5 
 41 | 87      D#6/Eb6     1244.5 
 42 | 86      D6          1174.7 
 43 | 85      C#6/Db6     1108.7 
 44 | 84      C6          1046.5 
 45 | 83      B5          987.8  
 46 | 82      A#5/Bb5     932.3  
 47 | 81      A5          880    
 48 | 80      G#5/Ab5     830.6  
 49 | 79      G5          784    
 50 | 78      F#5/Gb5     740    
 51 | 77      F5          698.5  
 52 | 76      E5          659.3  
 53 | 75      D#5/Eb5     622.3  
 54 | 74      D5          587.3  
 55 | 73      C#5/Db5     554.4  
 56 | 72      C5          523.3  
 57 | 71      B4          493.9  
 58 | 70      A#4/Bb4     466.2  
 59 | 69      A4          440    
 60 | 68      G#4/Ab4     415.3  
 61 | 67      G4          392    
 62 | 66      F#4/Gb4     370    
 63 | 65      F4          349.2  
 64 | 64      E4          329.6  
 65 | 63      D#4/Eb4     311.1  
 66 | 62      D4          293.7  
 67 | 61      C#4/Db4     277.2  
 68 | 60      C4          261.6  
 69 | 59      B3          246.9  
 70 | 58      A#3/Bb3     233.1  
 71 | 57      A3          220    
 72 | 56      G#3/Ab3     207.7  
 73 | 55      G3          196    
 74 | 54      F#3/Gb3     185    
 75 | 53      F3          174.6  
 76 | 52      E3          164.8  
 77 | 51      D#3/Eb3     155.6  
 78 | 50      D3          146.8  
 79 | 49      C#3/Db3     138.6  
 80 | 48      C3          130.8  
 81 | 47      B2          123.5  
 82 | 46      A#2/Bb2     116.5  
 83 | 45      A2          110    
 84 | 44      G#2/Ab2     103.8  
 85 | 43      G2          98     
 86 | 42      F#2/Gb2     92.5   
 87 | 41      F2          87.3   
 88 | 40      E2          82.4   
 89 | 39      D#2/Eb2     77.8   
 90 | 38      D2          73.4   
 91 | 37      C#2/Db2     69.3   
 92 | 36      C2          65.4   
 93 | 35      B1          61.7   
 94 | 34      A#1/Bb1     58.3   
 95 | 33      A1          55     
 96 | 32      G#1/Ab1     51.9   
 97 | 31      G1          49     
 98 | 30      F#1/Gb1     46.2   
 99 | 29      F1          43.7   
100 | 28      E1          41.2   
101 | 27      D#1/Eb1     38.9   
102 | 26      D1          36.7   
103 | 25      C#1/Db1     34.6   
104 | 24      C1          32.7   
105 | 23      B0          30.9   
106 | 22      A#0/Bb0     29.1   
107 | 21      A0          27.5   
108 | 0       rest        0        


--------------------------------------------------------------------------------
/svs/midi-note.scp:
--------------------------------------------------------------------------------
  1 | 127	G9
  2 | 126	F#9/Gb9
  3 | 125	F9
  4 | 124	E9
  5 | 123	D#9/Eb9
  6 | 122	D9
  7 | 121	C#9/Db9
  8 | 120	C9
  9 | 119	B8
 10 | 118	A#8/Bb8
 11 | 117	A8
 12 | 116	G#8/Ab8
 13 | 115	G8
 14 | 114	F#8/Gb8
 15 | 113	F8
 16 | 112	E8
 17 | 111	D#8/Eb8
 18 | 110	D8
 19 | 109	C#8/Db8
 20 | 108	C8
 21 | 107	B7
 22 | 106	A#7/Bb7
 23 | 105	A7
 24 | 104	G#7/Ab7
 25 | 103	G7
 26 | 102	F#7/Gb7
 27 | 101	F7
 28 | 100	E7
 29 | 99	D#7/Eb7
 30 | 98	D7
 31 | 97	C#7/Db7
 32 | 96	C7
 33 | 95	B6
 34 | 94	A#6/Bb6
 35 | 93	A6
 36 | 92	G#6/Ab6
 37 | 91	G6
 38 | 90	F#6/Gb6
 39 | 89	F6
 40 | 88	E6
 41 | 87	D#6/Eb6
 42 | 86	D6
 43 | 85	C#6/Db6
 44 | 84	C6
 45 | 83	B5
 46 | 82	A#5/Bb5
 47 | 81	A5
 48 | 80	G#5/Ab5
 49 | 79	G5
 50 | 78	F#5/Gb5
 51 | 77	F5
 52 | 76	E5
 53 | 75	D#5/Eb5
 54 | 74	D5
 55 | 73	C#5/Db5
 56 | 72	C5
 57 | 71	B4
 58 | 70	A#4/Bb4
 59 | 69	A4
 60 | 68	G#4/Ab4
 61 | 67	G4
 62 | 66	F#4/Gb4
 63 | 65	F4
 64 | 64	E4
 65 | 63	D#4/Eb4
 66 | 62	D4
 67 | 61	C#4/Db4
 68 | 60	C4
 69 | 59	B3
 70 | 58	A#3/Bb3
 71 | 57	A3
 72 | 56	G#3/Ab3
 73 | 55	G3
 74 | 54	F#3/Gb3
 75 | 53	F3
 76 | 52	E3
 77 | 51	D#3/Eb3
 78 | 50	D3
 79 | 49	C#3/Db3
 80 | 48	C3
 81 | 47	B2
 82 | 46	A#2/Bb2
 83 | 45	A2
 84 | 44	G#2/Ab2
 85 | 43	G2
 86 | 42	F#2/Gb2
 87 | 41	F2
 88 | 40	E2
 89 | 39	D#2/Eb2
 90 | 38	D2
 91 | 37	C#2/Db2
 92 | 36	C2
 93 | 35	B1
 94 | 34	A#1/Bb1
 95 | 33	A1
 96 | 32	G#1/Ab1
 97 | 31	G1
 98 | 30	F#1/Gb1
 99 | 29	F1
100 | 28	E1
101 | 27	D#1/Eb1
102 | 26	D1
103 | 25	C#1/Db1
104 | 24	C1
105 | 23	B0
106 | 22	A#0/Bb0
107 | 21	A0


--------------------------------------------------------------------------------
/svs/phone_map.py:
--------------------------------------------------------------------------------
 1 | _pause = ["unk", "sos", "eos", "ap", "sp"]
 2 | 
 3 | _initials = [
 4 |     "b",
 5 |     "c",
 6 |     "ch",
 7 |     "d",
 8 |     "f",
 9 |     "g",
10 |     "h",
11 |     "j",
12 |     "k",
13 |     "l",
14 |     "m",
15 |     "n",
16 |     "p",
17 |     "q",
18 |     "r",
19 |     "s",
20 |     "sh",
21 |     "t",
22 |     "w",
23 |     "x",
24 |     "y",
25 |     "z",
26 |     "zh",
27 | ]
28 | 
29 | _finals = [
30 |     "a",
31 |     "ai",
32 |     "an",
33 |     "ang",
34 |     "ao",
35 |     "e",
36 |     "ei",
37 |     "en",
38 |     "eng",
39 |     "er",
40 |     "i",
41 |     "ia",
42 |     "ian",
43 |     "iang",
44 |     "iao",
45 |     "ie",
46 |     "in",
47 |     "ing",
48 |     "iong",
49 |     "iu",
50 |     "o",
51 |     "ong",
52 |     "ou",
53 |     "u",
54 |     "ua",
55 |     "uai",
56 |     "uan",
57 |     "uang",
58 |     "ui",
59 |     "un",
60 |     "uo",
61 |     "v",
62 |     "van",
63 |     "ve",
64 |     "vn",
65 | ]
66 | 
67 | 
68 | symbols = _pause + _initials + _finals
69 | 
70 | # Mappings from symbol to numeric ID and vice versa:
71 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
72 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
73 | 
74 | 
75 | def label_to_ids(phones):
76 |     # use lower letter
77 |     sequence = [_symbol_to_id[symbol.lower()] for symbol in phones]
78 |     return sequence
79 | 
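A brief usage sketch, not part of the repository, mapping a phoneme line like those in svs_infer.txt to ids and voiced/unvoiced flags; the exact id values follow the symbol order defined above.

from svs.phone_map import label_to_ids
from svs.phone_uv import uv_map

phones = "g an SP zh i AP".split()
ids = label_to_ids(phones)                 # symbols are lower-cased first, so SP/AP resolve to sp/ap
uvs = [uv_map[p.lower()] for p in phones]  # 1 = voiced, 0 = unvoiced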


--------------------------------------------------------------------------------
/svs/phone_uv.py:
--------------------------------------------------------------------------------
 1 | # Basic Mandarin initials (shengmu) and finals (yunmu)
 2 | # Mandarin has only 4 voiced initials: m, n, l, r; the other 17 consonant initials are all unvoiced
 3 | # In Hanyu Pinyin, y and w appear only at the start of zero-initial syllables, mainly to mark syllable boundaries clearly.
 4 | # https://baijiahao.baidu.com/s?id=1655739561730224990&wfr=spider&for=pc
 5 | 
 6 | uv_map = {
 7 |     "unk":0,
 8 |     "sos":0,
 9 |     "eos":0,
10 |     "ap":0,
11 |     "sp":0,
12 |     "b":0,
13 |     "c":0,
14 |     "ch":0,
15 |     "d":0,
16 |     "f":0,
17 |     "g":0,
18 |     "h":0,
19 |     "j":0,
20 |     "k":0,
21 |     "l":1,
22 |     "m":1,
23 |     "n":1,
24 |     "p":0,
25 |     "q":0,
26 |     "r":1,
27 |     "s":0,
28 |     "sh":0,
29 |     "t":0,
30 |     "w":1,
31 |     "x":0,
32 |     "y":1,
33 |     "z":0,
34 |     "zh":0,
35 |     "a":1,
36 |     "ai":1,
37 |     "an":1,
38 |     "ang":1,
39 |     "ao":1,
40 |     "e":1,
41 |     "ei":1,
42 |     "en":1,
43 |     "eng":1,
44 |     "er":1,
45 |     "i":1,
46 |     "ia":1,
47 |     "ian":1,
48 |     "iang":1,
49 |     "iao":1,
50 |     "ie":1,
51 |     "in":1,
52 |     "ing":1,
53 |     "iong":1,
54 |     "iu":1,
55 |     "o":1,
56 |     "ong":1,
57 |     "ou":1,
58 |     "u":1,
59 |     "ua":1,
60 |     "uai":1,
61 |     "uan":1,
62 |     "uang":1,
63 |     "ui":1,
64 |     "un":1,
65 |     "uo":1,
66 |     "v":1,
67 |     "van":1,
68 |     "ve":1,
69 |     "vn":1
70 | }


--------------------------------------------------------------------------------
/svs_export.py:
--------------------------------------------------------------------------------
 1 | import sys,os
 2 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 3 | import torch
 4 | import argparse
 5 | from omegaconf import OmegaConf
 6 | 
 7 | from vits.models import SynthesizerTrn
 8 | 
 9 | 
10 | def load_model(checkpoint_path, model):
11 |     assert os.path.isfile(checkpoint_path)
12 |     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
13 |     saved_state_dict = checkpoint_dict["model_g"]
14 |     if hasattr(model, "module"):
15 |         state_dict = model.module.state_dict()
16 |     else:
17 |         state_dict = model.state_dict()
18 |     new_state_dict = {}
19 |     for k, v in state_dict.items():
20 |         try:
21 |             new_state_dict[k] = saved_state_dict[k]
22 |         except KeyError:
23 |             new_state_dict[k] = v
24 |     if hasattr(model, "module"):
25 |         model.module.load_state_dict(new_state_dict)
26 |     else:
27 |         model.load_state_dict(new_state_dict)
28 |     return model
29 | 
30 | 
31 | def save_pretrain(checkpoint_path, save_path):
32 |     assert os.path.isfile(checkpoint_path)
33 |     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
34 |     torch.save({
35 |         'model_g': checkpoint_dict['model_g'],
36 |         'model_d': checkpoint_dict['model_d'],
37 |     }, save_path)
38 | 
39 | 
40 | def save_model(model, checkpoint_path):
41 |     if hasattr(model, 'module'):
42 |         state_dict = model.module.state_dict()
43 |     else:
44 |         state_dict = model.state_dict()
45 |     torch.save({'model_g': state_dict}, checkpoint_path)
46 | 
47 | 
48 | def main(args):
49 |     hp = OmegaConf.load(args.config)
50 |     model = SynthesizerTrn(
51 |         hp.data.filter_length // 2 + 1,
52 |         hp.data.segment_size // hp.data.hop_length,
53 |         hp)
54 | 
55 |     load_model(args.model, model)
56 |     save_model(model, "svs_opencpop.pt")
57 | 
58 | 
59 | if __name__ == '__main__':
60 |     parser = argparse.ArgumentParser()
61 |     parser.add_argument('-c', '--config', type=str, required=True,
62 |                         help="yaml file for configuration")
63 |     parser.add_argument('-m', '--model', type=str, required=True,
64 |                         help="path of checkpoint pt file for evaluation")
65 |     args = parser.parse_args()
66 | 
67 |     main(args)
68 | 


--------------------------------------------------------------------------------
/svs_infer.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import numpy as np
  3 | import matplotlib.pyplot as plt
  4 | 
  5 | from scipy.io import wavfile
  6 | from time import *
  7 | 
  8 | import torch
  9 | import argparse
 10 | 
 11 | from vits.models import SynthesizerTrn
 12 | from util import SingInput
 13 | from util import FeatureInput
 14 | from omegaconf import OmegaConf
 15 | 
 16 | 
 17 | def save_wav(wav, path, rate):
 18 |     wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6
 19 |     wavfile.write(path, rate, wav.astype(np.int16))
 20 | 
 21 | 
 22 | def load_svs_model(checkpoint_path, model):
 23 |     assert os.path.isfile(checkpoint_path)
 24 |     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
 25 |     saved_state_dict = checkpoint_dict["model_g"]
 26 |     state_dict = model.state_dict()
 27 |     new_state_dict = {}
 28 |     for k, v in state_dict.items():
 29 |         try:
 30 |             new_state_dict[k] = saved_state_dict[k]
 31 |         except KeyError:
 32 |             print("%s is not in the checkpoint" % k)
 33 |             new_state_dict[k] = v
 34 |     model.load_state_dict(new_state_dict)
 35 |     return model
 36 | 
 37 | 
 38 | if __name__ == '__main__':
 39 |     parser = argparse.ArgumentParser()
 40 |     parser.add_argument('-c', '--config', type=str, required=True,
 41 |                         help="yaml file for configuration")
 42 |     parser.add_argument('-m', '--model', type=str, required=True,
 43 |                         help="path of checkpoint pt file")
 44 |     args = parser.parse_args()
 45 | 
 46 |     # define model and load checkpoint
 47 |     hps = OmegaConf.load(args.config)
 48 | 
 49 |     singInput = SingInput(hps.data.sampling_rate, hps.data.hop_length)
 50 |     featureInput = FeatureInput(hps.data.sampling_rate, hps.data.hop_length)
 51 | 
 52 |     net_g = SynthesizerTrn(
 53 |         hps.data.filter_length // 2 + 1,
 54 |         hps.data.segment_size // hps.data.hop_length,
 55 |         hps).cuda()
 56 |     net_g.eval()
 57 | 
 58 |     load_svs_model(args.model, net_g)
 59 | 
 60 |     # check directory existence
 61 |     os.makedirs("./svs_out", exist_ok=True)
 62 |     fo = open("./svs_infer.txt", "r+")
 63 |     while True:
 64 |         try:
 65 |             message = fo.readline().strip()
 66 |         except Exception as e:
 67 |             print("read error:", e)
 68 |             break
 69 |         if message is None:
 70 |             break
 71 |         if message == "":
 72 |             break
 73 |         print(message)
 74 |         (
 75 |             file,
 76 |             labels_ids,
 77 |             labels_frames,
 78 |             scores_ids,
 79 |             scores_dur,
 80 |             labels_slr,
 81 |             labels_uvs,
 82 |         ) = singInput.parseInput(message)
 83 |         labels_ids = singInput.expandInput(labels_ids, labels_frames)
 84 |         labels_uvs = singInput.expandInput(labels_uvs, labels_frames)
 85 |         labels_slr = singInput.expandInput(labels_slr, labels_frames)
 86 |         scores_ids = singInput.expandInput(scores_ids, labels_frames)
 87 |         scores_pit = singInput.scorePitch(scores_ids)
 88 |         # element by element
 89 |         scores_pit_ = scores_pit * labels_uvs
 90 |         scores_pit = singInput.smoothPitch(scores_pit_)
 91 | 
 92 |         fig = plt.figure(figsize=(12, 6))
 93 |         plt.plot(scores_pit_.T, "g")
 94 |         plt.plot(scores_pit.T, "r")
 95 |         plt.savefig(f"./svs_out/{file}_f0_.png", format="png")
 96 |         plt.close(fig)
 97 | 
 98 |         phone = torch.LongTensor(labels_ids)
 99 |         score = torch.LongTensor(scores_ids)
100 |         slurs = torch.LongTensor(labels_slr)
101 |         pitch = torch.FloatTensor(scores_pit)
102 | 
103 |         phone_lengths = phone.size()[0]
104 | 
105 |         with torch.no_grad():
106 |             phone = phone.cuda().unsqueeze(0)
107 |             score = score.cuda().unsqueeze(0)
108 |             pitch = pitch.cuda().unsqueeze(0)
109 |             slurs = slurs.cuda().unsqueeze(0)
110 |             phone_lengths = torch.LongTensor([phone_lengths]).cuda()
111 |             audio = (
112 |                 net_g.infer(phone, phone_lengths, score, pitch, slurs)[0, 0]
113 |                 .data.cpu()
114 |                 .float()
115 |                 .numpy()
116 |             )
117 | 
118 |         save_wav(audio, f"./svs_out/{file}.wav", hps.data.sampling_rate)
119 |     fo.close()
120 |     # can be deleted
121 |     os.system("chmod 777 ./svs_out -R")
122 | 


--------------------------------------------------------------------------------
/svs_infer.txt:
--------------------------------------------------------------------------------
 1 | 2001000001|感受停在我发端的指尖|g an sh ou t ing z ai w o f a d uan d e SP zh i j ian AP|G#4/Ab4 G#4/Ab4 G#4/Ab4 G#4/Ab4 F#4/Gb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 E4 E4 E4 E4 D#4/Eb4 D#4/Eb4 D#4/Eb4 D#4/Eb4 rest E4 E4 E4 E4 rest|0.253030 0.253030 0.428030 0.428030 0.320870 0.320870 0.358110 0.358110 0.218610 0.218610 0.519380 0.519380 0.351070 0.351070 0.152260 0.152260 0.089470 0.405810 0.405810 0.696660 0.696660 0.284630|0.0317 0.22133 0.15421 0.27382 0.06335 0.25752 0.07101 0.2871 0.03623 0.18238 0.18629 0.33309 0.01471 0.33636 0.01415 0.13811 0.08947 0.12862 0.27719 0.07962 0.61704 0.28463|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 2 | 2001000002|如何瞬间冻结时间|r u h e sh un j ian AP SP d ong j ie sh i j ian SP|B3 B3 B3 B3 B3 B3 G#4/Ab4 G#4/Ab4 rest rest B3 B3 B3 B3 B3 B3 F#4/Gb4 F#4/Gb4 rest|0.294760 0.294760 0.283550 0.283550 0.795250 0.795250 0.992200 0.992200 0.297130 0.104830 0.311040 0.311040 0.214620 0.214620 0.782750 0.782750 1.519540 1.519540 1.179120|0.06588 0.22888 0.11684 0.16671 0.18746 0.60779 0.11194 0.88026 0.29713 0.10483 0.03166 0.27938 0.05057 0.16405 0.21149 0.57126 0.13926 1.38028 1.17912|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 3 | 2001000003|记住望着我坚定的双眼|j i zh u w ang zh e w o SP j ian d ing d e sh uang y an AP|G#4/Ab4 G#4/Ab4 G#4/Ab4 G#4/Ab4 F#4/Gb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 E4 E4 rest E4 E4 D#4/Eb4 D#4/Eb4 D#4/Eb4 D#4/Eb4 E4 E4 E4 E4 rest|0.388470 0.388470 0.368320 0.368320 0.363510 0.363510 0.316690 0.316690 0.161350 0.161350 0.055570 0.495580 0.495580 0.342860 0.342860 0.141750 0.141750 0.398360 0.398360 0.785070 0.785070 0.317450|0.09945 0.28902 0.08103 0.28729 0.05083 0.31268 0.04303 0.27366 0.03603 0.12532 0.05557 0.15191 0.34367 0.02357 0.31929 0.02939 0.11236 0.21916 0.1792 0.22549 0.55958 0.31745|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 4 | 2001000004|也许已经没有明天|y e x v y i j ing AP m ei y ou m ing t ian SP AP|B3 B3 B3 B3 B3 B3 G#4/Ab4 G#4/Ab4 rest B3 B3 B3 B3 B3 B3 F#4/Gb4 F#4/Gb4 rest rest|0.236860 0.236860 0.426110 0.426110 0.660620 0.660620 1.021220 1.021220 0.409380 0.243270 0.243270 0.327560 0.327560 0.741700 0.741700 1.335140 1.335140 0.591900 0.515310|0.07979 0.15707 0.2089 0.21721 0.12179 0.53883 0.16915 0.85207 0.40938 0.06617 0.1771 0.04273 0.28483 0.11939 0.62231 0.17586 1.15928 0.5919 0.51531|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 5 | 2001000005|面对浩瀚的星海我们微小得像尘埃|m ian d ui h ao h an an d e x ing h ai ai ai AP w o m en w ei x iao d e x iang ch en ai ai ai SP|C#4/Db4 C#4/Db4 D#4/Eb4 D#4/Eb4 C#4/Db4 C#4/Db4 D#4/Eb4 D#4/Eb4 E4 D#4/Eb4 D#4/Eb4 E4 E4 G#4/Ab4 G#4/Ab4 A4 G#4/Ab4 rest C#4/Db4 C#4/Db4 C#4/Db4 C#4/Db4 D#4/Eb4 D#4/Eb4 C#4/Db4 C#4/Db4 D#4/Eb4 D#4/Eb4 E4 E4 E4 E4 G#4/Ab4 A4 G#4/Ab4 rest|0.196990 0.196990 0.102120 0.102120 0.304680 0.304680 0.096780 0.096780 0.100220 0.150010 0.150010 0.361460 0.361460 0.221070 0.221070 0.183240 0.478670 0.384620 0.106510 0.106510 0.143020 0.143020 0.169480 0.169480 0.224180 0.224180 0.089360 0.089360 0.414460 0.414460 0.378050 0.378050 0.162790 0.207380 0.317260 0.297040|0.02765 0.16934 0.01874 0.08338 0.0821 0.22258 0.0693 0.02748 0.10022 0.07137 0.07864 0.12471 0.23675 0.12356 0.09751 0.18324 0.47867 0.38462 0.0405 0.06601 0.08303 0.05999 0.04687 0.12261 0.09778 0.1264 0.02321 0.06615 0.11958 0.29488 0.06723 0.31082 0.16279 0.20738 0.31726 0.29704|0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0
 6 | 2001000006|漂浮在一片无奈|p iao f u z ai ai ai AP SP y i i p ian ian ian w u n ai SP AP|E4 E4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A4 G#4/Ab4 rest rest E4 E4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A4 G#4/Ab4 E4 E4 F#4/Gb4 F#4/Gb4 rest rest|0.185230 0.185230 0.177410 0.177410 0.193930 0.193930 0.259670 0.299340 0.215550 0.031770 0.197520 0.197520 0.165450 0.184760 0.184760 0.212290 0.246960 0.440370 0.440370 1.524950 1.524950 0.855830 0.559100|0.06011 0.12512 0.07517 0.10224 0.08603 0.1079 0.25967 0.29934 0.21555 0.03177 0.05175 0.14577 0.16545 0.0748 0.10996 0.21229 0.24696 0.09617 0.3442 0.1437 1.38125 0.85583 0.5591|0 0 0 0 0 0 1 1 0 0 0 0 1 0 0 1 1 0 0 0 0 0 0
 7 | 2001000007|缘份让我们相遇乱世以外|y van f en r ang w o m en x iang y v AP l uan sh i y i w ai AP|D#4/Eb4 D#4/Eb4 E4 E4 E4 E4 F#4/Gb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 G#4/Ab4 G#4/Ab4 rest B4 B4 B4 B4 C#5/Db5 C#5/Db5 C#5/Db5 C#5/Db5 rest|0.323070 0.323070 0.325290 0.325290 0.483290 0.483290 0.212040 0.212040 0.294600 0.294600 0.465110 0.465110 0.364020 0.364020 0.137130 0.151270 0.151270 0.270860 0.270860 0.434770 0.434770 1.570380 1.570380 0.462970|0.12204 0.20103 0.11182 0.21347 0.09912 0.38417 0.05549 0.15655 0.10139 0.19321 0.17622 0.28889 0.0609 0.30312 0.13713 0.03605 0.11522 0.14541 0.12545 0.12186 0.31291 0.09403 1.47635 0.46297|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 8 | 2001000008|命运却要我们危难中相爱|m ing y van q ve y ao w o m en w ei n an zh ong x iang ai SP AP|E4 E4 F#4/Gb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 G#4/Ab4 G#4/Ab4 B4 B4 B4 B4 E4 E4 G#4/Ab4 G#4/Ab4 F#4/Gb4 F#4/Gb4 F#4/Gb4 rest rest|0.332160 0.332160 0.315140 0.315140 0.371590 0.371590 0.285140 0.285140 0.394510 0.394510 0.358480 0.358480 0.524060 0.524060 0.176940 0.176940 0.239510 0.239510 0.494880 0.494880 1.260320 0.317390 0.358080|0.03995 0.29221 0.08516 0.22998 0.12953 0.24206 0.09533 0.18981 0.09528 0.29923 0.06899 0.28949 0.03119 0.49287 0.048 0.12894 0.04204 0.19747 0.1539 0.34098 1.26032 0.31739 0.35808|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 9 | 2001000009|也许未来遥远在光年之外|y e x v w ei l ai y ao y van z ai SP g uang n ian zh i w ai SP AP|E4 E4 E4 E4 E4 E4 F#4/Gb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 G#4/Ab4 G#4/Ab4 rest B4 B4 B4 B4 C#5/Db5 C#5/Db5 C#5/Db5 C#5/Db5 rest rest|0.226010 0.226010 0.367780 0.367780 0.377380 0.377380 0.308330 0.308330 0.397890 0.397890 0.369570 0.369570 0.452320 0.452320 0.075060 0.237700 0.237700 0.272190 0.272190 0.325600 0.325600 1.446250 1.446250 0.243310 0.346690|0.11666 0.10935 0.23067 0.13711 0.14195 0.23543 0.12932 0.17901 0.16096 0.23693 0.19611 0.17346 0.08484 0.36748 0.07506 0.0593 0.1784 0.06402 0.20817 0.07175 0.25385 0.093 1.35325 0.24331 0.34669|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
10 | 2001000010|我愿守候未知里为你等待|w o y van sh ou h ou w ei zh i l i AP w ei n i SP d eng d ai AP|E4 E4 F#4/Gb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 G#4/Ab4 G#4/Ab4 B4 B4 B4 B4 rest E4 E4 G#4/Ab4 G#4/Ab4 rest F#4/Gb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 rest|0.302770 0.302770 0.288530 0.288530 0.402910 0.402910 0.447020 0.447020 0.296470 0.296470 0.202850 0.202850 0.466880 0.466880 0.207550 0.135530 0.135530 0.337900 0.337900 0.070010 0.249830 0.249830 0.392400 0.392400 0.210080|0.10342 0.19935 0.06127 0.22726 0.16322 0.23969 0.12336 0.32366 0.07033 0.22614 0.09677 0.10608 0.18788 0.279 0.20755 0.06243 0.0731 0.09532 0.24258 0.07001 0.03048 0.21935 0.10486 0.28754 0.21008|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
11 | 


--------------------------------------------------------------------------------
/svs_infer_pitch.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import numpy as np
  3 | import matplotlib.pyplot as plt
  4 | 
  5 | from scipy.io import wavfile
  6 | from time import *
  7 | 
  8 | import torch
  9 | import argparse
 10 | 
 11 | from vits.models import SynthesizerTrn
 12 | from util import SingInput
 13 | from util import FeatureInput
 14 | from omegaconf import OmegaConf
 15 | 
 16 | from pitch.models import PitchDiffusion
 17 | from pitch.utils import fix_len_compatibility
 18 | 
 19 | 
 20 | def save_wav(wav, path, rate):
 21 |     wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6
 22 |     wavfile.write(path, rate, wav.astype(np.int16))
 23 | 
 24 | 
 25 | def load_svs_model(checkpoint_path, model):
 26 |     assert os.path.isfile(checkpoint_path)
 27 |     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
 28 |     saved_state_dict = checkpoint_dict["model_g"]
 29 |     state_dict = model.state_dict()
 30 |     new_state_dict = {}
 31 |     for k, v in state_dict.items():
 32 |         try:
 33 |             new_state_dict[k] = saved_state_dict[k]
 34 |         except KeyError:
 35 |             print("%s is not in the checkpoint" % k)
 36 |             new_state_dict[k] = v
 37 |     model.load_state_dict(new_state_dict)
 38 |     return model
 39 | 
 40 | 
 41 | if __name__ == '__main__':
 42 |     parser = argparse.ArgumentParser()
 43 |     parser.add_argument('-c', '--config', type=str, required=True,
 44 |                         help="yaml file for configuration")
 45 |     parser.add_argument('-m', '--model', type=str, required=True,
 46 |                         help="path of checkpoint pt file")
 47 |     parser.add_argument('-p', '--pitch', type=str, required=True,
 48 |                         help="path of pitch model checkpoint pt file")
 49 |     args = parser.parse_args()
 50 | 
 51 |     # define model and load checkpoint
 52 |     hps = OmegaConf.load(args.config)
 53 | 
 54 |     singInput = SingInput(hps.data.sampling_rate, hps.data.hop_length)
 55 |     featureInput = FeatureInput(hps.data.sampling_rate, hps.data.hop_length)
 56 | 
 57 |     net_g = SynthesizerTrn(
 58 |         hps.data.filter_length // 2 + 1,
 59 |         hps.data.segment_size // hps.data.hop_length,
 60 |         hps).cuda()
 61 |     net_g.eval()
 62 | 
 63 |     load_svs_model(args.model, net_g)
 64 | 
 65 |     net_p = PitchDiffusion().cuda()
 66 |     net_p.eval()
 67 |     load_svs_model(args.pitch, net_p)
 68 | 
 69 |     # check directory existence
 70 |     os.makedirs("./svs_out", exist_ok=True)
 71 |     fo = open("./svs_infer.txt", "r+")
 72 |     while True:
 73 |         try:
 74 |             message = fo.readline().strip()
 75 |         except Exception as e:
 76 |             print("read error:", e)
 77 |             break
 78 |         if message is None:
 79 |             break
 80 |         if message == "":
 81 |             break
 82 |         print(message)
 83 |         (
 84 |             file,
 85 |             labels_ids,
 86 |             labels_frames,
 87 |             scores_ids,
 88 |             scores_dur,
 89 |             labels_slr,
 90 |             labels_uvs,
 91 |         ) = singInput.parseInput(message)
 92 |         labels_ids = singInput.expandInput(labels_ids, labels_frames)
 93 |         labels_uvs = singInput.expandInput(labels_uvs, labels_frames)
 94 |         labels_slr = singInput.expandInput(labels_slr, labels_frames)
 95 |         scores_ids = singInput.expandInput(scores_ids, labels_frames)
 96 | 
 97 |         phone = torch.LongTensor(labels_ids)
 98 |         score = torch.LongTensor(scores_ids)
 99 |         slurs = torch.LongTensor(labels_slr)
100 | 
101 |         lengths = phone.size()[0]
102 |         lengths_fix = fix_len_compatibility(lengths)
103 | 
104 |         phone_fix = torch.zeros((1, lengths_fix), dtype=torch.long).cuda()
105 |         score_fix = torch.zeros((1, lengths_fix), dtype=torch.long).cuda()
106 |         slurs_fix = torch.zeros((1, lengths_fix), dtype=torch.long).cuda()
107 |         phone_fix[0, :lengths] = phone
108 |         score_fix[0, :lengths] = score
109 |         slurs_fix[0, :lengths] = slurs
110 | 
111 |         with torch.no_grad():
112 |             n_timesteps = 50
113 |             temperature = 1
114 |             # PIT
115 |             phone_lengths = torch.LongTensor([lengths_fix]).cuda()
116 |             pit_pri, pit_pre = net_p(phone_fix, phone_lengths, score_fix, slurs_fix, n_timesteps, temperature)
117 |             pitch = pit_pre[:, 0, :]
118 |             pitch = 2**pitch
119 |             print('~~~~~~~')
120 |             # SVS
121 |             audio = (
122 |                 net_g.infer(phone_fix, phone_lengths, score_fix, pitch, slurs_fix)[0, 0]
123 |                 .data.cpu()
124 |                 .float()
125 |                 .numpy()
126 |             )
127 | 
128 |         save_wav(audio, f"./svs_out/{file}.wav", hps.data.sampling_rate)
129 |     fo.close()
130 |     # can be deleted
131 |     os.system("chmod 777 ./svs_out -R")
132 | 


--------------------------------------------------------------------------------
/svs_song.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import numpy as np
  3 | 
  4 | from scipy.io import wavfile
  5 | from time import *
  6 | 
  7 | import torch
  8 | import argparse
  9 | 
 10 | from vits.models import SynthesizerTrn
 11 | from util import SingInput
 12 | from util import FeatureInput
 13 | from omegaconf import OmegaConf
 14 | 
 15 | 
 16 | def save_wav(wav, path, rate):
 17 |     wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6
 18 |     wavfile.write(path, rate, wav.astype(np.int16))
 19 | 
 20 | 
 21 | def load_svs_model(checkpoint_path, model):
 22 |     assert os.path.isfile(checkpoint_path)
 23 |     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
 24 |     saved_state_dict = checkpoint_dict["model_g"]
 25 |     state_dict = model.state_dict()
 26 |     new_state_dict = {}
 27 |     for k, v in state_dict.items():
 28 |         try:
 29 |             new_state_dict[k] = saved_state_dict[k]
 30 |         except:
 31 |             print("%s is not in the checkpoint" % k)
 32 |             new_state_dict[k] = v
 33 |     model.load_state_dict(new_state_dict)
 34 |     return model
 35 | 
 36 | 
 37 | if __name__ == '__main__':
 38 |     parser = argparse.ArgumentParser()
 39 |     parser.add_argument('-c', '--config', type=str, required=True,
 40 |                         help="yaml file for configuration")
 41 |     parser.add_argument('-m', '--model', type=str, required=True,
 42 |                         help="path of checkpoint pt file")
 43 |     args = parser.parse_args()
 44 | 
 45 |     # define model and load checkpoint
 46 |     hps = OmegaConf.load(args.config)
 47 | 
 48 |     singInput = SingInput(hps.data.sampling_rate, hps.data.hop_length)
 49 |     featureInput = FeatureInput(hps.data.sampling_rate, hps.data.hop_length)
 50 | 
 51 |     net_g = SynthesizerTrn(
 52 |         hps.data.filter_length // 2 + 1,
 53 |         hps.data.segment_size // hps.data.hop_length,
 54 |         hps).cuda()
 55 |     net_g.eval()
 56 | 
 57 |     load_svs_model(args.model, net_g)
 58 | 
 59 |     # check directory existence
 60 |     os.makedirs("./svs_out", exist_ok=True)
 61 |     fo = open("./svs_song.txt", "r+")
 62 |     song_rate = hps.data.sampling_rate
 63 |     song_time = fo.readline().strip().split("|")[1]
 64 |     song_length = int(song_rate * (float(song_time) + 30))
 65 |     song_data = np.zeros(song_length, dtype="float32")
 66 |     while True:
 67 |         try:
 68 |             message = fo.readline().strip()
 69 |         except Exception as e:
 70 |             print("read error:", e)
 71 |             break
 72 |         if message is None:
 73 |             break
 74 |         if message == "":
 75 |             break
 76 |         (
 77 |             item_indx,
 78 |             item_time,
 79 |             labels_ids,
 80 |             labels_frames,
 81 |             scores_ids,
 82 |             scores_dur,
 83 |             labels_slr,
 84 |             labels_uvs,
 85 |         ) = singInput.parseSong(message)
 86 |         labels_ids = singInput.expandInput(labels_ids, labels_frames)
 87 |         labels_uvs = singInput.expandInput(labels_uvs, labels_frames)
 88 |         labels_slr = singInput.expandInput(labels_slr, labels_frames)
 89 |         scores_ids = singInput.expandInput(scores_ids, labels_frames)
 90 |         scores_pit = singInput.scorePitch(scores_ids)
 91 |         # element by element
 92 |         scores_pit = scores_pit * labels_uvs
 93 |         # scores_pit = singInput.smoothPitch(scores_pit)
 94 |         # scores_pit = scores_pit * labels_uvs
 95 |         phone = torch.LongTensor(labels_ids)
 96 |         score = torch.LongTensor(scores_ids)
 97 |         slurs = torch.LongTensor(labels_slr)
 98 |         pitch = torch.FloatTensor(scores_pit)
 99 | 
100 |         phone_lengths = phone.size()[0]
101 | 
102 |         begin_time = time()
103 |         with torch.no_grad():
104 |             phone = phone.cuda().unsqueeze(0)
105 |             score = score.cuda().unsqueeze(0)
106 |             pitch = pitch.cuda().unsqueeze(0)
107 |             slurs = slurs.cuda().unsqueeze(0)
108 |             phone_lengths = torch.LongTensor([phone_lengths]).cuda()
109 |             audio = (
110 |                 net_g.infer(phone, phone_lengths, score, pitch, slurs)[0, 0]
111 |                 .data.cpu()
112 |                 .float()
113 |                 .numpy()
114 |             )
115 |        
116 |         save_wav(audio, f"./svs_out/{item_indx}.wav", hps.data.sampling_rate)
117 |         # wav
118 |         item_start = int(song_rate * float(item_time))
119 |         item_end = item_start + len(audio)
120 |         song_data[item_start:item_end] = audio
121 |     # out of for
122 |     song_data = np.array(song_data, dtype="float32")
123 |     save_wav(song_data, f"./svs_out/_song.wav", hps.data.sampling_rate)
124 |     fo.close()
125 |     # can be deleted
126 |     os.system("chmod 777 ./svs_out -R")
127 | 


--------------------------------------------------------------------------------
/svs_song.txt:
--------------------------------------------------------------------------------
 1 | song_time|116.88723672656248
 2 | 0|0000.694| 化 外 山 间 岁 月 皆 看 老|h ua w ai sh an j ian s ui y ve j ie k an l ao|57 57 64 64 62 62 60 60 59 59 60 60 62 62 64 64 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.506 0.506|0.064 0.249 0.088 0.249 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.273 0.064 0.273 0.064 0.241 0.096 0.506|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 3 | 1|0006.140| 洛 雪 无 声 天 地 掩 尘 嚣|l uo x ve w u sh eng t ian d i y an ch en x iao|57 57 64 64 62 62 60 60 59 59 60 60 62 62 64 64 69 69|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.590 0.590|0.096 0.249 0.088 0.249 0.088 0.249 0.088 0.305 0.032 0.305 0.032 0.249 0.088 0.273 0.064 0.249 0.088 0.590|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 4 | 2|0010.923| 他 看 尽 晨 曦 日 暮 AP 饮 罢 腰 间 酒 一 壶 AP 依 稀 当 年 孤 旅 踏 苍 霞 尽 处|t a k an j in ch en x i r i m u AP y in b a y ao j ian j iu y i h u AP y i x i d ang n ian g u l v t a c ang x ia j in ch u|60 60 62 62 64 64 62 62 67 67 64 64 62 62 rest 64 64 67 67 72 72 71 71 69 69 67 67 69 69 rest 67 67 64 64 62 62 64 64 62 62 60 60 59 59 60 60 62 62 64 64 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 1.180 1.180|0.032 0.273 0.064 0.273 0.064 0.273 0.064 0.249 0.088 0.249 0.088 0.249 0.088 0.337 0.249 0.088 0.297 0.040 0.249 0.088 0.273 0.064 0.273 0.064 0.249 0.088 0.273 0.064 0.421 0.165 0.088 0.249 0.088 0.305 0.032 0.249 0.088 0.273 0.064 0.241 0.096 0.305 0.032 0.249 0.088 0.249 0.088 0.273 0.064 0.273 0.064 1.180|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 5 | 3|0021.678| 风 霜 冷 冽 他 眉 目 AP 时 光 雕 琢 他 风 骨 AP 浮 世 南 柯 一 梦 冷 暖 都 藏 住|f eng sh uang l eng l ie t a m ei m u AP sh i g uang d iao z uo t a f eng g u AP f u sh i n an k e y i m eng l eng n uan d ou c ang zh u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 rest 64 64 62 62 64 64 62 62 67 67 64 64 60 60 rest 57 57 60 60 64 64 62 62 60 60 57 57 60 60 57 57 67 67 62 62 64 64|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674|0.064 0.249 0.088 0.241 0.096 0.241 0.096 0.305 0.032 0.249 0.088 0.249 0.088 0.337 0.249 0.088 0.273 0.064 0.305 0.032 0.249 0.088 0.305 0.032 0.273 0.064 0.273 0.064 0.337 0.273 0.064 0.249 0.088 0.249 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.241 0.096 0.249 0.088 0.305 0.032 0.249 0.088 0.273 0.064 0.674|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 6 | 4|0032.356| 哪 杯 酒 烫 过 肺 腑 AP 曾 换 他 睥 睨 一 顾 AP 剑 破 乾 坤 轮 转 山 河 倾 覆|n a b ei j iu t ang g uo f ei f u AP c eng h uan t a p i n i y i g u AP j ian p o q ian k un l un zh uan sh an h e q ing f u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 rest 64 64 62 62 64 64 62 62 67 67 64 64 60 60 rest 57 57 64 64 62 62 64 64 67 67 62 62 60 60 59 59 60 60 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.337 0.337 1.348 1.348|0.088 0.297 0.040 0.273 0.064 0.305 0.032 0.273 0.064 0.273 0.064 0.273 0.064 0.337 0.249 0.088 0.273 0.064 0.305 0.032 0.249 0.088 0.249 0.088 0.249 0.088 0.273 0.064 0.337 0.273 0.064 0.249 0.088 0.241 0.096 0.273 0.064 0.241 0.096 0.273 0.064 0.249 0.088 0.610 0.064 0.241 0.096 0.273 0.064 1.348|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 7 | 5|0043.620| 他 三 清 尘 外 剔 去 心 中 毒|t a s an q ing ch en w ai t i q v x in zh ong d u|57 57 60 60 64 64 62 62 60 60 59 59 60 60 62 62 64 64 57 57|0.169 0.169 0.169 0.169 0.674 0.674 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.590 0.590|0.032 0.081 0.088 0.073 0.096 0.610 0.064 0.249 0.088 0.305 0.032 0.241 0.096 0.249 0.088 0.273 0.064 0.305 0.032 0.590|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 8 | 6|0048.981| 尝 世 间 百 味 甘 醇 与 涩 苦|ch ang sh i j ian b ai w ei g an ch un y v s e k u|57 57 60 60 64 64 62 62 60 60 59 59 60 60 62 62 64 64 69 69|0.169 0.169 0.169 0.169 0.674 0.674 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 1.180 1.180|0.064 0.081 0.088 0.105 0.064 0.634 0.040 0.249 0.088 0.273 0.064 0.273 0.064 0.249 0.088 0.249 0.088 0.273 0.064 1.180|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 9 | 7|0053.929| 曾 有 谁 偏 执 不 悟 AP 谈 笑 斗 酒 至 酣 处 AP 而 今 不 过 拍 去 肩 上 红 尘 土|c eng y ou sh ui p ian zh i b u w u AP t an x iao d ou j iu zh i h an ch u AP er j in b u g uo p ai q v j ian sh ang h ong ch en t u|60 60 62 62 64 64 67 67 64 64 67 67 62 62 rest 62 62 67 67 72 72 71 71 69 69 67 67 69 69 rest 67 64 64 62 62 62 62 64 64 67 67 60 60 60 60 59 59 60 60 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674|0.088 0.249 0.088 0.249 0.088 0.249 0.088 0.273 0.064 0.297 0.040 0.249 0.088 0.337 0.305 0.032 0.249 0.088 0.305 0.032 0.273 0.064 0.273 0.064 0.273 0.064 0.273 0.064 0.337 0.337 0.273 0.064 0.297 0.040 0.273 0.064 0.249 0.088 0.241 0.096 0.273 0.064 0.249 0.088 0.273 0.064 0.273 0.064 0.305 0.032 0.674|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
10 | 8|0064.655| 风 霜 冷 冽 他 眉 目 时 光 雕 琢 他 风 骨 浮 世 南 柯 一 梦 冷 暖 都 藏 住|f eng sh uang l eng l ie t a m ei m u sh i g uang d iao z uo t a f eng g u f u sh i n an k e y i m eng l eng n uan d ou c ang zh u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 64 64 62 62 64 64 62 62 67 67 64 64 60 60 57 57 60 60 64 64 62 62 60 60 57 57 60 60 57 57 67 67 62 62 64 64|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.506 0.506|0.064 0.249 0.088 0.241 0.096 0.241 0.096 0.305 0.032 0.249 0.088 0.249 0.088 0.586 0.088 0.273 0.064 0.305 0.032 0.249 0.088 0.305 0.032 0.273 0.064 0.273 0.064 0.610 0.064 0.249 0.088 0.249 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.241 0.096 0.249 0.088 0.305 0.032 0.249 0.088 0.273 0.064 0.506|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
11 | 9|0075.418| 哪 杯 酒 烫 过 肺 腑 曾 换 他 睥 睨 一 顾 AP 剑 破 乾 坤 轮 转 山 河 倾 覆|n a b ei j iu t ang g uo f ei f u c eng h uan t a p i n i y i g u AP j ian p o q ian k un l un zh uan sh an h e q ing f u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 64 64 62 62 64 64 62 62 67 67 64 64 60 60 rest 57 57 64 64 62 62 64 64 67 67 62 62 60 60 59 59 60 60 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.337 0.337 0.674 0.674|0.088 0.297 0.040 0.273 0.064 0.305 0.032 0.273 0.064 0.273 0.064 0.273 0.064 0.586 0.088 0.273 0.064 0.305 0.032 0.249 0.088 0.249 0.088 0.249 0.088 0.273 0.064 0.421 0.189 0.064 0.249 0.088 0.241 0.096 0.273 0.064 0.241 0.096 0.273 0.064 0.249 0.088 0.610 0.064 0.241 0.096 0.273 0.064 0.674|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
12 | 10|0086.260| 到 最 后 沧 海 一 粟 AP 何 必 江 湖 多 殊 途 AP 当 年 论 剑 峰 顶 谁 几 笔 成 书|d ao z ui h ou c ang h ai y i s u AP h e b i j iang h u d uo sh u t u AP d ang n ian l un j ian f eng d ing sh ui j i b i ch eng sh u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 rest 64 64 62 62 64 64 62 62 67 67 64 64 60 60 rest 57 57 60 60 64 64 62 62 60 60 57 57 60 60 57 57 67 67 62 62 64 64|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674|0.032 0.249 0.088 0.273 0.064 0.249 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.421 0.189 0.064 0.297 0.040 0.273 0.064 0.273 0.064 0.305 0.032 0.249 0.088 0.305 0.032 0.421 0.221 0.032 0.249 0.088 0.241 0.096 0.273 0.064 0.273 0.064 0.305 0.032 0.249 0.088 0.273 0.064 0.297 0.040 0.273 0.064 0.249 0.088 0.674|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
13 | 11|0096.991| 纵 他 朝 众 生 再 晤 AP 奈 何 明 月 终 辜 负 AP 坐 听 晨 钟 难 算 太 虚 有 无|z ong t a ch ao zh ong sh eng z ai w u AP n ai h e m ing y ve zh ong g u f u AP z uo t ing ch en zh ong n an s uan t ai x v y ou w u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 rest 64 64 62 62 64 64 62 62 67 67 64 64 60 60 rest 57 57 64 64 62 62 64 64 62 62 60 60 59 59 60 60 59 59 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.169 0.169 1.264 1.264|0.088 0.305 0.032 0.273 0.064 0.273 0.064 0.249 0.088 0.249 0.088 0.249 0.088 0.421 0.165 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.273 0.064 0.273 0.064 0.273 0.064 0.421 0.165 0.088 0.305 0.032 0.273 0.064 0.273 0.064 0.249 0.088 0.249 0.088 0.305 0.032 0.586 0.088 0.249 0.088 0.081 0.088 1.264|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
14 | 12|0107.917| 天 道 勘 破 敢 问 一 句 悟 不|t ian d ao k an p o g an w en y i j v w u b u|57 57 64 64 62 62 64 64 62 62 60 60 59 59 60 60 62 62 64 64|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.506 0.506 0.337 0.337 0.590 0.590|0.032 0.305 0.032 0.273 0.064 0.249 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.357 0.064 0.418 0.088 0.297 0.040 0.590|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
15 | 13|0112.496| 悟 悟|w u w u|68 68 69 69|0.506 0.506 3.792 3.792|0.088 0.418 0.088 3.792|0 0 0 0


--------------------------------------------------------------------------------
/svs_song_pitch.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import numpy as np
  3 | 
  4 | from scipy.io import wavfile
  5 | from time import time
  6 | 
  7 | import torch
  8 | import argparse
  9 | 
 10 | from vits.models import SynthesizerTrn
 11 | from util import SingInput
 12 | from util import FeatureInput
 13 | from omegaconf import OmegaConf
 14 | 
 15 | 
 16 | from pitch.models import PitchDiffusion
 17 | from pitch.utils import fix_len_compatibility
 18 | 
 19 | 
 20 | def save_wav(wav, path, rate):
 21 |     wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6
 22 |     wavfile.write(path, rate, wav.astype(np.int16))
 23 | 
 24 | 
 25 | def load_svs_model(checkpoint_path, model):
 26 |     assert os.path.isfile(checkpoint_path)
 27 |     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
 28 |     saved_state_dict = checkpoint_dict["model_g"]
 29 |     state_dict = model.state_dict()
 30 |     new_state_dict = {}
 31 |     for k, v in state_dict.items():
 32 |         try:
 33 |             new_state_dict[k] = saved_state_dict[k]
 34 |         except KeyError:
 35 |             print("%s is not in the checkpoint" % k)
 36 |             new_state_dict[k] = v
 37 |     model.load_state_dict(new_state_dict)
 38 |     return model
 39 | 
 40 | 
 41 | if __name__ == '__main__':
 42 |     parser = argparse.ArgumentParser()
 43 |     parser.add_argument('-c', '--config', type=str, required=True,
 44 |                         help="yaml file for configuration")
 45 |     parser.add_argument('-m', '--model', type=str, required=True,
 46 |                         help="path of svs checkpoint pt file")
 47 |     parser.add_argument('-p', '--pitch', type=str, required=True,
 48 |                         help="path of pitch checkpoint pt file")
 49 |     args = parser.parse_args()
 50 | 
 51 |     # define model and load checkpoint
 52 |     hps = OmegaConf.load(args.config)
 53 | 
 54 |     singInput = SingInput(hps.data.sampling_rate, hps.data.hop_length)
 55 |     featureInput = FeatureInput(hps.data.sampling_rate, hps.data.hop_length)
 56 | 
 57 |     net_g = SynthesizerTrn(
 58 |         hps.data.filter_length // 2 + 1,
 59 |         hps.data.segment_size // hps.data.hop_length,
 60 |         hps).cuda()
 61 |     net_g.eval()
 62 | 
 63 |     load_svs_model(args.model, net_g)
 64 | 
 65 |     net_p = PitchDiffusion().cuda()
 66 |     net_p.eval()
 67 |     load_svs_model(args.pitch, net_p)
 68 | 
 69 |     # check directory existence
 70 |     os.makedirs("./svs_out", exist_ok=True)
 71 |     fo = open("./svs_song.txt", "r+")
 72 |     song_rate = hps.data.sampling_rate
 73 |     song_time = fo.readline().strip().split("|")[1]
 74 |     song_length = int(song_rate * (float(song_time) + 30))
 75 |     song_data = np.zeros(song_length, dtype="float32")
 76 |     while True:
 77 |         try:
 78 |             message = fo.readline().strip()
 79 |         except Exception as e:
 80 |             print("stopped reading due to:", e)
 81 |             break
 82 |         if message is None:
 83 |             break
 84 |         if message == "":
 85 |             break
 86 |         (
 87 |             item_indx,
 88 |             item_time,
 89 |             labels_ids,
 90 |             labels_frames,
 91 |             scores_ids,
 92 |             scores_dur,
 93 |             labels_slr,
 94 |             labels_uvs,
 95 |         ) = singInput.parseSong(message)
 96 |         labels_ids = singInput.expandInput(labels_ids, labels_frames)
 97 |         labels_uvs = singInput.expandInput(labels_uvs, labels_frames)
 98 |         labels_slr = singInput.expandInput(labels_slr, labels_frames)
 99 |         scores_ids = singInput.expandInput(scores_ids, labels_frames)
100 | 
101 |         phone = torch.LongTensor(labels_ids)
102 |         score = torch.LongTensor(scores_ids)
103 |         slurs = torch.LongTensor(labels_slr)
104 | 
105 |         lengths = phone.size()[0]
106 |         lengths_fix = fix_len_compatibility(lengths)
107 | 
108 |         phone_fix = torch.zeros((1, lengths_fix), dtype=torch.long).cuda()
109 |         score_fix = torch.zeros((1, lengths_fix), dtype=torch.long).cuda()
110 |         slurs_fix = torch.zeros((1, lengths_fix), dtype=torch.long).cuda()
111 |         phone_fix[0, :lengths] = phone
112 |         score_fix[0, :lengths] = score
113 |         slurs_fix[0, :lengths] = slurs
114 | 
115 |         begin_time = time()
116 |         with torch.no_grad():
117 |             n_timesteps = 50
118 |             temperature = 1
119 |             # PIT
120 |             phone_lengths = torch.LongTensor([lengths_fix]).cuda()
121 |             pit_pri, pit_pre = net_p(phone_fix, phone_lengths, score_fix, slurs_fix, n_timesteps, temperature)
122 |             pitch = pit_pre[:, 0, :]
123 |             pitch = 2**pitch
124 |             print('~~~~~~~')
125 |             audio = (
126 |                 net_g.infer(phone_fix, phone_lengths, score_fix, pitch, slurs_fix)[0, 0]
127 |                 .data.cpu()
128 |                 .float()
129 |                 .numpy()
130 |             )
131 |        
132 |         save_wav(audio, f"./svs_out/{item_indx}.wav", hps.data.sampling_rate)
133 |         # wav
134 |         item_start = int(song_rate * float(item_time))
135 |         item_end = item_start + len(audio)
136 |         song_data[item_start:item_end] = audio
137 |     # out of for
138 |     song_data = np.array(song_data, dtype="float32")
139 |     save_wav(song_data, f"./svs_out/_song.wav", hps.data.sampling_rate)
140 |     fo.close()
141 |     # can be deleted
142 |     os.system("chmod 777 ./svs_out -R")
143 | 
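Note (illustrative, not part of the repo): the loop above zero-pads every frame-level input to a length the pitch diffusion model accepts, then converts the predicted log2-F0 back to Hz. A minimal sketch of that step with toy values:

    import torch
    from pitch.utils import fix_len_compatibility

    lengths = 123                                 # true number of frames (toy value)
    lengths_fix = fix_len_compatibility(lengths)  # padded length accepted by the diffusion model
    phone_fix = torch.zeros((1, lengths_fix), dtype=torch.long)
    phone_fix[0, :lengths] = torch.randint(1, 60, (lengths,))  # hypothetical phoneme ids
    # score_fix / slurs_fix are zero-padded the same way; after net_p returns pit_pre,
    # pitch_hz = 2 ** pit_pre[:, 0, :]            # the model predicts log2(F0), as in the loop above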


--------------------------------------------------------------------------------
/svs_train.py:
--------------------------------------------------------------------------------
 1 | import sys,os
 2 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 3 | import argparse
 4 | import torch
 5 | import torch.multiprocessing as mp
 6 | from omegaconf import OmegaConf
 7 | 
 8 | from vits_extend.train import train
 9 | 
10 | torch.backends.cudnn.benchmark = True
11 | 
12 | 
13 | if __name__ == '__main__':
14 |     parser = argparse.ArgumentParser()
15 |     parser.add_argument('-c', '--config', type=str, required=True,
16 |                         help="yaml file for configuration")
17 |     parser.add_argument('-p', '--checkpoint_path', type=str, default=None,
18 |                         help="path of checkpoint pt file to resume training")
19 |     parser.add_argument('-n', '--name', type=str, required=True,
20 |                         help="name of the model for logging, saving checkpoint")
21 |     args = parser.parse_args()
22 | 
23 |     hp = OmegaConf.load(args.config)
24 |     with open(args.config, 'r') as f:
25 |         hp_str = ''.join(f.readlines())
26 | 
27 |     assert hp.data.hop_length == 320, \
28 |         'hp.data.hop_length must be equal to 320, got %d' % hp.data.hop_length
29 | 
30 |     args.num_gpus = 0
31 |     torch.manual_seed(hp.train.seed)
32 |     if torch.cuda.is_available():
33 |         torch.cuda.manual_seed(hp.train.seed)
34 |         args.num_gpus = torch.cuda.device_count()
35 |         print('Batch size per GPU :', hp.train.batch_size)
36 | 
37 |         if args.num_gpus > 1:
38 |             mp.spawn(train, nprocs=args.num_gpus,
39 |                      args=(args, args.checkpoint_path, hp, hp_str,))
40 |         else:
41 |             train(0, args, args.checkpoint_path, hp, hp_str)
42 |     else:
43 |         print('No GPU found!')
44 | 
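A small pre-flight check (illustrative) that mirrors the assertion above before launching training; it only reads values the script itself uses:

    from omegaconf import OmegaConf

    hp = OmegaConf.load("configs/singing_base.yaml")  # the repo's config file
    assert hp.data.hop_length == 320, "svs_train.py only supports hop_length == 320"
    print("batch size per GPU:", hp.train.batch_size)
    print("seed:", hp.train.seed)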


--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
  1 | import numpy as np
  2 | import librosa
  3 | import pyworld
  4 | 
  5 | from svs import label_to_ids, load_midi_map, uv_map
  6 | 
  7 | 
  8 | class SingInput(object):
  9 |     def __init__(self, samplerate=32000, hop_size=320):
 10 |         self.fs = samplerate
 11 |         self.hop = hop_size
 12 |         self.notemaper = load_midi_map()
 13 | 
 14 |     def phone_to_uv(self, phones):
 15 |         uv = []
 16 |         for phone in phones:
 17 |             uv.append(uv_map[phone.lower()])
 18 |         return uv
 19 | 
 20 |     def notes_to_id(self, notes):
 21 |         note_ids = []
 22 |         for note in notes:
 23 |             note_ids.append(self.notemaper[note])
 24 |         return note_ids
 25 | 
 26 |     def frame_duration(self, durations):
 27 |         ph_durs = [float(x) for x in durations]
 28 |         sentence_length = 0
 29 |         for ph_dur in ph_durs:
 30 |             sentence_length = sentence_length + ph_dur
 31 |         sentence_length = int(sentence_length * self.fs / self.hop + 0.5)
 32 | 
 33 |         sample_frame = []
 34 |         startTime = 0
 35 |         for i_ph in range(len(ph_durs)):
 36 |             start_frame = int(startTime * self.fs / self.hop + 0.5)
 37 |             end_frame = int((startTime + ph_durs[i_ph]) * self.fs / self.hop + 0.5)
 38 |             count_frame = end_frame - start_frame
 39 |             sample_frame.append(count_frame)
 40 |             startTime = startTime + ph_durs[i_ph]
 41 |         all_frame = np.sum(sample_frame)
 42 |         assert all_frame == sentence_length
 43 |         # match mel length
 44 |         sample_frame[-1] = sample_frame[-1] - 1
 45 |         return sample_frame
 46 | 
 47 |     def score_duration(self, durations):
 48 |         ph_durs = [float(x) for x in durations]
 49 |         sample_frame = []
 50 |         for i_ph in range(len(ph_durs)):
 51 |             count_frame = int(ph_durs[i_ph] * self.fs / self.hop + 0.5)
 52 |             if count_frame >= 256:
 53 |                 print("count_frame", count_frame)
 54 |                 count_frame = 255
 55 |             sample_frame.append(count_frame)
 56 |         return sample_frame
 57 | 
 58 |     def parseInput(self, singinfo: str):
 59 |         infos = singinfo.split("|")
 60 |         file = infos[0]
 61 |         # hanz = infos[1]
 62 |         phon = infos[2].split(" ")
 63 |         note = infos[3].split(" ")
 64 |         note_dur = infos[4].split(" ")
 65 |         phon_dur = infos[5].split(" ")
 66 |         phon_slr = infos[6].split(" ")
 67 | 
 68 |         labels_ids = label_to_ids(phon)
 69 |         labels_uvs = self.phone_to_uv(phon)
 70 |         labels_frames = self.frame_duration(phon_dur)
 71 |         scores_ids = self.notes_to_id(note)
 72 |         scores_dur = self.score_duration(note_dur)
 73 |         labels_slr = [int(x) for x in phon_slr]
 74 |         return (
 75 |             file,
 76 |             labels_ids,
 77 |             labels_frames,
 78 |             scores_ids,
 79 |             scores_dur,
 80 |             labels_slr,
 81 |             labels_uvs,
 82 |         )
 83 | 
 84 |     def parseSong(self, singinfo: str):
 85 |         infos = singinfo.split("|")
 86 |         item_indx = infos[0]
 87 |         item_time = infos[1]
 88 |         # hanz = infos[2]
 89 |         phon = infos[3].split(" ")
 90 |         note_ids = infos[4].split(" ")
 91 |         note_dur = infos[5].split(" ")
 92 |         phon_dur = infos[6].split(" ")
 93 |         phon_slr = infos[7].split(" ")
 94 | 
 95 |         labels_ids = label_to_ids(phon)
 96 |         labels_uvs = self.phone_to_uv(phon)
 97 |         labels_frames = self.frame_duration(phon_dur)
 98 |         scores_ids = [int(x) if x != "rest" else 0 for x in note_ids]
 99 |         scores_dur = self.score_duration(note_dur)
100 |         labels_slr = [int(x) for x in phon_slr]
101 |         return (
102 |             item_indx,
103 |             item_time,
104 |             labels_ids,
105 |             labels_frames,
106 |             scores_ids,
107 |             scores_dur,
108 |             labels_slr,
109 |             labels_uvs,
110 |         )
111 | 
112 |     def expandInput(self, labels_ids, labels_frames):
113 |         assert len(labels_ids) == len(labels_frames)
114 |         frame_num = np.sum(labels_frames)
115 |         frame_labels = np.zeros(frame_num, dtype=np.int64)
116 |         start = 0
117 |         for index, num in enumerate(labels_frames):
118 |             frame_labels[start : start + num] = labels_ids[index]
119 |             start += num
120 |         return frame_labels
121 | 
122 |     def scorePitch(self, scores_id):
123 |         score_pitch = np.zeros(len(scores_id), dtype=np.float64)
124 |         for index, score_id in enumerate(scores_id):
125 |             if score_id == 0:
126 |                 score_pitch[index] = 0
127 |             else:
128 |                 pitch = librosa.midi_to_hz(score_id)
129 |                 score_pitch[index] = round(pitch, 1)
130 |         return score_pitch
131 | 
132 |     def smoothPitch(self, pitch):
133 |         # smooth the pitch contour with a short convolution
134 |         kernel = np.hanning(5)  # symmetric Hanning window used as the smoothing kernel
135 |         kernel /= kernel.sum()
136 |         smooth_pitch = np.convolve(pitch, kernel, "same")
137 |         return smooth_pitch
138 | 
139 | 
140 | class FeatureInput(object):
141 |     def __init__(self, samplerate=32000, hop_size=320):
142 |         self.fs = samplerate
143 |         self.hop = hop_size
144 | 
145 |         self.f0_bin = 256
146 |         self.f0_max = 1100.0
147 |         self.f0_min = 50.0
148 |         self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
149 |         self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
150 | 
151 |     def compute_f0(self, file):
152 |         x, sr = librosa.load(file, sr=self.fs)
153 |         assert sr == self.fs
154 |         f0, t = pyworld.dio(
155 |             x.astype(np.double),
156 |             fs=sr,
157 |             f0_ceil=900,
158 |             frame_period=1000 * self.hop / sr,
159 |         )
160 |         f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.fs)
161 |         for index, pitch in enumerate(f0):
162 |             f0[index] = round(pitch, 1)
163 |         return f0
164 | 
165 |     def coarse_f0(self, f0):
166 |         f0_mel = 1127 * np.log(1 + f0 / 700)
167 |         f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * (
168 |             self.f0_bin - 2
169 |         ) / (self.f0_mel_max - self.f0_mel_min) + 1
170 | 
171 |         # use 0 or 1
172 |         f0_mel[f0_mel <= 1] = 1
173 |         f0_mel[f0_mel > self.f0_bin - 1] = self.f0_bin - 1
174 |         f0_coarse = np.rint(f0_mel).astype(np.int64)
175 |         assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (
176 |             f0_coarse.max(),
177 |             f0_coarse.min(),
178 |         )
179 |         return f0_coarse
180 | 
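A tiny worked example (illustrative) of the two helpers used most downstream: expandInput repeats each phoneme id for its frame count, and coarse_f0 quantizes F0 on a mel scale into bins 1..255:

    import numpy as np
    from util import SingInput, FeatureInput

    si = SingInput(32000, 320)
    print(si.expandInput([5, 9], [3, 2]))    # -> [5 5 5 9 9]

    fi = FeatureInput(32000, 320)
    print(fi.coarse_f0(np.array([220.0])))   # 220 Hz lands in roughly bin 60 of 1..255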


--------------------------------------------------------------------------------
/util/generate_index.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import random
 3 | import argparse
 4 | 
 5 | 
 6 | if __name__ == "__main__":
 7 |     parser = argparse.ArgumentParser()
 8 |     parser.add_argument('--file', type=str, required=True)
 9 |     args = parser.parse_args()
10 |     alls = []
11 |     fo = open(args.file, "r+")
12 |     while True:
13 |         try:
14 |             message = fo.readline().strip()
15 |         except Exception as e:
16 |             print("stopped reading due to:", e)
17 |             break
18 |         if message is None:
19 |             break
20 |         if message == "":
21 |             break
22 |         alls.append(message)
23 |     fo.close()
24 | 
25 |     valids = alls[:5]
26 |     trains = alls[5:]
27 | 
28 |     random.shuffle(trains)
29 |     os.makedirs("filelists", exist_ok=True)
30 | 
31 |     fw = open("./filelists/singing_valid.txt", "w", encoding="utf-8")
32 |     for strs in valids:
33 |         print(strs, file=fw)
34 |     fw.close()
35 | 
36 |     fw = open("./filelists/singing_train.txt", "w", encoding="utf-8")
37 |     for strs in trains:
38 |         print(strs, file=fw)
39 |     fw.close()
40 | 


--------------------------------------------------------------------------------
/util/generate_label.py:
--------------------------------------------------------------------------------
  1 | import os, sys
  2 | sys.path.append(os.getcwd())
  3 | import logging
  4 | logging.basicConfig(level=logging.INFO)  # ERROR & INFO
  5 | import argparse
  6 | import numpy as np
  7 | 
  8 | from omegaconf import OmegaConf
  9 | from util import SingInput, FeatureInput
 10 | 
 11 | 
 12 | if __name__ == "__main__":
 13 |     parser = argparse.ArgumentParser()
 14 |     parser.add_argument('--config', type=str, required=True)
 15 |     parser.add_argument('--data', type=str, required=True)
 16 |     parser.add_argument('--file', type=str, required=True)
 17 |     args = parser.parse_args()
 18 | 
 19 |     hps = OmegaConf.load(args.config)
 20 | 
 21 |     assert os.path.exists(args.file)
 22 |     assert os.path.exists(os.path.join(args.data, "wavs"))
 23 |     os.makedirs(os.path.join(args.data, "labels"), exist_ok=True)
 24 | 
 25 |     singInput = SingInput(hps.data.sampling_rate, hps.data.hop_length)
 26 |     featureInput = FeatureInput(hps.data.sampling_rate, hps.data.hop_length)
 27 | 
 28 |     raw_file = open(args.file, "r+")
 29 |     vits_file = open(os.path.join(args.data, "labels.txt"),
 30 |                      "w", encoding="utf-8")
 31 |     label_path = os.path.join(args.data, "labels")
 32 |     i = 0
 33 |     all_txt = []  # collect sentences to count the number of unique ones
 34 |     while True:
 35 |         try:
 36 |             message = raw_file.readline().strip()
 37 |         except Exception as e:
 38 |             print("stopped reading due to:", e)
 39 |             break
 40 |         if message is None:
 41 |             break
 42 |         if message == "":
 43 |             break
 44 |         # i = i + 1
 45 |         # if i > 5:
 46 |         #    break
 47 |         infos = message.split("|")
 48 |         file = infos[0]
 49 |         hanz = infos[1]
 50 |         all_txt.append(hanz)
 51 |         phon = infos[2].split(" ")
 52 |         note = infos[3].split(" ")
 53 |         note_dur = infos[4].split(" ")
 54 |         phon_dur = infos[5].split(" ")
 55 |         phon_slur = infos[6].split(" ")
 56 | 
 57 |         logging.info("----------------------------")
 58 |         logging.info(file)
 59 |         logging.info(hanz)
 60 |         logging.info(phon)
 61 |         # logging.info(note_dur)
 62 |         # logging.info(phon_dur)
 63 |         # logging.info(phon_slur)
 64 |         path_wave = os.path.join(args.data, "wavs", f"{file}.wav")
 65 |         path_label = os.path.join(label_path, f"{file}_label.npy")
 66 |         path_score = os.path.join(label_path, f"{file}_score.npy")
 67 |         path_pitch = os.path.join(label_path, f"{file}_pitch.npy")
 68 |         path_slurs = os.path.join(label_path, f"{file}_slurs.npy")
 69 | 
 70 |         (
 71 |             file,
 72 |             labels_ids,
 73 |             labels_frames,
 74 |             scores_ids,
 75 |             scores_dur,
 76 |             labels_slr,
 77 |             labels_uvs,
 78 |         ) = singInput.parseInput(message)
 79 |         labels_ids = singInput.expandInput(labels_ids, labels_frames)
 80 |         labels_uvs = singInput.expandInput(labels_uvs, labels_frames)
 81 |         labels_slr = singInput.expandInput(labels_slr, labels_frames)
 82 |         scores_ids = singInput.expandInput(scores_ids, labels_frames)
 83 | 
 84 |         featur_pit = featureInput.compute_f0(path_wave)
 85 |         featur_pit = featur_pit[: len(labels_ids)]
 86 |         featur_pit = featur_pit * labels_uvs
 87 | 
 88 |         assert len(labels_ids) == len(featur_pit)
 89 | 
 90 |         np.save(path_label, labels_ids, allow_pickle=False)
 91 |         np.save(path_score, scores_ids, allow_pickle=False)
 92 |         np.save(path_pitch, featur_pit, allow_pickle=False)
 93 |         np.save(path_slurs, labels_slr, allow_pickle=False)
 94 | 
 95 |         print(
 96 |             f"{path_wave}|{path_label}|{path_score}|{path_pitch}|{path_slurs}",
 97 |             file=vits_file,
 98 |         )
 99 | 
100 |     raw_file.close()
101 |     vits_file.close()
102 |     print(len(set(all_txt)))  # number of unique sentences
103 | 
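For reference, each line written to labels.txt above points at the wav plus the four per-frame arrays, which is exactly the field order later unpacked by vits/data_utils.TextAudioLoader (paths below are hypothetical):

    entry = ("data_svs/wavs/0001.wav|data_svs/labels/0001_label.npy|"
             "data_svs/labels/0001_score.npy|data_svs/labels/0001_pitch.npy|"
             "data_svs/labels/0001_slurs.npy")   # hypothetical example line
    audiopath, label, score, pitch, slurs = entry.split("|")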


--------------------------------------------------------------------------------
/util/resample.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import librosa
 3 | import argparse
 4 | import numpy as np
 5 | from tqdm import tqdm
 6 | from concurrent.futures import ThreadPoolExecutor, as_completed
 7 | from scipy.io import wavfile
 8 | 
 9 | 
10 | def resample_wave(wav_in, wav_out, sample_rate):
11 |     wav, _ = librosa.load(wav_in, sr=sample_rate)
12 |     wav = wav / np.abs(wav).max() * 0.6
13 |     wav = wav / max(0.01, np.max(np.abs(wav))) * 32767 * 0.6
14 |     wavfile.write(wav_out, sample_rate, wav.astype(np.int16))
15 | 
16 | 
17 | def process_file(file, wavPath, outPath, sr):
18 |     if file.endswith(".wav"):
19 |         file = file[:-4]
20 |         resample_wave(f"{wavPath}/{file}.wav", f"{outPath}/{file}.wav", sr)
21 | 
22 | 
23 | def process_files_with_thread_pool(wavPath, outPath, sr, thread_num=None):
24 |     files = [f for f in os.listdir(f"./{wavPath}") if f.endswith(".wav")]
25 | 
26 |     with ThreadPoolExecutor(max_workers=thread_num) as executor:
27 |         futures = {executor.submit(process_file, file, wavPath, outPath, sr): file for file in files}
28 | 
29 |         for future in tqdm(as_completed(futures), total=len(futures), desc=f'Processing {sr}'):
30 |             future.result()
31 | 
32 | 
33 | if __name__ == "__main__":
34 |     parser = argparse.ArgumentParser()
35 |     parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True)
36 |     parser.add_argument("-o", "--out", help="out", dest="out", required=True)
37 |     parser.add_argument("-s", "--sr", help="sample rate", dest="sr", type=int, required=True)
38 |     parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1)
39 | 
40 |     args = parser.parse_args()
41 |     print(args.wav)
42 |     print(args.out)
43 |     print(args.sr)
44 | 
45 |     os.makedirs(args.out, exist_ok=True)
46 |     wavPath = args.wav
47 |     outPath = args.out
48 | 
49 |     if args.thread_count == 0:
50 |         process_num = os.cpu_count() // 2 + 1
51 |     else:
52 |         process_num = args.thread_count
53 |     process_files_with_thread_pool(wavPath, outPath, args.sr, process_num)
54 | 
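Note on the two scaling lines in resample_wave: together they peak-normalize the output to roughly 0.6 of int16 full scale, independent of the input level. A quick check with a toy signal:

    import numpy as np

    wav = np.array([0.25, -0.5])                             # toy signal, peak 0.5
    wav = wav / np.abs(wav).max() * 0.6
    wav = wav / max(0.01, np.max(np.abs(wav))) * 32767 * 0.6
    print(wav.astype(np.int16))                              # -> [  9830 -19660], peak ~ 0.6 * 32767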


--------------------------------------------------------------------------------
/vits/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlayVoice/VI-SVS/d9768caa05c742bc5f580a4754b3cd7b444a5eb4/vits/__init__.py


--------------------------------------------------------------------------------
/vits/attentions.py:
--------------------------------------------------------------------------------
  1 | import math
  2 | import torch
  3 | from torch import nn
  4 | from einops import rearrange
  5 | 
  6 | 
  7 | class LayerNorm(nn.Module):
  8 |     def __init__(self, channels, eps=1e-4):
  9 |         super(LayerNorm, self).__init__()
 10 |         self.channels = channels
 11 |         self.eps = eps
 12 | 
 13 |         self.gamma = torch.nn.Parameter(torch.ones(channels))
 14 |         self.beta = torch.nn.Parameter(torch.zeros(channels))
 15 | 
 16 |     def forward(self, x):
 17 |         n_dims = len(x.shape)
 18 |         mean = torch.mean(x, 1, keepdim=True)
 19 |         variance = torch.mean((x - mean)**2, 1, keepdim=True)
 20 | 
 21 |         x = (x - mean) * torch.rsqrt(variance + self.eps)
 22 | 
 23 |         shape = [1, -1] + [1] * (n_dims - 2)
 24 |         x = x * self.gamma.view(*shape) + self.beta.view(*shape)
 25 |         return x
 26 | 
 27 | 
 28 | class ConvReluNorm(nn.Module):
 29 |     def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, 
 30 |                  n_layers, p_dropout, eps=1e-5):
 31 |         super(ConvReluNorm, self).__init__()
 32 |         self.in_channels = in_channels
 33 |         self.hidden_channels = hidden_channels
 34 |         self.out_channels = out_channels
 35 |         self.kernel_size = kernel_size
 36 |         self.n_layers = n_layers
 37 |         self.p_dropout = p_dropout
 38 |         self.eps = eps
 39 | 
 40 |         self.conv_layers = torch.nn.ModuleList()
 41 |         self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, 
 42 |                                                 kernel_size, padding=kernel_size//2))
 43 |         self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
 44 |         for _ in range(n_layers - 1):
 45 |             self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels, 
 46 |                                                     kernel_size, padding=kernel_size//2))
 47 |         self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
 48 |         self.proj.weight.data.zero_()
 49 |         self.proj.bias.data.zero_()
 50 | 
 51 |     def forward(self, x, x_mask):
 52 |         for i in range(self.n_layers):
 53 |             x = self.conv_layers[i](x * x_mask)
 54 |             x = self.instance_norm(x, x_mask)
 55 |             x = self.relu_drop(x)
 56 |         x = self.proj(x)
 57 |         return x * x_mask
 58 | 
 59 |     def instance_norm(self, x, mask, return_mean_std=False):
 60 |         mean, std = self.calc_mean_std(x, mask)
 61 |         x = (x - mean) / std
 62 |         if return_mean_std:
 63 |             return x, mean, std
 64 |         else:
 65 |             return x
 66 | 
 67 |     def calc_mean_std(self, x, mask=None):
 68 |         x = x * mask
 69 |         B, C = x.shape[:2]
 70 |         mn = x.view(B, C, -1).mean(-1)
 71 |         sd = (x.view(B, C, -1).var(-1) + self.eps).sqrt()
 72 |         mn = mn.view(B, C, *((len(x.shape) - 2) * [1]))
 73 |         sd = sd.view(B, C, *((len(x.shape) - 2) * [1]))
 74 |         return mn, sd
 75 | 
 76 | 
 77 | class RotaryPositionalEmbeddings(nn.Module):
 78 |     """
 79 |     ## RoPE module
 80 |     https://github.com/labmlai/annotated_deep_learning_paper_implementations
 81 |     
 82 |     Rotary encoding transforms pairs of features by rotating in the 2D plane.
 83 |     That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
 84 |     Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
 85 |     by an angle depending on the position of the token.
 86 |     """
 87 |     def __init__(self, d: int, base: int = 10_000):
 88 |         r"""
 89 |         * `d` is the number of features $d$
 90 |         * `base` is the constant used for calculating $\Theta$
 91 |         """
 92 |         super().__init__()
 93 |         self.base = base
 94 |         self.d = int(d)
 95 |         self.cos_cached = None
 96 |         self.sin_cached = None
 97 | 
 98 |     def _build_cache(self, x: torch.Tensor):
 99 |         r"""
100 |         Cache $\cos$ and $\sin$ values
101 |         """
102 |         # Return if cache is already built
103 |         if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
104 |             return
105 |         # Get sequence length
106 |         seq_len = x.shape[0]
107 |         theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
108 |         # Create position indexes `[0, 1, ..., seq_len - 1]`
109 |         seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
110 |         # Calculate the product of position index and $\theta_i$
111 |         idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
112 |         # Concatenate so that for row $m$ we have
113 |         idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
114 |         # Cache them
115 |         self.cos_cached = idx_theta2.cos()[:, None, None, :]
116 |         self.sin_cached = idx_theta2.sin()[:, None, None, :]
117 | 
118 |     def _neg_half(self, x: torch.Tensor):
119 |         d_2 = self.d // 2
120 |         return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
121 | 
122 |     def forward(self, x: torch.Tensor):
123 |         """
124 |         * `x` is the Tensor at the head of a key or a query with shape `[batch_size, n_heads, seq_len, d]`
125 |         """
126 |         x = rearrange(x, "b h t d -> t b h d")
127 |         self._build_cache(x)
128 |         # Split the features; rotary embeddings can be applied to just a subset of the features.
129 |         x_rope, x_pass = x[..., : self.d], x[..., self.d :]
130 |         # Calculate
131 |         neg_half_x = self._neg_half(x_rope)
132 |         x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
133 |         return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")
134 | 
135 | 
136 | class MultiHeadAttention(nn.Module):
137 |     def __init__(self, channels, out_channels, n_heads, 
138 |                  heads_share=True, p_dropout=0.0, proximal_bias=False, 
139 |                  proximal_init=False):
140 |         super(MultiHeadAttention, self).__init__()
141 |         assert channels % n_heads == 0
142 | 
143 |         self.channels = channels
144 |         self.out_channels = out_channels
145 |         self.n_heads = n_heads
146 |         self.heads_share = heads_share
147 |         self.proximal_bias = proximal_bias
148 |         self.p_dropout = p_dropout
149 |         self.attn = None
150 | 
151 |         self.k_channels = channels // n_heads
152 |         self.conv_q = torch.nn.Conv1d(channels, channels, 1)
153 |         self.conv_k = torch.nn.Conv1d(channels, channels, 1)
154 |         self.conv_v = torch.nn.Conv1d(channels, channels, 1)
155 | 
156 |         # from https://nn.labml.ai/transformers/rope/index.html
157 |         self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
158 |         self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
159 | 
160 |         self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
161 |         self.drop = torch.nn.Dropout(p_dropout)
162 | 
163 |         torch.nn.init.xavier_uniform_(self.conv_q.weight)
164 |         torch.nn.init.xavier_uniform_(self.conv_k.weight)
165 |         if proximal_init:
166 |             self.conv_k.weight.data.copy_(self.conv_q.weight.data)
167 |             self.conv_k.bias.data.copy_(self.conv_q.bias.data)
168 |         torch.nn.init.xavier_uniform_(self.conv_v.weight)
169 | 
170 |     def forward(self, x, c, attn_mask=None):
171 |         q = self.conv_q(x)
172 |         k = self.conv_k(c)
173 |         v = self.conv_v(c)
174 | 
175 |         x, self.attn = self.attention(q, k, v, mask=attn_mask)
176 | 
177 |         x = self.conv_o(x)
178 |         return x
179 | 
180 |     def attention(self, query, key, value, mask=None):
181 |         b, d, t_s, t_t = (*key.size(), query.size(2))
182 |         query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads)
183 |         key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads)
184 |         value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads)
185 | 
186 |         query = self.query_rotary_pe(query)
187 |         key = self.key_rotary_pe(key)
188 | 
189 |         scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
190 | 
191 |         if self.proximal_bias:
192 |             assert t_s == t_t, "Proximal bias is only available for self-attention."
193 |             scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, 
194 |                                                                     dtype=scores.dtype)
195 |         if mask is not None:
196 |             scores = scores.masked_fill(mask == 0, -1e4)
197 |         p_attn = torch.nn.functional.softmax(scores, dim=-1)
198 |         p_attn = self.drop(p_attn)
199 |         output = torch.matmul(p_attn, value)
200 |         output = output.transpose(2, 3).contiguous().view(b, d, t_t)
201 |         return output, p_attn
202 | 
203 |     def _attention_bias_proximal(self, length):
204 |         r = torch.arange(length, dtype=torch.float32)
205 |         diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
206 |         return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
207 | 
208 | 
209 | class FFN(nn.Module):
210 |     def __init__(self, in_channels, out_channels, filter_channels, kernel_size, 
211 |                  p_dropout=0.0):
212 |         super(FFN, self).__init__()
213 |         self.in_channels = in_channels
214 |         self.out_channels = out_channels
215 |         self.filter_channels = filter_channels
216 |         self.kernel_size = kernel_size
217 |         self.p_dropout = p_dropout
218 | 
219 |         self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, 
220 |                                       padding=kernel_size//2)
221 |         self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, 
222 |                                       padding=kernel_size//2)
223 |         self.drop = torch.nn.Dropout(p_dropout)
224 | 
225 |     def forward(self, x, x_mask):
226 |         x = self.conv_1(x * x_mask)
227 |         x = torch.relu(x)
228 |         x = self.drop(x)
229 |         x = self.conv_2(x * x_mask)
230 |         return x * x_mask
231 | 
232 | 
233 | class Encoder(nn.Module):
234 |     def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, 
235 |                  kernel_size=1, p_dropout=0.0, **kwargs):
236 |         super(Encoder, self).__init__()
237 |         self.hidden_channels = hidden_channels
238 |         self.filter_channels = filter_channels
239 |         self.n_heads = n_heads
240 |         self.n_layers = n_layers
241 |         self.kernel_size = kernel_size
242 |         self.p_dropout = p_dropout
243 | 
244 |         self.drop = torch.nn.Dropout(p_dropout)
245 |         self.attn_layers = torch.nn.ModuleList()
246 |         self.norm_layers_1 = torch.nn.ModuleList()
247 |         self.ffn_layers = torch.nn.ModuleList()
248 |         self.norm_layers_2 = torch.nn.ModuleList()
249 |         for _ in range(self.n_layers):
250 |             self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, 
251 |                                                        n_heads, p_dropout=p_dropout))
252 |             self.norm_layers_1.append(LayerNorm(hidden_channels))
253 |             self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
254 |                                        filter_channels, kernel_size, p_dropout=p_dropout))
255 |             self.norm_layers_2.append(LayerNorm(hidden_channels))
256 | 
257 |     def forward(self, x, x_mask):
258 |         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
259 |         for i in range(self.n_layers):
260 |             x = x * x_mask
261 |             y = self.attn_layers[i](x, x, attn_mask)
262 |             y = self.drop(y)
263 |             x = self.norm_layers_1[i](x + y)
264 |             y = self.ffn_layers[i](x, x_mask)
265 |             y = self.drop(y)
266 |             x = self.norm_layers_2[i](x + y)
267 |         x = x * x_mask
268 |         return x
269 | 
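A minimal shape check for the Encoder above (channel, head and layer counts here are assumed, typical VITS-style values): input and output are both [batch, channels, frames] with a [batch, 1, frames] mask, and rotary position embeddings are applied to queries and keys inside MultiHeadAttention:

    import torch
    from vits.attentions import Encoder

    enc = Encoder(hidden_channels=192, filter_channels=768,
                  n_heads=2, n_layers=4, kernel_size=3, p_dropout=0.1)
    x = torch.randn(1, 192, 100)       # [batch, channels, frames]
    x_mask = torch.ones(1, 1, 100)     # 1 = valid frame
    print(enc(x, x_mask).shape)        # -> torch.Size([1, 192, 100])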


--------------------------------------------------------------------------------
/vits/commons.py:
--------------------------------------------------------------------------------
  1 | import math
  2 | import numpy as np
  3 | import torch
  4 | from torch import nn
  5 | from torch.nn import functional as F
  6 | 
  7 | 
  8 | def slice_pitch_segments(x, ids_str, segment_size=4):
  9 |     ret = torch.zeros_like(x[:, :segment_size])
 10 |     for i in range(x.size(0)):
 11 |         idx_str = ids_str[i]
 12 |         idx_end = idx_str + segment_size
 13 |         ret[i] = x[i, idx_str:idx_end]
 14 |     return ret
 15 | 
 16 | 
 17 | def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4):
 18 |     b, d, t = x.size()
 19 |     if x_lengths is None:
 20 |         x_lengths = t
 21 |     ids_str_max = x_lengths - segment_size + 1
 22 |     ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
 23 |     ret = slice_segments(x, ids_str, segment_size)
 24 |     ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size)
 25 |     return ret, ret_pitch, ids_str
 26 | 
 27 | 
 28 | def rand_spec_segments(x, x_lengths=None, segment_size=4):
 29 |     b, d, t = x.size()
 30 |     if x_lengths is None:
 31 |         x_lengths = t
 32 |     ids_str_max = x_lengths - segment_size
 33 |     ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
 34 |     ret = slice_segments(x, ids_str, segment_size)
 35 |     return ret, ids_str
 36 | 
 37 | 
 38 | def init_weights(m, mean=0.0, std=0.01):
 39 |     classname = m.__class__.__name__
 40 |     if classname.find("Conv") != -1:
 41 |         m.weight.data.normal_(mean, std)
 42 | 
 43 | 
 44 | def get_padding(kernel_size, dilation=1):
 45 |     return int((kernel_size * dilation - dilation) / 2)
 46 | 
 47 | 
 48 | def convert_pad_shape(pad_shape):
 49 |     l = pad_shape[::-1]
 50 |     pad_shape = [item for sublist in l for item in sublist]
 51 |     return pad_shape
 52 | 
 53 | 
 54 | def kl_divergence(m_p, logs_p, m_q, logs_q):
 55 |     """KL(P||Q)"""
 56 |     kl = (logs_q - logs_p) - 0.5
 57 |     kl += (
 58 |         0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
 59 |     )
 60 |     return kl
 61 | 
 62 | 
 63 | def rand_gumbel(shape):
 64 |     """Sample from the Gumbel distribution, protect from overflows."""
 65 |     uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
 66 |     return -torch.log(-torch.log(uniform_samples))
 67 | 
 68 | 
 69 | def rand_gumbel_like(x):
 70 |     g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
 71 |     return g
 72 | 
 73 | 
 74 | def slice_segments(x, ids_str, segment_size=4):
 75 |     ret = torch.zeros_like(x[:, :, :segment_size])
 76 |     for i in range(x.size(0)):
 77 |         idx_str = ids_str[i]
 78 |         idx_end = idx_str + segment_size
 79 |         ret[i] = x[i, :, idx_str:idx_end]
 80 |     return ret
 81 | 
 82 | 
 83 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
 84 |     b, d, t = x.size()
 85 |     if x_lengths is None:
 86 |         x_lengths = t
 87 |     ids_str_max = x_lengths - segment_size + 1
 88 |     ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
 89 |     ret = slice_segments(x, ids_str, segment_size)
 90 |     return ret, ids_str
 91 | 
 92 | 
 93 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
 94 |     position = torch.arange(length, dtype=torch.float)
 95 |     num_timescales = channels // 2
 96 |     log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
 97 |         num_timescales - 1
 98 |     )
 99 |     inv_timescales = min_timescale * torch.exp(
100 |         torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
101 |     )
102 |     scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
103 |     signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
104 |     signal = F.pad(signal, [0, 0, 0, channels % 2])
105 |     signal = signal.view(1, channels, length)
106 |     return signal
107 | 
108 | 
109 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
110 |     b, channels, length = x.size()
111 |     signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
112 |     return x + signal.to(dtype=x.dtype, device=x.device)
113 | 
114 | 
115 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
116 |     b, channels, length = x.size()
117 |     signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
118 |     return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
119 | 
120 | 
121 | def subsequent_mask(length):
122 |     mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
123 |     return mask
124 | 
125 | 
126 | @torch.jit.script
127 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
128 |     n_channels_int = n_channels[0]
129 |     in_act = input_a + input_b
130 |     t_act = torch.tanh(in_act[:, :n_channels_int, :])
131 |     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
132 |     acts = t_act * s_act
133 |     return acts
134 | 
135 | 
136 | def convert_pad_shape(pad_shape):
137 |     l = pad_shape[::-1]
138 |     pad_shape = [item for sublist in l for item in sublist]
139 |     return pad_shape
140 | 
141 | 
142 | def shift_1d(x):
143 |     x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
144 |     return x
145 | 
146 | 
147 | def sequence_mask(length, max_length=None):
148 |     if max_length is None:
149 |         max_length = length.max()
150 |     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
151 |     return x.unsqueeze(0) < length.unsqueeze(1)
152 | 
153 | 
154 | def generate_path(duration, mask):
155 |     """
156 |     duration: [b, 1, t_x]
157 |     mask: [b, 1, t_y, t_x]
158 |     """
159 |     device = duration.device
160 | 
161 |     b, _, t_y, t_x = mask.shape
162 |     cum_duration = torch.cumsum(duration, -1)
163 | 
164 |     cum_duration_flat = cum_duration.view(b * t_x)
165 |     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
166 |     path = path.view(b, t_x, t_y)
167 |     path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
168 |     path = path.unsqueeze(1).transpose(2, 3) * mask
169 |     return path
170 | 
171 | 
172 | def clip_grad_value_(parameters, clip_value, norm_type=2):
173 |     if isinstance(parameters, torch.Tensor):
174 |         parameters = [parameters]
175 |     parameters = list(filter(lambda p: p.grad is not None, parameters))
176 |     norm_type = float(norm_type)
177 |     if clip_value is not None:
178 |         clip_value = float(clip_value)
179 | 
180 |     total_norm = 0
181 |     for p in parameters:
182 |         param_norm = p.grad.data.norm(norm_type)
183 |         total_norm += param_norm.item() ** norm_type
184 |         if clip_value is not None:
185 |             p.grad.data.clamp_(min=-clip_value, max=clip_value)
186 |     total_norm = total_norm ** (1.0 / norm_type)
187 |     return total_norm
188 | 
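Two of the helpers above on toy tensors (illustrative): sequence_mask turns lengths into boolean masks, and slice_segments cuts the fixed-size windows used for training crops:

    import torch
    from vits.commons import sequence_mask, slice_segments

    print(sequence_mask(torch.tensor([2, 4]), 5).int())
    # tensor([[1, 1, 0, 0, 0],
    #         [1, 1, 1, 1, 0]], dtype=torch.int32)

    x = torch.arange(10.0).view(1, 1, 10)
    print(slice_segments(x, torch.tensor([3]), segment_size=4))
    # tensor([[[3., 4., 5., 6.]]])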


--------------------------------------------------------------------------------
/vits/data_utils.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import numpy as np
  3 | import torch
  4 | import torch.utils.data
  5 | 
  6 | from vits.spectrogram import spectrogram_torch
  7 | from vits.utils import load_wav_to_torch
  8 | 
  9 | 
 10 | def load_filepaths(filename, split="|"):
 11 |     with open(filename, encoding='utf-8') as f:
 12 |         filepaths = [line.strip().split(split) for line in f]
 13 |     return filepaths
 14 | 
 15 | 
 16 | class TextAudioLoader(torch.utils.data.Dataset):
 17 |     """
 18 |     1) loads audio, text pairs
 19 |     2) normalizes text and converts them to sequences of integers
 20 |     3) computes spectrograms from audio files.
 21 |     """
 22 | 
 23 |     def __init__(self, audiopaths_and_text, hparams):
 24 |         self.audiopaths_and_text = load_filepaths(audiopaths_and_text)
 25 |         self.max_wav_value  = hparams.max_wav_value
 26 |         self.sampling_rate  = hparams.sampling_rate
 27 |         self.filter_length  = hparams.filter_length 
 28 |         self.hop_length     = hparams.hop_length 
 29 |         self.win_length     = hparams.win_length
 30 |         self.sampling_rate  = hparams.sampling_rate 
 31 |         self.min_text_len   = getattr(hparams, "min_text_len", 1)
 32 |         self.max_text_len   = getattr(hparams, "max_text_len", 5000)
 33 |         self._filter()
 34 |         print(f"~~~~~~~~~~~~~~~~~~~~~{len(self)}~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
 35 | 
 36 |     def _filter(self):
 37 |         """
 38 |         Filter text & store spec lengths
 39 |         """
 40 |         # Store spectrogram lengths for Bucketing
 41 |         # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
 42 |         # spec_length = wav_length // hop_length
 43 |         audiopaths_and_text_new = []
 44 |         lengths = []
 45 |         for audiopath, text, score, pitch, slur in self.audiopaths_and_text:
 46 |             if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
 47 |                 wav_len = os.path.getsize(audiopath) // (2 * self.hop_length)
 48 |                 if wav_len < 50: # no use short wave
 49 |                     continue
 50 |                 audiopaths_and_text_new.append([audiopath, text, score, pitch, slur])
 51 |                 lengths.append(wav_len)
 52 |         self.audiopaths_and_text = audiopaths_and_text_new
 53 |         self.lengths = lengths
 54 | 
 55 |     def get_audio_text_pair(self, audiopath_and_text):
 56 |         # separate filename and text
 57 |         file = audiopath_and_text[0]
 58 |         phone = audiopath_and_text[1]
 59 |         score = audiopath_and_text[2]
 60 |         pitch = audiopath_and_text[3]
 61 |         slurs = audiopath_and_text[4]
 62 | 
 63 |         phone, score, pitch, slurs = self.get_labels(phone, score, pitch, slurs)
 64 |         spec, wav = self.get_audio(file)
 65 | 
 66 |         len_phone = phone.size()[0]
 67 |         len_spec = spec.size()[-1]
 68 | 
 69 |         if len_phone != len_spec:
 70 |             # print("**************CareFull*******************")
 71 |             # print(f"filepath={audiopath_and_text[0]}")
 72 |             # print(f"len_text={len_phone}")
 73 |             # print(f"len_spec={len_spec}")
 74 |             if len_phone > len_spec:
 75 |                 print(file)
 76 |                 print("len_phone", len_phone)
 77 |                 print("len_spec", len_spec)
 78 |             assert len_phone < len_spec
 79 |             len_min = min(len_phone, len_spec)
 80 |             len_wav = len_min * self.hop_length
 81 |             # print(wav.size())
 82 |             # print(f"len_min={len_min}")
 83 |             # print(f"len_wav={len_wav}")
 84 |             spec = spec[:, :len_min]
 85 |             wav = wav[:, :len_wav]
 86 |         return (phone, score, pitch, slurs, spec, wav)
 87 | 
 88 |     def get_labels(self, phone, score, pitch, slurs):
 89 |         phone = np.load(phone)
 90 |         score = np.load(score)
 91 |         pitch = np.load(pitch)
 92 |         slurs = np.load(slurs)
 93 |         phone = torch.LongTensor(phone)
 94 |         score = torch.LongTensor(score)
 95 |         pitch = torch.FloatTensor(pitch)
 96 |         slurs = torch.LongTensor(slurs)
 97 |         return phone, score, pitch, slurs
 98 | 
 99 |     def get_audio(self, filename):
100 |         audio, sampling_rate = load_wav_to_torch(filename)
101 |         if sampling_rate != self.sampling_rate:
102 |             raise ValueError(
103 |                 "{} SR doesn't match target {} SR".format(
104 |                     sampling_rate, self.sampling_rate
105 |                 )
106 |             )
107 |         audio_norm = audio / self.max_wav_value
108 |         audio_norm = audio_norm.unsqueeze(0)
109 |         spec_filename = filename.replace(".wav", ".spec.pt")
110 |         if os.path.exists(spec_filename):
111 |             spec = torch.load(spec_filename)
112 |         else:
113 |             spec = spectrogram_torch(
114 |                 audio_norm,
115 |                 self.filter_length,
116 |                 self.sampling_rate,
117 |                 self.hop_length,
118 |                 self.win_length,
119 |                 center=False,
120 |             )
121 |             spec = torch.squeeze(spec, 0)
122 |             torch.save(spec, spec_filename)
123 |         return spec, audio_norm
124 | 
125 |     def __getitem__(self, index):
126 |         return self.get_audio_text_pair(self.audiopaths_and_text[index])
127 | 
128 |     def __len__(self):
129 |         return len(self.audiopaths_and_text)
130 | 
131 | 
132 | class TextAudioCollate:
133 |     """Zero-pads model inputs and targets"""
134 | 
135 |     def __init__(self, return_ids=False):
136 |         self.return_ids = return_ids
137 | 
138 |     def __call__(self, batch):
139 |         """Collates training batch from normalized text and audio
140 |         PARAMS
141 |         ------
142 |         batch: [text_normalized, spec_normalized, wav_normalized]
143 |         """
144 |         # Right zero-pad all one-hot text sequences to max input length
145 |         _, ids_sorted_decreasing = torch.sort(
146 |             torch.LongTensor([x[4].size(1) for x in batch]), dim=0, descending=True
147 |         )
148 | 
149 |         max_phone_len = max([len(x[0]) for x in batch])
150 |         phone_lengths = torch.LongTensor(len(batch))
151 |         phone_padded = torch.LongTensor(len(batch), max_phone_len)
152 |         score_padded = torch.LongTensor(len(batch), max_phone_len)
153 |         pitch_padded = torch.FloatTensor(len(batch), max_phone_len)
154 |         slurs_padded = torch.LongTensor(len(batch), max_phone_len)
155 |         phone_padded.zero_()
156 |         score_padded.zero_()
157 |         pitch_padded.zero_()
158 |         slurs_padded.zero_()
159 | 
160 |         max_spec_len = max([x[4].size(1) for x in batch])
161 |         max_wave_len = max([x[5].size(1) for x in batch])
162 |         spec_lengths = torch.LongTensor(len(batch))
163 |         wave_lengths = torch.LongTensor(len(batch))
164 |         spec_padded = torch.FloatTensor(len(batch), batch[0][4].size(0), max_spec_len)
165 |         wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
166 |         spec_padded.zero_()
167 |         wave_padded.zero_()
168 | 
169 |         for i in range(len(ids_sorted_decreasing)):
170 |             row = batch[ids_sorted_decreasing[i]]
171 | 
172 |             phone = row[0]
173 |             phone_padded[i, : phone.size(0)] = phone
174 |             phone_lengths[i] = phone.size(0)
175 | 
176 |             score = row[1]
177 |             score_padded[i, : score.size(0)] = score
178 | 
179 |             pitch = row[2]
180 |             pitch_padded[i, : pitch.size(0)] = pitch
181 | 
182 |             slurs = row[3]
183 |             slurs_padded[i, : slurs.size(0)] = slurs
184 | 
185 |             spec = row[4]
186 |             spec_padded[i, :, : spec.size(1)] = spec
187 |             spec_lengths[i] = spec.size(1)
188 | 
189 |             wave = row[5]
190 |             wave_padded[i, :, : wave.size(1)] = wave
191 |             wave_lengths[i] = wave.size(1)
192 | 
193 |         return (
194 |             phone_padded,
195 |             phone_lengths,
196 |             score_padded,
197 |             pitch_padded,
198 |             slurs_padded,
199 |             spec_padded,
200 |             spec_lengths,
201 |             wave_padded,
202 |             wave_lengths,
203 |         )
204 | 
205 | 
206 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
207 |     """
208 |     Maintain similar input lengths in a batch.
209 |     Length groups are specified by boundaries.
210 |     Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.
211 | 
212 |     It removes samples which are not included in the boundaries.
213 |     Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
214 |     """
215 | 
216 |     def __init__(
217 |         self,
218 |         dataset,
219 |         batch_size,
220 |         boundaries,
221 |         num_replicas=None,
222 |         rank=None,
223 |         shuffle=True,
224 |     ):
225 |         super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
226 |         self.lengths = dataset.lengths
227 |         self.batch_size = batch_size
228 |         self.boundaries = boundaries
229 | 
230 |         self.buckets, self.num_samples_per_bucket = self._create_buckets()
231 |         self.total_size = sum(self.num_samples_per_bucket)
232 |         self.num_samples = self.total_size // self.num_replicas
233 | 
234 |     def _create_buckets(self):
235 |         buckets = [[] for _ in range(len(self.boundaries) - 1)]
236 |         for i in range(len(self.lengths)):
237 |             length = self.lengths[i]
238 |             idx_bucket = self._bisect(length)
239 |             if idx_bucket != -1:
240 |                 buckets[idx_bucket].append(i)
241 | 
242 |         for i in range(len(buckets) - 1, 0, -1):
243 |             if len(buckets[i]) == 0:
244 |                 buckets.pop(i)
245 |                 self.boundaries.pop(i + 1)
246 | 
247 |         num_samples_per_bucket = []
248 |         for i in range(len(buckets)):
249 |             len_bucket = len(buckets[i])
250 |             total_batch_size = self.num_replicas * self.batch_size
251 |             rem = (
252 |                 total_batch_size - (len_bucket % total_batch_size)
253 |             ) % total_batch_size
254 |             num_samples_per_bucket.append(len_bucket + rem)
255 |         return buckets, num_samples_per_bucket
256 | 
257 |     def __iter__(self):
258 |         # deterministically shuffle based on epoch
259 |         g = torch.Generator()
260 |         g.manual_seed(self.epoch)
261 | 
262 |         indices = []
263 |         if self.shuffle:
264 |             for bucket in self.buckets:
265 |                 indices.append(torch.randperm(len(bucket), generator=g).tolist())
266 |         else:
267 |             for bucket in self.buckets:
268 |                 indices.append(list(range(len(bucket))))
269 | 
270 |         batches = []
271 |         for i in range(len(self.buckets)):
272 |             bucket = self.buckets[i]
273 |             len_bucket = len(bucket)
274 |             if len_bucket == 0:
275 |                 continue
276 |             ids_bucket = indices[i]
277 |             num_samples_bucket = self.num_samples_per_bucket[i]
278 | 
279 |             # add extra samples to make it evenly divisible
280 |             rem = num_samples_bucket - len_bucket
281 |             ids_bucket = (
282 |                 ids_bucket
283 |                 + ids_bucket * (rem // len_bucket)
284 |                 + ids_bucket[: (rem % len_bucket)]
285 |             )
286 | 
287 |             # subsample
288 |             ids_bucket = ids_bucket[self.rank:: self.num_replicas]
289 | 
290 |             # batching
291 |             for j in range(len(ids_bucket) // self.batch_size):
292 |                 batch = [
293 |                     bucket[idx]
294 |                     for idx in ids_bucket[
295 |                         j * self.batch_size: (j + 1) * self.batch_size
296 |                     ]
297 |                 ]
298 |                 batches.append(batch)
299 | 
300 |         if self.shuffle:
301 |             batch_ids = torch.randperm(len(batches), generator=g).tolist()
302 |             batches = [batches[i] for i in batch_ids]
303 |         self.batches = batches
304 | 
305 |         assert len(self.batches) * self.batch_size == self.num_samples
306 |         return iter(self.batches)
307 | 
308 |     def _bisect(self, x, lo=0, hi=None):
309 |         if hi is None:
310 |             hi = len(self.boundaries) - 1
311 | 
312 |         if hi > lo:
313 |             mid = (hi + lo) // 2
314 |             if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
315 |                 return mid
316 |             elif x <= self.boundaries[mid]:
317 |                 return self._bisect(x, lo, mid)
318 |             else:
319 |                 return self._bisect(x, mid + 1, hi)
320 |         else:
321 |             return -1
322 | 
323 |     def __len__(self):
324 |         return self.num_samples // self.batch_size
325 | 
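A minimal sketch (not part of the project code) of wiring the bucket sampler above into a PyTorch DataLoader; the batch size and boundaries below are illustrative assumptions, `dataset` is assumed to be a TextAudioLoader exposing a `.lengths` list, and the project's actual wiring lives in vits_extend/dataloader.py:

    import torch
    from torch.utils.data import DataLoader
    from vits.data_utils import DistributedBucketSampler, TextAudioCollate

    sampler = DistributedBucketSampler(
        dataset,                      # assumed: a TextAudioLoader instance with .lengths
        batch_size=8,
        boundaries=[32, 300, 400, 500, 600, 700, 800, 900, 1000],
        num_replicas=1, rank=0, shuffle=True)
    loader = DataLoader(dataset, batch_sampler=sampler,
                        collate_fn=TextAudioCollate(), num_workers=4)
    for epoch in range(2):
        sampler.set_epoch(epoch)      # reseeds the deterministic per-epoch shuffle
        for batch in loader:
            pass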


--------------------------------------------------------------------------------
/vits/losses.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | def feature_loss(fmap_r, fmap_g):
 5 |     loss = 0
 6 |     for dr, dg in zip(fmap_r, fmap_g):
 7 |         for rl, gl in zip(dr, dg):
 8 |             rl = rl.float().detach()
 9 |             gl = gl.float()
10 |             loss += torch.mean(torch.abs(rl - gl))
11 | 
12 |     return loss * 2
13 | 
14 | 
15 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
16 |     loss = 0
17 |     r_losses = []
18 |     g_losses = []
19 |     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
20 |         dr = dr.float()
21 |         dg = dg.float()
22 |         r_loss = torch.mean((1 - dr) ** 2)
23 |         g_loss = torch.mean(dg**2)
24 |         loss += r_loss + g_loss
25 |         r_losses.append(r_loss.item())
26 |         g_losses.append(g_loss.item())
27 | 
28 |     return loss, r_losses, g_losses
29 | 
30 | 
31 | def generator_loss(disc_outputs):
32 |     loss = 0
33 |     gen_losses = []
34 |     for dg in disc_outputs:
35 |         dg = dg.float()
36 |         l = torch.mean((1 - dg) ** 2)
37 |         gen_losses.append(l)
38 |         loss += l
39 | 
40 |     return loss, gen_losses
41 | 
42 | 
43 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
44 |     """
45 |     z_p, logs_q: [b, h, t_t]
46 |     m_p, logs_p: [b, h, t_t]
47 |     """
48 |     z_p = z_p.float()
49 |     logs_q = logs_q.float()
50 |     m_p = m_p.float()
51 |     logs_p = logs_p.float()
52 |     z_mask = z_mask.float()
53 | 
54 |     kl = logs_p - logs_q - 0.5
55 |     kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
56 |     kl = torch.sum(kl * z_mask)
57 |     l = kl / torch.sum(z_mask)
58 |     return l
59 | 
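As a quick illustration (dummy tensors, not project data), the helpers above implement least-squares GAN objectives plus the prior/posterior KL term; a minimal sketch of the adversarial losses:

    import torch
    from vits.losses import discriminator_loss, generator_loss

    # dummy per-discriminator score tensors of shape [B, T']
    real_scores = [torch.rand(2, 10), torch.rand(2, 12)]
    fake_scores = [torch.rand(2, 10), torch.rand(2, 12)]

    d_loss, r_losses, g_losses = discriminator_loss(real_scores, fake_scores)  # pushes real -> 1, fake -> 0
    adv_loss, _ = generator_loss(fake_scores)                                  # pushes fake -> 1
    print(d_loss.item(), adv_loss.item())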


--------------------------------------------------------------------------------
/vits/models.py:
--------------------------------------------------------------------------------
  1 | 
  2 | import torch
  3 | import math
  4 | 
  5 | from torch import nn
  6 | from torch.nn import functional as F
  7 | from vits import attentions
  8 | from vits import commons
  9 | from vits import modules
 10 | from vits.utils import f0_to_coarse
 11 | from vits_decoder.generator import Generator
 12 | 
 13 | 
 14 | class TextEncoder(nn.Module):
 15 |     def __init__(self,
 16 |                  out_channels,
 17 |                  hidden_channels,
 18 |                  filter_channels,
 19 |                  n_heads,
 20 |                  n_layers,
 21 |                  kernel_size,
 22 |                  p_dropout):
 23 |         super().__init__()
 24 |         self.out_channels = out_channels
 25 |         self.hidden_channels = hidden_channels
 26 |         self.emb_phone = nn.Embedding(63, hidden_channels)      # phone labels
 27 |         self.emb_score = nn.Embedding(128, hidden_channels)     # pitch notes
 28 |         self.emb_pitch = nn.Embedding(256, hidden_channels)     # pitch 256
 29 |         self.emb_slurs = nn.Embedding(2, hidden_channels)       # phone slur
 30 |         nn.init.normal_(self.emb_phone.weight, 0.0, hidden_channels**-0.5)
 31 |         nn.init.normal_(self.emb_score.weight, 0.0, hidden_channels**-0.5)
 32 |         nn.init.normal_(self.emb_pitch.weight, 0.0, hidden_channels**-0.5)
 33 |         nn.init.normal_(self.emb_slurs.weight, 0.0, hidden_channels**-0.5)
 34 |         self.enc = attentions.Encoder(
 35 |             hidden_channels,
 36 |             filter_channels,
 37 |             n_heads,
 38 |             n_layers,
 39 |             kernel_size,
 40 |             p_dropout)
 41 |         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 42 | 
 43 |     def forward(self, phone, lengths, score, slurs, pitch):
 44 |         x = self.emb_phone(phone) + self.emb_score(score) + self.emb_pitch(pitch) + self.emb_slurs(slurs)
 45 |         x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
 46 |         x = torch.transpose(x, 1, -1)  # [b, h, t]
 47 |         x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
 48 |             x.dtype
 49 |         )
 50 |         x = self.enc(x * x_mask, x_mask)
 51 |         stats = self.proj(x) * x_mask
 52 |         m, logs = torch.split(stats, self.out_channels, dim=1)
 53 |         z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
 54 |         return z, m, logs, x_mask, x
 55 | 
 56 | 
 57 | class ResidualCouplingBlock(nn.Module):
 58 |     def __init__(
 59 |         self,
 60 |         channels,
 61 |         hidden_channels,
 62 |         kernel_size,
 63 |         dilation_rate,
 64 |         n_layers,
 65 |         n_flows=3,
 66 |         gin_channels=0,
 67 |     ):
 68 |         super().__init__()
 69 |         self.flows = nn.ModuleList()
 70 |         for i in range(n_flows):
 71 |             self.flows.append(
 72 |                 modules.ResidualCouplingLayer(
 73 |                     channels,
 74 |                     hidden_channels,
 75 |                     kernel_size,
 76 |                     dilation_rate,
 77 |                     n_layers,
 78 |                     gin_channels=gin_channels,
 79 |                     mean_only=True,
 80 |                 )
 81 |             )
 82 |             self.flows.append(modules.Flip())
 83 | 
 84 |     def forward(self, x, x_mask, g=None, reverse=False):
 85 |         if not reverse:
 86 |             total_logdet = 0
 87 |             for flow in self.flows:
 88 |                 x, log_det = flow(x, x_mask, g=g, reverse=reverse)
 89 |                 total_logdet += log_det
 90 |             return x, total_logdet
 91 |         else:
 92 |             total_logdet = 0
 93 |             for flow in reversed(self.flows):
 94 |                 x, log_det = flow(x, x_mask, g=g, reverse=reverse)
 95 |                 total_logdet += log_det
 96 |             return x, total_logdet
 97 | 
 98 | 
 99 | class PosteriorEncoder(nn.Module):
100 |     def __init__(
101 |         self,
102 |         in_channels,
103 |         out_channels,
104 |         hidden_channels,
105 |         kernel_size,
106 |         dilation_rate,
107 |         n_layers,
108 |         gin_channels=0,
109 |     ):
110 |         super().__init__()
111 |         self.out_channels = out_channels
112 |         self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
113 |         self.enc = modules.WN(
114 |             hidden_channels,
115 |             kernel_size,
116 |             dilation_rate,
117 |             n_layers,
118 |             gin_channels=gin_channels,
119 |         )
120 |         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
121 | 
122 |     def forward(self, x, x_lengths, g=None):
123 |         x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
124 |             x.dtype
125 |         )
126 |         x = self.pre(x) * x_mask
127 |         x = self.enc(x, x_mask, g=g)
128 |         stats = self.proj(x) * x_mask
129 |         m, logs = torch.split(stats, self.out_channels, dim=1)
130 |         z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
131 |         return z, m, logs, x_mask
132 | 
133 | 
134 | class SynthesizerTrn(nn.Module):
135 |     def __init__(
136 |         self,
137 |         spec_channels,
138 |         segment_size,
139 |         hp
140 |     ):
141 |         super().__init__()
142 |         self.segment_size = segment_size
143 |         self.enc_p = TextEncoder(
144 |             hp.vits.inter_channels,
145 |             hp.vits.hidden_channels,
146 |             hp.vits.filter_channels,
147 |             2,
148 |             6,
149 |             3,
150 |             0.1,
151 |         )
152 |         self.enc_q = PosteriorEncoder(
153 |             spec_channels,
154 |             hp.vits.inter_channels,
155 |             hp.vits.hidden_channels,
156 |             5,
157 |             1,
158 |             16,
159 |             gin_channels=hp.vits.gin_channels,
160 |         )
161 |         self.flow = ResidualCouplingBlock(
162 |             hp.vits.inter_channels,
163 |             hp.vits.hidden_channels,
164 |             5,
165 |             1,
166 |             4,
167 |             gin_channels=hp.vits.gin_channels
168 |         )
169 |         self.dec = Generator(hp=hp)
170 | 
171 |     def forward(self, phone, phone_l, score, pitch, slurs, spec, spec_l):
172 | 
173 |         z_p, m_p, logs_p, ppg_mask, x = self.enc_p(
174 |             phone, phone_l, score, slurs, f0_to_coarse(pitch))
175 |         z_q, m_q, logs_q, spec_mask = self.enc_q(spec, spec_l)
176 | 
177 |         z_slice, pit_slice, ids_slice = commons.rand_slice_segments_with_pitch(
178 |             z_q, pitch, spec_l, self.segment_size)
179 |         audio = self.dec(z_slice, pit_slice)
180 | 
181 |         # SNAC to flow
182 |         z_f, logdet_f = self.flow(z_q, spec_mask)
183 |         z_r, logdet_r = self.flow(z_p, spec_mask, reverse=True)
184 |         return audio, ids_slice, spec_mask, (z_f, z_r, z_p, m_p, logs_p, z_q, m_q, logs_q, logdet_f, logdet_r)
185 | 
186 |     def infer(self, phone, phone_l, score, pitch, slurs):
187 |         z_p, m_p, logs_p, ppg_mask, x = self.enc_p(
188 |             phone, phone_l, score, slurs, f0_to_coarse(pitch))
189 |         z, _ = self.flow(z_p, ppg_mask, reverse=True)
190 |         o = self.dec(z * ppg_mask, pitch)
191 |         return o
192 | 
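For orientation, a minimal sketch of running the prior TextEncoder on dummy inputs; the channel sizes below are illustrative assumptions, not necessarily the values from configs/singing_base.yaml:

    import torch
    from vits.models import TextEncoder

    enc = TextEncoder(out_channels=192, hidden_channels=192, filter_channels=768,
                      n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.1)
    phone = torch.randint(0, 63, (1, 50))    # phone labels
    score = torch.randint(0, 128, (1, 50))   # MIDI notes
    pitch = torch.randint(1, 256, (1, 50))   # coarse F0 bins, as produced by f0_to_coarse
    slurs = torch.randint(0, 2, (1, 50))     # slur flags
    lengths = torch.tensor([50])
    z, m, logs, mask, x = enc(phone, lengths, score, slurs, pitch)
    print(z.shape)  # [1, 192, 50]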


--------------------------------------------------------------------------------
/vits/modules.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | from torch import nn
  3 | from vits import commons
  4 | 
  5 | 
  6 | class WN(torch.nn.Module):
  7 |     def __init__(
  8 |         self,
  9 |         hidden_channels,
 10 |         kernel_size,
 11 |         dilation_rate,
 12 |         n_layers,
 13 |         gin_channels=0,
 14 |         p_dropout=0,
 15 |     ):
 16 |         super(WN, self).__init__()
 17 |         assert kernel_size % 2 == 1
 18 |         self.hidden_channels = hidden_channels
 19 |         self.kernel_size = kernel_size
 20 |         self.dilation_rate = dilation_rate
 21 |         self.n_layers = n_layers
 22 |         self.gin_channels = gin_channels
 23 |         self.p_dropout = p_dropout
 24 | 
 25 |         self.in_layers = torch.nn.ModuleList()
 26 |         self.res_skip_layers = torch.nn.ModuleList()
 27 |         self.drop = nn.Dropout(p_dropout)
 28 | 
 29 |         if gin_channels != 0:
 30 |             cond_layer = torch.nn.Conv1d(
 31 |                 gin_channels, 2 * hidden_channels * n_layers, 1
 32 |             )
 33 |             self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
 34 | 
 35 |         for i in range(n_layers):
 36 |             dilation = dilation_rate**i
 37 |             padding = int((kernel_size * dilation - dilation) / 2)
 38 |             in_layer = torch.nn.Conv1d(
 39 |                 hidden_channels,
 40 |                 2 * hidden_channels,
 41 |                 kernel_size,
 42 |                 dilation=dilation,
 43 |                 padding=padding,
 44 |             )
 45 |             in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
 46 |             self.in_layers.append(in_layer)
 47 | 
 48 |             # last one is not necessary
 49 |             if i < n_layers - 1:
 50 |                 res_skip_channels = 2 * hidden_channels
 51 |             else:
 52 |                 res_skip_channels = hidden_channels
 53 | 
 54 |             res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
 55 |             res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
 56 |             self.res_skip_layers.append(res_skip_layer)
 57 | 
 58 |     def forward(self, x, x_mask, g=None, **kwargs):
 59 |         output = torch.zeros_like(x)
 60 |         n_channels_tensor = torch.IntTensor([self.hidden_channels])
 61 | 
 62 |         if g is not None:
 63 |             g = self.cond_layer(g)
 64 | 
 65 |         for i in range(self.n_layers):
 66 |             x_in = self.in_layers[i](x)
 67 |             if g is not None:
 68 |                 cond_offset = i * 2 * self.hidden_channels
 69 |                 g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
 70 |             else:
 71 |                 g_l = torch.zeros_like(x_in)
 72 | 
 73 |             acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
 74 |             acts = self.drop(acts)
 75 | 
 76 |             res_skip_acts = self.res_skip_layers[i](acts)
 77 |             if i < self.n_layers - 1:
 78 |                 res_acts = res_skip_acts[:, : self.hidden_channels, :]
 79 |                 x = (x + res_acts) * x_mask
 80 |                 output = output + res_skip_acts[:, self.hidden_channels:, :]
 81 |             else:
 82 |                 output = output + res_skip_acts
 83 |         return output * x_mask
 84 | 
 85 |     def remove_weight_norm(self):
 86 |         if self.gin_channels != 0:
 87 |             torch.nn.utils.remove_weight_norm(self.cond_layer)
 88 |         for l in self.in_layers:
 89 |             torch.nn.utils.remove_weight_norm(l)
 90 |         for l in self.res_skip_layers:
 91 |             torch.nn.utils.remove_weight_norm(l)
 92 | 
 93 | 
 94 | class Flip(nn.Module):
 95 |     def forward(self, x, *args, reverse=False, **kwargs):
 96 |         x = torch.flip(x, [1])
 97 |         logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
 98 |         return x, logdet
 99 | 
100 | 
101 | class ResidualCouplingLayer(nn.Module):
102 |     def __init__(
103 |         self,
104 |         channels,
105 |         hidden_channels,
106 |         kernel_size,
107 |         dilation_rate,
108 |         n_layers,
109 |         p_dropout=0,
110 |         gin_channels=0,
111 |         mean_only=False,
112 |     ):
113 |         assert channels % 2 == 0, "channels should be divisible by 2"
114 |         super().__init__()
115 |         self.channels = channels
116 |         self.hidden_channels = hidden_channels
117 |         self.kernel_size = kernel_size
118 |         self.dilation_rate = dilation_rate
119 |         self.n_layers = n_layers
120 |         self.half_channels = channels // 2
121 |         self.mean_only = mean_only
122 | 
123 |         self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
124 |         self.enc = WN(
125 |             hidden_channels,
126 |             kernel_size,
127 |             dilation_rate,
128 |             n_layers,
129 |             p_dropout=p_dropout,
130 |             gin_channels=gin_channels,
131 |         )
132 |         self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
133 |         self.post.weight.data.zero_()
134 |         self.post.bias.data.zero_()
135 | 
136 |     def forward(self, x, x_mask, g=None, reverse=False):
137 |         x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
138 |         h = self.pre(x0) * x_mask
139 |         h = self.enc(h, x_mask, g=g)
140 |         stats = self.post(h) * x_mask
141 |         if not self.mean_only:
142 |             m, logs = torch.split(stats, [self.half_channels] * 2, 1)
143 |         else:
144 |             m = stats
145 |             logs = torch.zeros_like(m)
146 | 
147 |         if not reverse:
148 |             x1 = m + x1 * torch.exp(logs) * x_mask
149 |             x = torch.cat([x0, x1], 1)
150 |             logdet = torch.sum(logs, [1, 2])
151 |             return x, logdet
152 |         else:
153 |             x1 = (x1 - m) * torch.exp(-logs) * x_mask
154 |             x = torch.cat([x0, x1], 1)
155 |             logdet = torch.sum(logs, [1, 2])
156 |             return x, logdet
157 | 
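A minimal sketch (illustrative sizes) showing that ResidualCouplingLayer is invertible: running the reverse pass on the forward output recovers the input:

    import torch
    from vits.modules import ResidualCouplingLayer

    layer = ResidualCouplingLayer(channels=4, hidden_channels=8, kernel_size=3,
                                  dilation_rate=1, n_layers=2, mean_only=True)
    x = torch.randn(1, 4, 10)
    x_mask = torch.ones(1, 1, 10)
    y, _ = layer(x, x_mask)                     # forward direction
    x_rec, _ = layer(y, x_mask, reverse=True)   # inverse direction
    print(torch.allclose(x, x_rec, atol=1e-5))  # True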


--------------------------------------------------------------------------------
/vits/spectrogram.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.utils.data
  3 | 
  4 | from librosa.filters import mel as librosa_mel_fn
  5 | 
  6 | MAX_WAV_VALUE = 32768.0
  7 | 
  8 | 
  9 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
 10 |     """
 11 |     PARAMS
 12 |     ------
 13 |     C: compression factor
 14 |     """
 15 |     return torch.log(torch.clamp(x, min=clip_val) * C)
 16 | 
 17 | 
 18 | def dynamic_range_decompression_torch(x, C=1):
 19 |     """
 20 |     PARAMS
 21 |     ------
 22 |     C: compression factor used to compress
 23 |     """
 24 |     return torch.exp(x) / C
 25 | 
 26 | 
 27 | def spectral_normalize_torch(magnitudes):
 28 |     output = dynamic_range_compression_torch(magnitudes)
 29 |     return output
 30 | 
 31 | 
 32 | def spectral_de_normalize_torch(magnitudes):
 33 |     output = dynamic_range_decompression_torch(magnitudes)
 34 |     return output
 35 | 
 36 | 
 37 | mel_basis = {}
 38 | hann_window = {}
 39 | 
 40 | 
 41 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
 42 |     if torch.min(y) < -1.0:
 43 |         print("min value is ", torch.min(y))
 44 |     if torch.max(y) > 1.0:
 45 |         print("max value is ", torch.max(y))
 46 | 
 47 |     global hann_window
 48 |     dtype_device = str(y.dtype) + "_" + str(y.device)
 49 |     wnsize_dtype_device = str(win_size) + "_" + dtype_device
 50 |     if wnsize_dtype_device not in hann_window:
 51 |         hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
 52 |             dtype=y.dtype, device=y.device
 53 |         )
 54 | 
 55 |     y = torch.nn.functional.pad(
 56 |         y.unsqueeze(1),
 57 |         (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
 58 |         mode="reflect",
 59 |     )
 60 |     y = y.squeeze(1)
 61 | 
 62 |     spec = torch.stft(
 63 |         y,
 64 |         n_fft,
 65 |         hop_length=hop_size,
 66 |         win_length=win_size,
 67 |         window=hann_window[wnsize_dtype_device],
 68 |         center=center,
 69 |         pad_mode="reflect",
 70 |         normalized=False,
 71 |         onesided=True,
 72 |         return_complex=False,
 73 |     )
 74 | 
 75 |     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
 76 |     return spec
 77 | 
 78 | 
 79 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
 80 |     global mel_basis
 81 |     dtype_device = str(spec.dtype) + "_" + str(spec.device)
 82 |     fmax_dtype_device = str(fmax) + "_" + dtype_device
 83 |     if fmax_dtype_device not in mel_basis:
 84 |         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
 85 |         mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
 86 |             dtype=spec.dtype, device=spec.device
 87 |         )
 88 |     spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
 89 |     spec = spectral_normalize_torch(spec)
 90 |     return spec
 91 | 
 92 | 
 93 | def mel_spectrogram_torch(
 94 |     y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
 95 | ):
 96 |     if torch.min(y) < -1.0:
 97 |         print("min value is ", torch.min(y))
 98 |     if torch.max(y) > 1.0:
 99 |         print("max value is ", torch.max(y))
100 | 
101 |     global mel_basis, hann_window
102 |     dtype_device = str(y.dtype) + "_" + str(y.device)
103 |     fmax_dtype_device = str(fmax) + "_" + dtype_device
104 |     wnsize_dtype_device = str(win_size) + "_" + dtype_device
105 |     if fmax_dtype_device not in mel_basis:
106 |         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
107 |         mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
108 |             dtype=y.dtype, device=y.device
109 |         )
110 |     if wnsize_dtype_device not in hann_window:
111 |         hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
112 |             dtype=y.dtype, device=y.device
113 |         )
114 | 
115 |     y = torch.nn.functional.pad(
116 |         y.unsqueeze(1),
117 |         (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
118 |         mode="reflect",
119 |     )
120 |     y = y.squeeze(1)
121 | 
122 |     spec = torch.stft(
123 |         y,
124 |         n_fft,
125 |         hop_length=hop_size,
126 |         win_length=win_size,
127 |         window=hann_window[wnsize_dtype_device],
128 |         center=center,
129 |         pad_mode="reflect",
130 |         normalized=False,
131 |         onesided=True,
132 |         return_complex=False,
133 |     )
134 | 
135 |     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
136 | 
137 |     spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
138 |     spec = spectral_normalize_torch(spec)
139 | 
140 |     return spec
141 | 
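A minimal sketch of computing a mel spectrogram with the helper above; the STFT/mel parameters here are illustrative assumptions, not the project's configured values:

    import torch
    from vits.spectrogram import mel_spectrogram_torch

    y = torch.rand(1, 16000) * 2 - 1  # one second of dummy audio in [-1, 1] at 16 kHz
    mel = mel_spectrogram_torch(y, n_fft=1024, num_mels=80, sampling_rate=16000,
                                hop_size=256, win_size=1024, fmin=0, fmax=8000)
    print(mel.shape)  # [1, 80, T]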


--------------------------------------------------------------------------------
/vits/utils.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import numpy as np
 3 | from scipy.io.wavfile import read
 4 | 
 5 | MATPLOTLIB_FLAG = False
 6 | 
 7 | 
 8 | def load_wav_to_torch(full_path):
 9 |     sampling_rate, data = read(full_path)
10 |     return torch.FloatTensor(data.astype(np.float32)), sampling_rate
11 | 
12 | 
13 | f0_bin = 256
14 | f0_max = 1100.0
15 | f0_min = 50.0
16 | f0_mel_min = 1127 * np.log(1 + f0_min / 700)
17 | f0_mel_max = 1127 * np.log(1 + f0_max / 700)
18 | 
19 | 
20 | def f0_to_coarse(f0):
21 |     is_torch = isinstance(f0, torch.Tensor)
22 |     f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * \
23 |         np.log(1 + f0 / 700)
24 |     f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * \
25 |         (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
26 | 
27 |     f0_mel[f0_mel <= 1] = 1
28 |     f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
29 |     f0_coarse = (
30 |         f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int64)
31 |     assert f0_coarse.max() <= 255 and f0_coarse.min(
32 |     ) >= 1, (f0_coarse.max(), f0_coarse.min())
33 |     return f0_coarse
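A short illustration (dummy values) of f0_to_coarse: it maps an F0 contour in Hz onto the 256 coarse bins consumed by the pitch embedding, with unvoiced frames (0 Hz) landing in bin 1:

    import torch
    from vits.utils import f0_to_coarse

    f0 = torch.tensor([[0.0, 110.0, 220.0, 440.0, 880.0]])  # Hz, 0 = unvoiced
    print(f0_to_coarse(f0))  # integer bins in [1, 255]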


--------------------------------------------------------------------------------
/vits_decoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .alias.act import SnakeAlias


--------------------------------------------------------------------------------
/vits_decoder/alias/__init__.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | #   LICENSE is in incl_licenses directory.
3 | 
4 | from .filter import *
5 | from .resample import *
6 | from .act import *


--------------------------------------------------------------------------------
/vits_decoder/alias/act.py:
--------------------------------------------------------------------------------
  1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
  2 | #   LICENSE is in incl_licenses directory.
  3 | 
  4 | import torch
  5 | import torch.nn as nn
  6 | import torch.nn.functional as F
  7 | 
  8 | from torch import sin, pow
  9 | from torch.nn import Parameter
 10 | from .resample import UpSample1d, DownSample1d
 11 | 
 12 | 
 13 | class Activation1d(nn.Module):
 14 |     def __init__(self,
 15 |                  activation,
 16 |                  up_ratio: int = 2,
 17 |                  down_ratio: int = 2,
 18 |                  up_kernel_size: int = 12,
 19 |                  down_kernel_size: int = 12):
 20 |         super().__init__()
 21 |         self.up_ratio = up_ratio
 22 |         self.down_ratio = down_ratio
 23 |         self.act = activation
 24 |         self.upsample = UpSample1d(up_ratio, up_kernel_size)
 25 |         self.downsample = DownSample1d(down_ratio, down_kernel_size)
 26 | 
 27 |     # x: [B,C,T]
 28 |     def forward(self, x):
 29 |         x = self.upsample(x)
 30 |         x = self.act(x)
 31 |         x = self.downsample(x)
 32 | 
 33 |         return x
 34 | 
 35 | 
 36 | class SnakeBeta(nn.Module):
 37 |     '''
 38 |     A modified Snake function which uses separate parameters for the magnitude of the periodic components
 39 |     Shape:
 40 |         - Input: (B, C, T)
 41 |         - Output: (B, C, T), same shape as the input
 42 |     Parameters:
 43 |         - alpha - trainable parameter that controls frequency
 44 |         - beta - trainable parameter that controls magnitude
 45 |     References:
 46 |         - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
 47 |         https://arxiv.org/abs/2006.08195
 48 |     Examples:
 49 |         >>> a1 = SnakeBeta(256)
 50 |         >>> x = torch.randn(256)
 51 |         >>> x = a1(x)
 52 |     '''
 53 | 
 54 |     def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
 55 |         '''
 56 |         Initialization.
 57 |         INPUT:
 58 |             - in_features: shape of the input
 59 |             - alpha - trainable parameter that controls frequency
 60 |             - beta - trainable parameter that controls magnitude
 61 |             alpha is initialized to 1 by default, higher values = higher-frequency.
 62 |             beta is initialized to 1 by default, higher values = higher-magnitude.
 63 |             alpha will be trained along with the rest of your model.
 64 |         '''
 65 |         super(SnakeBeta, self).__init__()
 66 |         self.in_features = in_features
 67 |         # initialize alpha
 68 |         self.alpha_logscale = alpha_logscale
 69 |         if self.alpha_logscale:  # log scale alphas initialized to zeros
 70 |             self.alpha = Parameter(torch.zeros(in_features) * alpha)
 71 |             self.beta = Parameter(torch.zeros(in_features) * alpha)
 72 |         else:  # linear scale alphas initialized to ones
 73 |             self.alpha = Parameter(torch.ones(in_features) * alpha)
 74 |             self.beta = Parameter(torch.ones(in_features) * alpha)
 75 |         self.alpha.requires_grad = alpha_trainable
 76 |         self.beta.requires_grad = alpha_trainable
 77 |         self.no_div_by_zero = 0.000000001
 78 | 
 79 |     def forward(self, x):
 80 |         '''
 81 |         Forward pass of the function.
 82 |         Applies the function to the input elementwise.
 83 |         SnakeBeta = x + 1/b * sin^2 (xa)
 84 |         '''
 85 |         alpha = self.alpha.unsqueeze(
 86 |             0).unsqueeze(-1)  # line up with x to [B, C, T]
 87 |         beta = self.beta.unsqueeze(0).unsqueeze(-1)
 88 |         if self.alpha_logscale:
 89 |             alpha = torch.exp(alpha)
 90 |             beta = torch.exp(beta)
 91 |         x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
 92 |         return x
 93 | 
 94 | 
 95 | class Mish(nn.Module):
 96 |     """
 97 |     Mish activation function is proposed in "Mish: A Self 
 98 |     Regularized Non-Monotonic Neural Activation Function" 
 99 |     paper, https://arxiv.org/abs/1908.08681.
100 |     """
101 | 
102 |     def __init__(self):
103 |         super().__init__()
104 | 
105 |     def forward(self, x):
106 |         return x * torch.tanh(F.softplus(x))
107 | 
108 | 
109 | class SnakeAlias(nn.Module):
110 |     def __init__(self,
111 |                  channels,
112 |                  up_ratio: int = 2,
113 |                  down_ratio: int = 2,
114 |                  up_kernel_size: int = 12,
115 |                  down_kernel_size: int = 12):
116 |         super().__init__()
117 |         self.up_ratio = up_ratio
118 |         self.down_ratio = down_ratio
119 |         self.act = SnakeBeta(channels, alpha_logscale=True)
120 |         self.upsample = UpSample1d(up_ratio, up_kernel_size)
121 |         self.downsample = DownSample1d(down_ratio, down_kernel_size)
122 | 
123 |     # x: [B,C,T]
124 |     def forward(self, x):
125 |         x = self.upsample(x)
126 |         x = self.act(x)
127 |         x = self.downsample(x)
128 | 
129 |         return x
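A minimal sketch of the anti-aliased activation (channel count and length are illustrative): SnakeAlias upsamples, applies SnakeBeta, then low-pass filters back down, so the time length is preserved:

    import torch
    from vits_decoder.alias.act import SnakeAlias

    act = SnakeAlias(channels=32)
    x = torch.randn(2, 32, 128)   # [B, C, T]
    print(act(x).shape)           # [2, 32, 128]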


--------------------------------------------------------------------------------
/vits_decoder/alias/filter.py:
--------------------------------------------------------------------------------
 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
 2 | #   LICENSE is in incl_licenses directory.
 3 | 
 4 | import torch
 5 | import torch.nn as nn
 6 | import torch.nn.functional as F
 7 | import math
 8 | 
 9 | if 'sinc' in dir(torch):
10 |     sinc = torch.sinc
11 | else:
12 |     # This code is adopted from adefossez's julius.core.sinc under the MIT License
13 |     # https://adefossez.github.io/julius/julius/core.html
14 |     #   LICENSE is in incl_licenses directory.
15 |     def sinc(x: torch.Tensor):
16 |         """
17 |         Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18 |         __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
19 |         """
20 |         return torch.where(x == 0,
21 |                            torch.tensor(1., device=x.device, dtype=x.dtype),
22 |                            torch.sin(math.pi * x) / math.pi / x)
23 | 
24 | 
25 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
26 | # https://adefossez.github.io/julius/julius/lowpass.html
27 | #   LICENSE is in incl_licenses directory.
28 | def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
29 |     even = (kernel_size % 2 == 0)
30 |     half_size = kernel_size // 2
31 | 
32 |     #For kaiser window
33 |     delta_f = 4 * half_width
34 |     A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
35 |     if A > 50.:
36 |         beta = 0.1102 * (A - 8.7)
37 |     elif A >= 21.:
38 |         beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
39 |     else:
40 |         beta = 0.
41 |     window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
42 | 
43 |     # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
44 |     if even:
45 |         time = (torch.arange(-half_size, half_size) + 0.5)
46 |     else:
47 |         time = torch.arange(kernel_size) - half_size
48 |     if cutoff == 0:
49 |         filter_ = torch.zeros_like(time)
50 |     else:
51 |         filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
52 |         # Normalize filter to have sum = 1, otherwise we will have a small leakage
53 |         # of the constant component in the input signal.
54 |         filter_ /= filter_.sum()
55 |     filter = filter_.view(1, 1, kernel_size)
56 | 
57 |     return filter
58 | 
59 | 
60 | class LowPassFilter1d(nn.Module):
61 |     def __init__(self,
62 |                  cutoff=0.5,
63 |                  half_width=0.6,
64 |                  stride: int = 1,
65 |                  padding: bool = True,
66 |                  padding_mode: str = 'replicate',
67 |                  kernel_size: int = 12):
68 |         # kernel_size should be even number for stylegan3 setup,
69 |         # in this implementation, odd number is also possible.
70 |         super().__init__()
71 |         if cutoff < -0.:
72 |             raise ValueError("Minimum cutoff must be larger than zero.")
73 |         if cutoff > 0.5:
74 |             raise ValueError("A cutoff above 0.5 does not make sense.")
75 |         self.kernel_size = kernel_size
76 |         self.even = (kernel_size % 2 == 0)
77 |         self.pad_left = kernel_size // 2 - int(self.even)
78 |         self.pad_right = kernel_size // 2
79 |         self.stride = stride
80 |         self.padding = padding
81 |         self.padding_mode = padding_mode
82 |         filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
83 |         self.register_buffer("filter", filter)
84 | 
85 |     #input [B, C, T]
86 |     def forward(self, x):
87 |         _, C, _ = x.shape
88 | 
89 |         if self.padding:
90 |             x = F.pad(x, (self.pad_left, self.pad_right),
91 |                       mode=self.padding_mode)
92 |         out = F.conv1d(x, self.filter.expand(C, -1, -1),
93 |                        stride=self.stride, groups=C)
94 | 
95 |         return out
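A minimal sketch (illustrative parameters) of the pieces above: the Kaiser-windowed sinc kernel is normalized to sum to 1 for a non-zero cutoff, and LowPassFilter1d with stride 2 halves the time dimension:

    import torch
    from vits_decoder.alias.filter import kaiser_sinc_filter1d, LowPassFilter1d

    filt = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
    print(filt.shape, float(filt.sum()))     # torch.Size([1, 1, 12]) 1.0

    lp = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
    print(lp(torch.randn(1, 4, 64)).shape)   # [1, 4, 32]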


--------------------------------------------------------------------------------
/vits_decoder/alias/resample.py:
--------------------------------------------------------------------------------
 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
 2 | #   LICENSE is in incl_licenses directory.
 3 | 
 4 | import torch.nn as nn
 5 | from torch.nn import functional as F
 6 | from .filter import LowPassFilter1d
 7 | from .filter import kaiser_sinc_filter1d
 8 | 
 9 | 
10 | class UpSample1d(nn.Module):
11 |     def __init__(self, ratio=2, kernel_size=None):
12 |         super().__init__()
13 |         self.ratio = ratio
14 |         self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15 |         self.stride = ratio
16 |         self.pad = self.kernel_size // ratio - 1
17 |         self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
18 |         self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
19 |         filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
20 |                                       half_width=0.6 / ratio,
21 |                                       kernel_size=self.kernel_size)
22 |         self.register_buffer("filter", filter)
23 | 
24 |     # x: [B, C, T]
25 |     def forward(self, x):
26 |         _, C, _ = x.shape
27 | 
28 |         x = F.pad(x, (self.pad, self.pad), mode='replicate')
29 |         x = self.ratio * F.conv_transpose1d(
30 |             x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
31 |         x = x[..., self.pad_left:-self.pad_right]
32 | 
33 |         return x
34 | 
35 | 
36 | class DownSample1d(nn.Module):
37 |     def __init__(self, ratio=2, kernel_size=None):
38 |         super().__init__()
39 |         self.ratio = ratio
40 |         self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
41 |         self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
42 |                                        half_width=0.6 / ratio,
43 |                                        stride=ratio,
44 |                                        kernel_size=self.kernel_size)
45 | 
46 |     def forward(self, x):
47 |         xx = self.lowpass(x)
48 | 
49 |         return xx
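A minimal sketch (illustrative sizes): a 2x UpSample1d followed by a 2x DownSample1d returns a tensor with the original time length:

    import torch
    from vits_decoder.alias.resample import UpSample1d, DownSample1d

    up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
    x = torch.randn(1, 8, 100)
    print(up(x).shape)        # [1, 8, 200]
    print(down(up(x)).shape)  # [1, 8, 100]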


--------------------------------------------------------------------------------
/vits_decoder/bigv.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | 
 4 | from torch.nn import Conv1d
 5 | from torch.nn.utils import weight_norm, remove_weight_norm
 6 | from .alias.act import SnakeAlias
 7 | 
 8 | 
 9 | def init_weights(m, mean=0.0, std=0.01):
10 |     classname = m.__class__.__name__
11 |     if classname.find("Conv") != -1:
12 |         m.weight.data.normal_(mean, std)
13 | 
14 | 
15 | def get_padding(kernel_size, dilation=1):
16 |     return int((kernel_size*dilation - dilation)/2)
17 | 
18 | 
19 | class AMPBlock(torch.nn.Module):
20 |     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
21 |         super(AMPBlock, self).__init__()
22 |         self.convs1 = nn.ModuleList([
23 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
24 |                                padding=get_padding(kernel_size, dilation[0]))),
25 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
26 |                                padding=get_padding(kernel_size, dilation[1]))),
27 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
28 |                                padding=get_padding(kernel_size, dilation[2])))
29 |         ])
30 |         self.convs1.apply(init_weights)
31 | 
32 |         self.convs2 = nn.ModuleList([
33 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
34 |                                padding=get_padding(kernel_size, 1))),
35 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
36 |                                padding=get_padding(kernel_size, 1))),
37 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
38 |                                padding=get_padding(kernel_size, 1)))
39 |         ])
40 |         self.convs2.apply(init_weights)
41 | 
42 |         # total number of conv layers
43 |         self.num_layers = len(self.convs1) + len(self.convs2)
44 | 
45 |         # periodic nonlinearity with snakebeta function and anti-aliasing
46 |         self.activations = nn.ModuleList([
47 |             SnakeAlias(channels) for _ in range(self.num_layers)
48 |         ])
49 | 
50 |     def forward(self, x):
51 |         acts1, acts2 = self.activations[::2], self.activations[1::2]
52 |         for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
53 |             xt = a1(x)
54 |             xt = c1(xt)
55 |             xt = a2(xt)
56 |             xt = c2(xt)
57 |             x = xt + x
58 |         return x
59 | 
60 |     def remove_weight_norm(self):
61 |         for l in self.convs1:
62 |             remove_weight_norm(l)
63 |         for l in self.convs2:
64 |             remove_weight_norm(l)
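A minimal sketch of an AMPBlock on dummy input (channel count and length are illustrative): each of the three stages applies SnakeAlias -> dilated conv -> SnakeAlias -> conv and adds the result back as a residual, so the shape is unchanged:

    import torch
    from vits_decoder.bigv import AMPBlock

    block = AMPBlock(channels=32, kernel_size=3, dilation=(1, 3, 5))
    x = torch.randn(1, 32, 64)
    print(block(x).shape)  # [1, 32, 64]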


--------------------------------------------------------------------------------
/vits_decoder/discriminator.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | 
 4 | from omegaconf import OmegaConf
 5 | from .msd import ScaleDiscriminator
 6 | from .mpd import MultiPeriodDiscriminator
 7 | from .mrd import MultiResolutionDiscriminator
 8 | 
 9 | 
10 | class Discriminator(nn.Module):
11 |     def __init__(self, hp):
12 |         super(Discriminator, self).__init__()
13 |         self.MRD = MultiResolutionDiscriminator(hp)
14 |         self.MPD = MultiPeriodDiscriminator(hp)
15 |         self.MSD = ScaleDiscriminator()
16 | 
17 |     def forward(self, x):
18 |         r = self.MRD(x)
19 |         p = self.MPD(x)
20 |         s = self.MSD(x)
21 |         return r + p + s
22 | 
23 | 
24 | if __name__ == '__main__':
25 |     hp = OmegaConf.load('../configs/singing_base.yaml')
26 |     model = Discriminator(hp)
27 | 
28 |     x = torch.randn(3, 1, 16384)
29 |     print(x.shape)
30 | 
31 |     output = model(x)
32 |     for features, score in output:
33 |         for feat in features:
34 |             print(feat.shape)
35 |         print(score.shape)
36 | 
37 |     pytorch_total_params = sum(p.numel()
38 |                                for p in model.parameters() if p.requires_grad)
39 |     print(pytorch_total_params)
40 | 


--------------------------------------------------------------------------------
/vits_decoder/generator.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn as nn
  3 | import torch.nn.functional as F
  4 | import numpy as np
  5 | 
  6 | from torch.nn import Conv1d
  7 | from torch.nn import ConvTranspose1d
  8 | from torch.nn.utils import weight_norm
  9 | from torch.nn.utils import remove_weight_norm
 10 | 
 11 | from .nsf import SourceModuleHnNSF
 12 | from .bigv import init_weights, AMPBlock, SnakeAlias
 13 | 
 14 | 
 15 | class Generator(torch.nn.Module):
 16 |     # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
 17 |     def __init__(self, hp):
 18 |         super(Generator, self).__init__()
 19 |         self.hp = hp
 20 |         self.num_kernels = len(hp.gen.resblock_kernel_sizes)
 21 |         self.num_upsamples = len(hp.gen.upsample_rates)
 22 |         # pre conv
 23 |         self.conv_pre = Conv1d(hp.gen.upsample_input,
 24 |                                hp.gen.upsample_initial_channel, 7, 1, padding=3)
 25 |         # nsf
 26 |         self.f0_upsamp = torch.nn.Upsample(
 27 |             scale_factor=np.prod(hp.gen.upsample_rates))
 28 |         self.m_source = SourceModuleHnNSF(sampling_rate=hp.data.sampling_rate)
 29 |         self.noise_convs = nn.ModuleList()
 30 |         # transposed conv-based upsamplers. does not apply anti-aliasing
 31 |         self.ups = nn.ModuleList()
 32 |         for i, (u, k) in enumerate(zip(hp.gen.upsample_rates, hp.gen.upsample_kernel_sizes)):
 33 |             # print(f'ups: {i} {k}, {u}, {(k - u) // 2}')
 34 |             # base
 35 |             self.ups.append(
 36 |                 weight_norm(
 37 |                     ConvTranspose1d(
 38 |                         hp.gen.upsample_initial_channel // (2 ** i),
 39 |                         hp.gen.upsample_initial_channel // (2 ** (i + 1)),
 40 |                         k,
 41 |                         u,
 42 |                         padding=(k - u) // 2)
 43 |                 )
 44 |             )
 45 |             # nsf
 46 |             if i + 1 < len(hp.gen.upsample_rates):
 47 |                 stride_f0 = np.prod(hp.gen.upsample_rates[i + 1:])
 48 |                 stride_f0 = int(stride_f0)
 49 |                 self.noise_convs.append(
 50 |                     Conv1d(
 51 |                         1,
 52 |                         hp.gen.upsample_initial_channel // (2 ** (i + 1)),
 53 |                         kernel_size=stride_f0 * 2,
 54 |                         stride=stride_f0,
 55 |                         padding=stride_f0 // 2,
 56 |                     )
 57 |                 )
 58 |             else:
 59 |                 self.noise_convs.append(
 60 |                     Conv1d(1, hp.gen.upsample_initial_channel //
 61 |                            (2 ** (i + 1)), kernel_size=1)
 62 |                 )
 63 | 
 64 |         # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
 65 |         self.resblocks = nn.ModuleList()
 66 |         for i in range(len(self.ups)):
 67 |             ch = hp.gen.upsample_initial_channel // (2 ** (i + 1))
 68 |             for k, d in zip(hp.gen.resblock_kernel_sizes, hp.gen.resblock_dilation_sizes):
 69 |                 self.resblocks.append(AMPBlock(ch, k, d))
 70 | 
 71 |         # post conv
 72 |         self.activation_post = SnakeAlias(ch)
 73 |         self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
 74 |         # weight initialization
 75 |         self.ups.apply(init_weights)
 76 | 
 77 |     def forward(self, x, f0):
 78 |         # Perturbation
 79 |         x = x + torch.randn_like(x)
 80 |         x = self.conv_pre(x)
 81 |         x = x * torch.tanh(F.softplus(x))
 82 |         # nsf
 83 |         f0 = f0[:, None]
 84 |         f0 = self.f0_upsamp(f0).transpose(1, 2)
 85 |         har_source = self.m_source(f0)
 86 |         har_source = har_source.transpose(1, 2)
 87 | 
 88 |         for i in range(self.num_upsamples):
 89 |             # upsampling
 90 |             x = self.ups[i](x)
 91 |             # nsf
 92 |             x_source = self.noise_convs[i](har_source)
 93 |             x = x + x_source
 94 |             # AMP blocks
 95 |             xs = None
 96 |             for j in range(self.num_kernels):
 97 |                 if xs is None:
 98 |                     xs = self.resblocks[i * self.num_kernels + j](x)
 99 |                 else:
100 |                     xs += self.resblocks[i * self.num_kernels + j](x)
101 |             x = xs / self.num_kernels
102 | 
103 |         # post conv
104 |         x = self.activation_post(x)
105 |         x = self.conv_post(x)
106 |         x = torch.tanh(x)
107 |         return x
108 | 
109 |     def remove_weight_norm(self):
110 |         for l in self.ups:
111 |             remove_weight_norm(l)
112 |         for l in self.resblocks:
113 |             l.remove_weight_norm()
114 | 
115 |     def eval(self, inference=False):
116 |         super(Generator, self).eval()
117 |         # don't remove weight norm while validation in training loop
118 |         if inference:
119 |             self.remove_weight_norm()
120 | 
121 |     def pitch2source(self, f0):
122 |         f0 = f0[:, None]
123 |         f0 = self.f0_upsamp(f0).transpose(1, 2)  # [1,len,1]
124 |         har_source = self.m_source(f0)
125 |         har_source = har_source.transpose(1, 2)  # [1,1,len]
126 |         return har_source
127 | 
128 |     def source2wav(self, audio):
129 |         MAX_WAV_VALUE = 32768.0
130 |         audio = audio.squeeze()
131 |         audio = MAX_WAV_VALUE * audio
132 |         audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)
133 |         audio = audio.short()
134 |         return audio.cpu().detach().numpy()
135 | 
136 |     def inference(self, x, har_source):
137 |         # Perturbation
138 |         x = x + torch.randn_like(x) * 0.1
139 |         x = self.conv_pre(x)
140 |         x = x * torch.tanh(F.softplus(x))
141 | 
142 |         for i in range(self.num_upsamples):
143 |             # upsampling
144 |             x = self.ups[i](x)
145 |             # nsf
146 |             x_source = self.noise_convs[i](har_source)
147 |             x = x + x_source
148 |             # AMP blocks
149 |             xs = None
150 |             for j in range(self.num_kernels):
151 |                 if xs is None:
152 |                     xs = self.resblocks[i * self.num_kernels + j](x)
153 |                 else:
154 |                     xs += self.resblocks[i * self.num_kernels + j](x)
155 |             x = xs / self.num_kernels
156 | 
157 |         # post conv
158 |         x = self.activation_post(x)
159 |         x = self.conv_post(x)
160 |         x = torch.tanh(x)
161 |         return x
162 | 


--------------------------------------------------------------------------------
/vits_decoder/mpd.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | from torch.nn.utils import weight_norm, spectral_norm
 5 | 
 6 | class DiscriminatorP(nn.Module):
 7 |     def __init__(self, hp, period):
 8 |         super(DiscriminatorP, self).__init__()
 9 | 
10 |         self.LRELU_SLOPE = hp.mpd.lReLU_slope
11 |         self.period = period
12 | 
13 |         kernel_size = hp.mpd.kernel_size
14 |         stride = hp.mpd.stride
15 |         norm_f = weight_norm if hp.mpd.use_spectral_norm == False else spectral_norm
16 | 
17 |         self.convs = nn.ModuleList([
18 |             norm_f(nn.Conv2d(1, 64, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
19 |             norm_f(nn.Conv2d(64, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
20 |             norm_f(nn.Conv2d(128, 256, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
21 |             norm_f(nn.Conv2d(256, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
22 |             norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), 1, padding=(kernel_size // 2, 0))),
23 |         ])
24 |         self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
25 | 
26 |     def forward(self, x):
27 |         fmap = []
28 | 
29 |         # 1d to 2d
30 |         b, c, t = x.shape
31 |         if t % self.period != 0: # pad first
32 |             n_pad = self.period - (t % self.period)
33 |             x = F.pad(x, (0, n_pad), "reflect")
34 |             t = t + n_pad
35 |         x = x.view(b, c, t // self.period, self.period)
36 | 
37 |         for l in self.convs:
38 |             x = l(x)
39 |             x = F.leaky_relu(x, self.LRELU_SLOPE)
40 |             fmap.append(x)
41 |         x = self.conv_post(x)
42 |         fmap.append(x)
43 |         x = torch.flatten(x, 1, -1)
44 | 
45 |         return fmap, x
46 | 
47 | 
48 | class MultiPeriodDiscriminator(nn.Module):
49 |     def __init__(self, hp):
50 |         super(MultiPeriodDiscriminator, self).__init__()
51 | 
52 |         self.discriminators = nn.ModuleList(
53 |             [DiscriminatorP(hp, period) for period in hp.mpd.periods]
54 |         )
55 | 
56 |     def forward(self, x):
57 |         ret = list()
58 |         for disc in self.discriminators:
59 |             ret.append(disc(x))
60 | 
61 |         return ret  # [(feat, score), (feat, score), (feat, score), (feat, score), (feat, score)]
62 | 


--------------------------------------------------------------------------------
/vits_decoder/mrd.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | from torch.nn.utils import weight_norm, spectral_norm
 5 | 
 6 | class DiscriminatorR(torch.nn.Module):
 7 |     def __init__(self, hp, resolution):
 8 |         super(DiscriminatorR, self).__init__()
 9 | 
10 |         self.resolution = resolution
11 |         self.LRELU_SLOPE = hp.mpd.lReLU_slope
12 | 
13 |         norm_f = weight_norm if hp.mrd.use_spectral_norm == False else spectral_norm
14 | 
15 |         self.convs = nn.ModuleList([
16 |             norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
17 |             norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
18 |             norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
19 |             norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
20 |             norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
21 |         ])
22 |         self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
23 | 
24 |     def forward(self, x):
25 |         fmap = []
26 | 
27 |         x = self.spectrogram(x)
28 |         x = x.unsqueeze(1)
29 |         for l in self.convs:
30 |             x = l(x)
31 |             x = F.leaky_relu(x, self.LRELU_SLOPE)
32 |             fmap.append(x)
33 |         x = self.conv_post(x)
34 |         fmap.append(x)
35 |         x = torch.flatten(x, 1, -1)
36 | 
37 |         return fmap, x
38 | 
39 |     def spectrogram(self, x):
40 |         n_fft, hop_length, win_length = self.resolution
41 |         x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect')
42 |         x = x.squeeze(1)
43 |         x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=False) #[B, F, TT, 2]
44 |         mag = torch.norm(x, p=2, dim=-1)  # [B, F, TT]
45 | 
46 |         return mag
47 | 
48 | 
49 | class MultiResolutionDiscriminator(torch.nn.Module):
50 |     def __init__(self, hp):
51 |         super(MultiResolutionDiscriminator, self).__init__()
52 |         self.resolutions = eval(hp.mrd.resolutions)
53 |         self.discriminators = nn.ModuleList(
54 |             [DiscriminatorR(hp, resolution) for resolution in self.resolutions]
55 |         )
56 | 
57 |     def forward(self, x):
58 |         ret = list()
59 |         for disc in self.discriminators:
60 |             ret.append(disc(x))
61 | 
62 |         return ret  # [(feat, score), (feat, score), (feat, score)]
63 | 


--------------------------------------------------------------------------------
/vits_decoder/msd.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | from torch.nn.utils import weight_norm
 5 | 
 6 | 
 7 | class ScaleDiscriminator(torch.nn.Module):
 8 |     def __init__(self):
 9 |         super(ScaleDiscriminator, self).__init__()
10 |         self.convs = nn.ModuleList([
11 |             weight_norm(nn.Conv1d(1, 16, 15, 1, padding=7)),
12 |             weight_norm(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
13 |             weight_norm(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
14 |             weight_norm(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
15 |             weight_norm(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
16 |             weight_norm(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
17 |         ])
18 |         self.conv_post = weight_norm(nn.Conv1d(1024, 1, 3, 1, padding=1))
19 | 
20 |     def forward(self, x):
21 |         fmap = []
22 |         for l in self.convs:
23 |             x = l(x)
24 |             x = F.leaky_relu(x, 0.1)
25 |             fmap.append(x)
26 |         x = self.conv_post(x)
27 |         fmap.append(x)
28 |         x = torch.flatten(x, 1, -1)
29 |         return [(fmap, x)]
30 | 


--------------------------------------------------------------------------------
/vits_extend/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlayVoice/VI-SVS/d9768caa05c742bc5f580a4754b3cd7b444a5eb4/vits_extend/__init__.py


--------------------------------------------------------------------------------
/vits_extend/dataloader.py:
--------------------------------------------------------------------------------
 1 | from torch.utils.data import DataLoader
 2 | from vits.data_utils import DistributedBucketSampler
 3 | from vits.data_utils import TextAudioLoader
 4 | from vits.data_utils import TextAudioCollate
 5 | 
 6 | 
 7 | def create_dataloader_train(hps, n_gpus, rank):
 8 |     collate_fn = TextAudioCollate()
 9 |     train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
10 |     train_sampler = DistributedBucketSampler(
11 |         train_dataset,
12 |         hps.train.batch_size,
13 |         [32, 300, 400, 500, 600, 700, 800, 900, 1000],
14 |         num_replicas=n_gpus,
15 |         rank=rank,
16 |         shuffle=True)
17 |     train_loader = DataLoader(
18 |         train_dataset,
19 |         num_workers=4,
20 |         shuffle=False,
21 |         pin_memory=True,
22 |         collate_fn=collate_fn,
23 |         batch_sampler=train_sampler)
24 |     return train_loader
25 | 
26 | 
27 | def create_dataloader_eval(hps):
28 |     collate_fn = TextAudioCollate()
29 |     eval_dataset = TextAudioLoader(hps.data.validation_files, hps.data)
30 |     eval_loader = DataLoader(
31 |         eval_dataset,
32 |         num_workers=2,
33 |         shuffle=False,
34 |         batch_size=hps.train.batch_size,
35 |         pin_memory=True,
36 |         drop_last=False,
37 |         collate_fn=collate_fn)
38 |     return eval_loader
39 | 


--------------------------------------------------------------------------------
/vits_extend/plotting.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | mpl_logger = logging.getLogger('matplotlib')  # must be configured before importing matplotlib
 3 | mpl_logger.setLevel(logging.WARNING)
 4 | import matplotlib
 5 | matplotlib.use("Agg")
 6 | 
 7 | import numpy as np
 8 | import matplotlib.pylab as plt
 9 | 
10 | 
11 | def save_figure_to_numpy(fig):
12 |     # copy the rendered canvas into a (3, H, W) uint8 numpy array
13 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
14 |     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
15 |     data = np.transpose(data, (2, 0, 1))
16 |     return data
17 | 
18 | 
19 | def plot_waveform_to_numpy(waveform):
20 |     fig, ax = plt.subplots(figsize=(12, 4))
21 |     # draw the raw waveform as a thin line over the sample index axis
22 |     ax.plot(range(len(waveform)), waveform,
23 |             linewidth=0.1, alpha=0.7, color='blue')
24 | 
25 |     plt.xlabel("Samples")
26 |     plt.ylabel("Amplitude")
27 |     plt.ylim(-1, 1)
28 |     plt.tight_layout()
29 | 
30 |     fig.canvas.draw()
31 |     data = save_figure_to_numpy(fig)
32 |     plt.close()
33 | 
34 |     return data
35 | 
36 | 
37 | def plot_spectrogram_to_numpy(spectrogram):
38 |     fig, ax = plt.subplots(figsize=(12, 4))
39 |     im = ax.imshow(spectrogram, aspect="auto", origin="lower",
40 |                    interpolation='none')
41 |     plt.colorbar(im, ax=ax)
42 |     plt.xlabel("Frames")
43 |     plt.ylabel("Channels")
44 |     plt.tight_layout()
45 | 
46 |     fig.canvas.draw()
47 |     data = save_figure_to_numpy(fig)
48 |     plt.close()
49 |     return data
50 | 
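A self-contained sketch of the intended call pattern; the random 80x200 matrix is only a stand-in for a real mel spectrogram.

    import numpy as np

    mel = np.random.rand(80, 200)           # stand-in: 80 mel channels, 200 frames
    img = plot_spectrogram_to_numpy(mel)    # CHW uint8 image suitable for SummaryWriter.add_image
    print(img.shape)                        # (3, height, width)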


--------------------------------------------------------------------------------
/vits_extend/stft.py:
--------------------------------------------------------------------------------
  1 | # MIT License
  2 | #
  3 | # Copyright (c) 2020 Jungil Kong
  4 | #
  5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
  6 | # of this software and associated documentation files (the "Software"), to deal
  7 | # in the Software without restriction, including without limitation the rights
  8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  9 | # copies of the Software, and to permit persons to whom the Software is
 10 | # furnished to do so, subject to the following conditions:
 11 | #
 12 | # The above copyright notice and this permission notice shall be included in all
 13 | # copies or substantial portions of the Software.
 14 | #
 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 21 | # SOFTWARE.
 22 | 
 23 | import math
 24 | import os
 25 | import random
 26 | import torch
 27 | import torch.utils.data
 28 | import numpy as np
 29 | from librosa.util import normalize
 30 | from scipy.io.wavfile import read
 31 | from librosa.filters import mel as librosa_mel_fn
 32 | 
 33 | 
 34 | class TacotronSTFT(torch.nn.Module):
 35 |     def __init__(self, filter_length=512, hop_length=160, win_length=512,
 36 |                  n_mel_channels=80, sampling_rate=16000, mel_fmin=0.0,
 37 |                  mel_fmax=None, center=False, device='cpu'):
 38 |         super(TacotronSTFT, self).__init__()
 39 |         self.n_mel_channels = n_mel_channels
 40 |         self.sampling_rate = sampling_rate
 41 |         self.n_fft = filter_length
 42 |         self.hop_size = hop_length
 43 |         self.win_size = win_length
 44 |         self.fmin = mel_fmin
 45 |         self.fmax = mel_fmax
 46 |         self.center = center
 47 | 
 48 |         mel = librosa_mel_fn(
 49 |             sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax)
 50 | 
 51 |         mel_basis = torch.from_numpy(mel).float().to(device)
 52 |         hann_window = torch.hann_window(win_length).to(device)
 53 | 
 54 |         self.register_buffer('mel_basis', mel_basis)
 55 |         self.register_buffer('hann_window', hann_window)
 56 | 
 57 |     def linear_spectrogram(self, y):
 58 |         # assert (torch.min(y.data) >= -1)
 59 |         # assert (torch.max(y.data) <= 1)
 60 | 
 61 |         y = torch.nn.functional.pad(y.unsqueeze(1),
 62 |                                     (int((self.n_fft - self.hop_size) / 2), int((self.n_fft - self.hop_size) / 2)),
 63 |                                     mode='reflect')
 64 |         y = y.squeeze(1)
 65 |         spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window,
 66 |                           center=self.center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
 67 |         spec = torch.norm(spec, p=2, dim=-1)
 68 | 
 69 |         return spec
 70 | 
 71 |     def mel_spectrogram(self, y):
 72 |         """Computes mel-spectrograms from a batch of waves
 73 |         PARAMS
 74 |         ------
 75 |         y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
 76 | 
 77 |         RETURNS
 78 |         -------
 79 |         mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
 80 |         """
 81 |         # assert(torch.min(y.data) >= -1)
 82 |         # assert(torch.max(y.data) <= 1)
 83 | 
 84 |         y = torch.nn.functional.pad(y.unsqueeze(1),
 85 |                                     (int((self.n_fft - self.hop_size) / 2), int((self.n_fft - self.hop_size) / 2)),
 86 |                                     mode='reflect')
 87 |         y = y.squeeze(1)
 88 | 
 89 |         spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window,
 90 |                           center=self.center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
 91 | 
 92 |         spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
 93 | 
 94 |         spec = torch.matmul(self.mel_basis, spec)
 95 |         spec = self.spectral_normalize_torch(spec)
 96 | 
 97 |         return spec
 98 | 
 99 |     def spectral_normalize_torch(self, magnitudes):
100 |         output = self.dynamic_range_compression_torch(magnitudes)
101 |         return output
102 | 
103 |     def dynamic_range_compression_torch(self, x, C=1, clip_val=1e-5):
104 |         return torch.log(torch.clamp(x, min=clip_val) * C)
105 | 
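A hedged example of how TacotronSTFT is typically called; it relies on the constructor defaults above (512-point FFT, 160 hop, 80 mel bins, 16 kHz), and the one-second random waveform is only illustrative.

    import torch

    stft = TacotronSTFT()                         # defaults from the constructor above
    wav = torch.randn(1, 16000).clamp(-1.0, 1.0)  # (B, T) waveform in [-1, 1]
    mel = stft.mel_spectrogram(wav)               # (B, 80, T // hop_size) log-compressed mel frames
    lin = stft.linear_spectrogram(wav)            # (B, n_fft // 2 + 1, T // hop_size) magnitudes
    print(mel.shape, lin.shape)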


--------------------------------------------------------------------------------
/vits_extend/stft_loss.py:
--------------------------------------------------------------------------------
  1 | # -*- coding: utf-8 -*-
  2 | 
  3 | # Copyright 2019 Tomoki Hayashi
  4 | #  MIT License (https://opensource.org/licenses/MIT)
  5 | 
  6 | """STFT-based Loss modules."""
  7 | 
  8 | import torch
  9 | import torch.nn.functional as F
 10 | 
 11 | 
 12 | def stft(x, fft_size, hop_size, win_length, window):
 13 |     """Perform STFT and convert to magnitude spectrogram.
 14 |     Args:
 15 |         x (Tensor): Input signal tensor (B, T).
 16 |         fft_size (int): FFT size.
 17 |         hop_size (int): Hop size.
 18 |         win_length (int): Window length.
 19 |         window (str): Window function type.
 20 |     Returns:
 21 |         Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
 22 |     """
 23 |     x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=False)
 24 |     real = x_stft[..., 0]
 25 |     imag = x_stft[..., 1]
 26 | 
 27 |     # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
 28 |     return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
 29 | 
 30 | 
 31 | class SpectralConvergengeLoss(torch.nn.Module):
 32 |     """Spectral convergence loss module."""
 33 | 
 34 |     def __init__(self):
 35 |         """Initialize spectral convergence loss module."""
 36 |         super(SpectralConvergengeLoss, self).__init__()
 37 | 
 38 |     def forward(self, x_mag, y_mag):
 39 |         """Calculate forward propagation.
 40 |         Args:
 41 |             x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
 42 |             y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
 43 |         Returns:
 44 |             Tensor: Spectral convergence loss value.
 45 |         """
 46 |         return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
 47 | 
 48 | 
 49 | class LogSTFTMagnitudeLoss(torch.nn.Module):
 50 |     """Log STFT magnitude loss module."""
 51 | 
 52 |     def __init__(self):
 53 |         """Initialize log STFT magnitude loss module."""
 54 |         super(LogSTFTMagnitudeLoss, self).__init__()
 55 | 
 56 |     def forward(self, x_mag, y_mag):
 57 |         """Calculate forward propagation.
 58 |         Args:
 59 |             x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
 60 |             y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
 61 |         Returns:
 62 |             Tensor: Log STFT magnitude loss value.
 63 |         """
 64 |         return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
 65 | 
 66 | 
 67 | class STFTLoss(torch.nn.Module):
 68 |     """STFT loss module."""
 69 | 
 70 |     def __init__(self, device, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
 71 |         """Initialize STFT loss module."""
 72 |         super(STFTLoss, self).__init__()
 73 |         self.fft_size = fft_size
 74 |         self.shift_size = shift_size
 75 |         self.win_length = win_length
 76 |         self.window = getattr(torch, window)(win_length).to(device)
 77 |         self.spectral_convergenge_loss = SpectralConvergengeLoss()
 78 |         self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
 79 | 
 80 |     def forward(self, x, y):
 81 |         """Calculate forward propagation.
 82 |         Args:
 83 |             x (Tensor): Predicted signal (B, T).
 84 |             y (Tensor): Groundtruth signal (B, T).
 85 |         Returns:
 86 |             Tensor: Spectral convergence loss value.
 87 |             Tensor: Log STFT magnitude loss value.
 88 |         """
 89 |         x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
 90 |         y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
 91 |         sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
 92 |         mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
 93 | 
 94 |         return sc_loss, mag_loss
 95 | 
 96 | 
 97 | class MultiResolutionSTFTLoss(torch.nn.Module):
 98 |     """Multi resolution STFT loss module."""
 99 | 
100 |     def __init__(self,
101 |                  device,
102 |                  resolutions,
103 |                  window="hann_window"):
104 |         """Initialize Multi resolution STFT loss module.
105 |         Args:
106 |             resolutions (list): List of (FFT size, hop size, window length).
107 |             window (str): Window function type.
108 |         """
109 |         super(MultiResolutionSTFTLoss, self).__init__()
110 |         self.stft_losses = torch.nn.ModuleList()
111 |         for fs, ss, wl in resolutions:
112 |             self.stft_losses += [STFTLoss(device, fs, ss, wl, window)]
113 | 
114 |     def forward(self, x, y):
115 |         """Calculate forward propagation.
116 |         Args:
117 |             x (Tensor): Predicted signal (B, T).
118 |             y (Tensor): Groundtruth signal (B, T).
119 |         Returns:
120 |             Tensor: Multi resolution spectral convergence loss value.
121 |             Tensor: Multi resolution log STFT magnitude loss value.
122 |         """
123 |         sc_loss = 0.0
124 |         mag_loss = 0.0
125 |         for f in self.stft_losses:
126 |             sc_l, mag_l = f(x, y)
127 |             sc_loss += sc_l
128 |             mag_loss += mag_l
129 | 
130 |         sc_loss /= len(self.stft_losses)
131 |         mag_loss /= len(self.stft_losses)
132 | 
133 |         return sc_loss, mag_loss
134 | 
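A short sketch of instantiating and calling MultiResolutionSTFTLoss; the three (FFT size, hop size, window length) triples below are common choices and are assumptions rather than values taken from this repo's config.

    import torch

    resolutions = [(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)]  # assumed triples
    criterion = MultiResolutionSTFTLoss('cpu', resolutions)

    fake = torch.randn(4, 16000).clamp(-1.0, 1.0)   # predicted waveforms (B, T)
    real = torch.randn(4, 16000).clamp(-1.0, 1.0)   # ground-truth waveforms (B, T)
    sc_loss, mag_loss = criterion(fake, real)
    loss = sc_loss + mag_loss                       # both terms are scalars averaged over resolutions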


--------------------------------------------------------------------------------
/vits_extend/validation.py:
--------------------------------------------------------------------------------
 1 | import tqdm
 2 | import torch
 3 | import torch.nn.functional as F
 4 | 
 5 | 
 6 | def validate(hp, args, generator, discriminator, valloader, stft, writer, step, device):
 7 |     generator.eval()
 8 |     discriminator.eval()
 9 |     torch.backends.cudnn.benchmark = False
10 | 
11 |     loader = tqdm.tqdm(valloader, desc='Validation loop')
12 |     mel_loss = 0.0
13 |     for idx, (phone, phone_l, score, pitch, slurs, spec, spec_l, audio, audio_l) in enumerate(loader):
14 |         phone = phone.to(device)
15 |         phone_l = phone_l.to(device)
16 |         score = score.to(device)
17 |         pitch = pitch.to(device)
18 |         slurs = slurs.to(device)
19 |         audio = audio.to(device)
20 | 
21 |         if hasattr(generator, 'module'):
22 |             fake_audio = generator.module.infer(phone, phone_l, score, pitch, slurs)[
23 |                 :, :, :audio.size(2)]
24 |         else:
25 |             fake_audio = generator.infer(phone, phone_l, score, pitch, slurs)[
26 |                 :, :, :audio.size(2)]
27 | 
28 |         mel_fake = stft.mel_spectrogram(fake_audio.squeeze(1))
29 |         mel_real = stft.mel_spectrogram(audio.squeeze(1))
30 | 
31 |         mel_loss += F.l1_loss(mel_fake, mel_real).item()
32 | 
33 |         if idx < hp.log.num_audio:
34 |             spec_fake = stft.linear_spectrogram(fake_audio.squeeze(1))
35 |             spec_real = stft.linear_spectrogram(audio.squeeze(1))
36 | 
37 |             audio = audio[0][0].cpu().detach().numpy()
38 |             fake_audio = fake_audio[0][0].cpu().detach().numpy()
39 |             spec_fake = spec_fake[0].cpu().detach().numpy()
40 |             spec_real = spec_real[0].cpu().detach().numpy()
41 |             writer.log_fig_audio(
42 |                 audio, fake_audio, spec_fake, spec_real, idx, step)
43 | 
44 |     mel_loss = mel_loss / len(valloader)  # average the per-batch L1 over the number of batches
45 | 
46 |     writer.log_validation(mel_loss, generator, discriminator, step)
47 | 
48 |     torch.backends.cudnn.benchmark = True
49 | 


--------------------------------------------------------------------------------
/vits_extend/writer.py:
--------------------------------------------------------------------------------
 1 | from torch.utils.tensorboard import SummaryWriter
 2 | import numpy as np
 3 | import librosa
 4 | 
 5 | from .plotting import plot_waveform_to_numpy, plot_spectrogram_to_numpy
 6 | 
 7 | class MyWriter(SummaryWriter):
 8 |     def __init__(self, hp, logdir):
 9 |         super(MyWriter, self).__init__(logdir)
10 |         self.sample_rate = hp.data.sampling_rate
11 | 
12 |     def log_training(self, g_loss, d_loss, mel_loss, stft_loss, k_loss, r_loss, score_loss, step):
13 |         self.add_scalar('train/g_loss', g_loss, step)
14 |         self.add_scalar('train/d_loss', d_loss, step)
15 |         
16 |         self.add_scalar('train/score_loss', score_loss, step)
17 |         self.add_scalar('train/stft_loss', stft_loss, step)
18 |         self.add_scalar('train/mel_loss', mel_loss, step)
19 |         self.add_scalar('train/kl_f_loss', k_loss, step)
20 |         self.add_scalar('train/kl_r_loss', r_loss, step)
21 | 
22 |     def log_validation(self, mel_loss, generator, discriminator, step):
23 |         self.add_scalar('validation/mel_loss', mel_loss, step)
24 | 
25 |     def log_fig_audio(self, real, fake, spec_fake, spec_real, idx, step):
26 |         if idx == 0:
27 |             spec_fake = librosa.amplitude_to_db(spec_fake, ref=np.max,top_db=80.)
28 |             spec_real = librosa.amplitude_to_db(spec_real, ref=np.max,top_db=80.)
29 |             self.add_image(f'spec_fake/{step}', plot_spectrogram_to_numpy(spec_fake), step)
30 |             self.add_image(f'wave_fake/{step}', plot_waveform_to_numpy(fake), step)
31 |             self.add_image(f'spec_real/{step}', plot_spectrogram_to_numpy(spec_real), step)
32 |             self.add_image(f'wave_real/{step}', plot_waveform_to_numpy(real), step)
33 | 
34 |             self.add_audio(f'fake/{step}', fake, step, self.sample_rate)
35 |             self.add_audio(f'real/{step}', real, step, self.sample_rate)
36 | 
37 |     def log_histogram(self, model, step):
38 |         for tag, value in model.named_parameters():
39 |             self.add_histogram(tag.replace('.', '/'), value.cpu().detach().numpy(), step)
40 | 
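For context, a hedged sketch of constructing the writer and logging one training step; the SimpleNamespace config, the log directory, and the scalar values are placeholders standing in for the repo's parsed YAML hyper-parameters and real losses.

    from types import SimpleNamespace
    from vits_extend.writer import MyWriter

    hp = SimpleNamespace(data=SimpleNamespace(sampling_rate=16000))   # stand-in for the YAML config
    writer = MyWriter(hp, 'logs/exp1')                                # assumed log directory

    # placeholder scalars standing in for the losses computed at a training step
    writer.log_training(g_loss=1.0, d_loss=0.5, mel_loss=0.3, stft_loss=0.2,
                        k_loss=0.1, r_loss=0.1, score_loss=0.05, step=0)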


--------------------------------------------------------------------------------