├── .gitignore
├── LICENSE
├── README.md
├── _config.yml
├── attentions.py
├── commons.py
├── configs
│   ├── requirements.txt
│   └── singing_base.json
├── data_utils.py
├── evaluate
│   ├── evaluate_f0.py
│   ├── evaluate_mcd.py
│   ├── evaluate_semitone.py
│   └── evaluate_vuv.py
├── evaluate_score.sh
├── filelists
│   ├── singing_test.txt
│   ├── singing_train.txt
│   └── singing_valid.txt
├── losses.py
├── mel_processing.py
├── models.py
├── modules.py
├── normalize_wav.py
├── plot_f0.py
├── prepare
│   ├── __init__.py
│   ├── align_wav_spec.py
│   ├── data_vits.py
│   ├── data_vits_phn.py
│   ├── data_vits_phn_ofuton.py
│   ├── dur_to_frame.py
│   ├── gen_ofuton_transcript.py
│   ├── midi-HZ.scp
│   ├── midi-note.scp
│   ├── phone_map.py
│   ├── phone_uv.py
│   ├── preprocess.py
│   ├── preprocess_jp.py
│   ├── resample_wav.py
│   └── resample_wav.sh
├── resource
│   ├── 2005000151.wav
│   ├── 2005000152.wav
│   ├── 2006000186.wav
│   ├── 2006000187.wav
│   ├── 2008000268.wav
│   ├── vising_loss.png
│   └── vising_mel.png
├── train.py
├── train.sh
├── transforms.py
├── utils.py
├── vsinging_debug.py
├── vsinging_infer.py
├── vsinging_infer.txt
├── vsinging_infer_jp.py
├── vsinging_infer_jp.txt
├── vsinging_song.py
└── vsinging_song_midi.txt

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pth
2 | *.pyc
3 | filelists/singing_train.txt
4 | filelists/singing_valid.txt
5 | filelists/vits_file.txt
6 | logs
7 | singing_out
8 | */*_res
9 | *.zip
10 | nohup.out
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution.
You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE.
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Init
2 | An unofficial implementation of VISinger
3 |
4 | # Reference Repos
5 | https://github.com/jaywalnut310/vits
6 |
7 | https://github.com/MoonInTheRiver/DiffSinger
8 |
9 | https://wenet.org.cn/opencpop/
10 |
11 | https://github.com/PlayVoice/VI-SVS
12 |
13 | # Data Preprocess
14 | ```bash
15 | export PYTHONPATH=.
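# Note: the scripts under prepare/ import top-level modules of this repo
# (e.g. utils, mel_processing), so they are meant to be run from the repo
# root with PYTHONPATH set as above.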
16 | ``` 17 | 18 | Generate ../VISinger_data/label_vits_phn/XXX._label.npy|XXX._label_dur.npy|XXX_score.npy|XXX_score_dur.npy|XXX_pitch.npy|XXX_slurs.npy 19 | 20 | ```bash 21 | python prepare/data_vits_phn.py 22 | ``` 23 | 24 | Generate filelists/vits_file.txt 25 | Format: wave path|label path|label duration path|score path|score duration path|pitch path|slurs path; 26 | 27 | ```bash 28 | python prepare/preprocess.py 29 | ``` 30 | 31 | # VISinger training 32 | 33 | ```bash 34 | python train.py -c configs/singing_base.json -m singing_base 35 | ``` 36 | 37 | or 38 | 39 | ```bash 40 | ./train.sh 41 | ``` 42 | 43 | # Inference 44 | 45 | ```bash 46 | ./evaluate_score.sh 47 | ``` 48 | 49 | ![LOSS](/resource/vising_loss.png) 50 | ![MEL](/resource/vising_mel.png) 51 | 52 | # Samples 53 | 54 | 57 | 58 | 61 | 62 | 65 | 66 | 69 | 70 | 73 | 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | remote_theme: pages-themes/cayman@v0.2.0 2 | plugins: 3 | - jekyll-remote-theme # add this line to the plugins list if you already have one -------------------------------------------------------------------------------- /attentions.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | import commons 9 | import modules 10 | from modules import LayerNorm 11 | 12 | 13 | class Encoder(nn.Module): 14 | def __init__( 15 | self, 16 | hidden_channels, 17 | filter_channels, 18 | n_heads, 19 | n_layers, 20 | kernel_size=1, 21 | p_dropout=0.0, 22 | window_size=10, 23 | **kwargs 24 | ): 25 | super().__init__() 26 | self.hidden_channels = hidden_channels 27 | self.filter_channels = filter_channels 28 | self.n_heads = n_heads 29 | self.n_layers = n_layers 30 | self.kernel_size = kernel_size 31 | self.p_dropout = p_dropout 32 | self.window_size = window_size 33 | 34 | self.drop = nn.Dropout(p_dropout) 35 | self.attn_layers = nn.ModuleList() 36 | self.norm_layers_1 = nn.ModuleList() 37 | self.ffn_layers = nn.ModuleList() 38 | self.norm_layers_2 = nn.ModuleList() 39 | for i in range(self.n_layers): 40 | self.attn_layers.append( 41 | MultiHeadAttention( 42 | hidden_channels, 43 | hidden_channels, 44 | n_heads, 45 | p_dropout=p_dropout, 46 | window_size=window_size, 47 | ) 48 | ) 49 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 50 | self.ffn_layers.append( 51 | FFN( 52 | hidden_channels, 53 | hidden_channels, 54 | filter_channels, 55 | kernel_size, 56 | p_dropout=p_dropout, 57 | ) 58 | ) 59 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 60 | 61 | def forward(self, x, x_mask): 62 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 63 | x = x * x_mask 64 | for i in range(self.n_layers): 65 | y = self.attn_layers[i](x, x, attn_mask) 66 | y = self.drop(y) 67 | x = self.norm_layers_1[i](x + y) 68 | 69 | y = self.ffn_layers[i](x, x_mask) 70 | y = self.drop(y) 71 | x = self.norm_layers_2[i](x + y) 72 | x = x * x_mask 73 | return x 74 | 75 | 76 | class Decoder(nn.Module): 77 | def __init__( 78 | self, 79 | hidden_channels, 80 | filter_channels, 81 | n_heads, 82 | n_layers, 83 | kernel_size=1, 84 | p_dropout=0.0, 85 | proximal_bias=False, 86 | proximal_init=True, 87 | **kwargs 88 | ): 89 | super().__init__() 90 | self.hidden_channels = hidden_channels 91 | self.filter_channels = 
filter_channels 92 | self.n_heads = n_heads 93 | self.n_layers = n_layers 94 | self.kernel_size = kernel_size 95 | self.p_dropout = p_dropout 96 | self.proximal_bias = proximal_bias 97 | self.proximal_init = proximal_init 98 | 99 | self.drop = nn.Dropout(p_dropout) 100 | self.self_attn_layers = nn.ModuleList() 101 | self.norm_layers_0 = nn.ModuleList() 102 | self.encdec_attn_layers = nn.ModuleList() 103 | self.norm_layers_1 = nn.ModuleList() 104 | self.ffn_layers = nn.ModuleList() 105 | self.norm_layers_2 = nn.ModuleList() 106 | for i in range(self.n_layers): 107 | self.self_attn_layers.append( 108 | MultiHeadAttention( 109 | hidden_channels, 110 | hidden_channels, 111 | n_heads, 112 | p_dropout=p_dropout, 113 | proximal_bias=proximal_bias, 114 | proximal_init=proximal_init, 115 | ) 116 | ) 117 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 118 | self.encdec_attn_layers.append( 119 | MultiHeadAttention( 120 | hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout 121 | ) 122 | ) 123 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 124 | self.ffn_layers.append( 125 | FFN( 126 | hidden_channels, 127 | hidden_channels, 128 | filter_channels, 129 | kernel_size, 130 | p_dropout=p_dropout, 131 | causal=True, 132 | ) 133 | ) 134 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 135 | 136 | def forward(self, x, x_mask, h, h_mask): 137 | """ 138 | x: decoder input 139 | h: encoder output 140 | """ 141 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( 142 | device=x.device, dtype=x.dtype 143 | ) 144 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 145 | x = x * x_mask 146 | for i in range(self.n_layers): 147 | y = self.self_attn_layers[i](x, x, self_attn_mask) 148 | y = self.drop(y) 149 | x = self.norm_layers_0[i](x + y) 150 | 151 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 152 | y = self.drop(y) 153 | x = self.norm_layers_1[i](x + y) 154 | 155 | y = self.ffn_layers[i](x, x_mask) 156 | y = self.drop(y) 157 | x = self.norm_layers_2[i](x + y) 158 | x = x * x_mask 159 | return x 160 | 161 | 162 | class MultiHeadAttention(nn.Module): 163 | def __init__( 164 | self, 165 | channels, 166 | out_channels, 167 | n_heads, 168 | p_dropout=0.0, 169 | window_size=None, 170 | heads_share=True, 171 | block_length=None, 172 | proximal_bias=False, 173 | proximal_init=False, 174 | ): 175 | super().__init__() 176 | assert channels % n_heads == 0 177 | 178 | self.channels = channels 179 | self.out_channels = out_channels 180 | self.n_heads = n_heads 181 | self.p_dropout = p_dropout 182 | self.window_size = window_size 183 | self.heads_share = heads_share 184 | self.block_length = block_length 185 | self.proximal_bias = proximal_bias 186 | self.proximal_init = proximal_init 187 | self.attn = None 188 | 189 | self.k_channels = channels // n_heads 190 | self.conv_q = nn.Conv1d(channels, channels, 1) 191 | self.conv_k = nn.Conv1d(channels, channels, 1) 192 | self.conv_v = nn.Conv1d(channels, channels, 1) 193 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 194 | self.drop = nn.Dropout(p_dropout) 195 | 196 | if window_size is not None: 197 | n_heads_rel = 1 if heads_share else n_heads 198 | rel_stddev = self.k_channels**-0.5 199 | self.emb_rel_k = nn.Parameter( 200 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 201 | * rel_stddev 202 | ) 203 | self.emb_rel_v = nn.Parameter( 204 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 205 | * rel_stddev 206 | ) 207 | 208 | nn.init.xavier_uniform_(self.conv_q.weight) 209 | 
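# (with proximal_init=True, the block just below re-copies conv_q's freshly
#  initialized weight and bias into conv_k, so queries and keys start identical)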
nn.init.xavier_uniform_(self.conv_k.weight) 210 | nn.init.xavier_uniform_(self.conv_v.weight) 211 | if proximal_init: 212 | with torch.no_grad(): 213 | self.conv_k.weight.copy_(self.conv_q.weight) 214 | self.conv_k.bias.copy_(self.conv_q.bias) 215 | 216 | def forward(self, x, c, attn_mask=None): 217 | q = self.conv_q(x) 218 | k = self.conv_k(c) 219 | v = self.conv_v(c) 220 | 221 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 222 | 223 | x = self.conv_o(x) 224 | return x 225 | 226 | def attention(self, query, key, value, mask=None): 227 | # reshape [b, d, t] -> [b, n_h, t, d_k] 228 | b, d, t_s, t_t = (*key.size(), query.size(2)) 229 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 230 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 231 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 232 | 233 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 234 | if self.window_size is not None: 235 | assert ( 236 | t_s == t_t 237 | ), "Relative attention is only available for self-attention." 238 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 239 | rel_logits = self._matmul_with_relative_keys( 240 | query / math.sqrt(self.k_channels), key_relative_embeddings 241 | ) 242 | scores_local = self._relative_position_to_absolute_position(rel_logits) 243 | scores = scores + scores_local 244 | if self.proximal_bias: 245 | assert t_s == t_t, "Proximal bias is only available for self-attention." 246 | scores = scores + self._attention_bias_proximal(t_s).to( 247 | device=scores.device, dtype=scores.dtype 248 | ) 249 | if mask is not None: 250 | scores = scores.masked_fill(mask == 0, -1e4) 251 | if self.block_length is not None: 252 | assert ( 253 | t_s == t_t 254 | ), "Local attention is only available for self-attention." 255 | block_mask = ( 256 | torch.ones_like(scores) 257 | .triu(-self.block_length) 258 | .tril(self.block_length) 259 | ) 260 | scores = scores.masked_fill(block_mask == 0, -1e4) 261 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 262 | p_attn = self.drop(p_attn) 263 | output = torch.matmul(p_attn, value) 264 | if self.window_size is not None: 265 | relative_weights = self._absolute_position_to_relative_position(p_attn) 266 | value_relative_embeddings = self._get_relative_embeddings( 267 | self.emb_rel_v, t_s 268 | ) 269 | output = output + self._matmul_with_relative_values( 270 | relative_weights, value_relative_embeddings 271 | ) 272 | output = ( 273 | output.transpose(2, 3).contiguous().view(b, d, t_t) 274 | ) # [b, n_h, t_t, d_k] -> [b, d, t_t] 275 | return output, p_attn 276 | 277 | def _matmul_with_relative_values(self, x, y): 278 | """ 279 | x: [b, h, l, m] 280 | y: [h or 1, m, d] 281 | ret: [b, h, l, d] 282 | """ 283 | ret = torch.matmul(x, y.unsqueeze(0)) 284 | return ret 285 | 286 | def _matmul_with_relative_keys(self, x, y): 287 | """ 288 | x: [b, h, l, d] 289 | y: [h or 1, m, d] 290 | ret: [b, h, l, m] 291 | """ 292 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 293 | return ret 294 | 295 | def _get_relative_embeddings(self, relative_embeddings, length): 296 | max_relative_position = 2 * self.window_size + 1 297 | # Pad first before slice to avoid using cond ops. 
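# Illustrative example: with window_size=4 there are 2*4+1 = 9 stored relative
# embeddings; a sequence of length=10 needs 2*10-1 = 19 relative positions, so
# pad_length = 10-(4+1) = 5 zero embeddings are added on each side
# (9 + 2*5 = 19) and the slice below keeps the whole padded range.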
298 | pad_length = max(length - (self.window_size + 1), 0) 299 | slice_start_position = max((self.window_size + 1) - length, 0) 300 | slice_end_position = slice_start_position + 2 * length - 1 301 | if pad_length > 0: 302 | padded_relative_embeddings = F.pad( 303 | relative_embeddings, 304 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), 305 | ) 306 | else: 307 | padded_relative_embeddings = relative_embeddings 308 | used_relative_embeddings = padded_relative_embeddings[ 309 | :, slice_start_position:slice_end_position 310 | ] 311 | return used_relative_embeddings 312 | 313 | def _relative_position_to_absolute_position(self, x): 314 | """ 315 | x: [b, h, l, 2*l-1] 316 | ret: [b, h, l, l] 317 | """ 318 | batch, heads, length, _ = x.size() 319 | # Concat columns of pad to shift from relative to absolute indexing. 320 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) 321 | 322 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 323 | x_flat = x.view([batch, heads, length * 2 * length]) 324 | x_flat = F.pad( 325 | x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) 326 | ) 327 | 328 | # Reshape and slice out the padded elements. 329 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ 330 | :, :, :length, length - 1 : 331 | ] 332 | return x_final 333 | 334 | def _absolute_position_to_relative_position(self, x): 335 | """ 336 | x: [b, h, l, l] 337 | ret: [b, h, l, 2*l-1] 338 | """ 339 | batch, heads, length, _ = x.size() 340 | # padd along column 341 | x = F.pad( 342 | x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) 343 | ) 344 | x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) 345 | # add 0's in the beginning that will skew the elements after reshape 346 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 347 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] 348 | return x_final 349 | 350 | def _attention_bias_proximal(self, length): 351 | """Bias for self-attention to encourage attention to close positions. 352 | Args: 353 | length: an integer scalar. 
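(The bias between positions i and j is -log(1 + |i - j|): about zero for
neighbouring frames and increasingly negative with distance.)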
354 | Returns: 355 | a Tensor with shape [1, 1, length, length] 356 | """ 357 | r = torch.arange(length, dtype=torch.float32) 358 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 359 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 360 | 361 | 362 | class FFN(nn.Module): 363 | def __init__( 364 | self, 365 | in_channels, 366 | out_channels, 367 | filter_channels, 368 | kernel_size, 369 | p_dropout=0.0, 370 | activation=None, 371 | causal=False, 372 | ): 373 | super().__init__() 374 | self.in_channels = in_channels 375 | self.out_channels = out_channels 376 | self.filter_channels = filter_channels 377 | self.kernel_size = kernel_size 378 | self.p_dropout = p_dropout 379 | self.activation = activation 380 | self.causal = causal 381 | 382 | if causal: 383 | self.padding = self._causal_padding 384 | else: 385 | self.padding = self._same_padding 386 | 387 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 388 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 389 | self.drop = nn.Dropout(p_dropout) 390 | 391 | def forward(self, x, x_mask): 392 | x = self.conv_1(self.padding(x * x_mask)) 393 | if self.activation == "gelu": 394 | x = x * torch.sigmoid(1.702 * x) 395 | else: 396 | x = torch.relu(x) 397 | x = self.drop(x) 398 | x = self.conv_2(self.padding(x * x_mask)) 399 | return x * x_mask 400 | 401 | def _causal_padding(self, x): 402 | if self.kernel_size == 1: 403 | return x 404 | pad_l = self.kernel_size - 1 405 | pad_r = 0 406 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 407 | x = F.pad(x, commons.convert_pad_shape(padding)) 408 | return x 409 | 410 | def _same_padding(self, x): 411 | if self.kernel_size == 1: 412 | return x 413 | pad_l = (self.kernel_size - 1) // 2 414 | pad_r = self.kernel_size // 2 415 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 416 | x = F.pad(x, commons.convert_pad_shape(padding)) 417 | return x 418 | -------------------------------------------------------------------------------- /commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size * dilation - dilation) / 2) 16 | 17 | 18 | def convert_pad_shape(pad_shape): 19 | l = pad_shape[::-1] 20 | pad_shape = [item for sublist in l for item in sublist] 21 | return pad_shape 22 | 23 | 24 | def kl_divergence(m_p, logs_p, m_q, logs_q): 25 | """KL(P||Q)""" 26 | kl = (logs_q - logs_p) - 0.5 27 | kl += ( 28 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 29 | ) 30 | return kl 31 | 32 | 33 | def rand_gumbel(shape): 34 | """Sample from the Gumbel distribution, protect from overflows.""" 35 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 36 | return -torch.log(-torch.log(uniform_samples)) 37 | 38 | 39 | def rand_gumbel_like(x): 40 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 41 | return g 42 | 43 | 44 | def slice_segments(x, ids_str, segment_size=4): 45 | ret = torch.zeros_like(x[:, :, :segment_size]) 46 | for i in range(x.size(0)): 47 | idx_str = ids_str[i] 48 | idx_end = idx_str + segment_size 49 | ret[i] = x[i, :, idx_str:idx_end] 50 | return ret 51 | 52 | 53 | def rand_slice_segments(x, 
x_lengths=None, segment_size=4): 54 | b, d, t = x.size() 55 | if x_lengths is None: 56 | x_lengths = t 57 | ids_str_max = x_lengths - segment_size + 1 58 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 59 | ret = slice_segments(x, ids_str, segment_size) 60 | return ret, ids_str 61 | 62 | 63 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): 64 | position = torch.arange(length, dtype=torch.float) 65 | num_timescales = channels // 2 66 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( 67 | num_timescales - 1 68 | ) 69 | inv_timescales = min_timescale * torch.exp( 70 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment 71 | ) 72 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 73 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 74 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 75 | signal = signal.view(1, channels, length) 76 | return signal 77 | 78 | 79 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 80 | b, channels, length = x.size() 81 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 82 | return x + signal.to(dtype=x.dtype, device=x.device) 83 | 84 | 85 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 86 | b, channels, length = x.size() 87 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 88 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 89 | 90 | 91 | def subsequent_mask(length): 92 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 93 | return mask 94 | 95 | 96 | @torch.jit.script 97 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 98 | n_channels_int = n_channels[0] 99 | in_act = input_a + input_b 100 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 101 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 102 | acts = t_act * s_act 103 | return acts 104 | 105 | 106 | def convert_pad_shape(pad_shape): 107 | l = pad_shape[::-1] 108 | pad_shape = [item for sublist in l for item in sublist] 109 | return pad_shape 110 | 111 | 112 | def shift_1d(x): 113 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 114 | return x 115 | 116 | 117 | def sequence_mask(length, max_length=None): 118 | if max_length is None: 119 | max_length = length.max() 120 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 121 | return x.unsqueeze(0) < length.unsqueeze(1) 122 | 123 | 124 | def generate_path(duration, mask): 125 | """ 126 | duration: [b, 1, t_x] 127 | mask: [b, 1, t_y, t_x] 128 | """ 129 | device = duration.device 130 | 131 | b, _, t_y, t_x = mask.shape 132 | cum_duration = torch.cumsum(duration, -1) 133 | 134 | cum_duration_flat = cum_duration.view(b * t_x) 135 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 136 | path = path.view(b, t_x, t_y) 137 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 138 | path = path.unsqueeze(1).transpose(2, 3) * mask 139 | return path 140 | 141 | 142 | def clip_grad_value_(parameters, clip_value, norm_type=2): 143 | if isinstance(parameters, torch.Tensor): 144 | parameters = [parameters] 145 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 146 | norm_type = float(norm_type) 147 | if clip_value is not None: 148 | clip_value = float(clip_value) 149 | 150 | total_norm = 0 151 | for p in parameters: 152 | 
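# accumulate each parameter's gradient norm; total_norm below combines them
# into the overall norm_type-norm across all parameters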
param_norm = p.grad.data.norm(norm_type) 153 | total_norm += param_norm.item() ** norm_type 154 | if clip_value is not None: 155 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 156 | total_norm = total_norm ** (1.0 / norm_type) 157 | return total_norm 158 | -------------------------------------------------------------------------------- /configs/requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==0.29.21 2 | librosa==0.8.0 3 | matplotlib==3.3.1 4 | numpy==1.18.5 5 | phonemizer==2.2.1 6 | scipy==1.5.2 7 | tensorboard==2.3.0 8 | torch==1.6.0 9 | torchvision==0.7.0 10 | Unidecode==1.1.1 11 | -------------------------------------------------------------------------------- /configs/singing_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 2000, 5 | "seed": 1234, 6 | "epochs": 20000, 7 | "learning_rate": 1e-4, 8 | "betas": [ 9 | 0.8, 10 | 0.99 11 | ], 12 | "eps": 1e-9, 13 | "batch_size": 6, 14 | "fp16_run": false, 15 | "lr_decay": 0.999875, 16 | "segment_size": 8192, 17 | "init_lr_ratio": 1, 18 | "warmup_epochs": 0, 19 | "c_mel": 45, 20 | "c_kl": 1.0, 21 | "keep_n_models": 20 22 | }, 23 | "data": { 24 | "training_files": "filelists/singing_train.txt", 25 | "validation_files": "filelists/singing_valid.txt", 26 | "max_wav_value": 32768.0, 27 | "sampling_rate": 24000, 28 | "filter_length": 1024, 29 | "hop_length": 256, 30 | "win_length": 1024, 31 | "n_mel_channels": 80, 32 | "mel_fmin": 0.0, 33 | "mel_fmax": null, 34 | "n_speakers": 0 35 | }, 36 | "model": { 37 | "inter_channels": 192, 38 | "hidden_channels": 192, 39 | "filter_channels": 768, 40 | "n_heads": 2, 41 | "n_layers": 6, 42 | "kernel_size": 3, 43 | "p_dropout": 0.1, 44 | "resblock": "1", 45 | "resblock_kernel_sizes": [ 46 | 3, 47 | 7, 48 | 11 49 | ], 50 | "resblock_dilation_sizes": [ 51 | [ 52 | 1, 53 | 3, 54 | 5 55 | ], 56 | [ 57 | 1, 58 | 3, 59 | 5 60 | ], 61 | [ 62 | 1, 63 | 3, 64 | 5 65 | ] 66 | ], 67 | "upsample_rates": [ 68 | 8, 69 | 8, 70 | 2, 71 | 2 72 | ], 73 | "upsample_initial_channel": 384, 74 | "upsample_kernel_sizes": [ 75 | 16, 76 | 16, 77 | 4, 78 | 4 79 | ], 80 | "use_spectral_norm": false 81 | } 82 | } -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.utils.data 5 | 6 | from mel_processing import spectrogram_torch 7 | from utils import load_wav_to_torch, load_filepaths_and_text 8 | import scipy.io.wavfile as sciwav 9 | 10 | 11 | class TextAudioLoader(torch.utils.data.Dataset): 12 | """ 13 | 1) loads audio, text pairs 14 | 2) normalizes text and converts them to sequences of integers 15 | 3) computes spectrograms from audio files. 
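Each filelist line carries seven fields:
    wav path|label path|label duration path|score path|score duration path|pitch path|slurs path
A minimal usage sketch (assuming hparams is the "data" section of
configs/singing_base.json, e.g. as loaded by utils.get_hparams in train.py):

    dataset = TextAudioLoader("filelists/singing_train.txt", hparams)
    phone, phone_dur, score, score_dur, pitch, slurs, spec, wav = dataset[0]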
16 | """ 17 | 18 | def __init__(self, audiopaths_and_text, hparams): 19 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) 20 | self.max_wav_value = hparams.max_wav_value 21 | self.sampling_rate = hparams.sampling_rate 22 | self.filter_length = hparams.filter_length 23 | self.hop_length = hparams.hop_length 24 | self.win_length = hparams.win_length 25 | self.sampling_rate = hparams.sampling_rate 26 | self.min_text_len = getattr(hparams, "min_text_len", 1) 27 | self.max_text_len = getattr(hparams, "max_text_len", 5000) 28 | self._filter() 29 | 30 | def _filter(self): 31 | """ 32 | Filter text & store spec lengths 33 | """ 34 | # Store spectrogram lengths for Bucketing 35 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) 36 | # spec_length = wav_length // hop_length 37 | audiopaths_and_text_new = [] 38 | lengths = [] 39 | 40 | for ( 41 | audiopath, 42 | text, 43 | text_dur, 44 | score, 45 | score_dur, 46 | pitch, 47 | slur, 48 | ) in self.audiopaths_and_text: 49 | if self.min_text_len <= len(text) and len(text) <= self.max_text_len: 50 | audiopaths_and_text_new.append( 51 | [audiopath, text, text_dur, score, score_dur, pitch, slur] 52 | ) 53 | lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length)) 54 | self.audiopaths_and_text = audiopaths_and_text_new 55 | self.lengths = lengths 56 | 57 | def get_audio_text_pair(self, audiopath_and_text): 58 | # separate filename and text 59 | file = audiopath_and_text[0] 60 | phone = audiopath_and_text[1] 61 | phone_dur = audiopath_and_text[2] 62 | score = audiopath_and_text[3] 63 | score_dur = audiopath_and_text[4] 64 | pitch = audiopath_and_text[5] 65 | slurs = audiopath_and_text[6] 66 | 67 | phone, phone_dur, score, score_dur, pitch, slurs = self.get_labels( 68 | phone, phone_dur, score, score_dur, pitch, slurs 69 | ) 70 | spec, wav = self.get_audio(file, phone_dur) 71 | 72 | len_phone = phone.size()[0] 73 | len_spec = spec.size()[-1] 74 | 75 | if len_phone != len_spec: 76 | # print("**************CareFull*******************") 77 | # print(f"filepath={audiopath_and_text[0]}") 78 | # print(f"len_text={len_phone}") 79 | # print(f"len_spec={len_spec}") 80 | if len_phone > len_spec: 81 | print(file) 82 | print("len_phone", len_phone) 83 | print("len_spec", len_spec) 84 | assert len_phone < len_spec 85 | # len_min = min(len_phone, len_spec) 86 | # amor hop_size=256 87 | len_wav = len_spec * self.hop_length 88 | # print(wav.size()) 89 | # print(f"len_min={len_min}") 90 | # print(f"len_wav={len_wav}") 91 | # spec = spec[:, :len_min] 92 | wav = wav[:, :len_wav] 93 | return (phone, phone_dur, score, score_dur, pitch, slurs, spec, wav) 94 | 95 | def get_labels(self, phone, phone_dur, score, score_dur, pitch, slurs): 96 | phone = np.load(phone) 97 | phone_dur = np.load(phone_dur) 98 | score = np.load(score) 99 | score_dur = np.load(score_dur) 100 | pitch = np.load(pitch) 101 | slurs = np.load(slurs) 102 | phone = torch.LongTensor(phone) 103 | phone_dur = torch.LongTensor(phone_dur) 104 | score = torch.LongTensor(score) 105 | score_dur = torch.LongTensor(score_dur) 106 | pitch = torch.FloatTensor(pitch) 107 | slurs = torch.LongTensor(slurs) 108 | return phone, phone_dur, score, score_dur, pitch, slurs 109 | 110 | def get_audio(self, filename, phone_dur): 111 | audio, sampling_rate = load_wav_to_torch(filename) 112 | if sampling_rate != self.sampling_rate: 113 | raise ValueError( 114 | "{} {} SR doesn't match target {} SR".format( 115 | filename, sampling_rate, self.sampling_rate 116 | ) 117 | ) 
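# int16 samples are scaled by max_wav_value (32768.0 in configs/singing_base.json)
# into the range [-1.0, 1.0]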
118 |         audio_norm = audio / self.max_wav_value
119 |         audio_norm = audio_norm.unsqueeze(0)
120 |         spec_filename = filename.replace(".wav", ".spec.pt")
121 |         if os.path.exists(spec_filename):
122 |             spec = torch.load(spec_filename)
123 |         else:
124 |             print("please run data_vits_phn.py first")
125 |             raise FileNotFoundError(spec_filename)
126 |         # else:
127 |         #     spec = spectrogram_torch(
128 |         #         audio_norm,
129 |         #         self.filter_length,
130 |         #         self.sampling_rate,
131 |         #         self.hop_length,
132 |         #         self.win_length,
133 |         #         center=False,
134 |         #     )
135 |         #     # align mel and wave
136 |         #     phone_dur_sum = torch.sum(phone_dur).item()
137 |         #     spec_length = spec.shape[2]
138 |
139 |         #     if spec_length > phone_dur_sum:
140 |         #         spec = spec[:, :, :phone_dur_sum]
141 |         #     elif spec_length < phone_dur_sum:
142 |         #         pad_length = phone_dur_sum - spec_length
143 |         #         spec = torch.nn.functional.pad(
144 |         #             input=spec, pad=(0, pad_length, 0, 0), mode="constant", value=0
145 |         #         )
146 |         #     assert spec.shape[2] == phone_dur_sum
147 |
148 |         #     # align wav
149 |         #     fixed_wav_len = phone_dur_sum * self.hop_length
150 |         #     if audio_norm.shape[1] > fixed_wav_len:
151 |         #         audio_norm = audio_norm[:, :fixed_wav_len]
152 |         #     elif audio_norm.shape[1] < fixed_wav_len:
153 |         #         pad_length = fixed_wav_len - audio_norm.shape[1]
154 |         #         audio_norm = torch.nn.functional.pad(
155 |         #             input=audio_norm,
156 |         #             pad=(0, pad_length, 0, 0),
157 |         #             mode="constant",
158 |         #             value=0,
159 |         #         )
160 |         #     assert audio_norm.shape[1] == fixed_wav_len
161 |
162 |         #     # rewrite aligned wav
163 |         #     audio = (audio_norm * self.max_wav_value).transpose(0, 1).numpy().astype(np.int16)
164 |
165 |         #     sciwav.write(
166 |         #         filename,
167 |         #         self.sampling_rate,
168 |         #         audio,
169 |         #     )
170 |         #     # save spec
171 |         #     spec = torch.squeeze(spec, 0)
172 |         #     torch.save(spec, spec_filename)
173 |         return spec, audio_norm
174 |
175 |     def __getitem__(self, index):
176 |         return self.get_audio_text_pair(self.audiopaths_and_text[index])
177 |
178 |     def __len__(self):
179 |         return len(self.audiopaths_and_text)
180 |
181 |
182 | class TextAudioCollate:
183 |     """Zero-pads model inputs and targets"""
184 |
185 |     def __init__(self, return_ids=False):
186 |         self.return_ids = return_ids
187 |
188 |     def __call__(self, batch):
189 |         """Collates a training batch from normalized text and audio.
190 |         PARAMS
191 |         ------
192 |         batch: [phone, phone_dur, score, score_dur, pitch, slurs, spec, wav]
193 |         """
194 |         # Right zero-pad all one-hot text sequences to max input length
195 |         _, ids_sorted_decreasing = torch.sort(
196 |             torch.LongTensor([x[6].size(1) for x in batch]), dim=0, descending=True
197 |         )
198 |
199 |         max_phone_len = max([len(x[0]) for x in batch])
200 |         max_spec_len = max([x[6].size(1) for x in batch])
201 |         max_wave_len = max([x[7].size(1) for x in batch])
202 |
203 |         phone_lengths = torch.LongTensor(len(batch))
204 |         phone_padded = torch.LongTensor(len(batch), max_phone_len)
205 |         phone_dur_padded = torch.LongTensor(len(batch), max_phone_len)
206 |         score_padded = torch.LongTensor(len(batch), max_phone_len)
207 |         score_dur_padded = torch.LongTensor(len(batch), max_phone_len)
208 |         pitch_padded = torch.FloatTensor(len(batch), max_spec_len)
209 |         slurs_padded = torch.LongTensor(len(batch), max_phone_len)
210 |         phone_padded.zero_()
211 |         phone_dur_padded.zero_()
212 |         score_padded.zero_()
213 |         score_dur_padded.zero_()
214 |         pitch_padded.zero_()
215 |         slurs_padded.zero_()
216 |
217 |         spec_lengths = torch.LongTensor(len(batch))
218 |         wave_lengths = torch.LongTensor(len(batch))
219 |         spec_padded = 
torch.FloatTensor(len(batch), batch[0][6].size(0), max_spec_len) 220 | wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len) 221 | spec_padded.zero_() 222 | wave_padded.zero_() 223 | 224 | for i in range(len(ids_sorted_decreasing)): 225 | row = batch[ids_sorted_decreasing[i]] 226 | 227 | phone = row[0] 228 | phone_padded[i, : phone.size(0)] = phone 229 | phone_lengths[i] = phone.size(0) 230 | 231 | phone_dur = row[1] 232 | phone_dur_padded[i, : phone_dur.size(0)] = phone_dur 233 | 234 | score = row[2] 235 | score_padded[i, : score.size(0)] = score 236 | 237 | score_dur = row[3] 238 | score_dur_padded[i, : score_dur.size(0)] = score_dur 239 | 240 | pitch = row[4] 241 | pitch_padded[i, : pitch.size(0)] = pitch 242 | 243 | slurs = row[5] 244 | slurs_padded[i, : slurs.size(0)] = slurs 245 | 246 | spec = row[6] 247 | spec_padded[i, :, : spec.size(1)] = spec 248 | spec_lengths[i] = spec.size(1) 249 | 250 | wave = row[7] 251 | wave_padded[i, :, : wave.size(1)] = wave 252 | wave_lengths[i] = wave.size(1) 253 | 254 | return ( 255 | phone_padded, 256 | phone_lengths, 257 | phone_dur_padded, 258 | score_padded, 259 | score_dur_padded, 260 | pitch_padded, 261 | slurs_padded, 262 | spec_padded, 263 | spec_lengths, 264 | wave_padded, 265 | wave_lengths, 266 | ) 267 | 268 | 269 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): 270 | """ 271 | Maintain similar input lengths in a batch. 272 | Length groups are specified by boundaries. 273 | Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. 274 | 275 | It removes samples which are not included in the boundaries. 276 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. 
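A minimal sketch (the boundary values here are illustrative, not taken
from train.py):

    sampler = DistributedBucketSampler(
        dataset, batch_size=6, boundaries=[32, 300, 400, 500],
        num_replicas=1, rank=0, shuffle=True)
    loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=sampler, collate_fn=TextAudioCollate())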
277 | """ 278 | 279 | def __init__( 280 | self, 281 | dataset, 282 | batch_size, 283 | boundaries, 284 | num_replicas=None, 285 | rank=None, 286 | shuffle=True, 287 | ): 288 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 289 | self.lengths = dataset.lengths 290 | self.batch_size = batch_size 291 | self.boundaries = boundaries 292 | 293 | self.buckets, self.num_samples_per_bucket = self._create_buckets() 294 | self.total_size = sum(self.num_samples_per_bucket) 295 | self.num_samples = self.total_size // self.num_replicas 296 | 297 | def _create_buckets(self): 298 | buckets = [[] for _ in range(len(self.boundaries) - 1)] 299 | for i in range(len(self.lengths)): 300 | length = self.lengths[i] 301 | idx_bucket = self._bisect(length) 302 | if idx_bucket != -1: 303 | buckets[idx_bucket].append(i) 304 | 305 | for i in range(len(buckets) - 1, 0, -1): 306 | if len(buckets[i]) == 0: 307 | buckets.pop(i) 308 | self.boundaries.pop(i + 1) 309 | 310 | num_samples_per_bucket = [] 311 | for i in range(len(buckets)): 312 | len_bucket = len(buckets[i]) 313 | total_batch_size = self.num_replicas * self.batch_size 314 | rem = ( 315 | total_batch_size - (len_bucket % total_batch_size) 316 | ) % total_batch_size 317 | num_samples_per_bucket.append(len_bucket + rem) 318 | return buckets, num_samples_per_bucket 319 | 320 | def __iter__(self): 321 | # deterministically shuffle based on epoch 322 | g = torch.Generator() 323 | g.manual_seed(self.epoch) 324 | 325 | indices = [] 326 | if self.shuffle: 327 | for bucket in self.buckets: 328 | indices.append(torch.randperm(len(bucket), generator=g).tolist()) 329 | else: 330 | for bucket in self.buckets: 331 | indices.append(list(range(len(bucket)))) 332 | 333 | batches = [] 334 | for i in range(len(self.buckets)): 335 | bucket = self.buckets[i] 336 | len_bucket = len(bucket) 337 | ids_bucket = indices[i] 338 | num_samples_bucket = self.num_samples_per_bucket[i] 339 | 340 | # add extra samples to make it evenly divisible 341 | rem = num_samples_bucket - len_bucket 342 | ids_bucket = ( 343 | ids_bucket 344 | + ids_bucket * (rem // len_bucket) 345 | + ids_bucket[: (rem % len_bucket)] 346 | ) 347 | 348 | # subsample 349 | ids_bucket = ids_bucket[self.rank :: self.num_replicas] 350 | 351 | # batching 352 | for j in range(len(ids_bucket) // self.batch_size): 353 | batch = [ 354 | bucket[idx] 355 | for idx in ids_bucket[ 356 | j * self.batch_size : (j + 1) * self.batch_size 357 | ] 358 | ] 359 | batches.append(batch) 360 | 361 | if self.shuffle: 362 | batch_ids = torch.randperm(len(batches), generator=g).tolist() 363 | batches = [batches[i] for i in batch_ids] 364 | self.batches = batches 365 | 366 | assert len(self.batches) * self.batch_size == self.num_samples 367 | return iter(self.batches) 368 | 369 | def _bisect(self, x, lo=0, hi=None): 370 | if hi is None: 371 | hi = len(self.boundaries) - 1 372 | 373 | if hi > lo: 374 | mid = (hi + lo) // 2 375 | if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: 376 | return mid 377 | elif x <= self.boundaries[mid]: 378 | return self._bisect(x, lo, mid) 379 | else: 380 | return self._bisect(x, mid + 1, hi) 381 | else: 382 | return -1 383 | 384 | def __len__(self): 385 | return self.num_samples // self.batch_size 386 | -------------------------------------------------------------------------------- /evaluate/evaluate_f0.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2021 Wen-Chin Huang and Tomoki Hayashi 4 
| # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5 |
6 | """Evaluate log-F0 RMSE between generated and groundtruth audios based on World."""
7 |
8 | import argparse
9 | import fnmatch
10 | import logging
11 | import multiprocessing as mp
12 | import os
13 | from typing import Dict, List, Tuple
14 |
15 | import librosa
16 | import numpy as np
17 | import pysptk
18 | import pyworld as pw
19 | import soundfile as sf
20 | from fastdtw import fastdtw
21 | from scipy import spatial
22 |
23 |
24 | def find_files(
25 |     root_dir: str, query: List[str] = ["*.flac", "*.wav"], include_root_dir: bool = True
26 | ) -> List[str]:
27 |     """Find files recursively.
28 |
29 |     Args:
30 |         root_dir (str): Root directory to search.
31 |         query (List[str]): Filename patterns to match.
32 |         include_root_dir (bool): If False, root_dir name is not included.
33 |
34 |     Returns:
35 |         List[str]: List of found filenames.
36 |
37 |     """
38 |     files = []
39 |     for root, dirnames, filenames in os.walk(root_dir, followlinks=True):
40 |         for q in query:
41 |             for filename in fnmatch.filter(filenames, q):
42 |                 files.append(os.path.join(root, filename))
43 |     if not include_root_dir:
44 |         files = [file_.replace(root_dir + "/", "") for file_ in files]
45 |
46 |     return files
47 |
48 |
49 | def world_extract(
50 |     x: np.ndarray,
51 |     fs: int,
52 |     f0min: int = 40,
53 |     f0max: int = 800,
54 |     n_fft: int = 512,
55 |     n_shift: int = 256,
56 |     mcep_dim: int = 25,
57 |     mcep_alpha: float = 0.41,
58 | ) -> Tuple[np.ndarray, np.ndarray]:
59 |     """Extract World-based acoustic features.
60 |
61 |     Args:
62 |         x (ndarray): 1D waveform array.
63 |         fs (int): Sampling rate.
64 |         f0min (int): Minimum f0 value (default=40).
65 |         f0max (int): Maximum f0 value (default=800).
66 |         n_fft (int): FFT length in point (default=512).
67 |         n_shift (int): Shift length in point (default=256).
68 |         mcep_dim (int): Dimension of mel-cepstrum (default=25).
69 |         mcep_alpha (float): All pass filter coefficient (default=0.41).
70 |
71 |     Returns:
72 |         ndarray: Mel-cepstrum with the size (N, mcep_dim + 1).
73 |         ndarray: F0 sequence (N,).
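    Note: harvest's frame_period is set to n_shift / fs * 1000 ms, so the
    extracted F0 frames line up with a spectrogram computed at the same hop.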
74 |
75 |     """
76 |     # extract features
77 |     x = x.astype(np.float64)
78 |     f0, time_axis = pw.harvest(
79 |         x,
80 |         fs,
81 |         f0_floor=f0min,
82 |         f0_ceil=f0max,
83 |         frame_period=n_shift / fs * 1000,
84 |     )
85 |     sp = pw.cheaptrick(x, f0, time_axis, fs, fft_size=n_fft)
86 |     if mcep_dim is None or mcep_alpha is None:
87 |         mcep_dim, mcep_alpha = _get_best_mcep_params(fs)
88 |     mcep = pysptk.sp2mc(sp, mcep_dim, mcep_alpha)
89 |
90 |     return mcep, f0
91 |
92 |
93 | def _get_basename(path: str) -> str:
94 |     return os.path.splitext(os.path.split(path)[-1])[0]
95 |
96 |
97 | def _get_best_mcep_params(fs: int) -> Tuple[int, float]:
98 |     if fs == 16000:
99 |         return 23, 0.42
100 |     elif fs == 22050:
101 |         return 34, 0.45
102 |     elif fs == 24000:
103 |         return 34, 0.46
104 |     elif fs == 44100:
105 |         return 39, 0.53
106 |     elif fs == 48000:
107 |         return 39, 0.55
108 |     else:
109 |         raise ValueError(f"Not found the setting for {fs}.")
110 |
111 |
112 | def calculate(
113 |     file_list: List[str],
114 |     gt_file_list: List[str],
115 |     args: argparse.Namespace,
116 |     f0_rmse_dict: Dict[str, float],
117 | ):
118 |     """Calculate log-F0 RMSE."""
119 |     for i, gen_path in enumerate(file_list):
120 |         corresponding_list = list(
121 |             filter(
122 |                 lambda gt_path: _get_basename(gt_path)[:-7] in gen_path, gt_file_list
123 |             )
124 |         )
125 |         assert len(corresponding_list) == 1
126 |         gt_path = corresponding_list[0]
127 |         gt_basename = _get_basename(gt_path)
128 |
129 |         # load wav file as int16
130 |         gen_x, gen_fs = sf.read(gen_path, dtype="int16")
131 |         gt_x, gt_fs = sf.read(gt_path, dtype="int16")
132 |
133 |         fs = gen_fs
134 |         if gen_fs != gt_fs:
135 |             gt_x = librosa.resample(gt_x.astype(np.float), gt_fs, gen_fs)
136 |
137 |         # extract ground truth and converted features
138 |         gen_mcep, gen_f0 = world_extract(
139 |             x=gen_x,
140 |             fs=fs,
141 |             f0min=args.f0min,
142 |             f0max=args.f0max,
143 |             n_fft=args.n_fft,
144 |             n_shift=args.n_shift,
145 |             mcep_dim=args.mcep_dim,
146 |             mcep_alpha=args.mcep_alpha,
147 |         )
148 |         gt_mcep, gt_f0 = world_extract(
149 |             x=gt_x,
150 |             fs=fs,
151 |             f0min=args.f0min,
152 |             f0max=args.f0max,
153 |             n_fft=args.n_fft,
154 |             n_shift=args.n_shift,
155 |             mcep_dim=args.mcep_dim,
156 |             mcep_alpha=args.mcep_alpha,
157 |         )
158 |
159 |         # DTW
160 |         _, path = fastdtw(gen_mcep, gt_mcep, dist=spatial.distance.euclidean)
161 |         twf = np.array(path).T
162 |         gen_f0_dtw = gen_f0[twf[0]]
163 |         gt_f0_dtw = gt_f0[twf[1]]
164 |
165 |         # Get voiced part
166 |         nonzero_idxs = np.where((gen_f0_dtw != 0) & (gt_f0_dtw != 0))[0]
167 |         gen_f0_dtw_voiced = np.log(gen_f0_dtw[nonzero_idxs])
168 |         gt_f0_dtw_voiced = np.log(gt_f0_dtw[nonzero_idxs])
169 |
170 |         # log F0 RMSE
171 |         log_f0_rmse = np.sqrt(np.mean((gen_f0_dtw_voiced - gt_f0_dtw_voiced) ** 2))
172 |         logging.info(f"{gt_basename} {log_f0_rmse:.4f}")
173 |         f0_rmse_dict[gt_basename] = log_f0_rmse
174 |
175 |
176 | def get_parser() -> argparse.ArgumentParser:
177 |     """Get argument parser."""
178 |     parser = argparse.ArgumentParser(description="Evaluate log-F0 RMSE.")
179 |     parser.add_argument(
180 |         "gen_wavdir_or_wavscp",
181 |         type=str,
182 |         help="Path of directory or wav.scp for generated waveforms.",
183 |     )
184 |     parser.add_argument(
185 |         "gt_wavdir_or_wavscp",
186 |         type=str,
187 |         help="Path of directory or wav.scp for ground truth waveforms.",
188 |     )
189 |     parser.add_argument(
190 |         "--outdir",
191 |         type=str,
192 |         help="Path of directory to write the results.",
193 |     )
194 |
195 |     # analysis related
196 |     parser.add_argument(
197 |         "--mcep_dim",
198 |         default=None,
199
| type=int, 200 | help=( 201 | "Dimension of mel cepstrum coefficients. " 202 | "If None, automatically set to the best dimension for the sampling." 203 | ), 204 | ) 205 | parser.add_argument( 206 | "--mcep_alpha", 207 | default=None, 208 | type=float, 209 | help=( 210 | "All pass constant for mel-cepstrum analysis. " 211 | "If None, automatically set to the best dimension for the sampling." 212 | ), 213 | ) 214 | parser.add_argument( 215 | "--n_fft", 216 | default=1024, 217 | type=int, 218 | help="The number of FFT points.", 219 | ) 220 | parser.add_argument( 221 | "--n_shift", 222 | default=256, 223 | type=int, 224 | help="The number of shift points.", 225 | ) 226 | parser.add_argument( 227 | "--f0min", 228 | default=40, 229 | type=int, 230 | help="Minimum f0 value.", 231 | ) 232 | parser.add_argument( 233 | "--f0max", 234 | default=800, 235 | type=int, 236 | help="Maximum f0 value.", 237 | ) 238 | parser.add_argument( 239 | "--nj", 240 | default=16, 241 | type=int, 242 | help="Number of parallel jobs.", 243 | ) 244 | parser.add_argument( 245 | "--verbose", 246 | default=1, 247 | type=int, 248 | help="Verbosity level. Higher is more logging.", 249 | ) 250 | return parser 251 | 252 | 253 | def main(): 254 | """Run log-F0 RMSE calculation in parallel.""" 255 | args = get_parser().parse_args() 256 | 257 | # logging info 258 | if args.verbose > 1: 259 | logging.basicConfig( 260 | level=logging.DEBUG, 261 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 262 | ) 263 | elif args.verbose > 0: 264 | logging.basicConfig( 265 | level=logging.INFO, 266 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 267 | ) 268 | else: 269 | logging.basicConfig( 270 | level=logging.WARN, 271 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 272 | ) 273 | logging.warning("Skip DEBUG/INFO messages") 274 | 275 | # find files 276 | if os.path.isdir(args.gen_wavdir_or_wavscp): 277 | gen_files = sorted(find_files(args.gen_wavdir_or_wavscp)) 278 | else: 279 | with open(args.gen_wavdir_or_wavscp) as f: 280 | gen_files = [line.strip().split(None, 1)[1] for line in f.readlines()] 281 | if gen_files[0].endswith("|"): 282 | raise ValueError("Not supported wav.scp format.") 283 | if os.path.isdir(args.gt_wavdir_or_wavscp): 284 | gt_files = sorted(find_files(args.gt_wavdir_or_wavscp)) 285 | else: 286 | with open(args.gt_wavdir_or_wavscp) as f: 287 | gt_files = [line.strip().split(None, 1)[1] for line in f.readlines()] 288 | if gt_files[0].endswith("|"): 289 | raise ValueError("Not supported wav.scp format.") 290 | 291 | # Get and divide list 292 | if len(gen_files) == 0: 293 | raise FileNotFoundError("Not found any generated audio files.") 294 | if len(gen_files) > len(gt_files): 295 | raise ValueError( 296 | "#groundtruth files are less than #generated files " 297 | f"(#gen={len(gen_files)} vs. #gt={len(gt_files)}). " 298 | "Please check the groundtruth directory." 
299 | ) 300 | logging.info("The number of utterances = %d" % len(gen_files)) 301 | file_lists = np.array_split(gen_files, args.nj) 302 | file_lists = [f_list.tolist() for f_list in file_lists] 303 | 304 | # multi processing 305 | with mp.Manager() as manager: 306 | log_f0_rmse_dict = manager.dict() 307 | processes = [] 308 | # for f in file_lists: 309 | # calculate(f, gt_files, args, log_f0_rmse_dict) 310 | for f in file_lists: 311 | p = mp.Process(target=calculate, args=(f, gt_files, args, log_f0_rmse_dict)) 312 | p.start() 313 | processes.append(p) 314 | 315 | # wait for all process 316 | for p in processes: 317 | p.join() 318 | 319 | # convert to standard list 320 | log_f0_rmse_dict = dict(log_f0_rmse_dict) 321 | 322 | # calculate statistics 323 | mean_log_f0_rmse = np.mean(np.array([v for v in log_f0_rmse_dict.values()])) 324 | std_log_f0_rmse = np.std(np.array([v for v in log_f0_rmse_dict.values()])) 325 | logging.info(f"Average: {mean_log_f0_rmse:.4f} ± {std_log_f0_rmse:.4f}") 326 | 327 | # write results 328 | if args.outdir is None: 329 | if os.path.isdir(args.gen_wavdir_or_wavscp): 330 | args.outdir = args.gen_wavdir_or_wavscp 331 | else: 332 | args.outdir = os.path.dirname(args.gen_wavdir_or_wavscp) 333 | os.makedirs(args.outdir, exist_ok=True) 334 | with open(f"{args.outdir}/utt2log_f0_rmse", "w") as f: 335 | for utt_id in sorted(log_f0_rmse_dict.keys()): 336 | log_f0_rmse = log_f0_rmse_dict[utt_id] 337 | f.write(f"{utt_id} {log_f0_rmse:.4f}\n") 338 | with open(f"{args.outdir}/log_f0_rmse_avg_result.txt", "w") as f: 339 | f.write(f"#utterances: {len(gen_files)}\n") 340 | f.write(f"Average: {mean_log_f0_rmse:.4f} ± {std_log_f0_rmse:.4f}") 341 | 342 | logging.info("Successfully finished log-F0 RMSE evaluation.") 343 | 344 | 345 | if __name__ == "__main__": 346 | main() 347 | -------------------------------------------------------------------------------- /evaluate/evaluate_mcd.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2020 Wen-Chin Huang and Tomoki Hayashi 4 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 5 | 6 | """Evaluate MCD between generated and groundtruth audios with SPTK-based mcep.""" 7 | 8 | import argparse 9 | import fnmatch 10 | import logging 11 | import multiprocessing as mp 12 | import os 13 | from typing import Dict, List, Tuple 14 | 15 | import librosa 16 | import numpy as np 17 | import pysptk 18 | import soundfile as sf 19 | from fastdtw import fastdtw 20 | from scipy import spatial 21 | 22 | 23 | def find_files( 24 | root_dir: str, query: List[str] = ["*.flac", "*.wav"], include_root_dir: bool = True 25 | ) -> List[str]: 26 | """Find files recursively. 27 | 28 | Args: 29 | root_dir (str): Root root_dir to find. 30 | query (List[str]): Query to find. 31 | include_root_dir (bool): If False, root_dir name is not included. 32 | 33 | Returns: 34 | List[str]: List of found filenames. 
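    Example (illustrative; assumes ./wavs holds a.wav and sub/b.flac):
        >>> find_files("./wavs")  # doctest: +SKIP
        ['./wavs/a.wav', './wavs/sub/b.flac']
        >>> find_files("./wavs", include_root_dir=False)  # doctest: +SKIP
        ['a.wav', 'sub/b.flac']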
35 | 36 | """ 37 | files = [] 38 | for root, dirnames, filenames in os.walk(root_dir, followlinks=True): 39 | for q in query: 40 | for filename in fnmatch.filter(filenames, q): 41 | files.append(os.path.join(root, filename)) 42 | if not include_root_dir: 43 | files = [file_.replace(root_dir + "/", "") for file_ in files] 44 | 45 | return files 46 | 47 | 48 | def sptk_extract( 49 | x: np.ndarray, 50 | fs: int, 51 | n_fft: int = 512, 52 | n_shift: int = 256, 53 | mcep_dim: int = 25, 54 | mcep_alpha: float = 0.41, 55 | is_padding: bool = False, 56 | ) -> np.ndarray: 57 | """Extract SPTK-based mel-cepstrum. 58 | 59 | Args: 60 | x (ndarray): 1D waveform array. 61 | fs (int): Sampling rate 62 | n_fft (int): FFT length in point (default=512). 63 | n_shift (int): Shift length in point (default=256). 64 | mcep_dim (int): Dimension of mel-cepstrum (default=25). 65 | mcep_alpha (float): All pass filter coefficient (default=0.41). 66 | is_padding (bool): Whether to pad the end of signal (default=False). 67 | 68 | Returns: 69 | ndarray: Mel-cepstrum with the size (N, n_fft). 70 | 71 | """ 72 | # perform padding 73 | if is_padding: 74 | n_pad = n_fft - (len(x) - n_fft) % n_shift 75 | x = np.pad(x, (0, n_pad), "reflect") 76 | 77 | # get number of frames 78 | n_frame = (len(x) - n_fft) // n_shift + 1 79 | 80 | # get window function 81 | win = pysptk.sptk.hamming(n_fft) 82 | 83 | # check mcep and alpha 84 | if mcep_dim is None or mcep_alpha is None: 85 | mcep_dim, mcep_alpha = _get_best_mcep_params(fs) 86 | 87 | # calculate spectrogram 88 | mcep = [ 89 | pysptk.mcep( 90 | x[n_shift * i : n_shift * i + n_fft] * win, 91 | mcep_dim, 92 | mcep_alpha, 93 | eps=1e-6, 94 | etype=1, 95 | ) 96 | for i in range(n_frame) 97 | ] 98 | 99 | return np.stack(mcep) 100 | 101 | 102 | def _get_basename(path: str) -> str: 103 | return os.path.splitext(os.path.split(path)[-1])[0] 104 | 105 | 106 | def _get_best_mcep_params(fs: int) -> Tuple[int, float]: 107 | if fs == 16000: 108 | return 23, 0.42 109 | elif fs == 22050: 110 | return 34, 0.45 111 | elif fs == 24000: 112 | return 34, 0.46 113 | elif fs == 44100: 114 | return 39, 0.53 115 | elif fs == 48000: 116 | return 39, 0.55 117 | else: 118 | raise ValueError(f"Not found the setting for {fs}.") 119 | 120 | 121 | def calculate( 122 | file_list: List[str], 123 | gt_file_list: List[str], 124 | args: argparse.Namespace, 125 | mcd_dict: Dict, 126 | ): 127 | """Calculate MCD.""" 128 | for i, gen_path in enumerate(file_list): 129 | corresponding_list = list( 130 | filter( 131 | lambda gt_path: _get_basename(gt_path)[:-7] in gen_path, gt_file_list 132 | ) 133 | ) 134 | print("corresponding_list", corresponding_list) 135 | assert len(corresponding_list) == 1 136 | gt_path = corresponding_list[0] 137 | gt_basename = _get_basename(gt_path) 138 | 139 | # load wav file as int16 140 | gen_x, gen_fs = sf.read(gen_path, dtype="int16") 141 | gt_x, gt_fs = sf.read(gt_path, dtype="int16") 142 | 143 | fs = gen_fs 144 | if gen_fs != gt_fs: 145 | gt_x = librosa.resample(gt_x.astype(np.float), gt_fs, gen_fs) 146 | 147 | # extract ground truth and converted features 148 | gen_mcep = sptk_extract( 149 | x=gen_x, 150 | fs=fs, 151 | n_fft=args.n_fft, 152 | n_shift=args.n_shift, 153 | mcep_dim=args.mcep_dim, 154 | mcep_alpha=args.mcep_alpha, 155 | ) 156 | gt_mcep = sptk_extract( 157 | x=gt_x, 158 | fs=fs, 159 | n_fft=args.n_fft, 160 | n_shift=args.n_shift, 161 | mcep_dim=args.mcep_dim, 162 | mcep_alpha=args.mcep_alpha, 163 | ) 164 | 165 | # DTW 166 | _, path = fastdtw(gen_mcep, gt_mcep, 
dist=spatial.distance.euclidean) 167 | twf = np.array(path).T 168 | gen_mcep_dtw = gen_mcep[twf[0]] 169 | gt_mcep_dtw = gt_mcep[twf[1]] 170 | 171 | # MCD 172 | diff2sum = np.sum((gen_mcep_dtw - gt_mcep_dtw) ** 2, 1) 173 | mcd = np.mean(10.0 / np.log(10.0) * np.sqrt(2 * diff2sum), 0) 174 | logging.info(f"{gt_basename} {mcd:.4f}") 175 | mcd_dict[gt_basename] = mcd 176 | 177 | 178 | def get_parser() -> argparse.Namespace: 179 | """Get argument parser.""" 180 | parser = argparse.ArgumentParser(description="Evaluate Mel-cepstrum distortion.") 181 | parser.add_argument( 182 | "gen_wavdir_or_wavscp", 183 | type=str, 184 | help="Path of directory or wav.scp for generated waveforms.", 185 | ) 186 | parser.add_argument( 187 | "gt_wavdir_or_wavscp", 188 | type=str, 189 | help="Path of directory or wav.scp for ground truth waveforms.", 190 | ) 191 | parser.add_argument( 192 | "--outdir", 193 | type=str, 194 | help="Path of directory to write the results.", 195 | ) 196 | 197 | # analysis related 198 | parser.add_argument( 199 | "--mcep_dim", 200 | default=None, 201 | type=int, 202 | help=( 203 | "Dimension of mel cepstrum coefficients. " 204 | "If None, automatically set to the best dimension for the sampling." 205 | ), 206 | ) 207 | parser.add_argument( 208 | "--mcep_alpha", 209 | default=None, 210 | type=float, 211 | help=( 212 | "All pass constant for mel-cepstrum analysis. " 213 | "If None, automatically set to the best dimension for the sampling." 214 | ), 215 | ) 216 | parser.add_argument( 217 | "--n_fft", 218 | default=1024, 219 | type=int, 220 | help="The number of FFT points.", 221 | ) 222 | parser.add_argument( 223 | "--n_shift", 224 | default=256, 225 | type=int, 226 | help="The number of shift points.", 227 | ) 228 | parser.add_argument( 229 | "--nj", 230 | default=16, 231 | type=int, 232 | help="Number of parallel jobs.", 233 | ) 234 | parser.add_argument( 235 | "--verbose", 236 | default=1, 237 | type=int, 238 | help="Verbosity level. 
Higher is more logging.", 239 | ) 240 | return parser 241 | 242 | 243 | def main(): 244 | """Run MCD calculation in parallel.""" 245 | args = get_parser().parse_args() 246 | 247 | # logging info 248 | if args.verbose > 1: 249 | logging.basicConfig( 250 | level=logging.DEBUG, 251 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 252 | ) 253 | elif args.verbose > 0: 254 | logging.basicConfig( 255 | level=logging.INFO, 256 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 257 | ) 258 | else: 259 | logging.basicConfig( 260 | level=logging.WARN, 261 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 262 | ) 263 | logging.warning("Skip DEBUG/INFO messages") 264 | 265 | # find files 266 | if os.path.isdir(args.gen_wavdir_or_wavscp): 267 | gen_files = sorted(find_files(args.gen_wavdir_or_wavscp)) 268 | else: 269 | with open(args.gen_wavdir_or_wavscp) as f: 270 | gen_files = [line.strip().split(None, 1)[1] for line in f.readlines()] 271 | if gen_files[0].endswith("|"): 272 | raise ValueError("The piped wav.scp format is not supported.") 273 | if os.path.isdir(args.gt_wavdir_or_wavscp): 274 | gt_files = sorted(find_files(args.gt_wavdir_or_wavscp)) 275 | else: 276 | with open(args.gt_wavdir_or_wavscp) as f: 277 | gt_files = [line.strip().split(None, 1)[1] for line in f.readlines()] 278 | if gt_files[0].endswith("|"): 279 | raise ValueError("The piped wav.scp format is not supported.") 280 | 281 | # Get and divide list 282 | if len(gen_files) == 0: 283 | raise FileNotFoundError("No generated audio files found.") 284 | if len(gen_files) > len(gt_files): 285 | raise ValueError( 286 | "There are fewer groundtruth files than generated files " 287 | f"(#gen={len(gen_files)} vs. #gt={len(gt_files)}). " 288 | "Please check the groundtruth directory."
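            # Illustrative check of the MCD formula in calculate() above (toy
            # value): a frame whose summed squared mcep difference is 0.5 gives
            # 10 / ln(10) * sqrt(2 * 0.5) = 4.3429 * 1.0 ≈ 4.34 dB; the reported
            # score is the mean over all DTW-aligned frames.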
289 | ) 290 | logging.info("The number of utterances = %d" % len(gen_files)) 291 | file_lists = np.array_split(gen_files, args.nj) 292 | file_lists = [f_list.tolist() for f_list in file_lists] 293 | 294 | # multi processing 295 | with mp.Manager() as manager: 296 | mcd_dict = manager.dict() 297 | processes = [] 298 | for f in file_lists: 299 | p = mp.Process(target=calculate, args=(f, gt_files, args, mcd_dict)) 300 | p.start() 301 | processes.append(p) 302 | 303 | # wait for all process 304 | for p in processes: 305 | p.join() 306 | 307 | # convert to standard list 308 | mcd_dict = dict(mcd_dict) 309 | 310 | # calculate statistics 311 | mean_mcd = np.mean(np.array([v for v in mcd_dict.values()])) 312 | std_mcd = np.std(np.array([v for v in mcd_dict.values()])) 313 | logging.info(f"Average: {mean_mcd:.4f} ± {std_mcd:.4f}") 314 | 315 | # write results 316 | if args.outdir is None: 317 | if os.path.isdir(args.gen_wavdir_or_wavscp): 318 | args.outdir = args.gen_wavdir_or_wavscp 319 | else: 320 | args.outdir = os.path.dirname(args.gen_wavdir_or_wavscp) 321 | os.makedirs(args.outdir, exist_ok=True) 322 | with open(f"{args.outdir}/utt2mcd", "w") as f: 323 | for utt_id in sorted(mcd_dict.keys()): 324 | mcd = mcd_dict[utt_id] 325 | f.write(f"{utt_id} {mcd:.4f}\n") 326 | with open(f"{args.outdir}/mcd_avg_result.txt", "w") as f: 327 | f.write(f"#utterances: {len(gen_files)}\n") 328 | f.write(f"Average: {mean_mcd:.4f} ± {std_mcd:.4f}") 329 | 330 | logging.info("Successfully finished MCD evaluation.") 331 | 332 | 333 | if __name__ == "__main__": 334 | main() 335 | -------------------------------------------------------------------------------- /evaluate/evaluate_semitone.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2021 Wen-Chin Huang and Tomoki Hayashi 4 | # Copyright 2022 Shuai Guo 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Evaluate semitone ACC between generated and groundtruth audios based on World.""" 8 | 9 | import argparse 10 | import fnmatch 11 | import logging 12 | import multiprocessing as mp 13 | import os 14 | from math import log2, pow 15 | from typing import Dict, List, Tuple 16 | 17 | import librosa 18 | import numpy as np 19 | import pysptk 20 | import pyworld as pw 21 | import soundfile as sf 22 | from fastdtw import fastdtw 23 | from scipy import spatial 24 | 25 | 26 | def _Hz2Semitone(freq): 27 | """_Hz2Semitone.""" 28 | A4 = 440 29 | C0 = A4 * pow(2, -4.75) 30 | name = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"] 31 | 32 | if freq == 0: 33 | return "Sil" # silence 34 | else: 35 | h = round(12 * log2(freq / C0)) 36 | octave = h // 12 37 | n = h % 12 38 | return name[n] + "_" + str(octave) 39 | 40 | 41 | def find_files( 42 | root_dir: str, query: List[str] = ["*.flac", "*.wav"], include_root_dir: bool = True 43 | ) -> List[str]: 44 | """Find files recursively. 45 | 46 | Args: 47 | root_dir (str): Root root_dir to find. 48 | query (List[str]): Query to find. 49 | include_root_dir (bool): If False, root_dir name is not included. 50 | 51 | Returns: 52 | List[str]: List of found filenames. 
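    Example (illustrative; assumes dump/ holds utt1.flac and utt2.wav and
    only flac files should be collected):
        >>> find_files("dump", query=["*.flac"])  # doctest: +SKIP
        ['dump/utt1.flac']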
53 | 54 | """ 55 | files = [] 56 | for root, dirnames, filenames in os.walk(root_dir, followlinks=True): 57 | for q in query: 58 | for filename in fnmatch.filter(filenames, q): 59 | files.append(os.path.join(root, filename)) 60 | if not include_root_dir: 61 | files = [file_.replace(root_dir + "/", "") for file_ in files] 62 | 63 | return files 64 | 65 | 66 | def world_extract( 67 | x: np.ndarray, 68 | fs: int, 69 | f0min: int = 40, 70 | f0max: int = 800, 71 | n_fft: int = 512, 72 | n_shift: int = 256, 73 | mcep_dim: int = 25, 74 | mcep_alpha: float = 0.41, 75 | ) -> np.ndarray: 76 | """Extract World-based acoustic features. 77 | 78 | Args: 79 | x (ndarray): 1D waveform array. 80 | fs (int): Minimum f0 value (default=40). 81 | f0 (int): Maximum f0 value (default=800). 82 | n_shift (int): Shift length in point (default=256). 83 | n_fft (int): FFT length in point (default=512). 84 | n_shift (int): Shift length in point (default=256). 85 | mcep_dim (int): Dimension of mel-cepstrum (default=25). 86 | mcep_alpha (float): All pass filter coefficient (default=0.41). 87 | 88 | Returns: 89 | ndarray: Mel-cepstrum with the size (N, n_fft). 90 | ndarray: F0 sequence (N,). 91 | 92 | """ 93 | # extract features 94 | x = x.astype(np.float64) 95 | f0, time_axis = pw.harvest( 96 | x, 97 | fs, 98 | f0_floor=f0min, 99 | f0_ceil=f0max, 100 | frame_period=n_shift / fs * 1000, 101 | ) 102 | sp = pw.cheaptrick(x, f0, time_axis, fs, fft_size=n_fft) 103 | if mcep_dim is None or mcep_alpha is None: 104 | mcep_dim, mcep_alpha = _get_best_mcep_params(fs) 105 | mcep = pysptk.sp2mc(sp, mcep_dim, mcep_alpha) 106 | 107 | return mcep, f0 108 | 109 | 110 | def _get_basename(path: str) -> str: 111 | return os.path.splitext(os.path.split(path)[-1])[0] 112 | 113 | 114 | def _get_best_mcep_params(fs: int) -> Tuple[int, float]: 115 | if fs == 16000: 116 | return 23, 0.42 117 | elif fs == 22050: 118 | return 34, 0.45 119 | elif fs == 24000: 120 | return 34, 0.46 121 | elif fs == 44100: 122 | return 39, 0.53 123 | elif fs == 48000: 124 | return 39, 0.55 125 | else: 126 | raise ValueError(f"Not found the setting for {fs}.") 127 | 128 | 129 | def calculate( 130 | file_list: List[str], 131 | gt_file_list: List[str], 132 | args: argparse.Namespace, 133 | semitone_acc_dict: Dict[str, float], 134 | ): 135 | """Calculate semitone ACC.""" 136 | for i, gen_path in enumerate(file_list): 137 | corresponding_list = list( 138 | filter( 139 | lambda gt_path: _get_basename(gt_path)[:-7] in gen_path, gt_file_list 140 | ) 141 | ) 142 | assert len(corresponding_list) == 1 143 | gt_path = corresponding_list[0] 144 | gt_basename = _get_basename(gt_path) 145 | 146 | # load wav file as int16 147 | gen_x, gen_fs = sf.read(gen_path, dtype="int16") 148 | gt_x, gt_fs = sf.read(gt_path, dtype="int16") 149 | 150 | fs = gen_fs 151 | if gen_fs != gt_fs: 152 | gt_x = librosa.resample(gt_x.astype(np.float), gt_fs, gen_fs) 153 | 154 | # extract ground truth and converted features 155 | gen_mcep, gen_f0 = world_extract( 156 | x=gen_x, 157 | fs=fs, 158 | f0min=args.f0min, 159 | f0max=args.f0max, 160 | n_fft=args.n_fft, 161 | n_shift=args.n_shift, 162 | mcep_dim=args.mcep_dim, 163 | mcep_alpha=args.mcep_alpha, 164 | ) 165 | gt_mcep, gt_f0 = world_extract( 166 | x=gt_x, 167 | fs=fs, 168 | f0min=args.f0min, 169 | f0max=args.f0max, 170 | n_fft=args.n_fft, 171 | n_shift=args.n_shift, 172 | mcep_dim=args.mcep_dim, 173 | mcep_alpha=args.mcep_alpha, 174 | ) 175 | 176 | # DTW 177 | _, path = fastdtw(gen_mcep, gt_mcep, dist=spatial.distance.euclidean) 178 | twf = 
np.array(path).T 179 | gen_f0_dtw = gen_f0[twf[0]] 180 | gt_f0_dtw = gt_f0[twf[1]] 181 | 182 | # Semitone ACC 183 | semitone_GT = np.array([_Hz2Semitone(_f0) for _f0 in gt_f0_dtw]) 184 | semitone_predict = np.array([_Hz2Semitone(_f0) for _f0 in gen_f0_dtw]) 185 | semitone_ACC = float((semitone_GT == semitone_predict).sum()) / len(semitone_GT) 186 | semitone_acc_dict[gt_basename] = semitone_ACC 187 | 188 | 189 | def get_parser() -> argparse.Namespace: 190 | """Get argument parser.""" 191 | parser = argparse.ArgumentParser(description="Evaluate Mel-cepstrum distortion.") 192 | parser.add_argument( 193 | "gen_wavdir_or_wavscp", 194 | type=str, 195 | help="Path of directory or wav.scp for generated waveforms.", 196 | ) 197 | parser.add_argument( 198 | "gt_wavdir_or_wavscp", 199 | type=str, 200 | help="Path of directory or wav.scp for ground truth waveforms.", 201 | ) 202 | parser.add_argument( 203 | "--outdir", 204 | type=str, 205 | help="Path of directory to write the results.", 206 | ) 207 | 208 | # analysis related 209 | parser.add_argument( 210 | "--mcep_dim", 211 | default=None, 212 | type=int, 213 | help=( 214 | "Dimension of mel cepstrum coefficients. " 215 | "If None, automatically set to the best dimension for the sampling." 216 | ), 217 | ) 218 | parser.add_argument( 219 | "--mcep_alpha", 220 | default=None, 221 | type=float, 222 | help=( 223 | "All pass constant for mel-cepstrum analysis. " 224 | "If None, automatically set to the best dimension for the sampling." 225 | ), 226 | ) 227 | parser.add_argument( 228 | "--n_fft", 229 | default=1024, 230 | type=int, 231 | help="The number of FFT points.", 232 | ) 233 | parser.add_argument( 234 | "--n_shift", 235 | default=256, 236 | type=int, 237 | help="The number of shift points.", 238 | ) 239 | parser.add_argument( 240 | "--f0min", 241 | default=40, 242 | type=int, 243 | help="Minimum f0 value.", 244 | ) 245 | parser.add_argument( 246 | "--f0max", 247 | default=800, 248 | type=int, 249 | help="Maximum f0 value.", 250 | ) 251 | parser.add_argument( 252 | "--nj", 253 | default=16, 254 | type=int, 255 | help="Number of parallel jobs.", 256 | ) 257 | parser.add_argument( 258 | "--verbose", 259 | default=1, 260 | type=int, 261 | help="Verbosity level. 
Higher is more logging.", 262 | ) 263 | return parser 264 | 265 | 266 | def main(): 267 | """Run semitone ACC calculation in parallel.""" 268 | args = get_parser().parse_args() 269 | 270 | # logging info 271 | if args.verbose > 1: 272 | logging.basicConfig( 273 | level=logging.DEBUG, 274 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 275 | ) 276 | elif args.verbose > 0: 277 | logging.basicConfig( 278 | level=logging.INFO, 279 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 280 | ) 281 | else: 282 | logging.basicConfig( 283 | level=logging.WARN, 284 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 285 | ) 286 | logging.warning("Skip DEBUG/INFO messages") 287 | 288 | # find files 289 | if os.path.isdir(args.gen_wavdir_or_wavscp): 290 | gen_files = sorted(find_files(args.gen_wavdir_or_wavscp)) 291 | else: 292 | with open(args.gen_wavdir_or_wavscp) as f: 293 | gen_files = [line.strip().split(None, 1)[1] for line in f.readlines()] 294 | if gen_files[0].endswith("|"): 295 | raise ValueError("Not supported wav.scp format.") 296 | if os.path.isdir(args.gt_wavdir_or_wavscp): 297 | gt_files = sorted(find_files(args.gt_wavdir_or_wavscp)) 298 | else: 299 | with open(args.gt_wavdir_or_wavscp) as f: 300 | gt_files = [line.strip().split(None, 1)[1] for line in f.readlines()] 301 | if gt_files[0].endswith("|"): 302 | raise ValueError("Not supported wav.scp format.") 303 | 304 | # Get and divide list 305 | if len(gen_files) == 0: 306 | raise FileNotFoundError("Not found any generated audio files.") 307 | if len(gen_files) > len(gt_files): 308 | raise ValueError( 309 | "#groundtruth files are less than #generated files " 310 | f"(#gen={len(gen_files)} vs. #gt={len(gt_files)}). " 311 | "Please check the groundtruth directory." 
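            # Illustrative check of _Hz2Semitone() above (assuming A4 = 440 Hz):
            # C0 = 440 * 2**-4.75 ≈ 16.352 Hz, so 440.0 Hz maps to "A_4" and
            # 261.6 Hz maps to "C_4"; frames with f0 == 0 map to "Sil" and
            # still count toward the accuracy denominator.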
312 | ) 313 | logging.info("The number of utterances = %d" % len(gen_files)) 314 | file_lists = np.array_split(gen_files, args.nj) 315 | file_lists = [f_list.tolist() for f_list in file_lists] 316 | 317 | # multi processing 318 | with mp.Manager() as manager: 319 | semitone_acc_dict = manager.dict() 320 | processes = [] 321 | for f in file_lists: 322 | p = mp.Process( 323 | target=calculate, args=(f, gt_files, args, semitone_acc_dict) 324 | ) 325 | p.start() 326 | processes.append(p) 327 | 328 | # wait for all process 329 | for p in processes: 330 | p.join() 331 | 332 | # convert to standard list 333 | semitone_acc_dict = dict(semitone_acc_dict) 334 | 335 | # calculate statistics 336 | mean_semitone_acc = np.mean(np.array([v for v in semitone_acc_dict.values()])) 337 | logging.info(f"Average - Semitone_ACC: {mean_semitone_acc*100:.2f}%") 338 | 339 | # write results 340 | if args.outdir is None: 341 | if os.path.isdir(args.gen_wavdir_or_wavscp): 342 | args.outdir = args.gen_wavdir_or_wavscp 343 | else: 344 | args.outdir = os.path.dirname(args.gen_wavdir_or_wavscp) 345 | os.makedirs(args.outdir, exist_ok=True) 346 | with open(f"{args.outdir}/utt2semitone_acc", "w") as f: 347 | for utt_id in sorted(semitone_acc_dict.keys()): 348 | semitone_ACC = semitone_acc_dict[utt_id] 349 | f.write(f"{utt_id} {semitone_ACC*100:.2f}%\n") 350 | with open(f"{args.outdir}/semitone_acc_avg_result.txt", "w") as f: 351 | f.write(f"#utterances: {len(gen_files)}\n") 352 | f.write(f"Average: {mean_semitone_acc*100:.2f}%") 353 | 354 | logging.info("Successfully finished semitone ACC evaluation.") 355 | 356 | 357 | if __name__ == "__main__": 358 | main() 359 | -------------------------------------------------------------------------------- /evaluate/evaluate_vuv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2021 Wen-Chin Huang and Tomoki Hayashi 4 | # Copyright 2022 Shuai Guo 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Evaluate VUV error between generated and groundtruth audios based on World.""" 8 | 9 | import argparse 10 | import fnmatch 11 | import logging 12 | import multiprocessing as mp 13 | import os 14 | from typing import Dict, List, Tuple 15 | 16 | import librosa 17 | import numpy as np 18 | import pysptk 19 | import pyworld as pw 20 | import soundfile as sf 21 | from fastdtw import fastdtw 22 | from scipy import spatial 23 | 24 | 25 | def _Hz2Flag(freq): 26 | if freq == 0: 27 | return False 28 | else: 29 | return True 30 | 31 | 32 | def find_files( 33 | root_dir: str, query: List[str] = ["*.flac", "*.wav"], include_root_dir: bool = True 34 | ) -> List[str]: 35 | """Find files recursively. 36 | 37 | Args: 38 | root_dir (str): Root root_dir to find. 39 | query (List[str]): Query to find. 40 | include_root_dir (bool): If False, root_dir name is not included. 41 | 42 | Returns: 43 | List[str]: List of found filenames. 
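    Note:
        Directories are walked with followlinks=True, so symlinked
        subdirectories (e.g. a linked wav dump) are searched as well.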
44 | 45 | """ 46 | files = [] 47 | for root, dirnames, filenames in os.walk(root_dir, followlinks=True): 48 | for q in query: 49 | for filename in fnmatch.filter(filenames, q): 50 | files.append(os.path.join(root, filename)) 51 | if not include_root_dir: 52 | files = [file_.replace(root_dir + "/", "") for file_ in files] 53 | 54 | return files 55 | 56 | 57 | def world_extract( 58 | x: np.ndarray, 59 | fs: int, 60 | f0min: int = 40, 61 | f0max: int = 800, 62 | n_fft: int = 512, 63 | n_shift: int = 256, 64 | mcep_dim: int = 25, 65 | mcep_alpha: float = 0.41, 66 | ) -> np.ndarray: 67 | """Extract World-based acoustic features. 68 | 69 | Args: 70 | x (ndarray): 1D waveform array. 71 | fs (int): Minimum f0 value (default=40). 72 | f0 (int): Maximum f0 value (default=800). 73 | n_shift (int): Shift length in point (default=256). 74 | n_fft (int): FFT length in point (default=512). 75 | n_shift (int): Shift length in point (default=256). 76 | mcep_dim (int): Dimension of mel-cepstrum (default=25). 77 | mcep_alpha (float): All pass filter coefficient (default=0.41). 78 | 79 | Returns: 80 | ndarray: Mel-cepstrum with the size (N, n_fft). 81 | ndarray: F0 sequence (N,). 82 | 83 | """ 84 | # extract features 85 | x = x.astype(np.float64) 86 | f0, time_axis = pw.harvest( 87 | x, 88 | fs, 89 | f0_floor=f0min, 90 | f0_ceil=f0max, 91 | frame_period=n_shift / fs * 1000, 92 | ) 93 | sp = pw.cheaptrick(x, f0, time_axis, fs, fft_size=n_fft) 94 | if mcep_dim is None or mcep_alpha is None: 95 | mcep_dim, mcep_alpha = _get_best_mcep_params(fs) 96 | mcep = pysptk.sp2mc(sp, mcep_dim, mcep_alpha) 97 | 98 | return mcep, f0 99 | 100 | 101 | def _get_basename(path: str) -> str: 102 | return os.path.splitext(os.path.split(path)[-1])[0] 103 | 104 | 105 | def _get_best_mcep_params(fs: int) -> Tuple[int, float]: 106 | if fs == 16000: 107 | return 23, 0.42 108 | elif fs == 22050: 109 | return 34, 0.45 110 | elif fs == 24000: 111 | return 34, 0.46 112 | elif fs == 44100: 113 | return 39, 0.53 114 | elif fs == 48000: 115 | return 39, 0.55 116 | else: 117 | raise ValueError(f"Not found the setting for {fs}.") 118 | 119 | 120 | def calculate( 121 | file_list: List[str], 122 | gt_file_list: List[str], 123 | args: argparse.Namespace, 124 | vuv_err_dict: Dict[str, float], 125 | ): 126 | """Calculate VUV error.""" 127 | for i, gen_path in enumerate(file_list): 128 | corresponding_list = list( 129 | filter( 130 | lambda gt_path: _get_basename(gt_path)[:-7] in gen_path, gt_file_list 131 | ) 132 | ) 133 | assert len(corresponding_list) == 1 134 | gt_path = corresponding_list[0] 135 | gt_basename = _get_basename(gt_path) 136 | 137 | # load wav file as int16 138 | gen_x, gen_fs = sf.read(gen_path, dtype="int16") 139 | gt_x, gt_fs = sf.read(gt_path, dtype="int16") 140 | 141 | fs = gen_fs 142 | if gen_fs != gt_fs: 143 | gt_x = librosa.resample(gt_x.astype(np.float), gt_fs, gen_fs) 144 | 145 | # extract ground truth and converted features 146 | gen_mcep, gen_f0 = world_extract( 147 | x=gen_x, 148 | fs=fs, 149 | f0min=args.f0min, 150 | f0max=args.f0max, 151 | n_fft=args.n_fft, 152 | n_shift=args.n_shift, 153 | mcep_dim=args.mcep_dim, 154 | mcep_alpha=args.mcep_alpha, 155 | ) 156 | gt_mcep, gt_f0 = world_extract( 157 | x=gt_x, 158 | fs=fs, 159 | f0min=args.f0min, 160 | f0max=args.f0max, 161 | n_fft=args.n_fft, 162 | n_shift=args.n_shift, 163 | mcep_dim=args.mcep_dim, 164 | mcep_alpha=args.mcep_alpha, 165 | ) 166 | 167 | # DTW 168 | _, path = fastdtw(gen_mcep, gt_mcep, dist=spatial.distance.euclidean) 169 | twf = np.array(path).T 170 | 
gen_f0_dtw = gen_f0[twf[0]] 171 | gt_f0_dtw = gt_f0[twf[1]] 172 | 173 | # VUV ERR 174 | vuv_GT = np.array([_Hz2Flag(_f0) for _f0 in gt_f0_dtw]) 175 | vuv_predict = np.array([_Hz2Flag(_f0) for _f0 in gen_f0_dtw]) 176 | vuv_ERR = float((vuv_GT != vuv_predict).sum()) / len(vuv_GT) 177 | vuv_err_dict[gt_basename] = vuv_ERR 178 | 179 | 180 | def get_parser() -> argparse.Namespace: 181 | """Get argument parser.""" 182 | parser = argparse.ArgumentParser(description="Evaluate Mel-cepstrum distortion.") 183 | parser.add_argument( 184 | "gen_wavdir_or_wavscp", 185 | type=str, 186 | help="Path of directory or wav.scp for generated waveforms.", 187 | ) 188 | parser.add_argument( 189 | "gt_wavdir_or_wavscp", 190 | type=str, 191 | help="Path of directory or wav.scp for ground truth waveforms.", 192 | ) 193 | parser.add_argument( 194 | "--outdir", 195 | type=str, 196 | help="Path of directory to write the results.", 197 | ) 198 | 199 | # analysis related 200 | parser.add_argument( 201 | "--mcep_dim", 202 | default=None, 203 | type=int, 204 | help=( 205 | "Dimension of mel cepstrum coefficients. " 206 | "If None, automatically set to the best dimension for the sampling." 207 | ), 208 | ) 209 | parser.add_argument( 210 | "--mcep_alpha", 211 | default=None, 212 | type=float, 213 | help=( 214 | "All pass constant for mel-cepstrum analysis. " 215 | "If None, automatically set to the best dimension for the sampling." 216 | ), 217 | ) 218 | parser.add_argument( 219 | "--n_fft", 220 | default=1024, 221 | type=int, 222 | help="The number of FFT points.", 223 | ) 224 | parser.add_argument( 225 | "--n_shift", 226 | default=256, 227 | type=int, 228 | help="The number of shift points.", 229 | ) 230 | parser.add_argument( 231 | "--f0min", 232 | default=40, 233 | type=int, 234 | help="Minimum f0 value.", 235 | ) 236 | parser.add_argument( 237 | "--f0max", 238 | default=800, 239 | type=int, 240 | help="Maximum f0 value.", 241 | ) 242 | parser.add_argument( 243 | "--nj", 244 | default=16, 245 | type=int, 246 | help="Number of parallel jobs.", 247 | ) 248 | parser.add_argument( 249 | "--verbose", 250 | default=1, 251 | type=int, 252 | help="Verbosity level. 
Higher is more logging.", 253 | ) 254 | return parser 255 | 256 | 257 | def main(): 258 | """Run VUV error calculation in parallel.""" 259 | args = get_parser().parse_args() 260 | 261 | # logging info 262 | if args.verbose > 1: 263 | logging.basicConfig( 264 | level=logging.DEBUG, 265 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 266 | ) 267 | elif args.verbose > 0: 268 | logging.basicConfig( 269 | level=logging.INFO, 270 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 271 | ) 272 | else: 273 | logging.basicConfig( 274 | level=logging.WARN, 275 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 276 | ) 277 | logging.warning("Skip DEBUG/INFO messages") 278 | 279 | # find files 280 | if os.path.isdir(args.gen_wavdir_or_wavscp): 281 | gen_files = sorted(find_files(args.gen_wavdir_or_wavscp)) 282 | else: 283 | with open(args.gen_wavdir_or_wavscp) as f: 284 | gen_files = [line.strip().split(None, 1)[1] for line in f.readlines()] 285 | if gen_files[0].endswith("|"): 286 | raise ValueError("Not supported wav.scp format.") 287 | if os.path.isdir(args.gt_wavdir_or_wavscp): 288 | gt_files = sorted(find_files(args.gt_wavdir_or_wavscp)) 289 | else: 290 | with open(args.gt_wavdir_or_wavscp) as f: 291 | gt_files = [line.strip().split(None, 1)[1] for line in f.readlines()] 292 | if gt_files[0].endswith("|"): 293 | raise ValueError("Not supported wav.scp format.") 294 | 295 | # Get and divide list 296 | if len(gen_files) == 0: 297 | raise FileNotFoundError("Not found any generated audio files.") 298 | if len(gen_files) > len(gt_files): 299 | raise ValueError( 300 | "#groundtruth files are less than #generated files " 301 | f"(#gen={len(gen_files)} vs. #gt={len(gt_files)}). " 302 | "Please check the groundtruth directory." 
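            # Illustrative check of the VUV metric in calculate() above (toy f0
            # tracks over five aligned frames):
            #   gt : [0, 220, 230, 0, 240] -> [F, T, T, F, T]
            #   gen: [0,   0, 231, 0, 242] -> [F, F, T, F, T]
            # one mismatch out of five frames gives VUV_ERR = 20.00%.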
303 | ) 304 | logging.info("The number of utterances = %d" % len(gen_files)) 305 | file_lists = np.array_split(gen_files, args.nj) 306 | file_lists = [f_list.tolist() for f_list in file_lists] 307 | 308 | # multi processing 309 | with mp.Manager() as manager: 310 | vuv_err_dict = manager.dict() 311 | processes = [] 312 | for f in file_lists: 313 | p = mp.Process(target=calculate, args=(f, gt_files, args, vuv_err_dict)) 314 | p.start() 315 | processes.append(p) 316 | 317 | # wait for all process 318 | for p in processes: 319 | p.join() 320 | 321 | # convert to standard list 322 | vuv_err_dict = dict(vuv_err_dict) 323 | 324 | # calculate statistics 325 | mean_vuv_err = np.mean(np.array([v for v in vuv_err_dict.values()])) 326 | logging.info(f"Average - VUV_ERROR: {mean_vuv_err*100:.2f}%") 327 | 328 | # write results 329 | if args.outdir is None: 330 | if os.path.isdir(args.gen_wavdir_or_wavscp): 331 | args.outdir = args.gen_wavdir_or_wavscp 332 | else: 333 | args.outdir = os.path.dirname(args.gen_wavdir_or_wavscp) 334 | os.makedirs(args.outdir, exist_ok=True) 335 | with open(f"{args.outdir}/utt2vuv_error", "w") as f: 336 | for utt_id in sorted(vuv_err_dict.keys()): 337 | vuv_ERR = vuv_err_dict[utt_id] 338 | f.write(f"{utt_id} {vuv_ERR*100:.2f}%\n") 339 | with open(f"{args.outdir}/vuv_error_avg_result.txt", "w") as f: 340 | f.write(f"#utterances: {len(gen_files)}\n") 341 | f.write(f"Average: {mean_vuv_err*100:.2f}%") 342 | 343 | logging.info("Successfully finished VUV error evaluation.") 344 | 345 | 346 | if __name__ == "__main__": 347 | main() 348 | -------------------------------------------------------------------------------- /evaluate_score.sh: -------------------------------------------------------------------------------- 1 | echo "Generating" 2 | python vsinging_infer.py 3 | 4 | echo "Scoring" 5 | 6 | 7 | _gt_wavscp="singing_gt" 8 | _dir="evaluate" 9 | _gen_wavdir="singing_out" 10 | 11 | if [ ! 
-d "singing_gt" ] ; then 12 | echo "copy gt" 13 | mkdir -p "singing_gt" 14 | python normalize_wav.py 15 | fi 16 | 17 | # Objective Evaluation - MCD 18 | echo "Begin Scoring for MCD metrics on ${dset}, results are written under ${_dir}/MCD_res" 19 | 20 | mkdir -p "${_dir}/MCD_res" 21 | python evaluate/evaluate_mcd.py \ 22 | ${_gen_wavdir} \ 23 | ${_gt_wavscp} \ 24 | --outdir "${_gen_wavdir}/MCD_res" 25 | 26 | # Objective Evaluation - log-F0 RMSE 27 | echo "Begin Scoring for F0 related metrics on ${dset}, results are written under ${_dir}/F0_res" 28 | 29 | mkdir -p "${_dir}/F0_res" 30 | python evaluate/evaluate_f0.py \ 31 | ${_gen_wavdir} \ 32 | ${_gt_wavscp} \ 33 | --outdir "${_gen_wavdir}/F0_res" 34 | 35 | # Objective Evaluation - semitone ACC 36 | echo "Begin Scoring for SEMITONE related metrics on ${dset}, results are written under ${_dir}/SEMITONE_res" 37 | 38 | mkdir -p "${_dir}/SEMITONE_res" 39 | python evaluate/evaluate_semitone.py \ 40 | ${_gen_wavdir} \ 41 | ${_gt_wavscp} \ 42 | --outdir "${_gen_wavdir}/SEMITONE_res" 43 | 44 | # Objective Evaluation - VUV error 45 | echo "Begin Scoring for VUV related metrics on ${dset}, results are written under ${_dir}/VUV_res" 46 | 47 | mkdir -p "${_dir}/VUV_res" 48 | python evaluate/evaluate_vuv.py \ 49 | ${_gen_wavdir} \ 50 | ${_gt_wavscp} \ 51 | --outdir "${_gen_wavdir}/VUV_res" 52 | 53 | zip singing_out.zip singing_out/*.wav -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import commons 5 | 6 | 7 | def feature_loss(fmap_r, fmap_g): 8 | loss = 0 9 | for dr, dg in zip(fmap_r, fmap_g): 10 | for rl, gl in zip(dr, dg): 11 | rl = rl.float().detach() 12 | gl = gl.float() 13 | loss += torch.mean(torch.abs(rl - gl)) 14 | 15 | return loss * 2 16 | 17 | 18 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 19 | loss = 0 20 | r_losses = [] 21 | g_losses = [] 22 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 23 | dr = dr.float() 24 | dg = dg.float() 25 | r_loss = torch.mean((1 - dr) ** 2) 26 | g_loss = torch.mean(dg**2) 27 | loss += r_loss + g_loss 28 | r_losses.append(r_loss.item()) 29 | g_losses.append(g_loss.item()) 30 | 31 | return loss, r_losses, g_losses 32 | 33 | 34 | def generator_loss(disc_outputs): 35 | loss = 0 36 | gen_losses = [] 37 | for dg in disc_outputs: 38 | dg = dg.float() 39 | l = torch.mean((1 - dg) ** 2) 40 | gen_losses.append(l) 41 | loss += l 42 | 43 | return loss, gen_losses 44 | 45 | 46 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): 47 | """ 48 | z_p, logs_q: [b, h, t_t] 49 | m_p, logs_p: [b, h, t_t] 50 | """ 51 | z_p = z_p.float() 52 | logs_q = logs_q.float() 53 | m_p = m_p.float() 54 | logs_p = logs_p.float() 55 | z_mask = z_mask.float() 56 | 57 | kl = logs_p - logs_q - 0.5 58 | kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) 59 | kl = torch.sum(kl * z_mask) 60 | l = kl / torch.sum(z_mask) 61 | return l 62 | -------------------------------------------------------------------------------- /mel_processing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import torch.utils.data 8 | import numpy as np 9 | import librosa 10 | import librosa.util as librosa_util 11 | from librosa.util import normalize, pad_center, tiny 12 | from 
scipy.signal import get_window 13 | from scipy.io.wavfile import read 14 | from librosa.filters import mel as librosa_mel_fn 15 | 16 | MAX_WAV_VALUE = 32768.0 17 | 18 | 19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 20 | """ 21 | PARAMS 22 | ------ 23 | C: compression factor 24 | """ 25 | return torch.log(torch.clamp(x, min=clip_val) * C) 26 | 27 | 28 | def dynamic_range_decompression_torch(x, C=1): 29 | """ 30 | PARAMS 31 | ------ 32 | C: compression factor used to compress 33 | """ 34 | return torch.exp(x) / C 35 | 36 | 37 | def spectral_normalize_torch(magnitudes): 38 | output = dynamic_range_compression_torch(magnitudes) 39 | return output 40 | 41 | 42 | def spectral_de_normalize_torch(magnitudes): 43 | output = dynamic_range_decompression_torch(magnitudes) 44 | return output 45 | 46 | 47 | mel_basis = {} 48 | hann_window = {} 49 | 50 | 51 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 52 | if torch.min(y) < -1.0: 53 | print("min value is ", torch.min(y)) 54 | if torch.max(y) > 1.0: 55 | print("max value is ", torch.max(y)) 56 | 57 | global hann_window 58 | dtype_device = str(y.dtype) + "_" + str(y.device) 59 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 60 | if wnsize_dtype_device not in hann_window: 61 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 62 | dtype=y.dtype, device=y.device 63 | ) 64 | 65 | y = torch.nn.functional.pad( 66 | y.unsqueeze(1), 67 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 68 | mode="reflect", 69 | ) 70 | y = y.squeeze(1) 71 | 72 | spec = torch.stft( 73 | y, 74 | n_fft, 75 | hop_length=hop_size, 76 | win_length=win_size, 77 | window=hann_window[wnsize_dtype_device], 78 | center=center, 79 | pad_mode="reflect", 80 | normalized=False, 81 | onesided=True, 82 | return_complex=True, 83 | ) 84 | spec = torch.view_as_real(spec) 85 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 86 | return spec 87 | 88 | 89 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 90 | global mel_basis 91 | dtype_device = str(spec.dtype) + "_" + str(spec.device) 92 | fmax_dtype_device = str(fmax) + "_" + dtype_device 93 | if fmax_dtype_device not in mel_basis: 94 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 95 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 96 | dtype=spec.dtype, device=spec.device 97 | ) 98 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 99 | spec = spectral_normalize_torch(spec) 100 | return spec 101 | 102 | 103 | def mel_spectrogram_torch( 104 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False 105 | ): 106 | if torch.min(y) < -1.0: 107 | print("min value is ", torch.min(y)) 108 | if torch.max(y) > 1.0: 109 | print("max value is ", torch.max(y)) 110 | 111 | global mel_basis, hann_window 112 | dtype_device = str(y.dtype) + "_" + str(y.device) 113 | fmax_dtype_device = str(fmax) + "_" + dtype_device 114 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 115 | if fmax_dtype_device not in mel_basis: 116 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 117 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 118 | dtype=y.dtype, device=y.device 119 | ) 120 | if wnsize_dtype_device not in hann_window: 121 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 122 | dtype=y.dtype, device=y.device 123 | ) 124 | 125 | y = torch.nn.functional.pad( 126 | y.unsqueeze(1), 127 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
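        # Illustrative padding arithmetic (assuming n_fft=1024, hop_size=256):
        # (1024 - 256) / 2 = 384 reflect-padded samples per side, so with
        # center=False a signal of T samples (T a multiple of the hop) yields
        # exactly T / 256 STFT frames, matching hop-aligned label durations.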
128 | mode="reflect", 129 | ) 130 | y = y.squeeze(1) 131 | 132 | spec = torch.stft( 133 | y, 134 | n_fft, 135 | hop_length=hop_size, 136 | win_length=win_size, 137 | window=hann_window[wnsize_dtype_device], 138 | center=center, 139 | pad_mode="reflect", 140 | normalized=False, 141 | onesided=True, 142 | return_complex=True, 143 | ) 144 | spec = torch.view_as_real(spec) 145 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 146 | 147 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 148 | spec = spectral_normalize_torch(spec) 149 | 150 | return spec 151 | -------------------------------------------------------------------------------- /normalize_wav.py: -------------------------------------------------------------------------------- 1 | from prepare.align_wav_spec import Align 2 | import os 3 | from tqdm import tqdm 4 | 5 | align = Align(32768, 24000, 1024, 256, 1024) 6 | output_path = "singing_gt" 7 | input_path = "/home/yyu479/VISinger_data/wav_dump_24k" 8 | 9 | files = os.listdir(path=input_path) 10 | for i, wav_file in enumerate(tqdm(files)): 11 | suffix = os.path.splitext(os.path.split(wav_file)[-1])[1] 12 | if not suffix == ".wav": 13 | continue 14 | basename = os.path.splitext(os.path.split(wav_file)[-1])[0][:-7] 15 | align.normalize_wav( 16 | os.path.join(input_path, wav_file), os.path.join(output_path, wav_file) 17 | ) 18 | -------------------------------------------------------------------------------- /plot_f0.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | import librosa 5 | import librosa.display 6 | 7 | from prepare.data_vits_phn import FeatureInput, SingInput 8 | 9 | # setting 10 | hop_length = 256 11 | sample_rate = 24000 12 | wav_name = "2001000001" 13 | input_path = "singing_gt/2001000001_bits16.wav" 14 | 15 | # get mel 16 | y, sr = librosa.load(input_path, sr=sample_rate) 17 | librosa.feature.melspectrogram(y=y, sr=sr) 18 | D = librosa.stft(y, hop_length=hop_length) # STFT of y 19 | S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max) 20 | 21 | # get f0 22 | featureInput = FeatureInput("singing_gt/", sr, hop_length) 23 | featur_pit = featureInput.compute_f0("2001000001_bits16.wav") 24 | 25 | fo = open("../VISinger_data/transcriptions.txt", "r+") 26 | # load text info 27 | 28 | while True: 29 | try: 30 | message = fo.readline().strip() 31 | except Exception as e: 32 | print("nothing of except:", e) 33 | if message == None: 34 | break 35 | if message == "": 36 | break 37 | if wav_name in message: 38 | break 39 | print(message) 40 | 41 | infos = message.split("|") 42 | file = infos[0] 43 | hanz = infos[1] 44 | phon = infos[2].split(" ") 45 | note = infos[3].split(" ") 46 | note_dur = infos[4].split(" ") 47 | phon_dur = infos[5].split(" ") 48 | phon_slur = infos[6].split(" ") 49 | 50 | 51 | singInput = SingInput(sample_rate, hop_length) 52 | 53 | ( 54 | file, 55 | labels_ids, 56 | labels_dur, 57 | scores_ids, 58 | scores_dur, 59 | labels_slr, 60 | labels_uvs, 61 | ) = singInput.parseInput(message) 62 | labels_uvs = np.repeat(labels_uvs, labels_dur, axis=0) 63 | featur_pit = featur_pit[: len(labels_uvs)] 64 | featur_pit_uv = featur_pit * labels_uvs 65 | 66 | uv = featur_pit == 0 67 | featur_pit_intp = np.copy(featur_pit) 68 | featur_pit_intp[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], featur_pit[~uv]) 69 | # plot 70 | # plt.figure() 71 | fig = plt.figure(figsize=(15, 6)) 72 | 73 | librosa.display.specshow( 74 | S_db, y_axis="log", sr=sr, hop_length=hop_length, 
x_axis="frames" 75 | ) 76 | 77 | (F0_ori,) = plt.plot(featur_pit.T, "r", label="F0_ori", alpha=0.9) 78 | (F0_uv,) = plt.plot(featur_pit_uv.T, "y", label="F0_uv", alpha=0.9) 79 | (F0_intp,) = plt.plot(featur_pit_intp.T, "b", label="F0_intp", alpha=0.9) 80 | plt.legend([F0_ori, F0_uv, F0_intp], ["F0_ori", "F0_uv", "F0_intp"], loc="upper right") 81 | plt.colorbar(format="%+2.0f dB") 82 | plt.savefig("f0.png") 83 | -------------------------------------------------------------------------------- /prepare/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerryuhoo/VISinger/ad8bc167c10275dd513ae466e73deae2f7045c99/prepare/__init__.py -------------------------------------------------------------------------------- /prepare/align_wav_spec.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | 5 | from mel_processing import spectrogram_torch 6 | from utils import load_wav_to_torch 7 | import scipy.io.wavfile as sciwav 8 | import os 9 | 10 | 11 | class Align: 12 | def __init__( 13 | self, max_wav_value, sampling_rate, filter_length, hop_length, win_length 14 | ): 15 | self.max_wav_value = max_wav_value 16 | self.sampling_rate = sampling_rate 17 | self.filter_length = filter_length 18 | self.hop_length = hop_length 19 | self.win_length = win_length 20 | 21 | def align_wav_spec(self, filename, phone_dur): 22 | phone_dur = np.int32(phone_dur) 23 | phone_dur = torch.Tensor(phone_dur).to(torch.int32) 24 | audio, sampling_rate = load_wav_to_torch(filename) 25 | if sampling_rate != self.sampling_rate: 26 | raise ValueError( 27 | "{} SR doesn't match target {} SR".format( 28 | sampling_rate, self.sampling_rate 29 | ) 30 | ) 31 | audio_norm = audio / self.max_wav_value 32 | audio_norm = audio_norm.unsqueeze(0) 33 | spec_filename = filename.replace(".wav", ".spec.pt") 34 | if os.path.exists(spec_filename): 35 | spec = torch.load(spec_filename) 36 | else: 37 | spec = spectrogram_torch( 38 | audio_norm, 39 | self.filter_length, 40 | self.sampling_rate, 41 | self.hop_length, 42 | self.win_length, 43 | center=False, 44 | ) 45 | # align mel and wave 46 | phone_dur_sum = torch.sum(phone_dur).item() 47 | spec_length = spec.shape[2] 48 | 49 | if spec_length > phone_dur_sum: 50 | spec = spec[:, :, :phone_dur_sum] 51 | elif spec_length < phone_dur_sum: 52 | pad_length = phone_dur_sum - spec_length 53 | spec = torch.nn.functional.pad( 54 | input=spec, pad=(0, pad_length, 0, 0), mode="constant", value=0 55 | ) 56 | assert spec.shape[2] == phone_dur_sum 57 | 58 | # align wav 59 | fixed_wav_len = phone_dur_sum * self.hop_length 60 | if audio_norm.shape[1] > fixed_wav_len: 61 | audio_norm = audio_norm[:, :fixed_wav_len] 62 | elif audio_norm.shape[1] < fixed_wav_len: 63 | pad_length = fixed_wav_len - audio_norm.shape[1] 64 | audio_norm = torch.nn.functional.pad( 65 | input=audio_norm, 66 | pad=(0, pad_length, 0, 0), 67 | mode="constant", 68 | value=0, 69 | ) 70 | assert audio_norm.shape[1] == fixed_wav_len 71 | 72 | # rewrite aligned wav 73 | audio = ( 74 | (audio_norm * self.max_wav_value) 75 | .transpose(0, 1) 76 | .numpy() 77 | .astype(np.int16) 78 | ) 79 | 80 | sciwav.write( 81 | filename, 82 | self.sampling_rate, 83 | audio, 84 | ) 85 | # save spec 86 | spec = torch.squeeze(spec, 0) 87 | torch.save(spec, spec_filename) 88 | return spec.shape[1] 89 | 90 | def normalize_wav(self, input_path, output_path): 91 | audio, sampling_rate = 
load_wav_to_torch(input_path) 92 | audio_norm = audio.numpy() / self.max_wav_value 93 | audio_norm *= 32767 / max(0.01, np.max(np.abs(audio_norm))) * 0.6 94 | sciwav.write( 95 | output_path, 96 | sampling_rate, 97 | audio_norm.astype(np.int16), 98 | ) 99 | -------------------------------------------------------------------------------- /prepare/data_vits.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import librosa 5 | import pyworld 6 | 7 | from prepare.phone_map import label_to_ids 8 | from prepare.phone_uv import uv_map 9 | 10 | 11 | def load_midi_map(): 12 | notemap = {} 13 | notemap["rest"] = 0 14 | fo = open("./prepare/midi-note.scp", "r+") 15 | while True: 16 | try: 17 | message = fo.readline().strip() 18 | except Exception as e: 19 | print("nothing of except:", e) 20 | break 21 | if message == None: 22 | break 23 | if message == "": 24 | break 25 | infos = message.split() 26 | notemap[infos[1]] = int(infos[0]) 27 | fo.close() 28 | return notemap 29 | 30 | 31 | class SingInput(object): 32 | def __init__(self, samplerate=16000, hop_size=128): 33 | self.fs = samplerate 34 | self.hop = hop_size 35 | self.notemaper = load_midi_map() 36 | 37 | def phone_to_uv(self, phones): 38 | uv = [] 39 | for phone in phones: 40 | uv.append(uv_map[phone.lower()]) 41 | return uv 42 | 43 | def notes_to_id(self, notes): 44 | note_ids = [] 45 | for note in notes: 46 | note_ids.append(self.notemaper[note]) 47 | return note_ids 48 | 49 | def frame_duration(self, durations): 50 | ph_durs = [float(x) for x in durations] 51 | sentence_length = 0 52 | for ph_dur in ph_durs: 53 | sentence_length = sentence_length + ph_dur 54 | sentence_length = int(sentence_length * self.fs / self.hop + 0.5) 55 | 56 | sample_frame = [] 57 | startTime = 0 58 | for i_ph in range(len(ph_durs)): 59 | start_frame = int(startTime * self.fs / self.hop + 0.5) 60 | end_frame = int((startTime + ph_durs[i_ph]) * self.fs / self.hop + 0.5) 61 | count_frame = end_frame - start_frame 62 | sample_frame.append(count_frame) 63 | startTime = startTime + ph_durs[i_ph] 64 | all_frame = np.sum(sample_frame) 65 | assert all_frame == sentence_length 66 | # match mel length 67 | sample_frame[-1] = sample_frame[-1] - 1 68 | return sample_frame 69 | 70 | def score_duration(self, durations): 71 | ph_durs = [float(x) for x in durations] 72 | sample_frame = [] 73 | for i_ph in range(len(ph_durs)): 74 | count_frame = int(ph_durs[i_ph] * self.fs / self.hop + 0.5) 75 | if count_frame >= 256: 76 | print("count_frame", count_frame) 77 | count_frame = 255 78 | sample_frame.append(count_frame) 79 | return sample_frame 80 | 81 | def parseInput(self, singinfo: str): 82 | infos = singinfo.split("|") 83 | file = infos[0] 84 | # hanz = infos[1] 85 | phon = infos[2].split(" ") 86 | note = infos[3].split(" ") 87 | note_dur = infos[4].split(" ") 88 | phon_dur = infos[5].split(" ") 89 | phon_slr = infos[6].split(" ") 90 | 91 | labels_ids = label_to_ids(phon) 92 | labels_uvs = self.phone_to_uv(phon) 93 | labels_frames = self.frame_duration(phon_dur) 94 | scores_ids = self.notes_to_id(note) 95 | scores_dur = self.score_duration(note_dur) 96 | labels_slr = [int(x) for x in phon_slr] 97 | return ( 98 | file, 99 | labels_ids, 100 | labels_frames, 101 | scores_ids, 102 | scores_dur, 103 | labels_slr, 104 | labels_uvs, 105 | ) 106 | 107 | def parseSong(self, singinfo: str): 108 | infos = singinfo.split("|") 109 | item_indx = infos[0] 110 | item_time = infos[1] 111 | # hanz = infos[2] 
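        # Illustrative line layout consumed here (hypothetical values; the real
        # transcription files define the actual fields):
        #   0001|0.0_12.8|<text>|sh i ...|60 62 ...|0.50 ...|0.30 ...|0 0 ...
        # i.e. index|time|text|phones|note ids|note durs|phone durs|slur flags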
112 | phon = infos[3].split(" ") 113 | note_ids = infos[4].split(" ") 114 | note_dur = infos[5].split(" ") 115 | phon_dur = infos[6].split(" ") 116 | phon_slr = infos[7].split(" ") 117 | 118 | labels_ids = label_to_ids(phon) 119 | labels_uvs = self.phone_to_uv(phon) 120 | labels_frames = self.frame_duration(phon_dur) 121 | scores_ids = [int(x) if x != "rest" else 0 for x in note_ids] 122 | scores_dur = self.score_duration(note_dur) 123 | labels_slr = [int(x) for x in phon_slr] 124 | return ( 125 | item_indx, 126 | item_time, 127 | labels_ids, 128 | labels_frames, 129 | scores_ids, 130 | scores_dur, 131 | labels_slr, 132 | labels_uvs, 133 | ) 134 | 135 | def expandInput(self, labels_ids, labels_frames): 136 | assert len(labels_ids) == len(labels_frames) 137 | frame_num = np.sum(labels_frames) 138 | frame_labels = np.zeros(frame_num, dtype=np.int64) 139 | start = 0 140 | for index, num in enumerate(labels_frames): 141 | frame_labels[start : start + num] = labels_ids[index] 142 | start += num 143 | return frame_labels 144 | 145 | def scorePitch(self, scores_id): 146 | score_pitch = np.zeros(len(scores_id), dtype=np.float64) 147 | for index, score_id in enumerate(scores_id): 148 | if score_id == 0: 149 | score_pitch[index] = 0 150 | else: 151 | pitch = librosa.midi_to_hz(score_id) 152 | score_pitch[index] = round(pitch, 1) 153 | return score_pitch 154 | 155 | def smoothPitch(self, pitch): 156 | # smooth the pitch contour by convolution 157 | kernel = np.hanning(5) # a symmetric Hanning smoothing kernel 158 | kernel /= kernel.sum() 159 | smooth_pitch = np.convolve(pitch, kernel, "same") 160 | return smooth_pitch 161 | 162 | 163 | class FeatureInput(object): 164 | def __init__(self, path, samplerate=16000, hop_size=128): 165 | self.fs = samplerate 166 | self.hop = hop_size 167 | self.path = path 168 | 169 | self.f0_bin = 256 170 | self.f0_max = 1100.0 171 | self.f0_min = 50.0 172 | self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700) 173 | self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700) 174 | 175 | def compute_f0(self, filename): 176 | x, sr = librosa.load(self.path + filename, sr=self.fs) 177 | assert sr == self.fs 178 | f0, t = pyworld.dio( 179 | x.astype(np.double), 180 | fs=sr, 181 | f0_ceil=800, 182 | frame_period=1000 * self.hop / sr, 183 | ) 184 | f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.fs) 185 | for index, pitch in enumerate(f0): 186 | f0[index] = round(pitch, 1) 187 | return f0 188 | 189 | def coarse_f0(self, f0): 190 | f0_mel = 1127 * np.log(1 + f0 / 700) 191 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * ( 192 | self.f0_bin - 2 193 | ) / (self.f0_mel_max - self.f0_mel_min) + 1 194 | 195 | # clamp to [1, f0_bin - 1]; unvoiced frames map to 1 196 | f0_mel[f0_mel <= 1] = 1 197 | f0_mel[f0_mel > self.f0_bin - 1] = self.f0_bin - 1 198 | f0_coarse = np.rint(f0_mel).astype(np.int64) 199 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, ( 200 | f0_coarse.max(), 201 | f0_coarse.min(), 202 | ) 203 | return f0_coarse 204 | 205 | def diff_f0(self, scores_pit, featur_pit, labels_frames): 206 | length_pit = min(len(scores_pit), len(featur_pit)) 207 | offset_pit = np.zeros(length_pit, dtype=np.int64) 208 | for idx in range(length_pit): 209 | s_pit = scores_pit[idx] 210 | f_pit = featur_pit[idx] 211 | if s_pit == 0 or f_pit == 0: 212 | offset_pit[idx] = 0 213 | else: 214 | tmp = int(f_pit - s_pit) 215 | tmp = +128 if tmp > +128 else tmp 216 | tmp = -127 if tmp < -127 else tmp 217 | tmp = 256 + tmp if tmp < 0 else tmp 218 | offset_pit[idx] = tmp 219 | offset_pit[offset_pit > 255] = 255 220 | offset_pit[offset_pit < 0] = 0 221 | # start = 0 222 | # for
num in labels_frames: 223 | # print("---------------------------------------------") 224 | # print(scores_pit[start:start+num]) 225 | # print(featur_pit[start:start+num]) 226 | # print(offset_pit[start:start+num]) 227 | # start += num 228 | return offset_pit 229 | 230 | 231 | if __name__ == "__main__": 232 | logging.basicConfig(level=logging.INFO) # ERROR & INFO 233 | 234 | notemaper = load_midi_map() 235 | logging.info(notemaper) 236 | 237 | singInput = SingInput(16000, 256) 238 | featureInput = FeatureInput("../VISinger_data/wav_dump_16k/", 16000, 256) 239 | 240 | if not os.path.exists("../VISinger_data/label_vits"): 241 | os.mkdir("../VISinger_data/label_vits") 242 | 243 | fo = open("../VISinger_data/transcriptions.txt", "r+") 244 | vits_file = open("./filelists/vits_file.txt", "w", encoding="utf-8") 245 | i = 0 246 | all_txt = [] # for counting distinct sentences 247 | while True: 248 | try: 249 | message = fo.readline().strip() 250 | except Exception as e: 251 | print("nothing of except:", e) 252 | break 253 | if message is None: 254 | break 255 | if message == "": 256 | break 257 | i = i + 1 258 | # if i > 5: 259 | # exit() 260 | infos = message.split("|") 261 | file = infos[0] 262 | hanz = infos[1] 263 | all_txt.append(hanz) 264 | phon = infos[2].split(" ") 265 | note = infos[3].split(" ") 266 | note_dur = infos[4].split(" ") 267 | phon_dur = infos[5].split(" ") 268 | phon_slur = infos[6].split(" ") 269 | 270 | logging.info("----------------------------") 271 | logging.info(file) 272 | logging.info(hanz) 273 | logging.info(phon) 274 | # logging.info(note_dur) 275 | # logging.info(phon_dur) 276 | # logging.info(phon_slur) 277 | 278 | ( 279 | file, 280 | labels_ids, 281 | labels_frames, 282 | scores_ids, 283 | scores_dur, 284 | labels_slr, 285 | labels_uvs, 286 | ) = singInput.parseInput(message) 287 | labels_ids = singInput.expandInput(labels_ids, labels_frames) 288 | labels_uvs = singInput.expandInput(labels_uvs, labels_frames) 289 | labels_slr = singInput.expandInput(labels_slr, labels_frames) 290 | scores_ids = singInput.expandInput(scores_ids, labels_frames) 291 | scores_pit = singInput.scorePitch(scores_ids) 292 | featur_pit = featureInput.compute_f0(f"{file}_bits16.wav") 293 | featur_pit = featur_pit[: len(labels_ids)] 294 | featur_pit = featur_pit * labels_uvs 295 | coarse_pit = featureInput.coarse_f0(featur_pit) 296 | 297 | # offset_pit = featureInput.diff_f0(scores_pit, featur_pit, labels_frames) 298 | assert len(labels_ids) == len(coarse_pit) 299 | 300 | logging.info(labels_ids) 301 | logging.info(scores_ids) 302 | logging.info(coarse_pit) 303 | logging.info(labels_slr) 304 | 305 | np.save( 306 | f"../VISinger_data/label_vits/{file}_label.npy", 307 | labels_ids, 308 | allow_pickle=False, 309 | ) 310 | np.save( 311 | f"../VISinger_data/label_vits/{file}_score.npy", 312 | scores_ids, 313 | allow_pickle=False, 314 | ) 315 | np.save( 316 | f"../VISinger_data/label_vits/{file}_pitch.npy", 317 | coarse_pit, 318 | allow_pickle=False, 319 | ) 320 | np.save( 321 | f"../VISinger_data/label_vits/{file}_slurs.npy", 322 | labels_slr, 323 | allow_pickle=False, 324 | ) 325 | 326 | # line format: wave path|label path|score path|pitch path|slurs path; paths start with one "." when run from the repo root and two ".." when called from a subdirectory 327 | path_wave = f"../VISinger_data/wav_dump_16k/{file}_bits16.wav" 328 | path_label = f"../VISinger_data/label_vits/{file}_label.npy" 329 | path_score = f"../VISinger_data/label_vits/{file}_score.npy" 330 | path_pitch = f"../VISinger_data/label_vits/{file}_pitch.npy" 331 | path_slurs = f"../VISinger_data/label_vits/{file}_slurs.npy" 332 |
print( 333 | f"{path_wave}|{path_label}|{path_score}|{path_pitch}|{path_slurs}", 334 | file=vits_file, 335 | ) 336 | 337 | fo.close() 338 | vits_file.close() 339 | print(len(set(all_txt))) # 统计非重复的句子个数 340 | -------------------------------------------------------------------------------- /prepare/data_vits_phn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import librosa 5 | import pyworld 6 | 7 | from prepare.phone_map import label_to_ids 8 | from prepare.phone_uv import uv_map 9 | from prepare.dur_to_frame import dur_to_frame 10 | from prepare.align_wav_spec import Align 11 | 12 | 13 | def load_midi_map(): 14 | notemap = {} 15 | notemap["rest"] = 0 16 | fo = open("./prepare/midi-note.scp", "r+") 17 | while True: 18 | try: 19 | message = fo.readline().strip() 20 | except Exception as e: 21 | print("nothing of except:", e) 22 | break 23 | if message == None: 24 | break 25 | if message == "": 26 | break 27 | infos = message.split() 28 | notemap[infos[1]] = int(infos[0]) 29 | fo.close() 30 | return notemap 31 | 32 | 33 | class SingInput(object): 34 | def __init__(self, sample_rate=24000, hop_size=256): 35 | self.fs = sample_rate 36 | self.hop = hop_size 37 | self.notemaper = load_midi_map() 38 | self.align = Align(32768, sample_rate, 1024, hop_size, 1024) 39 | 40 | def phone_to_uv(self, phones): 41 | uv = [] 42 | for phone in phones: 43 | uv.append(uv_map[phone.lower()]) 44 | return uv 45 | 46 | def notes_to_id(self, notes): 47 | note_ids = [] 48 | for note in notes: 49 | note_ids.append(self.notemaper[note]) 50 | return note_ids 51 | 52 | def frame_duration(self, durations): 53 | ph_durs = [float(x) for x in durations] 54 | sentence_length = 0 55 | for ph_dur in ph_durs: 56 | sentence_length = sentence_length + ph_dur 57 | sentence_length = int(sentence_length * self.fs / self.hop + 0.5) 58 | 59 | sample_frame = [] 60 | startTime = 0 61 | for i_ph in range(len(ph_durs)): 62 | start_frame = int(startTime * self.fs / self.hop + 0.5) 63 | end_frame = int((startTime + ph_durs[i_ph]) * self.fs / self.hop + 0.5) 64 | count_frame = end_frame - start_frame 65 | sample_frame.append(count_frame) 66 | startTime = startTime + ph_durs[i_ph] 67 | all_frame = np.sum(sample_frame) 68 | assert all_frame == sentence_length 69 | # match mel length 70 | sample_frame[-1] = sample_frame[-1] - 1 71 | return sample_frame 72 | 73 | def score_duration(self, durations): 74 | ph_durs = [float(x) for x in durations] 75 | sample_frame = [] 76 | for i_ph in range(len(ph_durs)): 77 | count_frame = int(ph_durs[i_ph] * self.fs / self.hop + 0.5) 78 | if count_frame >= 256: 79 | print("count_frame", count_frame) 80 | count_frame = 255 81 | sample_frame.append(count_frame) 82 | return sample_frame 83 | 84 | def parseInput(self, singinfo: str): 85 | infos = singinfo.split("|") 86 | file = infos[0] 87 | # hanz = infos[1] 88 | phon = infos[2].split(" ") 89 | note = infos[3].split(" ") 90 | note_dur = infos[4].split(" ") 91 | phon_dur = infos[5].split(" ") 92 | phon_slr = infos[6].split(" ") 93 | 94 | labels_ids = label_to_ids(phon) 95 | labels_uvs = self.phone_to_uv(phon) 96 | note_ids = self.notes_to_id(note) 97 | # convert into float 98 | note_dur = [eval(i) for i in note_dur] 99 | phon_dur = [eval(i) for i in phon_dur] 100 | 101 | note_dur = dur_to_frame(note_dur, self.fs, self.hop) 102 | phon_dur = dur_to_frame(phon_dur, self.fs, self.hop) 103 | labels_slr = [int(x) for x in phon_slr] 104 | 105 | # print("labels_ids", labels_ids) 
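        # Worked example of the dur_to_frame() conversion used above (a
        # sketch, assuming the defaults fs=24000 and hop=256): a 0.5 s
        # phoneme maps to
        #     int(0.5 * 24000 / 256 + 0.5) = int(47.375) = 47 frames,
        # i.e. each duration is rounded to the nearest whole hop, while
        # frame_duration() instead rounds segment *boundaries* so that the
        # per-phoneme counts always sum to the sentence length.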
106 | # print("note_dur", note_dur) 107 | # print("phon_dur", phon_dur) 108 | # print("labels_slr", labels_slr) 109 | return ( 110 | file, 111 | labels_ids, 112 | phon_dur, 113 | note_ids, 114 | note_dur, 115 | labels_slr, 116 | labels_uvs, 117 | ) 118 | 119 | def parseSong(self, singinfo: str): 120 | infos = singinfo.split("|") 121 | item_indx = infos[0] 122 | item_time = infos[1] 123 | # hanz = infos[2] 124 | phon = infos[3].split(" ") 125 | note_ids = infos[4].split(" ") 126 | note_dur = infos[5].split(" ") 127 | phon_dur = infos[6].split(" ") 128 | phon_slr = infos[7].split(" ") 129 | 130 | labels_ids = label_to_ids(phon) 131 | labels_uvs = self.phone_to_uv(phon) 132 | labels_frames = self.frame_duration(phon_dur) 133 | scores_ids = [int(x) if x != "rest" else 0 for x in note_ids] 134 | scores_dur = self.score_duration(note_dur) 135 | labels_slr = [int(x) for x in phon_slr] 136 | return ( 137 | item_indx, 138 | item_time, 139 | labels_ids, 140 | labels_frames, 141 | scores_ids, 142 | scores_dur, 143 | labels_slr, 144 | labels_uvs, 145 | ) 146 | 147 | def expandInput(self, labels_ids, labels_frames): 148 | assert len(labels_ids) == len(labels_frames) 149 | frame_num = np.sum(labels_frames) 150 | frame_labels = np.zeros(frame_num, dtype=np.int) 151 | start = 0 152 | for index, num in enumerate(labels_frames): 153 | frame_labels[start : start + num] = labels_ids[index] 154 | start += num 155 | return frame_labels 156 | 157 | def scorePitch(self, scores_id): 158 | score_pitch = np.zeros(len(scores_id), dtype=np.float) 159 | for index, score_id in enumerate(scores_id): 160 | if score_id == 0: 161 | score_pitch[index] = 0 162 | else: 163 | pitch = librosa.midi_to_hz(score_id) 164 | score_pitch[index] = round(pitch, 1) 165 | return score_pitch 166 | 167 | def smoothPitch(self, pitch): 168 | # 使用卷积对数据平滑 169 | kernel = np.hanning(5) # 随机生成一个卷积核(对称的) 170 | kernel /= kernel.sum() 171 | smooth_pitch = np.convolve(pitch, kernel, "same") 172 | return smooth_pitch 173 | 174 | def align_process(self, file, phn_dur): 175 | return self.align.align_wav_spec(file, phn_dur) 176 | 177 | 178 | class FeatureInput(object): 179 | def __init__(self, path, samplerate=24000, hop_size=256): 180 | self.fs = samplerate 181 | self.hop = hop_size 182 | self.path = path 183 | 184 | self.f0_bin = 256 185 | self.f0_max = 1100.0 186 | self.f0_min = 50.0 187 | self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700) 188 | self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700) 189 | 190 | def compute_f0(self, filename): 191 | x, sr = librosa.load(self.path + filename, self.fs) 192 | assert sr == self.fs 193 | f0, t = pyworld.dio( 194 | x.astype(np.double), 195 | fs=sr, 196 | f0_ceil=800, 197 | frame_period=1000 * self.hop / sr, 198 | ) 199 | f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.fs) 200 | for index, pitch in enumerate(f0): 201 | f0[index] = round(pitch, 1) 202 | return f0 203 | 204 | def coarse_f0(self, f0): 205 | f0_mel = 1127 * np.log(1 + f0 / 700) 206 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * ( 207 | self.f0_bin - 2 208 | ) / (self.f0_mel_max - self.f0_mel_min) + 1 209 | 210 | # use 0 or 1 211 | f0_mel[f0_mel <= 1] = 1 212 | f0_mel[f0_mel > self.f0_bin - 1] = self.f0_bin - 1 213 | f0_coarse = np.rint(f0_mel).astype(np.int) 214 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, ( 215 | f0_coarse.max(), 216 | f0_coarse.min(), 217 | ) 218 | return f0_coarse 219 | 220 | def diff_f0(self, scores_pit, featur_pit, labels_frames): 221 | length_pit = min(len(scores_pit), 
len(featur_pit)) 222 | offset_pit = np.zeros(length_pit, dtype=np.int) 223 | for idx in range(length_pit): 224 | s_pit = scores_pit[idx] 225 | f_pit = featur_pit[idx] 226 | if s_pit == 0 or f_pit == 0: 227 | offset_pit[idx] = 0 228 | else: 229 | tmp = int(f_pit - s_pit) 230 | tmp = +128 if tmp > +128 else tmp 231 | tmp = -127 if tmp < -127 else tmp 232 | tmp = 256 + tmp if tmp < 0 else tmp 233 | offset_pit[idx] = tmp 234 | offset_pit[offset_pit > 255] = 255 235 | offset_pit[offset_pit < 0] = 0 236 | # start = 0 237 | # for num in labels_frames: 238 | # print("---------------------------------------------") 239 | # print(scores_pit[start:start+num]) 240 | # print(featur_pit[start:start+num]) 241 | # print(offset_pit[start:start+num]) 242 | # start += num 243 | return offset_pit 244 | 245 | 246 | if __name__ == "__main__": 247 | output_path = "../VISinger_data/label_vits_phn/" 248 | wav_path = "../VISinger_data/wav_dump_24k/" 249 | logging.basicConfig(level=logging.INFO) # ERROR & INFO 250 | pitch_norm = True 251 | pitch_intp = True 252 | uv_process = False 253 | 254 | notemaper = load_midi_map() 255 | logging.info(notemaper) 256 | 257 | sample_rate = 24000 258 | hop_size = 256 259 | singInput = SingInput(sample_rate, hop_size) 260 | featureInput = FeatureInput(wav_path, sample_rate, hop_size) 261 | 262 | if not os.path.exists(output_path): 263 | os.mkdir(output_path) 264 | 265 | fo = open("../VISinger_data/transcriptions.txt", "r+") 266 | # vits_file = open("./filelists/vits_file_phn.txt", "w", encoding="utf-8") 267 | vits_file = open("./filelists/vits_file.txt", "w", encoding="utf-8") 268 | i = 0 269 | all_txt = [] # 统计非重复的句子个数 270 | while True: 271 | try: 272 | message = fo.readline().strip() 273 | except Exception as e: 274 | print("nothing of except:", e) 275 | break 276 | if message == None: 277 | break 278 | if message == "": 279 | break 280 | i = i + 1 281 | # if i > 5: 282 | # exit() 283 | infos = message.split("|") 284 | file = infos[0] 285 | hanz = infos[1] 286 | all_txt.append(hanz) 287 | phon = infos[2].split(" ") 288 | note = infos[3].split(" ") 289 | note_dur = infos[4].split(" ") 290 | phon_dur = infos[5].split(" ") 291 | phon_slur = infos[6].split(" ") 292 | 293 | logging.info("----------------------------") 294 | logging.info("file {}".format(file)) 295 | logging.info("lyrics {}".format(hanz)) 296 | logging.info("phn {}".format(phon)) 297 | # logging.info(note_dur) 298 | # logging.info(phon_dur) 299 | # logging.info(phon_slur) 300 | 301 | ( 302 | file, 303 | labels_ids, 304 | labels_dur, 305 | scores_ids, 306 | scores_dur, 307 | labels_slr, 308 | labels_uvs, 309 | ) = singInput.parseInput(message) 310 | # labels_ids = singInput.expandInput(labels_ids, labels_frames) 311 | # labels_uvs = singInput.expandInput(labels_uvs, labels_frames) 312 | # labels_slr = singInput.expandInput(labels_slr, labels_frames) 313 | # scores_ids = singInput.expandInput(scores_ids, labels_frames) 314 | # scores_pit = singInput.scorePitch(scores_ids) 315 | featur_pit = featureInput.compute_f0(f"{file}_bits16.wav") 316 | wav_file = os.path.join(wav_path, file + "_bits16.wav") 317 | 318 | spec_len = singInput.align_process(wav_file, labels_dur) 319 | 320 | # extend uv 321 | labels_uvs = np.repeat(labels_uvs, labels_dur, axis=0) 322 | 323 | featur_pit = featur_pit[:spec_len] 324 | 325 | if featur_pit.shape[0] < spec_len: 326 | pad_length = spec_len - featur_pit.shape[0] 327 | featur_pit = np.pad(featur_pit, pad_width=(0, pad_length), mode="constant") 328 | assert featur_pit.shape[0] == spec_len 329 | 
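        # Quick sanity check for the pitch handling below (a sketch): with
        # the mel-style normalization 2595 * log10(1 + f0 / 700) / 500,
        # f0 = 440 Hz (A4) maps to about 1.10 and the 1100 Hz f0_max to
        # about 2.13, so normalized pitch stays roughly in [0, 2.2];
        # unvoiced frames remain exactly 0 until the interpolation branch
        # (pitch_intp) fills them in from their voiced neighbors.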
if uv_process: 330 | featur_pit = featur_pit * labels_uvs 331 | coarse_pit = featureInput.coarse_f0(featur_pit) 332 | 333 | # log f0 334 | if not pitch_norm: 335 | nonzero_idxs = np.where(featur_pit != 0)[0] 336 | featur_pit[nonzero_idxs] = np.log(featur_pit[nonzero_idxs]) 337 | else: 338 | featur_pit = 2595.0 * np.log10(1.0 + featur_pit / 700.0) / 500 339 | 340 | if pitch_intp: 341 | uv = featur_pit == 0 342 | featur_pit_intp = np.copy(featur_pit) 343 | featur_pit_intp[uv] = np.interp( 344 | np.where(uv)[0], np.where(~uv)[0], featur_pit[~uv] 345 | ) 346 | 347 | # offset_pit = featureInput.diff_f0(scores_pit, featur_pit, labels_frames) 348 | # assert len(labels_ids) == len(coarse_pit) 349 | assert len(labels_ids) == len(labels_dur) 350 | assert len(labels_dur) == len(scores_ids) 351 | assert len(scores_ids) == len(scores_dur) 352 | assert len(scores_dur) == len(labels_slr) 353 | 354 | logging.info("labels_ids {}".format(labels_ids)) 355 | # logging.info("labels_dur {}".format(labels_dur)) 356 | # logging.info("scores_ids {}".format(scores_ids)) 357 | # logging.info("scores_dur {}".format(scores_dur)) 358 | # logging.info("labels_slr {}".format(labels_slr)) 359 | # logging.info("labels_uvs {}".format(labels_uvs)) 360 | # logging.info("featur_pit {}".format(featur_pit)) 361 | logging.info("featur_pit_intp {}".format(featur_pit_intp)) 362 | 363 | np.save( 364 | output_path + f"{file}_label.npy", 365 | labels_ids, 366 | allow_pickle=False, 367 | ) 368 | np.save( 369 | output_path + f"{file}_label_dur.npy", 370 | labels_dur, 371 | allow_pickle=False, 372 | ) 373 | np.save( 374 | output_path + f"{file}_score.npy", 375 | scores_ids, 376 | allow_pickle=False, 377 | ) 378 | np.save( 379 | output_path + f"{file}_score_dur.npy", 380 | scores_dur, 381 | allow_pickle=False, 382 | ) 383 | if not pitch_intp: 384 | np.save( 385 | output_path + f"{file}_pitch.npy", 386 | featur_pit, 387 | allow_pickle=False, 388 | ) 389 | else: 390 | np.save( 391 | output_path + f"{file}_pitch.npy", 392 | featur_pit_intp, 393 | allow_pickle=False, 394 | ) 395 | # np.save( 396 | # output_path + f"{file}_pitch.npy", 397 | # coarse_pit, 398 | # allow_pickle=False, 399 | # ) 400 | np.save( 401 | output_path + f"{file}_slurs.npy", 402 | labels_slr, 403 | allow_pickle=False, 404 | ) 405 | 406 | # wave path|label path|label frame|score path|score duration;上面是一个.(当前目录),下面是两个..(从子目录调用) 407 | path_wave = wav_path + f"{file}_bits16.wav" 408 | path_label = output_path + f"{file}_label.npy" 409 | path_label_dur = output_path + f"{file}_label_dur.npy" 410 | path_score = output_path + f"{file}_score.npy" 411 | path_score_dur = output_path + f"{file}_score_dur.npy" 412 | path_pitch = output_path + f"{file}_pitch.npy" 413 | path_slurs = output_path + f"{file}_slurs.npy" 414 | print( 415 | f"{path_wave}|{path_label}|{path_label_dur}|{path_score}|{path_score_dur}|{path_pitch}|{path_slurs}", 416 | file=vits_file, 417 | ) 418 | 419 | fo.close() 420 | vits_file.close() 421 | print(len(set(all_txt))) # 统计非重复的句子个数 422 | -------------------------------------------------------------------------------- /prepare/data_vits_phn_ofuton.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import librosa 5 | import pyworld 6 | 7 | from prepare.phone_map import label_to_ids 8 | from prepare.phone_uv import uv_map 9 | from prepare.dur_to_frame import dur_to_frame 10 | from prepare.align_wav_spec import Align 11 | 12 | 13 | def load_midi_map(): 14 | notemap = {} 15 | 
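    # midi-note.scp holds one "number name" pair per line (the file appears
    # later in this repo), e.g. "69 A4" and "60 C4", so notemap ends up as
    # {"rest": 0, ..., "A4": 69, "C4": 60, ...}.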
notemap["rest"] = 0 16 | fo = open("./prepare/midi-note.scp", "r+") 17 | while True: 18 | try: 19 | message = fo.readline().strip() 20 | except Exception as e: 21 | print("nothing of except:", e) 22 | break 23 | if message == None: 24 | break 25 | if message == "": 26 | break 27 | infos = message.split() 28 | notemap[infos[1]] = int(infos[0]) 29 | fo.close() 30 | return notemap 31 | 32 | 33 | class SingInput(object): 34 | def __init__(self, sample_rate=24000, hop_size=256): 35 | self.fs = sample_rate 36 | self.hop = hop_size 37 | self.notemaper = load_midi_map() 38 | self.align = Align(32768, sample_rate, 1024, hop_size, 1024) 39 | 40 | def phone_to_uv(self, phones): 41 | uv = [] 42 | for phone in phones: 43 | uv.append(uv_map[phone.lower()]) 44 | return uv 45 | 46 | def notes_to_id(self, notes): 47 | note_ids = [] 48 | for note in notes: 49 | note_ids.append(self.notemaper[note]) 50 | return note_ids 51 | 52 | def frame_duration(self, durations): 53 | ph_durs = [float(x) for x in durations] 54 | sentence_length = 0 55 | for ph_dur in ph_durs: 56 | sentence_length = sentence_length + ph_dur 57 | sentence_length = int(sentence_length * self.fs / self.hop + 0.5) 58 | 59 | sample_frame = [] 60 | startTime = 0 61 | for i_ph in range(len(ph_durs)): 62 | start_frame = int(startTime * self.fs / self.hop + 0.5) 63 | end_frame = int((startTime + ph_durs[i_ph]) * self.fs / self.hop + 0.5) 64 | count_frame = end_frame - start_frame 65 | sample_frame.append(count_frame) 66 | startTime = startTime + ph_durs[i_ph] 67 | all_frame = np.sum(sample_frame) 68 | assert all_frame == sentence_length 69 | # match mel length 70 | sample_frame[-1] = sample_frame[-1] - 1 71 | return sample_frame 72 | 73 | def score_duration(self, durations): 74 | ph_durs = [float(x) for x in durations] 75 | sample_frame = [] 76 | for i_ph in range(len(ph_durs)): 77 | count_frame = int(ph_durs[i_ph] * self.fs / self.hop + 0.5) 78 | if count_frame >= 256: 79 | print("count_frame", count_frame) 80 | count_frame = 255 81 | sample_frame.append(count_frame) 82 | return sample_frame 83 | 84 | def parseInput(self, singinfo: str): 85 | infos = singinfo.split("|") 86 | file = infos[0] 87 | # hanz = infos[1] 88 | phon = infos[2].split(" ") 89 | note = infos[3].split(" ") 90 | note_dur = infos[4].split(" ") 91 | phon_dur = infos[5].split(" ") 92 | phon_slr = infos[6].split(" ") 93 | 94 | labels_ids = label_to_ids(phon) 95 | # labels_uvs = self.phone_to_uv(phon) 96 | note_ids = self.notes_to_id(note) 97 | # convert into float 98 | note_dur = [eval(i) for i in note_dur] 99 | phon_dur = [eval(i) for i in phon_dur] 100 | 101 | note_dur = dur_to_frame(note_dur, self.fs, self.hop) 102 | phon_dur = dur_to_frame(phon_dur, self.fs, self.hop) 103 | labels_slr = [int(x) for x in phon_slr] 104 | 105 | # print("labels_ids", labels_ids) 106 | # print("note_dur", note_dur) 107 | # print("phon_dur", phon_dur) 108 | # print("labels_slr", labels_slr) 109 | return ( 110 | file, 111 | labels_ids, 112 | phon_dur, 113 | note_ids, 114 | note_dur, 115 | labels_slr, 116 | # labels_uvs, 117 | ) 118 | 119 | def parseSong(self, singinfo: str): 120 | infos = singinfo.split("|") 121 | item_indx = infos[0] 122 | item_time = infos[1] 123 | # hanz = infos[2] 124 | phon = infos[3].split(" ") 125 | note_ids = infos[4].split(" ") 126 | note_dur = infos[5].split(" ") 127 | phon_dur = infos[6].split(" ") 128 | phon_slr = infos[7].split(" ") 129 | 130 | labels_ids = label_to_ids(phon) 131 | # labels_uvs = self.phone_to_uv(phon) 132 | labels_frames = self.frame_duration(phon_dur) 
133 | scores_ids = [int(x) if x != "rest" else 0 for x in note_ids] 134 | scores_dur = self.score_duration(note_dur) 135 | labels_slr = [int(x) for x in phon_slr] 136 | return ( 137 | item_indx, 138 | item_time, 139 | labels_ids, 140 | labels_frames, 141 | scores_ids, 142 | scores_dur, 143 | labels_slr, 144 | # labels_uvs, 145 | ) 146 | 147 | def expandInput(self, labels_ids, labels_frames): 148 | assert len(labels_ids) == len(labels_frames) 149 | frame_num = np.sum(labels_frames) 150 | frame_labels = np.zeros(frame_num, dtype=np.int) 151 | start = 0 152 | for index, num in enumerate(labels_frames): 153 | frame_labels[start : start + num] = labels_ids[index] 154 | start += num 155 | return frame_labels 156 | 157 | def scorePitch(self, scores_id): 158 | score_pitch = np.zeros(len(scores_id), dtype=np.float) 159 | for index, score_id in enumerate(scores_id): 160 | if score_id == 0: 161 | score_pitch[index] = 0 162 | else: 163 | pitch = librosa.midi_to_hz(score_id) 164 | score_pitch[index] = round(pitch, 1) 165 | return score_pitch 166 | 167 | def smoothPitch(self, pitch): 168 | # 使用卷积对数据平滑 169 | kernel = np.hanning(5) # 随机生成一个卷积核(对称的) 170 | kernel /= kernel.sum() 171 | smooth_pitch = np.convolve(pitch, kernel, "same") 172 | return smooth_pitch 173 | 174 | def align_process(self, file, phn_dur): 175 | return self.align.align_wav_spec(file, phn_dur) 176 | 177 | 178 | class FeatureInput(object): 179 | def __init__(self, path, samplerate=24000, hop_size=256): 180 | self.fs = samplerate 181 | self.hop = hop_size 182 | self.path = path 183 | 184 | self.f0_bin = 256 185 | self.f0_max = 1100.0 186 | self.f0_min = 50.0 187 | self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700) 188 | self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700) 189 | 190 | def compute_f0(self, filename): 191 | x, sr = librosa.load(self.path + filename, self.fs) 192 | assert sr == self.fs 193 | f0, t = pyworld.dio( 194 | x.astype(np.double), 195 | fs=sr, 196 | f0_ceil=800, 197 | frame_period=1000 * self.hop / sr, 198 | ) 199 | f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.fs) 200 | for index, pitch in enumerate(f0): 201 | f0[index] = round(pitch, 1) 202 | return f0 203 | 204 | def coarse_f0(self, f0): 205 | f0_mel = 1127 * np.log(1 + f0 / 700) 206 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * ( 207 | self.f0_bin - 2 208 | ) / (self.f0_mel_max - self.f0_mel_min) + 1 209 | 210 | # use 0 or 1 211 | f0_mel[f0_mel <= 1] = 1 212 | f0_mel[f0_mel > self.f0_bin - 1] = self.f0_bin - 1 213 | f0_coarse = np.rint(f0_mel).astype(np.int) 214 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, ( 215 | f0_coarse.max(), 216 | f0_coarse.min(), 217 | ) 218 | return f0_coarse 219 | 220 | def diff_f0(self, scores_pit, featur_pit, labels_frames): 221 | length_pit = min(len(scores_pit), len(featur_pit)) 222 | offset_pit = np.zeros(length_pit, dtype=np.int) 223 | for idx in range(length_pit): 224 | s_pit = scores_pit[idx] 225 | f_pit = featur_pit[idx] 226 | if s_pit == 0 or f_pit == 0: 227 | offset_pit[idx] = 0 228 | else: 229 | tmp = int(f_pit - s_pit) 230 | tmp = +128 if tmp > +128 else tmp 231 | tmp = -127 if tmp < -127 else tmp 232 | tmp = 256 + tmp if tmp < 0 else tmp 233 | offset_pit[idx] = tmp 234 | offset_pit[offset_pit > 255] = 255 235 | offset_pit[offset_pit < 0] = 0 236 | # start = 0 237 | # for num in labels_frames: 238 | # print("---------------------------------------------") 239 | # print(scores_pit[start:start+num]) 240 | # print(featur_pit[start:start+num]) 241 | # 
print(offset_pit[start:start+num]) 242 | # start += num 243 | return offset_pit 244 | 245 | 246 | if __name__ == "__main__": 247 | output_path = "../VISinger_ofuton_data/label_vits_phn/" 248 | wav_path = "../VISinger_ofuton_data/wav_dump_24k/" 249 | logging.basicConfig(level=logging.INFO) # ERROR & INFO 250 | pitch_norm = True 251 | pitch_intp = True 252 | uv_process = False 253 | 254 | notemaper = load_midi_map() 255 | logging.info(notemaper) 256 | 257 | sample_rate = 24000 258 | hop_size = 256 259 | singInput = SingInput(sample_rate, hop_size) 260 | featureInput = FeatureInput(wav_path, sample_rate, hop_size) 261 | 262 | if not os.path.exists(output_path): 263 | os.mkdir(output_path) 264 | 265 | fo = open("../VISinger_ofuton_data/transcriptions.txt", "r+") 266 | # vits_file = open("./filelists/vits_file_phn.txt", "w", encoding="utf-8") 267 | vits_file = open("./filelists/vits_file.txt", "w", encoding="utf-8") 268 | i = 0 269 | all_txt = [] # 统计非重复的句子个数 270 | while True: 271 | try: 272 | message = fo.readline().strip() 273 | except Exception as e: 274 | print("nothing of except:", e) 275 | break 276 | if message == None: 277 | break 278 | if message == "": 279 | break 280 | i = i + 1 281 | # if i > 5: 282 | # exit() 283 | infos = message.split("|") 284 | file = infos[0] 285 | hanz = infos[1] 286 | all_txt.append(hanz) 287 | phon = infos[2].split(" ") 288 | note = infos[3].split(" ") 289 | note_dur = infos[4].split(" ") 290 | phon_dur = infos[5].split(" ") 291 | phon_slur = infos[6].split(" ") 292 | 293 | logging.info("----------------------------") 294 | logging.info("file {}".format(file)) 295 | logging.info("lyrics {}".format(hanz)) 296 | logging.info("phn {}".format(phon)) 297 | # logging.info(note_dur) 298 | # logging.info(phon_dur) 299 | # logging.info(phon_slur) 300 | 301 | ( 302 | file, 303 | labels_ids, 304 | labels_dur, 305 | scores_ids, 306 | scores_dur, 307 | labels_slr, 308 | # labels_uvs, 309 | ) = singInput.parseInput(message) 310 | # labels_ids = singInput.expandInput(labels_ids, labels_frames) 311 | # labels_uvs = singInput.expandInput(labels_uvs, labels_frames) 312 | # labels_slr = singInput.expandInput(labels_slr, labels_frames) 313 | # scores_ids = singInput.expandInput(scores_ids, labels_frames) 314 | # scores_pit = singInput.scorePitch(scores_ids) 315 | featur_pit = featureInput.compute_f0(f"{file}.wav") 316 | wav_file = os.path.join(wav_path, file + ".wav") 317 | 318 | spec_len = singInput.align_process(wav_file, labels_dur) 319 | 320 | # extend uv 321 | # labels_uvs = np.repeat(labels_uvs, labels_dur, axis=0) 322 | 323 | featur_pit = featur_pit[:spec_len] 324 | 325 | if featur_pit.shape[0] < spec_len: 326 | pad_length = spec_len - featur_pit.shape[0] 327 | featur_pit = np.pad(featur_pit, pad_width=(0, pad_length), mode="constant") 328 | assert featur_pit.shape[0] == spec_len 329 | # if uv_process: 330 | # featur_pit = featur_pit * labels_uvs 331 | coarse_pit = featureInput.coarse_f0(featur_pit) 332 | 333 | # log f0 334 | if not pitch_norm: 335 | nonzero_idxs = np.where(featur_pit != 0)[0] 336 | featur_pit[nonzero_idxs] = np.log(featur_pit[nonzero_idxs]) 337 | else: 338 | featur_pit = 2595.0 * np.log10(1.0 + featur_pit / 700.0) / 500 339 | 340 | if pitch_intp: 341 | uv = featur_pit == 0 342 | featur_pit_intp = np.copy(featur_pit) 343 | featur_pit_intp[uv] = np.interp( 344 | np.where(uv)[0], np.where(~uv)[0], featur_pit[~uv] 345 | ) 346 | 347 | # offset_pit = featureInput.diff_f0(scores_pit, featur_pit, labels_frames) 348 | # assert len(labels_ids) == len(coarse_pit) 
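        # The asserts below pin down the phoneme-level contract: label ids,
        # label durations, score ids, score durations and slurs each carry
        # exactly one entry per phoneme, while featur_pit / featur_pit_intp
        # are frame-level arrays of length spec_len.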
349 | assert len(labels_ids) == len(labels_dur) 350 | assert len(labels_dur) == len(scores_ids) 351 | assert len(scores_ids) == len(scores_dur) 352 | assert len(scores_dur) == len(labels_slr) 353 | 354 | logging.info("labels_ids {}".format(labels_ids)) 355 | # logging.info("labels_dur {}".format(labels_dur)) 356 | # logging.info("scores_ids {}".format(scores_ids)) 357 | # logging.info("scores_dur {}".format(scores_dur)) 358 | # logging.info("labels_slr {}".format(labels_slr)) 359 | # logging.info("labels_uvs {}".format(labels_uvs)) 360 | # logging.info("featur_pit {}".format(featur_pit)) 361 | logging.info("featur_pit_intp {}".format(featur_pit_intp)) 362 | 363 | np.save( 364 | output_path + f"{file}_label.npy", 365 | labels_ids, 366 | allow_pickle=False, 367 | ) 368 | np.save( 369 | output_path + f"{file}_label_dur.npy", 370 | labels_dur, 371 | allow_pickle=False, 372 | ) 373 | np.save( 374 | output_path + f"{file}_score.npy", 375 | scores_ids, 376 | allow_pickle=False, 377 | ) 378 | np.save( 379 | output_path + f"{file}_score_dur.npy", 380 | scores_dur, 381 | allow_pickle=False, 382 | ) 383 | if not pitch_intp: 384 | np.save( 385 | output_path + f"{file}_pitch.npy", 386 | featur_pit, 387 | allow_pickle=False, 388 | ) 389 | else: 390 | np.save( 391 | output_path + f"{file}_pitch.npy", 392 | featur_pit_intp, 393 | allow_pickle=False, 394 | ) 395 | # np.save( 396 | # output_path + f"{file}_pitch.npy", 397 | # coarse_pit, 398 | # allow_pickle=False, 399 | # ) 400 | np.save( 401 | output_path + f"{file}_slurs.npy", 402 | labels_slr, 403 | allow_pickle=False, 404 | ) 405 | 406 | # wave path|label path|label frame|score path|score duration;上面是一个.(当前目录),下面是两个..(从子目录调用) 407 | path_wave = wav_path + f"{file}.wav" 408 | path_label = output_path + f"{file}_label.npy" 409 | path_label_dur = output_path + f"{file}_label_dur.npy" 410 | path_score = output_path + f"{file}_score.npy" 411 | path_score_dur = output_path + f"{file}_score_dur.npy" 412 | path_pitch = output_path + f"{file}_pitch.npy" 413 | path_slurs = output_path + f"{file}_slurs.npy" 414 | print( 415 | f"{path_wave}|{path_label}|{path_label_dur}|{path_score}|{path_score_dur}|{path_pitch}|{path_slurs}", 416 | file=vits_file, 417 | ) 418 | 419 | fo.close() 420 | vits_file.close() 421 | print(len(set(all_txt))) # 统计非重复的句子个数 422 | -------------------------------------------------------------------------------- /prepare/dur_to_frame.py: -------------------------------------------------------------------------------- 1 | def dur_to_frame(ds, fs, hop_size): 2 | frames = [int(i * fs / hop_size + 0.5) for i in ds] 3 | return frames 4 | -------------------------------------------------------------------------------- /prepare/gen_ofuton_transcript.py: -------------------------------------------------------------------------------- 1 | import music21 as m21 2 | import os 3 | from typing import Iterable, List, Optional, Union 4 | 5 | 6 | def pyopenjtalk_g2p(text) -> List[str]: 7 | import pyopenjtalk 8 | 9 | # phones is a str object separated by space 10 | phones = pyopenjtalk.g2p(text, kana=False) 11 | phones = phones.split(" ") 12 | return phones 13 | 14 | 15 | def text2tokens_svs(syllable: str) -> List[str]: 16 | customed_dic = { 17 | "へ": ["h", "e"], 18 | "ヴぁ": ["v", "a"], 19 | "ヴぃ": ["v", "i"], 20 | "ヴぇ": ["v", "e"], 21 | "ヴぉ": ["v", "i"], 22 | "でぇ": ["dy", "e"], 23 | } 24 | tokens = pyopenjtalk_g2p(syllable) 25 | if syllable in customed_dic: 26 | tokens = customed_dic[syllable] 27 | return tokens 28 | 29 | 30 | def note_filter(note_name, note_map): 
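    # Rewrites a music21 note name into the midi-note.scp key format: any
    # "-" (flat marker) is stripped, and a sharp gains its enharmonic twin,
    # e.g. note_filter("C#4", note_map) -> "C#4/Db4", since note_map["C"]
    # supplies the flat letter "D".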
31 | note_name = note_name.replace("-", "") 32 | if "#" in note_name: 33 | note_name = note_name + "/" + note_map[note_name[0]] + "b" + note_name[2] 34 | return note_name 35 | 36 | 37 | # eval(valid), dev(test), train 38 | def process(base_path, file_path): 39 | 40 | note_map = { 41 | "A": "B", 42 | "B": "C", 43 | "C": "D", 44 | "D": "E", 45 | "E": "F", 46 | "F": "G", 47 | "G": "A", 48 | } 49 | 50 | label_path = file_path + "label" 51 | text_path = file_path + "text" 52 | data = [] 53 | 54 | for line in open(label_path, "r"): 55 | # add phn and phn_dur 56 | str_list = line.replace("\n", "").split(" ") 57 | name = str_list[0] 58 | phn_dur = [] 59 | phn = [] 60 | score = [] 61 | score_dur = [] 62 | 63 | for i in range(1, len(str_list)): 64 | try: 65 | phn_dur_ = str(round(float(str_list[i + 1]) - float(str_list[i]), 6)) 66 | # phn_dur_ = float(str_list[i + 1]) - float(str_list[i]) 67 | except: 68 | if str_list[i] != "" and str_list[i].isalpha(): 69 | phn.append(str_list[i]) 70 | phn_dict.add(str_list[i]) 71 | continue 72 | 73 | phn_dur.append(phn_dur_) 74 | 75 | # append text 76 | for line2 in open(text_path, "r"): 77 | str_list2 = line2.replace("\n", "").split(" ") 78 | if str_list2[0] != name: 79 | continue 80 | else: 81 | text_ = str_list2[1] 82 | break 83 | 84 | # add score and score_dur 85 | musicxmlscp = open(os.path.join(file_path, "xml.scp"), "r", encoding="utf-8") 86 | for xml_line in musicxmlscp: 87 | xmlline = xml_line.strip().split(" ") 88 | recording_id = xmlline[0] 89 | if recording_id != name: 90 | continue 91 | else: 92 | path = base_path + xmlline[1] 93 | parse_file = m21.converter.parse(path) 94 | part = parse_file.parts[0].flat 95 | m = parse_file.metronomeMarkBoundaries() 96 | tempo = m[0][2] 97 | for part in parse_file.parts: 98 | for note in part.recurse().notes: 99 | note_dur_ = note.quarterLength * 60 / tempo.number 100 | note_name_ = note_filter(note.nameWithOctave, note_map) 101 | note_text_ = note.lyric 102 | # print("note_text1", text_) 103 | # print("note_text_", note_text_) 104 | if not note_text_: 105 | continue 106 | note_phn_ = text2tokens_svs(note_text_) 107 | for i in range(len(note_phn_)): 108 | score.append(note_name_) 109 | score_dur.append(str(note_dur_)) 110 | # print("note_phn", note_phn_) 111 | break 112 | 113 | # print("tempo", tempo.number) 114 | 115 | # TODO: add slur. 
currently all 0 116 | slur = [] 117 | for i in range(len(phn)): 118 | slur.append("0") 119 | 120 | # add one line 121 | data.append( 122 | name 123 | + "|" 124 | + text_ 125 | + "|" 126 | + " ".join(phn) 127 | + "|" 128 | + " ".join(score) 129 | + "|" 130 | + " ".join(score_dur) 131 | + "|" 132 | + " ".join(phn_dur) 133 | + "|" 134 | + " ".join(slur) 135 | ) 136 | print(data) 137 | assert len(phn) == len(phn_dur) 138 | assert len(phn) == len(score) 139 | assert len(phn) == len(score_dur) 140 | assert len(phn) == len(slur) 141 | return data 142 | 143 | 144 | base_path = "/home/yyu479/espnet/egs2/ofuton_p_utagoe_db/svs1/" 145 | 146 | data = [] 147 | phn_dict = set() 148 | data_eval = process(base_path, base_path + "dump/raw/eval/") 149 | data_dev = process(base_path, base_path + "dump/raw/org/dev/") 150 | data_tr_no_dev = process(base_path, base_path + "dump/raw/org/tr_no_dev/") 151 | data = data_eval + data_dev + data_tr_no_dev 152 | 153 | with open("transcriptions.txt", "w") as f: 154 | for i in data: 155 | f.writelines(i) 156 | f.write("\n") 157 | 158 | phn_dict_sort = list(phn_dict) 159 | phn_dict_sort.sort() 160 | with open("dict.txt", "w") as f: 161 | for i in phn_dict_sort: 162 | f.writelines(i) 163 | f.write("\n") 164 | -------------------------------------------------------------------------------- /prepare/midi-HZ.scp: -------------------------------------------------------------------------------- 1 | 127 G9 12543.9 2 | 126 F#9/Gb9 11839.8 3 | 125 F9 11175.3 4 | 124 E9 10548.1 5 | 123 D#9/Eb9 9956.1 6 | 122 D9 9397.3 7 | 121 C#9/Db9 8869.8 8 | 120 C9 8372 9 | 119 B8 7902.1 10 | 118 A#8/Bb8 7458.6 11 | 117 A8 7040 12 | 116 G#8/Ab8 6644.9 13 | 115 G8 6271.9 14 | 114 F#8/Gb8 5919.9 15 | 113 F8 5587.7 16 | 112 E8 5274 17 | 111 D#8/Eb8 4978 18 | 110 D8 4698.6 19 | 109 C#8/Db8 4434.9 20 | 108 C8 4186 21 | 107 B7 3951.1 22 | 106 A#7/Bb7 3729.3 23 | 105 A7 3520 24 | 104 G#7/Ab7 3322.4 25 | 103 G7 3136 26 | 102 F#7/Gb7 2960 27 | 101 F7 2793.8 28 | 100 E7 2637 29 | 99 D#7/Eb7 2489 30 | 98 D7 2349.3 31 | 97 C#7/Db7 2217.5 32 | 96 C7 2093 33 | 95 B6 1975.5 34 | 94 A#6/Bb6 1864.7 35 | 93 A6 1760 36 | 92 G#6/Ab6 1661.2 37 | 91 G6 1568 38 | 90 F#6/Gb6 1480 39 | 89 F6 1396.9 40 | 88 E6 1318.5 41 | 87 D#6/Eb6 1244.5 42 | 86 D6 1174.7 43 | 85 C#6/Db6 1108.7 44 | 84 C6 1046.5 45 | 83 B5 987.8 46 | 82 A#5/Bb5 932.3 47 | 81 A5 880 48 | 80 G#5/Ab5 830.6 49 | 79 G5 784 50 | 78 F#5/Gb5 740 51 | 77 F5 698.5 52 | 76 E5 659.3 53 | 75 D#5/Eb5 622.3 54 | 74 D5 587.3 55 | 73 C#5/Db5 554.4 56 | 72 C5 523.3 57 | 71 B4 493.9 58 | 70 A#4/Bb4 466.2 59 | 69 A4 440 60 | 68 G#4/Ab4 415.3 61 | 67 G4 392 62 | 66 F#4/Gb4 370 63 | 65 F4 349.2 64 | 64 E4 329.6 65 | 63 D#4/Eb4 311.1 66 | 62 D4 293.7 67 | 61 C#4/Db4 277.2 68 | 60 C4 261.6 69 | 59 B3 246.9 70 | 58 A#3/Bb3 233.1 71 | 57 A3 220 72 | 56 G#3/Ab3 207.7 73 | 55 G3 196 74 | 54 F#3/Gb3 185 75 | 53 F3 174.6 76 | 52 E3 164.8 77 | 51 D#3/Eb3 155.6 78 | 50 D3 146.8 79 | 49 C#3/Db3 138.6 80 | 48 C3 130.8 81 | 47 B2 123.5 82 | 46 A#2/Bb2 116.5 83 | 45 A2 110 84 | 44 G#2/Ab2 103. 
85 | 43 G2 98 86 | 42 F#2/Gb2 92.5 87 | 41 F2 87.3 88 | 40 E2 82.4 89 | 39 D#2/Eb2 77.8 90 | 38 D2 73.4 91 | 37 C#2/Db2 69.3 92 | 36 C2 65.4 93 | 35 B1 61.7 94 | 34 A#1/Bb1 58.3 95 | 33 A1 55 96 | 32 G#1/Ab1 51.9 97 | 31 G1 49 98 | 30 F#1/Gb1 46.2 99 | 29 F1 43.7 100 | 28 E1 41.2 101 | 27 D#1/Eb1 38.9 102 | 26 D1 36.7 103 | 25 C#1/Db1 34.6 104 | 24 C1 32.7 105 | 23 B0 30.9 106 | 22 A#0/Bb0 29.1 107 | 21 A0 27.5 108 | 0 rest 0 -------------------------------------------------------------------------------- /prepare/midi-note.scp: -------------------------------------------------------------------------------- 1 | 127 G9 2 | 126 F#9/Gb9 3 | 125 F9 4 | 124 E9 5 | 123 D#9/Eb9 6 | 122 D9 7 | 121 C#9/Db9 8 | 120 C9 9 | 119 B8 10 | 118 A#8/Bb8 11 | 117 A8 12 | 116 G#8/Ab8 13 | 115 G8 14 | 114 F#8/Gb8 15 | 113 F8 16 | 112 E8 17 | 111 D#8/Eb8 18 | 110 D8 19 | 109 C#8/Db8 20 | 108 C8 21 | 107 B7 22 | 106 A#7/Bb7 23 | 105 A7 24 | 104 G#7/Ab7 25 | 103 G7 26 | 102 F#7/Gb7 27 | 101 F7 28 | 100 E7 29 | 99 D#7/Eb7 30 | 98 D7 31 | 97 C#7/Db7 32 | 96 C7 33 | 95 B6 34 | 94 A#6/Bb6 35 | 93 A6 36 | 92 G#6/Ab6 37 | 91 G6 38 | 90 F#6/Gb6 39 | 89 F6 40 | 88 E6 41 | 87 D#6/Eb6 42 | 86 D6 43 | 85 C#6/Db6 44 | 84 C6 45 | 83 B5 46 | 82 A#5/Bb5 47 | 81 A5 48 | 80 G#5/Ab5 49 | 79 G5 50 | 78 F#5/Gb5 51 | 77 F5 52 | 76 E5 53 | 75 D#5/Eb5 54 | 74 D5 55 | 73 C#5/Db5 56 | 72 C5 57 | 71 B4 58 | 70 A#4/Bb4 59 | 69 A4 60 | 68 G#4/Ab4 61 | 67 G4 62 | 66 F#4/Gb4 63 | 65 F4 64 | 64 E4 65 | 63 D#4/Eb4 66 | 62 D4 67 | 61 C#4/Db4 68 | 60 C4 69 | 59 B3 70 | 58 A#3/Bb3 71 | 57 A3 72 | 56 G#3/Ab3 73 | 55 G3 74 | 54 F#3/Gb3 75 | 53 F3 76 | 52 E3 77 | 51 D#3/Eb3 78 | 50 D3 79 | 49 C#3/Db3 80 | 48 C3 81 | 47 B2 82 | 46 A#2/Bb2 83 | 45 A2 84 | 44 G#2/Ab2 85 | 43 G2 86 | 42 F#2/Gb2 87 | 41 F2 88 | 40 E2 89 | 39 D#2/Eb2 90 | 38 D2 91 | 37 C#2/Db2 92 | 36 C2 93 | 35 B1 94 | 34 A#1/Bb1 95 | 33 A1 96 | 32 G#1/Ab1 97 | 31 G1 98 | 30 F#1/Gb1 99 | 29 F1 100 | 28 E1 101 | 27 D#1/Eb1 102 | 26 D1 103 | 25 C#1/Db1 104 | 24 C1 105 | 23 B0 106 | 22 A#0/Bb0 107 | 21 A0 -------------------------------------------------------------------------------- /prepare/phone_map.py: -------------------------------------------------------------------------------- 1 | _pause = ["unk", "sos", "eos", "ap", "sp"] 2 | 3 | _initials = [ 4 | "b", 5 | "c", 6 | "ch", 7 | "d", 8 | "f", 9 | "g", 10 | "h", 11 | "j", 12 | "k", 13 | "l", 14 | "m", 15 | "n", 16 | "p", 17 | "q", 18 | "r", 19 | "s", 20 | "sh", 21 | "t", 22 | "w", 23 | "x", 24 | "y", 25 | "z", 26 | "zh", 27 | ] 28 | 29 | _finals = [ 30 | "a", 31 | "ai", 32 | "an", 33 | "ang", 34 | "ao", 35 | "e", 36 | "ei", 37 | "en", 38 | "eng", 39 | "er", 40 | "i", 41 | "ia", 42 | "ian", 43 | "iang", 44 | "iao", 45 | "ie", 46 | "in", 47 | "ing", 48 | "iong", 49 | "iu", 50 | "o", 51 | "ong", 52 | "ou", 53 | "u", 54 | "ua", 55 | "uai", 56 | "uan", 57 | "uang", 58 | "ui", 59 | "un", 60 | "uo", 61 | "v", 62 | "van", 63 | "ve", 64 | "vn", 65 | ] 66 | 67 | lang = "cn" 68 | if lang == "cn": 69 | symbols = _pause + _initials + _finals 70 | elif lang == "jp": 71 | symbols = [ 72 | "I", 73 | "N", 74 | "a", 75 | "b", 76 | "by", 77 | "ch", 78 | "cl", 79 | "d", 80 | "dy", 81 | "e", 82 | "f", 83 | "g", 84 | "gy", 85 | "h", 86 | "hy", 87 | "i", 88 | "j", 89 | "k", 90 | "ky", 91 | "m", 92 | "my", 93 | "n", 94 | "ny", 95 | "o", 96 | "p", 97 | "py", 98 | "r", 99 | "ry", 100 | "s", 101 | "sh", 102 | "t", 103 | "ts", 104 | "ty", 105 | "u", 106 | "v", 107 | "w", 108 | "y", 109 | "z", 110 | ] 111 | 112 | # Mappings from symbol to numeric ID and vice 
versa:
113 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
114 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
115 |
116 |
117 | def label_to_ids(phones):
118 |     # use lower letter
119 |     if lang == "cn":
120 |         sequence = [_symbol_to_id[symbol.lower()] for symbol in phones]
121 |     elif lang == "jp":
122 |         sequence = [_symbol_to_id[symbol] for symbol in phones]
123 |     return sequence
124 |
125 |
126 | def get_vocab_size():
127 |     return len(symbols)
128 |
--------------------------------------------------------------------------------
/prepare/phone_uv.py:
--------------------------------------------------------------------------------
1 | # Basic Mandarin initials and finals
2 | # Mandarin has only 4 voiced initials: m, n, l, r; the other 17 consonant initials are voiceless
3 | # In pinyin, y and w appear only at the start of zero-initial syllables; their main role is to make syllable boundaries clear.
4 | # https://baijiahao.baidu.com/s?id=1655739561730224990&wfr=spider&for=pc
5 |
6 | uv_map = {
7 |     "unk":0,
8 |     "sos":0,
9 |     "eos":0,
10 |     "ap":0,
11 |     "sp":0,
12 |     "b":0,
13 |     "c":0,
14 |     "ch":0,
15 |     "d":0,
16 |     "f":0,
17 |     "g":0,
18 |     "h":0,
19 |     "j":0,
20 |     "k":0,
21 |     "l":1,
22 |     "m":1,
23 |     "n":1,
24 |     "p":0,
25 |     "q":0,
26 |     "r":1,
27 |     "s":0,
28 |     "sh":0,
29 |     "t":0,
30 |     "w":1,
31 |     "x":0,
32 |     "y":1,
33 |     "z":0,
34 |     "zh":0,
35 |     "a":1,
36 |     "ai":1,
37 |     "an":1,
38 |     "ang":1,
39 |     "ao":1,
40 |     "e":1,
41 |     "ei":1,
42 |     "en":1,
43 |     "eng":1,
44 |     "er":1,
45 |     "i":1,
46 |     "ia":1,
47 |     "ian":1,
48 |     "iang":1,
49 |     "iao":1,
50 |     "ie":1,
51 |     "in":1,
52 |     "ing":1,
53 |     "iong":1,
54 |     "iu":1,
55 |     "o":1,
56 |     "ong":1,
57 |     "ou":1,
58 |     "u":1,
59 |     "ua":1,
60 |     "uai":1,
61 |     "uan":1,
62 |     "uang":1,
63 |     "ui":1,
64 |     "un":1,
65 |     "uo":1,
66 |     "v":1,
67 |     "van":1,
68 |     "ve":1,
69 |     "vn":1
70 | }
--------------------------------------------------------------------------------
/prepare/preprocess.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | if __name__ == "__main__":
4 |
5 |     alls = []
6 |     fo = open("./filelists/vits_file.txt", "r+")
7 |     while True:
8 |         try:
9 |             message = fo.readline().strip()
10 |         except Exception as e:
11 |             print("nothing of except:", e)
12 |             break
13 |         if message is None:
14 |             break
15 |         if message == "":
16 |             break
17 |         alls.append(message)
18 |     fo.close()
19 |
20 |     valids = alls[:150]
21 |     tests = alls[150:300]
22 |     trains = alls[300:]
23 |
24 |     random.shuffle(trains)
25 |
26 |     fw = open("./filelists/singing_valid.txt", "w", encoding="utf-8")
27 |     for strs in valids:
28 |         print(strs, file=fw)
29 |     fw.close()
30 |
31 |     fw = open("./filelists/singing_test.txt", "w", encoding="utf-8")
32 |     for strs in tests:
33 |         print(strs, file=fw)
34 |
35 |     fw = open("./filelists/singing_train.txt", "w", encoding="utf-8")
36 |     for strs in trains:
37 |         print(strs, file=fw)
38 |
39 |     fw.close()
40 |
--------------------------------------------------------------------------------
/prepare/preprocess_jp.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | if __name__ == "__main__":
4 |
5 |     alls = []
6 |     fo = open("./filelists/vits_file.txt", "r+")
7 |     while True:
8 |         try:
9 |             message = fo.readline().strip()
10 |         except Exception as e:
11 |             print("nothing of except:", e)
12 |             break
13 |         if message is None:
14 |             break
15 |         if message == "":
16 |             break
17 |         alls.append(message)
18 |     fo.close()
19 |
20 |     valids = alls[:70]
21 |     tests = alls[70:134]
22 |     trains = alls[134:]
23 |
24 |     random.shuffle(trains)
25 |
26 |     fw = open("./filelists/singing_valid.txt", "w", encoding="utf-8")
27 |     for strs in valids:
28 |         print(strs, file=fw)
29 |     fw.close()
30 |
31 |     fw = open("./filelists/singing_test.txt", "w", encoding="utf-8")
32 |     for strs in tests:
33 |         print(strs, file=fw)
34 |
35 |     fw = open("./filelists/singing_train.txt", "w", encoding="utf-8")
36 |     for strs in trains:
37 |         print(strs, file=fw)
38 |
39 |     fw.close()
40 |
--------------------------------------------------------------------------------
/prepare/resample_wav.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 |
5 | def process_utterance(
6 |     audio_dir,
7 |     wav_dumpdir,
8 |     segment,
9 |     tgt_sr=24000,
10 | ):
11 |     uid, lyrics, phns, notes, syb_dur, phn_dur, keep = segment.strip().split("|")
12 |     cmd = "sox {}.wav -c 1 -t wavpcm -b 16 -r {} {}_bits16.wav".format(
13 |         os.path.join(audio_dir, uid),
14 |         tgt_sr,
15 |         os.path.join(wav_dumpdir, uid),
16 |     )
17 |     print("uid", uid)
18 |     os.system(cmd)
19 |
20 |
21 | def process_subset(args, set_name):
22 |     with open(
23 |         os.path.join(args.src_data, "segments", set_name + ".txt"),
24 |         "r",
25 |         encoding="utf-8",
26 |     ) as f:
27 |         segments = f.read().strip().split("\n")
28 |     for segment in segments:
29 |         process_utterance(
30 |             os.path.join(args.src_data, "segments", "wavs"),
31 |             args.wav_dumpdir,
32 |             segment,
33 |             tgt_sr=args.sr,
34 |         )
35 |
36 |
37 | if __name__ == "__main__":
38 |     parser = argparse.ArgumentParser(description="Prepare Data for Opencpop Database")
39 |     parser.add_argument("src_data", type=str, help="source data directory")
40 |     parser.add_argument(
41 |         "--wav_dumpdir", type=str, help="wav dump directory (rebit)", default="wav_dump"
42 |     )
43 |     parser.add_argument("--sr", type=int, help="sampling rate (Hz)")
44 |     args = parser.parse_args()
45 |
46 |     for name in ["train", "test"]:
47 |         process_subset(args, name)
48 |
--------------------------------------------------------------------------------
/prepare/resample_wav.sh:
--------------------------------------------------------------------------------
1 | OPENCPOP=/home/yyu479/svs/data/Opencpop/
2 | fs=24000
3 | output=/home/yyu479/VISinger_data/wav_dump_24k
4 | mkdir -p ${output}
5 | python resample_wav.py ${OPENCPOP} \
6 |     --wav_dumpdir ${output} \
7 |     --sr ${fs}
--------------------------------------------------------------------------------
/resource/2005000151.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jerryuhoo/VISinger/ad8bc167c10275dd513ae466e73deae2f7045c99/resource/2005000151.wav
--------------------------------------------------------------------------------
/resource/2005000152.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jerryuhoo/VISinger/ad8bc167c10275dd513ae466e73deae2f7045c99/resource/2005000152.wav
--------------------------------------------------------------------------------
/resource/2006000186.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jerryuhoo/VISinger/ad8bc167c10275dd513ae466e73deae2f7045c99/resource/2006000186.wav
--------------------------------------------------------------------------------
/resource/2006000187.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jerryuhoo/VISinger/ad8bc167c10275dd513ae466e73deae2f7045c99/resource/2006000187.wav
--------------------------------------------------------------------------------
/resource/2008000268.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerryuhoo/VISinger/ad8bc167c10275dd513ae466e73deae2f7045c99/resource/2008000268.wav -------------------------------------------------------------------------------- /resource/vising_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerryuhoo/VISinger/ad8bc167c10275dd513ae466e73deae2f7045c99/resource/vising_loss.png -------------------------------------------------------------------------------- /resource/vising_mel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jerryuhoo/VISinger/ad8bc167c10275dd513ae466e73deae2f7045c99/resource/vising_mel.png -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | nohup python train.py -c configs/singing_base.json -m singing_base & -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform(inputs, 13 | unnormalized_widths, 14 | unnormalized_heights, 15 | unnormalized_derivatives, 16 | inverse=False, 17 | tails=None, 18 | tail_bound=1., 19 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 20 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 21 | min_derivative=DEFAULT_MIN_DERIVATIVE): 22 | 23 | if tails is None: 24 | spline_fn = rational_quadratic_spline 25 | spline_kwargs = {} 26 | else: 27 | spline_fn = unconstrained_rational_quadratic_spline 28 | spline_kwargs = { 29 | 'tails': tails, 30 | 'tail_bound': tail_bound 31 | } 32 | 33 | outputs, logabsdet = spline_fn( 34 | inputs=inputs, 35 | unnormalized_widths=unnormalized_widths, 36 | unnormalized_heights=unnormalized_heights, 37 | unnormalized_derivatives=unnormalized_derivatives, 38 | inverse=inverse, 39 | min_bin_width=min_bin_width, 40 | min_bin_height=min_bin_height, 41 | min_derivative=min_derivative, 42 | **spline_kwargs 43 | ) 44 | return outputs, logabsdet 45 | 46 | 47 | def searchsorted(bin_locations, inputs, eps=1e-6): 48 | bin_locations[..., -1] += eps 49 | return torch.sum( 50 | inputs[..., None] >= bin_locations, 51 | dim=-1 52 | ) - 1 53 | 54 | 55 | def unconstrained_rational_quadratic_spline(inputs, 56 | unnormalized_widths, 57 | unnormalized_heights, 58 | unnormalized_derivatives, 59 | inverse=False, 60 | tails='linear', 61 | tail_bound=1., 62 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 63 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 64 | min_derivative=DEFAULT_MIN_DERIVATIVE): 65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 66 | outside_interval_mask = ~inside_interval_mask 67 | 68 | outputs = torch.zeros_like(inputs) 69 | logabsdet = torch.zeros_like(inputs) 70 | 71 | if tails == 'linear': 72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 73 | constant = np.log(np.exp(1 - min_derivative) - 1) 74 | unnormalized_derivatives[..., 0] = constant 75 | unnormalized_derivatives[..., -1] = constant 76 | 77 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 78 | 
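        # Outside [-tail_bound, tail_bound] the 'linear' tails act as the
        # identity map, so the Jacobian there is 1 and log|det J| = 0 --
        # which is exactly what the next assignment records for the
        # outside-interval entries.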
logabsdet[outside_interval_mask] = 0 79 | else: 80 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 81 | 82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 89 | min_bin_width=min_bin_width, 90 | min_bin_height=min_bin_height, 91 | min_derivative=min_derivative 92 | ) 93 | 94 | return outputs, logabsdet 95 | 96 | def rational_quadratic_spline(inputs, 97 | unnormalized_widths, 98 | unnormalized_heights, 99 | unnormalized_derivatives, 100 | inverse=False, 101 | left=0., right=1., bottom=0., top=1., 102 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 103 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 104 | min_derivative=DEFAULT_MIN_DERIVATIVE): 105 | if torch.min(inputs) < left or torch.max(inputs) > right: 106 | raise ValueError('Input to a transform is not within its domain') 107 | 108 | num_bins = unnormalized_widths.shape[-1] 109 | 110 | if min_bin_width * num_bins > 1.0: 111 | raise ValueError('Minimal bin width too large for the number of bins') 112 | if min_bin_height * num_bins > 1.0: 113 | raise ValueError('Minimal bin height too large for the number of bins') 114 | 115 | widths = F.softmax(unnormalized_widths, dim=-1) 116 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 117 | cumwidths = torch.cumsum(widths, dim=-1) 118 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 119 | cumwidths = (right - left) * cumwidths + left 120 | cumwidths[..., 0] = left 121 | cumwidths[..., -1] = right 122 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 123 | 124 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 125 | 126 | heights = F.softmax(unnormalized_heights, dim=-1) 127 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 128 | cumheights = torch.cumsum(heights, dim=-1) 129 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 130 | cumheights = (top - bottom) * cumheights + bottom 131 | cumheights[..., 0] = bottom 132 | cumheights[..., -1] = top 133 | heights = cumheights[..., 1:] - cumheights[..., :-1] 134 | 135 | if inverse: 136 | bin_idx = searchsorted(cumheights, inputs)[..., None] 137 | else: 138 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 139 | 140 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 141 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 142 | 143 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 144 | delta = heights / widths 145 | input_delta = delta.gather(-1, bin_idx)[..., 0] 146 | 147 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 148 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 149 | 150 | input_heights = heights.gather(-1, bin_idx)[..., 0] 151 | 152 | if inverse: 153 | a = (((inputs - input_cumheights) * (input_derivatives 154 | + input_derivatives_plus_one 155 | - 2 * input_delta) 156 | + input_heights * (input_delta - input_derivatives))) 157 | b = (input_heights * input_derivatives 158 | - (inputs - input_cumheights) * (input_derivatives 159 | + input_derivatives_plus_one 160 | - 2 * input_delta)) 161 | c = - input_delta * (inputs - input_cumheights) 162 | 163 | discriminant = 
b.pow(2) - 4 * a * c 164 | assert (discriminant >= 0).all() 165 | 166 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 167 | outputs = root * input_bin_widths + input_cumwidths 168 | 169 | theta_one_minus_theta = root * (1 - root) 170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 171 | * theta_one_minus_theta) 172 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 173 | + 2 * input_delta * theta_one_minus_theta 174 | + input_derivatives * (1 - root).pow(2)) 175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 176 | 177 | return outputs, -logabsdet 178 | else: 179 | theta = (inputs - input_cumwidths) / input_bin_widths 180 | theta_one_minus_theta = theta * (1 - theta) 181 | 182 | numerator = input_heights * (input_delta * theta.pow(2) 183 | + input_derivatives * theta_one_minus_theta) 184 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 185 | * theta_one_minus_theta) 186 | outputs = input_cumheights + numerator / denominator 187 | 188 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 189 | + 2 * input_delta * theta_one_minus_theta 190 | + input_derivatives * (1 - theta).pow(2)) 191 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 192 | 193 | return outputs, logabsdet 194 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | from scipy.io.wavfile import read 10 | import torch 11 | 12 | MATPLOTLIB_FLAG = False 13 | 14 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 15 | logger = logging 16 | 17 | 18 | def load_checkpoint(checkpoint_path, model, optimizer=None): 19 | assert os.path.isfile(checkpoint_path) 20 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 21 | iteration = checkpoint_dict['iteration'] 22 | learning_rate = checkpoint_dict['learning_rate'] 23 | if optimizer is not None: 24 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 25 | saved_state_dict = checkpoint_dict['model'] 26 | if hasattr(model, 'module'): 27 | state_dict = model.module.state_dict() 28 | else: 29 | state_dict = model.state_dict() 30 | new_state_dict= {} 31 | for k, v in state_dict.items(): 32 | try: 33 | new_state_dict[k] = saved_state_dict[k] 34 | except: 35 | logger.info("%s is not in the checkpoint" % k) 36 | new_state_dict[k] = v 37 | if hasattr(model, 'module'): 38 | model.module.load_state_dict(new_state_dict) 39 | else: 40 | model.load_state_dict(new_state_dict) 41 | logger.info("Loaded checkpoint '{}' (iteration {})" .format( 42 | checkpoint_path, iteration)) 43 | return model, optimizer, learning_rate, iteration 44 | 45 | 46 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 47 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 48 | iteration, checkpoint_path)) 49 | if hasattr(model, 'module'): 50 | state_dict = model.module.state_dict() 51 | else: 52 | state_dict = model.state_dict() 53 | torch.save({'model': state_dict, 54 | 'iteration': iteration, 55 | 'optimizer': optimizer.state_dict(), 56 | 'learning_rate': learning_rate}, checkpoint_path) 57 | 58 | 59 | def summarize(writer, global_step, scalars={}, 
histograms={}, images={}, audios={}, audio_sampling_rate=22050): 60 | for k, v in scalars.items(): 61 | writer.add_scalar(k, v, global_step) 62 | for k, v in histograms.items(): 63 | writer.add_histogram(k, v, global_step) 64 | for k, v in images.items(): 65 | writer.add_image(k, v, global_step, dataformats='HWC') 66 | for k, v in audios.items(): 67 | writer.add_audio(k, v, global_step, audio_sampling_rate) 68 | 69 | 70 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 71 | f_list = glob.glob(os.path.join(dir_path, regex)) 72 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 73 | x = f_list[-1] 74 | print(x) 75 | return x 76 | 77 | 78 | def plot_spectrogram_to_numpy(spectrogram): 79 | global MATPLOTLIB_FLAG 80 | if not MATPLOTLIB_FLAG: 81 | import matplotlib 82 | matplotlib.use("Agg") 83 | MATPLOTLIB_FLAG = True 84 | mpl_logger = logging.getLogger('matplotlib') 85 | mpl_logger.setLevel(logging.WARNING) 86 | import matplotlib.pylab as plt 87 | import numpy as np 88 | 89 | fig, ax = plt.subplots(figsize=(10,2)) 90 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 91 | interpolation='none') 92 | plt.colorbar(im, ax=ax) 93 | plt.xlabel("Frames") 94 | plt.ylabel("Channels") 95 | plt.tight_layout() 96 | 97 | fig.canvas.draw() 98 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 99 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 100 | plt.close() 101 | return data 102 | 103 | 104 | def plot_alignment_to_numpy(alignment, info=None): 105 | global MATPLOTLIB_FLAG 106 | if not MATPLOTLIB_FLAG: 107 | import matplotlib 108 | matplotlib.use("Agg") 109 | MATPLOTLIB_FLAG = True 110 | mpl_logger = logging.getLogger('matplotlib') 111 | mpl_logger.setLevel(logging.WARNING) 112 | import matplotlib.pylab as plt 113 | import numpy as np 114 | 115 | fig, ax = plt.subplots(figsize=(6, 4)) 116 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', 117 | interpolation='none') 118 | fig.colorbar(im, ax=ax) 119 | xlabel = 'Decoder timestep' 120 | if info is not None: 121 | xlabel += '\n\n' + info 122 | plt.xlabel(xlabel) 123 | plt.ylabel('Encoder timestep') 124 | plt.tight_layout() 125 | 126 | fig.canvas.draw() 127 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 128 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 129 | plt.close() 130 | return data 131 | 132 | 133 | def load_wav_to_torch(full_path): 134 | sampling_rate, data = read(full_path) 135 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 136 | 137 | 138 | def load_filepaths_and_text(filename, split="|"): 139 | with open(filename, encoding='utf-8') as f: 140 | filepaths_and_text = [line.strip().split(split) for line in f] 141 | return filepaths_and_text 142 | 143 | 144 | def get_hparams(init=True): 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument('-c', '--config', type=str, default="./configs/base.json", 147 | help='JSON file for configuration') 148 | parser.add_argument('-m', '--model', type=str, required=True, 149 | help='Model name') 150 | 151 | args = parser.parse_args() 152 | model_dir = os.path.join("./logs", args.model) 153 | 154 | if not os.path.exists(model_dir): 155 | os.makedirs(model_dir) 156 | 157 | config_path = args.config 158 | config_save_path = os.path.join(model_dir, "config.json") 159 | if init: 160 | with open(config_path, "r") as f: 161 | data = f.read() 162 | with open(config_save_path, "w") as f: 163 | f.write(data) 164 | else: 165 | with open(config_save_path, "r") as 
144 | def get_hparams(init=True):
145 |     parser = argparse.ArgumentParser()
146 |     parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
147 |                         help='JSON file for configuration')
148 |     parser.add_argument('-m', '--model', type=str, required=True,
149 |                         help='Model name')
150 | 
151 |     args = parser.parse_args()
152 |     model_dir = os.path.join("./logs", args.model)
153 | 
154 |     if not os.path.exists(model_dir):
155 |         os.makedirs(model_dir)
156 | 
157 |     config_path = args.config
158 |     config_save_path = os.path.join(model_dir, "config.json")
159 |     if init:
160 |         with open(config_path, "r") as f:
161 |             data = f.read()
162 |         with open(config_save_path, "w") as f:
163 |             f.write(data)
164 |     else:
165 |         with open(config_save_path, "r") as f:
166 |             data = f.read()
167 |     config = json.loads(data)
168 | 
169 |     hparams = HParams(**config)
170 |     hparams.model_dir = model_dir
171 |     return hparams
172 | 
173 | 
174 | def get_hparams_from_dir(model_dir):
175 |     config_save_path = os.path.join(model_dir, "config.json")
176 |     with open(config_save_path, "r") as f:
177 |         data = f.read()
178 |     config = json.loads(data)
179 | 
180 |     hparams = HParams(**config)
181 |     hparams.model_dir = model_dir
182 |     return hparams
183 | 
184 | 
185 | def get_hparams_from_file(config_path):
186 |     with open(config_path, "r") as f:
187 |         data = f.read()
188 |     config = json.loads(data)
189 | 
190 |     hparams = HParams(**config)
191 |     return hparams
192 | 
193 | 
194 | def check_git_hash(model_dir):
195 |     source_dir = os.path.dirname(os.path.realpath(__file__))
196 |     if not os.path.exists(os.path.join(source_dir, ".git")):
197 |         logger.warning("{} is not a git repository, therefore hash value comparison will be ignored.".format(
198 |             source_dir
199 |         ))
200 |         return
201 | 
202 |     cur_hash = subprocess.getoutput("git rev-parse HEAD")
203 | 
204 |     path = os.path.join(model_dir, "githash")
205 |     if os.path.exists(path):
206 |         saved_hash = open(path).read()
207 |         if saved_hash != cur_hash:
208 |             logger.warning("git hash values are different. {}(saved) != {}(current)".format(
209 |                 saved_hash[:8], cur_hash[:8]))
210 |     else:
211 |         open(path, "w").write(cur_hash)
212 | 
213 | 
214 | def get_logger(model_dir, filename="train.log"):
215 |     global logger
216 |     logger = logging.getLogger(os.path.basename(model_dir))
217 |     logger.setLevel(logging.DEBUG)
218 | 
219 |     formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
220 |     if not os.path.exists(model_dir):
221 |         os.makedirs(model_dir)
222 |     h = logging.FileHandler(os.path.join(model_dir, filename))
223 |     h.setLevel(logging.DEBUG)
224 |     h.setFormatter(formatter)
225 |     logger.addHandler(h)
226 |     return logger
227 | 
228 | 
229 | class HParams:
230 |     def __init__(self, **kwargs):
231 |         for k, v in kwargs.items():
232 |             if isinstance(v, dict):
233 |                 v = HParams(**v)
234 |             self[k] = v
235 | 
236 |     def keys(self):
237 |         return self.__dict__.keys()
238 | 
239 |     def items(self):
240 |         return self.__dict__.items()
241 | 
242 |     def values(self):
243 |         return self.__dict__.values()
244 | 
245 |     def __len__(self):
246 |         return len(self.__dict__)
247 | 
248 |     def __getitem__(self, key):
249 |         return getattr(self, key)
250 | 
251 |     def __setitem__(self, key, value):
252 |         return setattr(self, key, value)
253 | 
254 |     def __contains__(self, key):
255 |         return key in self.__dict__
256 | 
257 |     def __repr__(self):
258 |         return self.__dict__.__repr__()
259 | 
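# A minimal, self-contained sketch of HParams behavior (the example values are
# illustrative, not taken from configs/singing_base.json): nested dicts are
# wrapped recursively, and item access falls through to getattr/setattr.
if __name__ == "__main__":
    _demo = HParams(train={"batch_size": 16}, data={"sampling_rate": 16000})
    assert _demo.train.batch_size == _demo["train"]["batch_size"] == 16
    assert "data" in _demo and len(_demo) == 2
    print(_demo)  # {'train': {'batch_size': 16}, 'data': {'sampling_rate': 16000}}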
--------------------------------------------------------------------------------
/vsinging_debug.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | 
5 | from scipy.io import wavfile
6 | from time import time
7 | 
8 | import torch
9 | import utils
10 | from models import SynthesizerTrn
11 | 
12 | 
13 | def save_wav(wav, path, rate):
14 |     wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6  # peak-normalize to ~60% full scale
15 |     wavfile.write(path, rate, wav.astype(np.int16))
16 | 
17 | 
18 | # define model and load checkpoint
19 | hps = utils.get_hparams_from_file("./configs/singing_base.json")
20 | 
21 | net_g = SynthesizerTrn(
22 |     hps.data.filter_length // 2 + 1,
23 |     hps.train.segment_size // hps.data.hop_length,
24 |     **hps.model,
25 | ).cuda()
26 | 
27 | _ = utils.load_checkpoint("./logs/singing_base/G_160000.pth", net_g, None)
28 | net_g.eval()
29 | # net_g.remove_weight_norm()
30 | 
31 | # check directory existence
32 | if not os.path.exists("./singing_out"):
33 |     os.makedirs("./singing_out")
34 | 
35 | idxs = [
36 |     "2001000001",
37 |     "2001000002",
38 |     "2001000003",
39 |     "2001000004",
40 |     "2001000005",
41 |     "2001000006",
42 |     "2051001912",
43 |     "2051001913",
44 |     "2051001914",
45 |     "2051001915",
46 |     "2051001916",
47 |     "2051001917",
48 | ]
49 | for idx in idxs:
50 |     phone = np.load(f"../VISinger_data/label_vits/{idx}_label.npy")
51 |     score = np.load(f"../VISinger_data/label_vits/{idx}_score.npy")
52 |     pitch = np.load(f"../VISinger_data/label_vits/{idx}_pitch.npy")
53 |     slurs = np.load(f"../VISinger_data/label_vits/{idx}_slurs.npy")
54 |     phone = torch.LongTensor(phone)
55 |     score = torch.LongTensor(score)
56 |     pitch = torch.LongTensor(pitch)
57 |     slurs = torch.LongTensor(slurs)
58 | 
59 |     phone_lengths = phone.size()[0]
60 | 
61 |     begin_time = time()
62 |     with torch.no_grad():
63 |         phone = phone.cuda().unsqueeze(0)
64 |         score = score.cuda().unsqueeze(0)
65 |         pitch = pitch.cuda().unsqueeze(0)
66 |         slurs = slurs.cuda().unsqueeze(0)
67 |         phone_lengths = torch.LongTensor([phone_lengths]).cuda()
68 |         audio = (
69 |             net_g.infer(phone, phone_lengths, score, pitch, slurs)[0][0, 0]
70 |             .data.cpu()
71 |             .float()
72 |             .numpy()
73 |         )
74 |     end_time = time()
75 |     run_time = end_time - begin_time
76 |     print("Synth Time (Seconds):", run_time)
77 |     data_len = len(audio) / hps.data.sampling_rate
78 |     print("Wave Time (Seconds):", data_len)
79 |     print("Real-time Factor:", run_time / data_len)
80 |     save_wav(audio, f"./singing_out/singing_{idx}.wav", hps.data.sampling_rate)
81 | 
82 | # can be deleted
83 | os.system("chmod 777 ./singing_out -R")
84 | 
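# A hedged sketch of the per-utterance features the script above loads. The
# exact contents of ../VISinger_data/label_vits/ are an assumption (presumably
# produced by the scripts under prepare/), but each array is a 1-D integer
# sequence of one shared length T, so the four tensors stay frame-aligned:
#
#     {idx}_label.npy  ->  phoneme ids
#     {idx}_score.npy  ->  MIDI note ids
#     {idx}_pitch.npy  ->  coarse F0 bin ids
#     {idx}_slurs.npy  ->  slur flags (0/1)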
--------------------------------------------------------------------------------
/vsinging_infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | 
5 | from scipy.io import wavfile
6 | from time import time
7 | 
8 | import torch
9 | import utils
10 | from models import Synthesizer
11 | from prepare.data_vits import SingInput
12 | from prepare.data_vits import FeatureInput
13 | from prepare.phone_map import get_vocab_size
14 | 
15 | 
16 | def save_wav(wav, path, rate):
17 |     wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6  # peak-normalize to ~60% full scale
18 |     wavfile.write(path, rate, wav.astype(np.int16))
19 | 
20 | 
21 | use_cuda = True
22 | 
23 | # define model and load checkpoint
24 | hps = utils.get_hparams_from_file("./configs/singing_base.json")
25 | 
26 | vocab_size = get_vocab_size()
27 | 
28 | model_path = "./logs/singing_base/"
29 | saved_models = os.listdir(model_path)
30 | iter_nums = []
31 | for i in range(len(saved_models)):
32 |     if os.path.splitext(saved_models[i])[1] == ".pth" and "G" in saved_models[i]:
33 |         iter_nums.append(int(os.path.splitext(saved_models[i])[0][2:]))
34 | iter_nums = sorted(iter_nums, reverse=True)
35 | 
36 | print("start inferring (G_" + str(iter_nums[0]) + ".pth)")
37 | 
38 | net_g = Synthesizer(
39 |     vocab_size,
40 |     hps.data.filter_length // 2 + 1,
41 |     hps.train.segment_size // hps.data.hop_length,
42 |     **hps.model,
43 | )  # .cuda()
44 | 
45 | if use_cuda:
46 |     net_g = net_g.cuda()
47 | 
48 | _ = utils.load_checkpoint(
49 |     "./logs/singing_base/G_" + str(iter_nums[0]) + ".pth", net_g, None
50 | )
51 | 
52 | net_g.eval()
53 | # net_g.remove_weight_norm()
54 | 
55 | singInput = SingInput(hps.data.sampling_rate, hps.data.hop_length)
56 | featureInput = FeatureInput(
57 |     "../VISinger_data/wav_dump_16k/", hps.data.sampling_rate, hps.data.hop_length
58 | )
59 | 
60 | # check directory existence
61 | if not os.path.exists("./singing_out"):
62 |     os.makedirs("./singing_out")
63 | 
64 | fo = open("./vsinging_infer.txt", "r+")
65 | while True:
66 |     try:
67 |         message = fo.readline().strip()
68 |     except Exception as e:
69 |         print("stopped reading input, exception:", e)
70 |         break
71 |     if message is None:
72 |         break
73 |     if message == "":
74 |         break
75 |     print(message)
76 |     (
77 |         file,
78 |         labels_ids,
79 |         labels_frames,
80 |         scores_ids,
81 |         scores_dur,
82 |         labels_slr,
83 |         labels_uvs,
84 |     ) = singInput.parseInput(message)
85 | 
86 |     phone = torch.LongTensor(labels_ids)
87 |     score = torch.LongTensor(scores_ids)
88 |     score_dur = torch.LongTensor(scores_dur)
89 |     slurs = torch.LongTensor(labels_slr)
90 | 
91 |     phone_lengths = phone.size()[0]
92 | 
93 |     begin_time = time()
94 |     with torch.no_grad():
95 |         if use_cuda:
96 |             phone = phone.cuda().unsqueeze(0)
97 |             score = score.cuda().unsqueeze(0)
98 |             score_dur = score_dur.cuda().unsqueeze(0)
99 |             slurs = slurs.cuda().unsqueeze(0)
100 |             phone_lengths = torch.LongTensor([phone_lengths]).cuda()
101 |         else:
102 |             phone = phone.unsqueeze(0)
103 |             score = score.unsqueeze(0)
104 |             score_dur = score_dur.unsqueeze(0)
105 |             slurs = slurs.unsqueeze(0)
106 |             phone_lengths = torch.LongTensor([phone_lengths])
107 |         audio = (
108 |             net_g.infer(phone, phone_lengths, score, score_dur, slurs)[0][0, 0]
109 |             .data.cpu()
110 |             .float()
111 |             .numpy()
112 |         )
113 |     end_time = time()
114 |     run_time = end_time - begin_time
115 |     print("Synth Time (Seconds):", run_time)
116 |     data_len = len(audio) / hps.data.sampling_rate
117 |     print("Wave Time (Seconds):", data_len)
118 |     print("Real-time Factor:", run_time / data_len)
119 |     save_wav(audio, f"./singing_out/{file}.wav", hps.data.sampling_rate)
120 | fo.close()
121 | # can be deleted
122 | os.system("chmod 777 ./singing_out -R")
123 | 
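# The directory scan above re-implements checkpoint discovery by hand; a
# shorter equivalent (a sketch, assuming generator checkpoints are named
# G_<step>.pth as elsewhere in this repo) would reuse utils.py:
#
#     ckpt_path = utils.latest_checkpoint_path("./logs/singing_base", "G_*.pth")
#     _ = utils.load_checkpoint(ckpt_path, net_g, None)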
--------------------------------------------------------------------------------
/vsinging_infer_jp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | 
5 | from scipy.io import wavfile
6 | from time import time
7 | 
8 | import torch
9 | import utils
10 | from models import SynthesizerTrn
11 | from prepare.data_vits_phn_ofuton import SingInput
12 | from prepare.data_vits_phn_ofuton import FeatureInput
13 | from prepare.phone_map import get_vocab_size
14 | 
15 | 
16 | def save_wav(wav, path, rate):
17 |     wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6  # peak-normalize to ~60% full scale
18 |     wavfile.write(path, rate, wav.astype(np.int16))
19 | 
20 | 
21 | use_cuda = True
22 | 
23 | # define model and load checkpoint
24 | hps = utils.get_hparams_from_file("./configs/singing_base.json")
25 | 
26 | vocab_size = get_vocab_size()
27 | 
28 | net_g = SynthesizerTrn(
29 |     vocab_size,
30 |     hps.data.filter_length // 2 + 1,
31 |     hps.train.segment_size // hps.data.hop_length,
32 |     **hps.model,
33 | )  # .cuda()
34 | 
35 | if use_cuda:
36 |     net_g = net_g.cuda()
37 | 
38 | _ = utils.load_checkpoint("./logs/singing_base/G_40000.pth", net_g, None)
39 | net_g.eval()
40 | # net_g.remove_weight_norm()
41 | 
42 | singInput = SingInput(hps.data.sampling_rate, hps.data.hop_length)
43 | featureInput = FeatureInput(
44 |     "../VISinger_data/wav_dump_16k/", hps.data.sampling_rate, hps.data.hop_length
45 | )
46 | 
47 | # check directory existence
48 | if not os.path.exists("./singing_out"):
49 |     os.makedirs("./singing_out")
50 | 
51 | fo = open("./vsinging_infer_jp.txt", "r+")
52 | while True:
53 |     try:
54 |         message = fo.readline().strip()
55 |     except Exception as e:
56 |         print("stopped reading input, exception:", e)
57 |         break
58 |     if message is None:
59 |         break
60 |     if message == "":
61 |         break
62 |     print(message)
63 |     (
64 |         file,
65 |         labels_ids,
66 |         labels_frames,
67 |         scores_ids,
68 |         scores_dur,
69 |         labels_slr,
70 |         # labels_uvs,
71 |     ) = singInput.parseInput(message)
72 | 
73 |     phone = torch.LongTensor(labels_ids)
74 |     score = torch.LongTensor(scores_ids)
75 |     score_dur = torch.LongTensor(scores_dur)
76 |     slurs = torch.LongTensor(labels_slr)
77 | 
78 |     phone_lengths = phone.size()[0]
79 | 
80 |     begin_time = time()
81 |     with torch.no_grad():
82 |         if use_cuda:
83 |             phone = phone.cuda().unsqueeze(0)
84 |             score = score.cuda().unsqueeze(0)
85 |             score_dur = score_dur.cuda().unsqueeze(0)
86 |             slurs = slurs.cuda().unsqueeze(0)
87 |             phone_lengths = torch.LongTensor([phone_lengths]).cuda()
88 |         else:
89 |             phone = phone.unsqueeze(0)
90 |             score = score.unsqueeze(0)
91 |             score_dur = score_dur.unsqueeze(0)
92 |             slurs = slurs.unsqueeze(0)
93 |             phone_lengths = torch.LongTensor([phone_lengths])
94 |         audio = (
95 |             net_g.infer(phone, phone_lengths, score, score_dur, slurs)[0][0, 0]
96 |             .data.cpu()
97 |             .float()
98 |             .numpy()
99 |         )
100 |     end_time = time()
101 |     run_time = end_time - begin_time
102 |     print("Synth Time (Seconds):", run_time)
103 |     data_len = len(audio) / hps.data.sampling_rate
104 |     print("Wave Time (Seconds):", data_len)
105 |     print("Real-time Factor:", run_time / data_len)
106 |     save_wav(audio, f"./singing_out/{file}.wav", hps.data.sampling_rate)
107 | fo.close()
108 | # can be deleted
109 | os.system("chmod 777 ./singing_out -R")
110 | 
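# A small worked example for the timing printout shared by these inference
# scripts: run_time / data_len is a real-time factor (seconds of compute per
# second of audio), not a percentage. Synthesizing 4.0 s of audio in 0.5 s
# gives 0.5 / 4.0 = 0.125, i.e. about 8x faster than real time.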
--------------------------------------------------------------------------------
/vsinging_song.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | 
4 | from scipy.io import wavfile
5 | from time import time
6 | 
7 | import torch
8 | import utils
9 | from models import Synthesizer
10 | from prepare.data_vits import SingInput
11 | from prepare.data_vits import FeatureInput
12 | 
13 | 
14 | def save_wav(wav, path, rate):
15 |     wav *= 32767 / max(0.01, np.max(np.abs(wav))) * 0.6  # peak-normalize to ~60% full scale
16 |     wavfile.write(path, rate, wav.astype(np.int16))
17 | 
18 | 
19 | # define model and load checkpoint
20 | hps = utils.get_hparams_from_file("./configs/singing_base.json")
21 | 
22 | net_g = Synthesizer(
23 |     hps.data.filter_length // 2 + 1,
24 |     hps.train.segment_size // hps.data.hop_length,
25 |     **hps.model,
26 | )
27 | 
28 | # _ = utils.load_checkpoint("./logs/singing_base/G_160000.pth", net_g, None)
29 | # net_g.remove_weight_norm()
30 | # torch.save(net_g, "visinger.pth")
31 | net_g = torch.load("visinger.pth", map_location="cpu")
32 | net_g.eval().cuda()
33 | # net_g.remove_weight_norm()
34 | 
35 | singInput = SingInput(16000, 256)
36 | featureInput = FeatureInput("../VISinger_data/wav_dump_16k/", 16000, 256)
37 | 
38 | # check directory existence
39 | if not os.path.exists("./singing_out"):
40 |     os.makedirs("./singing_out")
41 | 
42 | fo = open("./vsinging_song_midi.txt", "r+")
43 | song_rate = 16000
44 | song_time = fo.readline().strip().split("|")[1]
45 | song_length = int(song_rate * (float(song_time) + 30))
46 | song_data = np.zeros(song_length, dtype="float32")
47 | while True:
48 |     try:
49 |         message = fo.readline().strip()
50 |     except Exception as e:
51 |         print("stopped reading input, exception:", e)
52 |         break
53 |     if message is None:
54 |         break
55 |     if message == "":
56 |         break
57 |     (
58 |         item_indx,
59 |         item_time,
60 |         labels_ids,
61 |         labels_frames,
62 |         scores_ids,
63 |         scores_dur,
64 |         labels_slr,
65 |         labels_uvs,
66 |     ) = singInput.parseSong(message)
67 |     labels_ids = singInput.expandInput(labels_ids, labels_frames)
68 |     labels_uvs = singInput.expandInput(labels_uvs, labels_frames)
69 |     labels_slr = singInput.expandInput(labels_slr, labels_frames)
70 |     scores_ids = singInput.expandInput(scores_ids, labels_frames)
71 |     scores_pit = singInput.scorePitch(scores_ids)
72 |     # element by element
73 |     scores_pit = scores_pit * labels_uvs
74 |     # scores_pit = singInput.smoothPitch(scores_pit)
75 |     # scores_pit = scores_pit * labels_uvs
76 |     phone = torch.LongTensor(labels_ids)
77 |     score = torch.LongTensor(scores_ids)
78 |     slurs = torch.LongTensor(labels_slr)
79 |     pitch = featureInput.coarse_f0(scores_pit)
80 |     pitch = torch.LongTensor(pitch)
81 | 
82 |     phone_lengths = phone.size()[0]
83 | 
84 |     begin_time = time()
85 |     with torch.no_grad():
86 |         phone = phone.cuda().unsqueeze(0)
87 |         score = score.cuda().unsqueeze(0)
88 |         pitch = pitch.cuda().unsqueeze(0)
89 |         slurs = slurs.cuda().unsqueeze(0)
90 |         phone_lengths = torch.LongTensor([phone_lengths]).cuda()
91 |         audio = (
92 |             net_g.infer(phone, phone_lengths, score, pitch, slurs)[0][0, 0]
93 |             .data.cpu()
94 |             .float()
95 |             .numpy()
96 |         )
97 |     end_time = time()
98 |     run_time = end_time - begin_time
99 |     print("Synth Time (Seconds):", run_time)
100 |     data_len = len(audio) / song_rate
101 |     print("Wave Time (Seconds):", data_len)
102 |     print("Real-time Factor:", run_time / data_len)
103 |     save_wav(audio, f"./singing_out/{item_indx}.wav", hps.data.sampling_rate)
104 |     # place this clip at its start offset in the full song buffer
105 |     item_start = int(song_rate * float(item_time))
106 |     item_end = item_start + len(audio)
107 |     song_data[item_start:item_end] = audio
108 | # after the loop: write the stitched full song
109 | song_data = np.array(song_data, dtype="float32")
110 | save_wav(song_data, "./singing_out/_song.wav", hps.data.sampling_rate)
111 | fo.close()
112 | # can be deleted
113 | os.system("chmod 777 ./singing_out -R")
114 | 
--------------------------------------------------------------------------------
/vsinging_song_midi.txt:
--------------------------------------------------------------------------------
1 | song_time|116.88723672656248 2 | 0|0000.694| 化 外 山 间 岁 月 皆 看 老|h ua w ai sh an j ian s ui y ve j ie k an l ao|57 57 64 64 62 62 60 60 59 59 60 60 62 62 64 64 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.506 0.506|0.064 0.249 0.088 0.249 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.273 0.064 0.273 0.064 0.241 0.096 0.506|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 | 1|0006.140| 洛 雪 无 声 天 地 掩 尘 嚣|l uo x ve w u sh eng t ian d i y an ch en x iao|57 57 64 64 62 62 60 60 59 59 60 60 62 62 64 64 69 69|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.590 0.590|0.096 0.249 0.088 0.249 0.088 0.249 0.088 0.305 0.032 0.305 0.032 0.249 0.088 0.273 0.064 0.249 0.088 0.590|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 4 | 2|0010.923| 他 看 尽 晨 曦 日 暮 AP 饮 罢 腰 间 酒 一 壶 AP 依 稀 当 年 孤 旅 踏 苍 霞 尽 处|t a k an j in ch en x i r i m u AP y in b a y ao j ian j iu y i h u AP y i x i d ang n ian g u l v t a c ang x ia j in ch u|60 60 62 62 64 64 62 62 67 67 64 64 62 62 rest 64 64 67 67 72 72 71 71 69 69 67 67 69 69 rest 67 67 64 64 62 62 64 64 62 62 60 60 59 59 60 60 62 62 64 64 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253
0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 1.180 1.180|0.032 0.273 0.064 0.273 0.064 0.273 0.064 0.249 0.088 0.249 0.088 0.249 0.088 0.337 0.249 0.088 0.297 0.040 0.249 0.088 0.273 0.064 0.273 0.064 0.249 0.088 0.273 0.064 0.421 0.165 0.088 0.249 0.088 0.305 0.032 0.249 0.088 0.273 0.064 0.241 0.096 0.305 0.032 0.249 0.088 0.249 0.088 0.273 0.064 0.273 0.064 1.180|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 5 | 3|0021.678| 风 霜 冷 冽 他 眉 目 AP 时 光 雕 琢 他 风 骨 AP 浮 世 南 柯 一 梦 冷 暖 都 藏 住|f eng sh uang l eng l ie t a m ei m u AP sh i g uang d iao z uo t a f eng g u AP f u sh i n an k e y i m eng l eng n uan d ou c ang zh u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 rest 64 64 62 62 64 64 62 62 67 67 64 64 60 60 rest 57 57 60 60 64 64 62 62 60 60 57 57 60 60 57 57 67 67 62 62 64 64|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674|0.064 0.249 0.088 0.241 0.096 0.241 0.096 0.305 0.032 0.249 0.088 0.249 0.088 0.337 0.249 0.088 0.273 0.064 0.305 0.032 0.249 0.088 0.305 0.032 0.273 0.064 0.273 0.064 0.337 0.273 0.064 0.249 0.088 0.249 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.241 0.096 0.249 0.088 0.305 0.032 0.249 0.088 0.273 0.064 0.674|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 6 | 4|0032.356| 哪 杯 酒 烫 过 肺 腑 AP 曾 换 他 睥 睨 一 顾 AP 剑 破 乾 坤 轮 转 山 河 倾 覆|n a b ei j iu t ang g uo f ei f u AP c eng h uan t a p i n i y i g u AP j ian p o q ian k un l un zh uan sh an h e q ing f u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 rest 64 64 62 62 64 64 62 62 67 67 64 64 60 60 rest 57 57 64 64 62 62 64 64 67 67 62 62 60 60 59 59 60 60 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.337 0.337 1.348 1.348|0.088 0.297 0.040 0.273 0.064 0.305 0.032 0.273 0.064 0.273 0.064 0.273 0.064 0.337 0.249 0.088 0.273 0.064 0.305 0.032 0.249 0.088 0.249 0.088 0.249 0.088 0.273 0.064 0.337 0.273 0.064 0.249 0.088 0.241 0.096 0.273 0.064 0.241 0.096 0.273 0.064 0.249 0.088 0.610 0.064 0.241 0.096 0.273 0.064 1.348|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 7 | 5|0043.620| 他 三 清 尘 外 剔 去 心 中 毒|t a s an q ing ch en w ai t i q v x in zh ong d u|57 57 60 60 64 64 62 62 60 60 59 59 60 60 62 62 64 64 57 57|0.169 0.169 0.169 0.169 0.674 0.674 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.590 0.590|0.032 0.081 0.088 0.073 0.096 0.610 0.064 0.249 0.088 0.305 0.032 0.241 0.096 0.249 0.088 0.273 0.064 0.305 0.032 0.590|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 8 | 6|0048.981| 尝 世 间 百 味 甘 醇 与 涩 苦|ch ang sh i j ian b ai w ei g an ch un y v s e k u|57 57 60 60 64 64 62 62 60 60 59 59 60 60 62 62 64 64 69 69|0.169 0.169 0.169 0.169 0.674 0.674 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 1.180 1.180|0.064 0.081 0.088 0.105 0.064 0.634 0.040 0.249 0.088 0.273 0.064 0.273 0.064 0.249 0.088 0.249 0.088 0.273 0.064 1.180|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 9 | 
7|0053.929| 曾 有 谁 偏 执 不 悟 AP 谈 笑 斗 酒 至 酣 处 AP 而 今 不 过 拍 去 肩 上 红 尘 土|c eng y ou sh ui p ian zh i b u w u AP t an x iao d ou j iu zh i h an ch u AP er j in b u g uo p ai q v j ian sh ang h ong ch en t u|60 60 62 62 64 64 67 67 64 64 67 67 62 62 rest 62 62 67 67 72 72 71 71 69 69 67 67 69 69 rest 67 64 64 62 62 62 62 64 64 67 67 60 60 60 60 59 59 60 60 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674|0.088 0.249 0.088 0.249 0.088 0.249 0.088 0.273 0.064 0.297 0.040 0.249 0.088 0.337 0.305 0.032 0.249 0.088 0.305 0.032 0.273 0.064 0.273 0.064 0.273 0.064 0.273 0.064 0.337 0.337 0.273 0.064 0.297 0.040 0.273 0.064 0.249 0.088 0.241 0.096 0.273 0.064 0.249 0.088 0.273 0.064 0.273 0.064 0.305 0.032 0.674|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 10 | 8|0064.655| 风 霜 冷 冽 他 眉 目 时 光 雕 琢 他 风 骨 浮 世 南 柯 一 梦 冷 暖 都 藏 住|f eng sh uang l eng l ie t a m ei m u sh i g uang d iao z uo t a f eng g u f u sh i n an k e y i m eng l eng n uan d ou c ang zh u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 64 64 62 62 64 64 62 62 67 67 64 64 60 60 57 57 60 60 64 64 62 62 60 60 57 57 60 60 57 57 67 67 62 62 64 64|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.506 0.506|0.064 0.249 0.088 0.241 0.096 0.241 0.096 0.305 0.032 0.249 0.088 0.249 0.088 0.586 0.088 0.273 0.064 0.305 0.032 0.249 0.088 0.305 0.032 0.273 0.064 0.273 0.064 0.610 0.064 0.249 0.088 0.249 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.241 0.096 0.249 0.088 0.305 0.032 0.249 0.088 0.273 0.064 0.506|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 11 | 9|0075.418| 哪 杯 酒 烫 过 肺 腑 曾 换 他 睥 睨 一 顾 AP 剑 破 乾 坤 轮 转 山 河 倾 覆|n a b ei j iu t ang g uo f ei f u c eng h uan t a p i n i y i g u AP j ian p o q ian k un l un zh uan sh an h e q ing f u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 64 64 62 62 64 64 62 62 67 67 64 64 60 60 rest 57 57 64 64 62 62 64 64 67 67 62 62 60 60 59 59 60 60 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.337 0.337 0.674 0.674|0.088 0.297 0.040 0.273 0.064 0.305 0.032 0.273 0.064 0.273 0.064 0.273 0.064 0.586 0.088 0.273 0.064 0.305 0.032 0.249 0.088 0.249 0.088 0.249 0.088 0.273 0.064 0.421 0.189 0.064 0.249 0.088 0.241 0.096 0.273 0.064 0.241 0.096 0.273 0.064 0.249 0.088 0.610 0.064 0.241 0.096 0.273 0.064 0.674|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 12 | 10|0086.260| 到 最 后 沧 海 一 粟 AP 何 必 江 湖 多 殊 途 AP 当 年 论 剑 峰 顶 谁 几 笔 成 书|d ao z ui h ou c ang h ai y i s u AP h e b i j iang h u d uo sh u t u AP d ang n ian l un j ian f eng d ing sh ui j i b i ch eng sh u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 rest 64 64 62 62 64 64 62 62 67 67 64 64 60 60 rest 57 57 60 60 64 64 62 62 60 60 57 57 60 60 57 57 67 67 62 62 64 64|0.337 0.337 0.337 0.337 0.337 0.337 0.337 
0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674|0.032 0.249 0.088 0.273 0.064 0.249 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.421 0.189 0.064 0.297 0.040 0.273 0.064 0.273 0.064 0.305 0.032 0.249 0.088 0.305 0.032 0.421 0.221 0.032 0.249 0.088 0.241 0.096 0.273 0.064 0.273 0.064 0.305 0.032 0.249 0.088 0.273 0.064 0.297 0.040 0.273 0.064 0.249 0.088 0.674|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 13 | 11|0096.991| 纵 他 朝 众 生 再 晤 AP 奈 何 明 月 终 辜 负 AP 坐 听 晨 钟 难 算 太 虚 有 无|z ong t a ch ao zh ong sh eng z ai w u AP n ai h e m ing y ve zh ong g u f u AP z uo t ing ch en zh ong n an s uan t ai x v y ou w u|64 64 67 67 69 69 67 67 72 72 69 69 67 67 rest 64 64 62 62 64 64 62 62 67 67 64 64 60 60 rest 57 57 64 64 62 62 64 64 62 62 60 60 59 59 60 60 59 59 57 57|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.253 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.674 0.674 0.337 0.337 0.169 0.169 1.264 1.264|0.088 0.305 0.032 0.273 0.064 0.273 0.064 0.249 0.088 0.249 0.088 0.249 0.088 0.421 0.165 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.273 0.064 0.273 0.064 0.273 0.064 0.421 0.165 0.088 0.305 0.032 0.273 0.064 0.273 0.064 0.249 0.088 0.249 0.088 0.305 0.032 0.586 0.088 0.249 0.088 0.081 0.088 1.264|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 14 | 12|0107.917| 天 道 勘 破 敢 问 一 句 悟 不|t ian d ao k an p o g an w en y i j v w u b u|57 57 64 64 62 62 64 64 62 62 60 60 59 59 60 60 62 62 64 64|0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.337 0.421 0.421 0.506 0.506 0.337 0.337 0.590 0.590|0.032 0.305 0.032 0.273 0.064 0.249 0.088 0.273 0.064 0.249 0.088 0.249 0.088 0.357 0.064 0.418 0.088 0.297 0.040 0.590|0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 15 | 13|0112.496| 悟 悟|w u w u|68 68 69 69|0.506 0.506 3.792 3.792|0.088 0.418 0.088 3.792|0 0 0 0 --------------------------------------------------------------------------------