├── LICENSE ├── README.md ├── attentions.py ├── commons.py ├── configs └── quickvc_44100.json ├── convert.py ├── crepe.py ├── data_utils_new_new.py ├── dataset └── encode.py ├── inference.py ├── losses.py ├── mel_processing.py ├── models.py ├── modules.py ├── pqmf.py ├── qvcfinalwhite.png ├── requirements.txt ├── stft.py ├── stft_loss.py ├── train.py ├── transforms.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ㌧㌧ 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # QuickVC(44100Hz 日本語HuBERT対応版) 2 | 3 | このリポジトリは、44100Hzの音声を学習および出力できるように編集した[QuickVC-VoiceConversion](https://github.com/quickvc/QuickVC-VoiceConversion)です。但し、以下の点を変更しております。 4 | - ContentEncoderをWaveNetからAttentionに変更 5 | - HuBERT-softのHiddenUnitsを、[日本語HuBERT](https://huggingface.co/rinna/japanese-hubert-base)12層目768dim特徴量に変更 6 | - MS-iSTFT-VITSのsubband数を4⇛8に変更 7 | - ContentEncoderにF0埋め込みを追加。それに従い、PreprocessにF0抽出処理を追加。 8 | 9 | 10 | 11 | 12 | ## ⚠Work in Progress⚠ 13 | 学習と推論を実装済。事前学習モデル等は学習音源の問題が解決し次第公開。 14 | 15 | ## 1. 環境構築 16 | 17 | Anacondaによる実行環境構築を想定する。 18 | 19 | 0. Anacondaで"QuickVC"という名前の仮想環境を作成する。[y]or nを聞かれたら[y]を入力する。 20 | ```sh 21 | conda create -n QuickVC python=3.8 22 | ``` 23 | 0. 仮想環境を有効化する。 24 | ```sh 25 | conda activate QuickVC 26 | ``` 27 | 0. このレポジトリをクローンする(もしくはDownload Zipでダウンロードする) 28 | 29 | ```sh 30 | git clone https://github.com/tonnetonne814/QuickVC-44100-Ja_HuBERT.git 31 | cd QuickVC-44100-Ja_HuBERT.git # フォルダへ移動 32 | ``` 33 | 34 | 0. [https://pytorch.org/](https://pytorch.org/)のURLよりPyTorchをインストールする。 35 | ```sh 36 | # OS=Linux, CUDA=11.7 の例 37 | pip3 install torch torchvision torchaudio 38 | ``` 39 | 40 | 0. その他、必要なパッケージをインストールする。 41 | ```sh 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | ## 2. データセットの準備 46 | 47 | [JVSコーパス](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvs_corpus)は配布時の音源が24000Hzの為適さないが、説明のために[JVSコーパス](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvs_corpus)の学習を想定します。 48 | 49 | 1. [こちら](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvs_corpus)からJVSコーパスをダウンロード&解凍する。 50 | 1. 音源を44100Hz16Bitモノラル音源へと変換する。 51 | 1. 
解凍したフォルダを、datasetフォルダへ移動し、以下を実行する。 52 | ```sh 53 | python3 ./dataset/encode.py --model japanese-hubert-base --f0 harvest 54 | ``` 55 | > F0抽出のライブラリは、["dio", "parselmouth", "harvest", "crepe"]から選択可能。適宜変更すること。 56 | 57 | 58 | ## 3. [configs](configs)フォルダ内のjsonを編集 59 | 主要なパラメータを説明します。必要であれば編集する。 60 | | 分類 | パラメータ名 | 説明 | 61 | |:-----:|:-----------------:|:---------------------------------------------------------:| 62 | | train | log_interval | 指定ステップ毎にロスを算出し記録する | 63 | | train | eval_interval | 指定ステップ毎にモデル評価を行う | 64 | | train | epochs | 学習データ全体を学習する回数 | 65 | | train | batch_size | 一度のパラメータ更新に使用する学習データ数 | 66 | 67 | 68 | ## 4. 学習 69 | 次のコマンドを入力することで、学習を開始する。YourModelNameは自由に変更して良い。 70 | > ⚠CUDA Out of Memoryのエラーが出た場合には、config.jsonにてbatch_sizeを小さくする。 71 | 72 | ```sh 73 | python train.py -c configs/quickvc_44100.json -m YourModelName 74 | ``` 75 | 76 | 学習経過はターミナルにも表示されるが、tensorboardを用いて確認することで、生成音声の視聴や、スペクトログラム、各ロス遷移を目視で確認することができます。 77 | ```sh 78 | tensorboard --logdir logs 79 | ``` 80 | 81 | ## 5. 推論 82 | 次のコマンドを入力することで、推論を開始する。config.jsonへのパス、生成器モデルパスを指定する。 83 | 84 | ```sh 85 | python inference.py --config ./path/to/config.json --model_path ./path/to/G_xxx.pth 86 | ``` 87 | 実行後、Terminal上にて使用するデバイスを選択後、以下のループが処理される。 88 | 1. ターゲット音声パスの入力 89 | 1. ソース音声ファイルパスの入力 90 | 1. F0計算方式の入力 ( dio:0 | parselmouth:1 | harvest:2 | crepe:3 ) 91 | 1. 処理実行、処理時間表示 92 | 1. 音声とログデータの保存(infer_logsフォルダがデフォルト) 93 | 1. 音声の再生 94 | 95 | ## 事前学習モデル 96 | 未実装。学習音源選定中。 97 | 98 | ## 参考文献 99 | - [QuickVC-VoiceConversion](https://github.com/quickvc/QuickVC-VoiceConversion) 100 | - [MB-ISTFT-VITS](https://github.com/MasayaKawamura/MB-iSTFT-VITS) 101 | - [japanese-hubert-base](https://huggingface.co/rinna/japanese-hubert-base) 102 | - [fish-diffusion](https://github.com/fishaudio/fish-diffusion) 103 | - [so-vits-svc](https://github.com/svc-develop-team/so-vits-svc) 104 | - [etrieval-based-Voice-Conversion-WebUI](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI) 105 | -------------------------------------------------------------------------------- /attentions.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | import commons 9 | import modules 10 | from modules import LayerNorm 11 | 12 | 13 | class Encoder(nn.Module): 14 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs): 15 | super().__init__() 16 | self.hidden_channels = hidden_channels 17 | self.filter_channels = filter_channels 18 | self.n_heads = n_heads 19 | self.n_layers = n_layers 20 | self.kernel_size = kernel_size 21 | self.p_dropout = p_dropout 22 | self.window_size = window_size 23 | 24 | self.drop = nn.Dropout(p_dropout) 25 | self.attn_layers = nn.ModuleList() 26 | self.norm_layers_1 = nn.ModuleList() 27 | self.ffn_layers = nn.ModuleList() 28 | self.norm_layers_2 = nn.ModuleList() 29 | for i in range(self.n_layers): 30 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) 31 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 32 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) 33 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 34 | 35 | def forward(self, x, x_mask): 36 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 37 | x = x * x_mask 
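# Each of the n_layers blocks below applies windowed relative-position
# self-attention followed by a 1-D convolutional FFN; both sub-layers use
# dropout, a residual connection and post-LayerNorm, and padded frames stay
# zeroed through x_mask / attn_mask.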
38 | for i in range(self.n_layers): 39 | y = self.attn_layers[i](x, x, attn_mask) 40 | y = self.drop(y) 41 | x = self.norm_layers_1[i](x + y) 42 | 43 | y = self.ffn_layers[i](x, x_mask) 44 | y = self.drop(y) 45 | x = self.norm_layers_2[i](x + y) 46 | x = x * x_mask 47 | return x 48 | 49 | 50 | class Decoder(nn.Module): 51 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs): 52 | super().__init__() 53 | self.hidden_channels = hidden_channels 54 | self.filter_channels = filter_channels 55 | self.n_heads = n_heads 56 | self.n_layers = n_layers 57 | self.kernel_size = kernel_size 58 | self.p_dropout = p_dropout 59 | self.proximal_bias = proximal_bias 60 | self.proximal_init = proximal_init 61 | 62 | self.drop = nn.Dropout(p_dropout) 63 | self.self_attn_layers = nn.ModuleList() 64 | self.norm_layers_0 = nn.ModuleList() 65 | self.encdec_attn_layers = nn.ModuleList() 66 | self.norm_layers_1 = nn.ModuleList() 67 | self.ffn_layers = nn.ModuleList() 68 | self.norm_layers_2 = nn.ModuleList() 69 | for i in range(self.n_layers): 70 | self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init)) 71 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 72 | self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) 73 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 74 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) 75 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 76 | 77 | def forward(self, x, x_mask, h, h_mask): 78 | """ 79 | x: decoder input 80 | h: encoder output 81 | """ 82 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) 83 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 84 | x = x * x_mask 85 | for i in range(self.n_layers): 86 | y = self.self_attn_layers[i](x, x, self_attn_mask) 87 | y = self.drop(y) 88 | x = self.norm_layers_0[i](x + y) 89 | 90 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 91 | y = self.drop(y) 92 | x = self.norm_layers_1[i](x + y) 93 | 94 | y = self.ffn_layers[i](x, x_mask) 95 | y = self.drop(y) 96 | x = self.norm_layers_2[i](x + y) 97 | x = x * x_mask 98 | return x 99 | 100 | 101 | class MultiHeadAttention(nn.Module): 102 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False): 103 | super().__init__() 104 | assert channels % n_heads == 0 105 | 106 | self.channels = channels 107 | self.out_channels = out_channels 108 | self.n_heads = n_heads 109 | self.p_dropout = p_dropout 110 | self.window_size = window_size 111 | self.heads_share = heads_share 112 | self.block_length = block_length 113 | self.proximal_bias = proximal_bias 114 | self.proximal_init = proximal_init 115 | self.attn = None 116 | 117 | self.k_channels = channels // n_heads 118 | self.conv_q = nn.Conv1d(channels, channels, 1) 119 | self.conv_k = nn.Conv1d(channels, channels, 1) 120 | self.conv_v = nn.Conv1d(channels, channels, 1) 121 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 122 | self.drop = nn.Dropout(p_dropout) 123 | 124 | if window_size is not None: 125 | n_heads_rel = 1 if heads_share else n_heads 126 | rel_stddev = self.k_channels**-0.5 127 | self.emb_rel_k = 
nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 128 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 129 | 130 | nn.init.xavier_uniform_(self.conv_q.weight) 131 | nn.init.xavier_uniform_(self.conv_k.weight) 132 | nn.init.xavier_uniform_(self.conv_v.weight) 133 | if proximal_init: 134 | with torch.no_grad(): 135 | self.conv_k.weight.copy_(self.conv_q.weight) 136 | self.conv_k.bias.copy_(self.conv_q.bias) 137 | 138 | def forward(self, x, c, attn_mask=None): 139 | q = self.conv_q(x) 140 | k = self.conv_k(c) 141 | v = self.conv_v(c) 142 | 143 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 144 | 145 | x = self.conv_o(x) 146 | return x 147 | 148 | def attention(self, query, key, value, mask=None): 149 | # reshape [b, d, t] -> [b, n_h, t, d_k] 150 | b, d, t_s, t_t = (*key.size(), query.size(2)) 151 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 152 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 153 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 154 | 155 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 156 | if self.window_size is not None: 157 | assert t_s == t_t, "Relative attention is only available for self-attention." 158 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 159 | rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings) 160 | scores_local = self._relative_position_to_absolute_position(rel_logits) 161 | scores = scores + scores_local 162 | if self.proximal_bias: 163 | assert t_s == t_t, "Proximal bias is only available for self-attention." 164 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) 165 | if mask is not None: 166 | scores = scores.masked_fill(mask == 0, -1e4) 167 | if self.block_length is not None: 168 | assert t_s == t_t, "Local attention is only available for self-attention." 169 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) 170 | scores = scores.masked_fill(block_mask == 0, -1e4) 171 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 172 | p_attn = self.drop(p_attn) 173 | output = torch.matmul(p_attn, value) 174 | if self.window_size is not None: 175 | relative_weights = self._absolute_position_to_relative_position(p_attn) 176 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 177 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) 178 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] 179 | return output, p_attn 180 | 181 | def _matmul_with_relative_values(self, x, y): 182 | """ 183 | x: [b, h, l, m] 184 | y: [h or 1, m, d] 185 | ret: [b, h, l, d] 186 | """ 187 | ret = torch.matmul(x, y.unsqueeze(0)) 188 | return ret 189 | 190 | def _matmul_with_relative_keys(self, x, y): 191 | """ 192 | x: [b, h, l, d] 193 | y: [h or 1, m, d] 194 | ret: [b, h, l, m] 195 | """ 196 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 197 | return ret 198 | 199 | def _get_relative_embeddings(self, relative_embeddings, length): 200 | max_relative_position = 2 * self.window_size + 1 201 | # Pad first before slice to avoid using cond ops. 
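# The learned table covers offsets in [-window_size, window_size]
# (2*window_size + 1 rows); a length-t sequence needs 2*t - 1 relative
# positions, so the table is padded symmetrically when t > window_size + 1
# and the centred span of 2*t - 1 rows is sliced out below.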
202 | pad_length = max(length - (self.window_size + 1), 0) 203 | slice_start_position = max((self.window_size + 1) - length, 0) 204 | slice_end_position = slice_start_position + 2 * length - 1 205 | if pad_length > 0: 206 | padded_relative_embeddings = F.pad( 207 | relative_embeddings, 208 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) 209 | else: 210 | padded_relative_embeddings = relative_embeddings 211 | used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position] 212 | return used_relative_embeddings 213 | 214 | def _relative_position_to_absolute_position(self, x): 215 | """ 216 | x: [b, h, l, 2*l-1] 217 | ret: [b, h, l, l] 218 | """ 219 | batch, heads, length, _ = x.size() 220 | # Concat columns of pad to shift from relative to absolute indexing. 221 | x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) 222 | 223 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 224 | x_flat = x.view([batch, heads, length * 2 * length]) 225 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]])) 226 | 227 | # Reshape and slice out the padded elements. 228 | x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] 229 | return x_final 230 | 231 | def _absolute_position_to_relative_position(self, x): 232 | """ 233 | x: [b, h, l, l] 234 | ret: [b, h, l, 2*l-1] 235 | """ 236 | batch, heads, length, _ = x.size() 237 | # padd along column 238 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) 239 | x_flat = x.view([batch, heads, length**2 + length*(length -1)]) 240 | # add 0's in the beginning that will skew the elements after reshape 241 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 242 | x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] 243 | return x_final 244 | 245 | def _attention_bias_proximal(self, length): 246 | """Bias for self-attention to encourage attention to close positions. 247 | Args: 248 | length: an integer scalar. 
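(The bias is -log(1 + |i - j|), so logits decay smoothly with the distance between positions i and j.)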
249 | Returns: 250 | a Tensor with shape [1, 1, length, length] 251 | """ 252 | r = torch.arange(length, dtype=torch.float32) 253 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 254 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 255 | 256 | 257 | class FFN(nn.Module): 258 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False): 259 | super().__init__() 260 | self.in_channels = in_channels 261 | self.out_channels = out_channels 262 | self.filter_channels = filter_channels 263 | self.kernel_size = kernel_size 264 | self.p_dropout = p_dropout 265 | self.activation = activation 266 | self.causal = causal 267 | 268 | if causal: 269 | self.padding = self._causal_padding 270 | else: 271 | self.padding = self._same_padding 272 | 273 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 274 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 275 | self.drop = nn.Dropout(p_dropout) 276 | 277 | def forward(self, x, x_mask): 278 | x = self.conv_1(self.padding(x * x_mask)) 279 | if self.activation == "gelu": 280 | x = x * torch.sigmoid(1.702 * x) 281 | else: 282 | x = torch.relu(x) 283 | x = self.drop(x) 284 | x = self.conv_2(self.padding(x * x_mask)) 285 | return x * x_mask 286 | 287 | def _causal_padding(self, x): 288 | if self.kernel_size == 1: 289 | return x 290 | pad_l = self.kernel_size - 1 291 | pad_r = 0 292 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 293 | x = F.pad(x, commons.convert_pad_shape(padding)) 294 | return x 295 | 296 | def _same_padding(self, x): 297 | if self.kernel_size == 1: 298 | return x 299 | pad_l = (self.kernel_size - 1) // 2 300 | pad_r = self.kernel_size // 2 301 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 302 | x = F.pad(x, commons.convert_pad_shape(padding)) 303 | return x 304 | -------------------------------------------------------------------------------- /commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size*dilation - dilation)/2) 16 | 17 | 18 | def convert_pad_shape(pad_shape): 19 | l = pad_shape[::-1] 20 | pad_shape = [item for sublist in l for item in sublist] 21 | return pad_shape 22 | 23 | 24 | def intersperse(lst, item): 25 | result = [item] * (len(lst) * 2 + 1) 26 | result[1::2] = lst 27 | return result 28 | 29 | 30 | def kl_divergence(m_p, logs_p, m_q, logs_q): 31 | """KL(P||Q)""" 32 | kl = (logs_q - logs_p) - 0.5 33 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. 
* logs_q) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | ret = torch.zeros_like(x[:, :, :segment_size]) 50 | for i in range(x.size(0)): 51 | idx_str = ids_str[i] 52 | idx_end = idx_str + segment_size 53 | ret[i] = x[i, :, idx_str:idx_end] 54 | return ret 55 | 56 | def slice_segments_2dim(x, ids_str, segment_size=4): 57 | ret = torch.zeros_like(x[:, :segment_size]) 58 | for i in range(x.size(0)): 59 | idx_str = ids_str[i] 60 | idx_end = idx_str + segment_size 61 | ret[i] = x[i, idx_str:idx_end] 62 | return ret 63 | 64 | 65 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 66 | b, d, t = x.size() 67 | if x_lengths is None: 68 | x_lengths = t 69 | ids_str_max = x_lengths - segment_size + 1 70 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 71 | ret = slice_segments(x, ids_str, segment_size) 72 | return ret, ids_str 73 | 74 | def rand_spec_segments(x, x_lengths=None, segment_size=4): 75 | b, d, t = x.size() 76 | if x_lengths is None: 77 | x_lengths = t 78 | ids_str_max = x_lengths - segment_size 79 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 80 | ret = slice_segments(x, ids_str, segment_size) 81 | return ret, ids_str 82 | 83 | 84 | def get_timing_signal_1d( 85 | length, channels, min_timescale=1.0, max_timescale=1.0e4): 86 | position = torch.arange(length, dtype=torch.float) 87 | num_timescales = channels // 2 88 | log_timescale_increment = ( 89 | math.log(float(max_timescale) / float(min_timescale)) / 90 | (num_timescales - 1)) 91 | inv_timescales = min_timescale * torch.exp( 92 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 93 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 94 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 95 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 96 | signal = signal.view(1, channels, length) 97 | return signal 98 | 99 | 100 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 101 | b, channels, length = x.size() 102 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 103 | return x + signal.to(dtype=x.dtype, device=x.device) 104 | 105 | 106 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 107 | b, channels, length = x.size() 108 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 109 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 110 | 111 | 112 | def subsequent_mask(length): 113 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 114 | return mask 115 | 116 | 117 | @torch.jit.script 118 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 119 | n_channels_int = n_channels[0] 120 | in_act = input_a + input_b 121 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 122 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 123 | acts = t_act * s_act 124 | return acts 125 | 126 | 127 | def convert_pad_shape(pad_shape): 128 | l = pad_shape[::-1] 129 | pad_shape = [item for sublist in l for item in sublist] 130 | return pad_shape 131 | 132 | 133 | def shift_1d(x): 
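# Shift the last (time) axis right by one step: zero-pad the front and drop the final frame.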
134 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 135 | return x 136 | 137 | 138 | def sequence_mask(length, max_length=None): 139 | if max_length is None: 140 | max_length = length.max() 141 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 142 | return x.unsqueeze(0) < length.unsqueeze(1) 143 | 144 | 145 | def generate_path(duration, mask): 146 | """ 147 | duration: [b, 1, t_x] 148 | mask: [b, 1, t_y, t_x] 149 | """ 150 | device = duration.device 151 | 152 | b, _, t_y, t_x = mask.shape 153 | cum_duration = torch.cumsum(duration, -1) 154 | 155 | cum_duration_flat = cum_duration.view(b * t_x) 156 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 157 | path = path.view(b, t_x, t_y) 158 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 159 | path = path.unsqueeze(1).transpose(2,3) * mask 160 | return path 161 | 162 | 163 | def clip_grad_value_(parameters, clip_value, norm_type=2): 164 | if isinstance(parameters, torch.Tensor): 165 | parameters = [parameters] 166 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 167 | norm_type = float(norm_type) 168 | if clip_value is not None: 169 | clip_value = float(clip_value) 170 | 171 | total_norm = 0 172 | for p in parameters: 173 | param_norm = p.grad.data.norm(norm_type) 174 | total_norm += param_norm.item() ** norm_type 175 | if clip_value is not None: 176 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 177 | total_norm = total_norm ** (1. / norm_type) 178 | return total_norm 179 | -------------------------------------------------------------------------------- /configs/quickvc_44100.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 20, 4 | "eval_interval": 5000, 5 | "seed": 1234, 6 | "epochs": 2000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 64, 11 | "fp16_run": false, 12 | "lr_decay": 0.999875, 13 | "segment_size": 16384, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0, 18 | "use_sr": true, 19 | "max_speclen": 2048, 20 | "port": "8002", 21 | "fft_sizes": [768, 1366, 342], 22 | "hop_sizes": [60, 120, 20], 23 | "win_lengths": [300, 600, 120], 24 | "window": "hann_window" 25 | }, 26 | "data": { 27 | "training_files":"./filelist/train.txt", 28 | "validation_files":"./filelist/test.txt", 29 | "text_cleaners":["english_cleaners2"], 30 | "max_wav_value": 32768.0, 31 | "sampling_rate": 44100, 32 | "filter_length": 2048, 33 | "hop_length": 512, 34 | "win_length": 2048, 35 | "n_mel_channels": 80, 36 | "mel_fmin": 0.0, 37 | "mel_fmax": null, 38 | "add_blank": true, 39 | "n_speakers": 0, 40 | "cleaned_text": true 41 | }, 42 | "model": { 43 | "ms_istft_vits": true, 44 | "mb_istft_vits": false, 45 | "istft_vits": false, 46 | "subbands": 8, 47 | "gen_istft_n_fft": 16, 48 | "gen_istft_hop_size": 4, 49 | "inter_channels": 192, 50 | "hidden_channels": 192, 51 | "filter_channels": 768, 52 | "n_heads": 2, 53 | "n_layers": 6, 54 | "kernel_size": 3, 55 | "p_dropout": 0.1, 56 | "resblock": "1", 57 | "resblock_kernel_sizes": [3,7,11], 58 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 59 | "upsample_rates": [4,4], 60 | "upsample_initial_channel": 512, 61 | "upsample_kernel_sizes": [16,16], 62 | "n_layers_q": 3, 63 | "use_spectral_norm": false, 64 | "gin_channels": 256, 65 | "use_sdp": false, 66 | "ssl_dim": 1024, 67 | "use_spk": false 68 | } 69 | 70 | } 71 | 
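One consistency constraint worth checking when editing `configs/quickvc_44100.json`: the MS-iSTFT decoder upsamples each latent frame by the product of `upsample_rates`, `subbands` and `gen_istft_hop_size`, and that product should equal `data.hop_length` (here 4 × 4 × 8 × 4 = 512) so that one spectrogram frame maps to exactly 512 samples of 44100 Hz audio. A minimal sanity-check sketch, assuming the JSON above is used as-is; `check_decoder_hop` is an illustrative helper, not part of this repository:

```python
import json
import math

def check_decoder_hop(config_path: str) -> None:
    """Warn if the decoder's total upsampling does not match data.hop_length."""
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)
    model, data, train = cfg["model"], cfg["data"], cfg["train"]
    total_up = math.prod(model["upsample_rates"]) * model["subbands"] * model["gen_istft_hop_size"]
    print("decoder upsampling:", total_up, "| data.hop_length:", data["hop_length"])
    assert total_up == data["hop_length"], "decoder upsampling and hop_length disagree"
    # Each training clip then covers segment_size / hop_length latent frames (16384 / 512 = 32).
    print("frames per training segment:", train["segment_size"] // data["hop_length"])

if __name__ == "__main__":
    check_decoder_hop("configs/quickvc_44100.json")
```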
-------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import librosa 5 | import time 6 | from scipy.io.wavfile import write 7 | from tqdm import tqdm 8 | 9 | import utils 10 | from models import SynthesizerTrn 11 | from mel_processing import mel_spectrogram_torch 12 | import logging 13 | logging.getLogger('numba').setLevel(logging.WARNING) 14 | import torch.autograd.profiler as profiler 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--hpfile", type=str, default="logs/quickvc/config.json", help="path to json config file") 19 | parser.add_argument("--ptfile", type=str, default="logs/quickvc/quickvc.pth", help="path to pth file") 20 | parser.add_argument("--txtpath", type=str, default="convert.txt", help="path to txt file") 21 | parser.add_argument("--outdir", type=str, default="output/quickvc", help="path to output dir") 22 | parser.add_argument("--use_timestamp", default=False, action="store_true") 23 | args = parser.parse_args() 24 | 25 | os.makedirs(args.outdir, exist_ok=True) 26 | hps = utils.get_hparams_from_file(args.hpfile) 27 | 28 | print("Loading model...") 29 | net_g = SynthesizerTrn( 30 | hps.data.filter_length // 2 + 1, 31 | hps.train.segment_size // hps.data.hop_length, 32 | **hps.model).cuda() 33 | _ = net_g.eval() 34 | total = sum([param.nelement() for param in net_g.parameters()]) 35 | 36 | print("Number of parameter: %.2fM" % (total/1e6)) 37 | print("Loading checkpoint...") 38 | _ = utils.load_checkpoint(args.ptfile, net_g, None) 39 | 40 | print(f"Loading hubert_soft checkpoint") 41 | hubert_soft = torch.hub.load("bshall/hubert:main", f"hubert_soft").cuda() 42 | print("Loaded soft hubert.") 43 | 44 | print("Processing text...") 45 | titles, srcs, tgts = [], [], [] 46 | with open(args.txtpath, "r") as f: 47 | for rawline in f.readlines(): 48 | title, src, tgt = rawline.strip().split("|") 49 | titles.append(title) 50 | srcs.append(src) 51 | tgts.append(tgt) 52 | 53 | print("Synthesizing...") 54 | 55 | with torch.no_grad(): 56 | for line in tqdm(zip(titles, srcs, tgts)): 57 | title, src, tgt = line 58 | # tgt 59 | wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate) 60 | wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20) 61 | wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).cuda() 62 | mel_tgt = mel_spectrogram_torch( 63 | wav_tgt, 64 | hps.data.filter_length, 65 | hps.data.n_mel_channels, 66 | hps.data.sampling_rate, 67 | hps.data.hop_length, 68 | hps.data.win_length, 69 | hps.data.mel_fmin, 70 | hps.data.mel_fmax 71 | ) 72 | # src 73 | wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate) 74 | wav_src = torch.from_numpy(wav_src).unsqueeze(0).unsqueeze(0).cuda() 75 | print(wav_src.size()) 76 | #long running 77 | #do something other 78 | c = hubert_soft.units(wav_src) 79 | 80 | 81 | 82 | c=c.transpose(2,1) 83 | #print(c.size()) 84 | audio = net_g.infer(c, mel=mel_tgt) 85 | 86 | audio = audio[0][0].data.cpu().float().numpy() 87 | if args.use_timestamp: 88 | timestamp = time.strftime("%m-%d_%H-%M", time.localtime()) 89 | write(os.path.join(args.outdir, "{}.wav".format(timestamp+"_"+title)), hps.data.sampling_rate, audio) 90 | else: 91 | write(os.path.join(args.outdir, f"{title}.wav"), hps.data.sampling_rate, audio) 92 | -------------------------------------------------------------------------------- /crepe.py: 
-------------------------------------------------------------------------------- 1 | from typing import Optional,Union 2 | try: 3 | from typing import Literal 4 | except Exception as e: 5 | from typing_extensions import Literal 6 | import numpy as np 7 | import torch 8 | import torchcrepe 9 | from torch import nn 10 | from torch.nn import functional as F 11 | import scipy 12 | 13 | #from:https://github.com/fishaudio/fish-diffusion 14 | 15 | def repeat_expand( 16 | content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest" 17 | ): 18 | """Repeat content to target length. 19 | This is a wrapper of torch.nn.functional.interpolate. 20 | 21 | Args: 22 | content (torch.Tensor): tensor 23 | target_len (int): target length 24 | mode (str, optional): interpolation mode. Defaults to "nearest". 25 | 26 | Returns: 27 | torch.Tensor: tensor 28 | """ 29 | 30 | ndim = content.ndim 31 | 32 | if content.ndim == 1: 33 | content = content[None, None] 34 | elif content.ndim == 2: 35 | content = content[None] 36 | 37 | assert content.ndim == 3 38 | 39 | is_np = isinstance(content, np.ndarray) 40 | if is_np: 41 | content = torch.from_numpy(content) 42 | 43 | results = torch.nn.functional.interpolate(content, size=target_len, mode=mode) 44 | 45 | if is_np: 46 | results = results.numpy() 47 | 48 | if ndim == 1: 49 | return results[0, 0] 50 | elif ndim == 2: 51 | return results[0] 52 | 53 | 54 | class BasePitchExtractor: 55 | def __init__( 56 | self, 57 | hop_length: int = 512, 58 | f0_min: float = 50.0, 59 | f0_max: float = 1100.0, 60 | keep_zeros: bool = True, 61 | ): 62 | """Base pitch extractor. 63 | 64 | Args: 65 | hop_length (int, optional): Hop length. Defaults to 512. 66 | f0_min (float, optional): Minimum f0. Defaults to 50.0. 67 | f0_max (float, optional): Maximum f0. Defaults to 1100.0. 68 | keep_zeros (bool, optional): Whether keep zeros in pitch. Defaults to True. 69 | """ 70 | 71 | self.hop_length = hop_length 72 | self.f0_min = f0_min 73 | self.f0_max = f0_max 74 | self.keep_zeros = keep_zeros 75 | 76 | def __call__(self, x, sampling_rate=44100, pad_to=None): 77 | raise NotImplementedError("BasePitchExtractor is not callable.") 78 | 79 | def post_process(self, x, sampling_rate, f0, pad_to): 80 | if isinstance(f0, np.ndarray): 81 | f0 = torch.from_numpy(f0).float().to(x.device) 82 | 83 | if pad_to is None: 84 | return f0 85 | 86 | f0 = repeat_expand(f0, pad_to) 87 | 88 | if self.keep_zeros: 89 | return f0 90 | 91 | vuv_vector = torch.zeros_like(f0) 92 | vuv_vector[f0 > 0.0] = 1.0 93 | vuv_vector[f0 <= 0.0] = 0.0 94 | 95 | # Remove 0 frequency and apply linear interpolation 96 | nzindex = torch.nonzero(f0).squeeze() 97 | f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy() 98 | time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy() 99 | time_frame = np.arange(pad_to) * self.hop_length / sampling_rate 100 | 101 | if f0.shape[0] <= 0: 102 | return torch.zeros(pad_to, dtype=torch.float, device=x.device),torch.zeros(pad_to, dtype=torch.float, device=x.device) 103 | 104 | if f0.shape[0] == 1: 105 | return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],torch.ones(pad_to, dtype=torch.float, device=x.device) 106 | 107 | # Probably can be rewritten with torch? 
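# Interpolate the voiced F0 values back onto the full frame grid, holding the
# first / last voiced value at the edges; the voiced/unvoiced mask is resized
# to pad_to with nearest-neighbour (order-0) zoom.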
108 | f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) 109 | vuv_vector = vuv_vector.cpu().numpy() 110 | vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) 111 | 112 | return f0,vuv_vector 113 | 114 | 115 | class MaskedAvgPool1d(nn.Module): 116 | def __init__( 117 | self, kernel_size: int, stride: Optional[int] = None, padding: Optional[int] = 0 118 | ): 119 | """An implementation of mean pooling that supports masked values. 120 | 121 | Args: 122 | kernel_size (int): The size of the median pooling window. 123 | stride (int, optional): The stride of the median pooling window. Defaults to None. 124 | padding (int, optional): The padding of the median pooling window. Defaults to 0. 125 | """ 126 | 127 | super(MaskedAvgPool1d, self).__init__() 128 | self.kernel_size = kernel_size 129 | self.stride = stride or kernel_size 130 | self.padding = padding 131 | 132 | def forward(self, x, mask=None): 133 | ndim = x.dim() 134 | if ndim == 2: 135 | x = x.unsqueeze(1) 136 | 137 | assert ( 138 | x.dim() == 3 139 | ), "Input tensor must have 2 or 3 dimensions (batch_size, channels, width)" 140 | 141 | # Apply the mask by setting masked elements to zero, or make NaNs zero 142 | if mask is None: 143 | mask = ~torch.isnan(x) 144 | 145 | # Ensure mask has the same shape as the input tensor 146 | assert x.shape == mask.shape, "Input tensor and mask must have the same shape" 147 | 148 | masked_x = torch.where(mask, x, torch.zeros_like(x)) 149 | # Create a ones kernel with the same number of channels as the input tensor 150 | ones_kernel = torch.ones(x.size(1), 1, self.kernel_size, device=x.device) 151 | 152 | # Perform sum pooling 153 | sum_pooled = nn.functional.conv1d( 154 | masked_x, 155 | ones_kernel, 156 | stride=self.stride, 157 | padding=self.padding, 158 | groups=x.size(1), 159 | ) 160 | 161 | # Count the non-masked (valid) elements in each pooling window 162 | valid_count = nn.functional.conv1d( 163 | mask.float(), 164 | ones_kernel, 165 | stride=self.stride, 166 | padding=self.padding, 167 | groups=x.size(1), 168 | ) 169 | valid_count = valid_count.clamp(min=1) # Avoid division by zero 170 | 171 | # Perform masked average pooling 172 | avg_pooled = sum_pooled / valid_count 173 | 174 | # Fill zero values with NaNs 175 | avg_pooled[avg_pooled == 0] = float("nan") 176 | 177 | if ndim == 2: 178 | return avg_pooled.squeeze(1) 179 | 180 | return avg_pooled 181 | 182 | 183 | class MaskedMedianPool1d(nn.Module): 184 | def __init__( 185 | self, kernel_size: int, stride: Optional[int] = None, padding: Optional[int] = 0 186 | ): 187 | """An implementation of median pooling that supports masked values. 188 | 189 | This implementation is inspired by the median pooling implementation in 190 | https://gist.github.com/rwightman/f2d3849281624be7c0f11c85c87c1598 191 | 192 | Args: 193 | kernel_size (int): The size of the median pooling window. 194 | stride (int, optional): The stride of the median pooling window. Defaults to None. 195 | padding (int, optional): The padding of the median pooling window. Defaults to 0. 
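Masked (or NaN) samples are excluded from each window's median; windows with no valid samples come back as NaN.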
196 | """ 197 | 198 | super(MaskedMedianPool1d, self).__init__() 199 | self.kernel_size = kernel_size 200 | self.stride = stride or kernel_size 201 | self.padding = padding 202 | 203 | def forward(self, x, mask=None): 204 | ndim = x.dim() 205 | if ndim == 2: 206 | x = x.unsqueeze(1) 207 | 208 | assert ( 209 | x.dim() == 3 210 | ), "Input tensor must have 2 or 3 dimensions (batch_size, channels, width)" 211 | 212 | if mask is None: 213 | mask = ~torch.isnan(x) 214 | 215 | assert x.shape == mask.shape, "Input tensor and mask must have the same shape" 216 | 217 | masked_x = torch.where(mask, x, torch.zeros_like(x)) 218 | 219 | x = F.pad(masked_x, (self.padding, self.padding), mode="reflect") 220 | mask = F.pad( 221 | mask.float(), (self.padding, self.padding), mode="constant", value=0 222 | ) 223 | 224 | x = x.unfold(2, self.kernel_size, self.stride) 225 | mask = mask.unfold(2, self.kernel_size, self.stride) 226 | 227 | x = x.contiguous().view(x.size()[:3] + (-1,)) 228 | mask = mask.contiguous().view(mask.size()[:3] + (-1,)).to(x.device) 229 | 230 | # Combine the mask with the input tensor 231 | #x_masked = torch.where(mask.bool(), x, torch.fill_(torch.zeros_like(x),float("inf"))) 232 | x_masked = torch.where(mask.bool(), x, torch.FloatTensor([float("inf")]).to(x.device)) 233 | 234 | # Sort the masked tensor along the last dimension 235 | x_sorted, _ = torch.sort(x_masked, dim=-1) 236 | 237 | # Compute the count of non-masked (valid) values 238 | valid_count = mask.sum(dim=-1) 239 | 240 | # Calculate the index of the median value for each pooling window 241 | median_idx = (torch.div((valid_count - 1), 2, rounding_mode='trunc')).clamp(min=0) 242 | 243 | # Gather the median values using the calculated indices 244 | median_pooled = x_sorted.gather(-1, median_idx.unsqueeze(-1).long()).squeeze(-1) 245 | 246 | # Fill infinite values with NaNs 247 | median_pooled[torch.isinf(median_pooled)] = float("nan") 248 | 249 | if ndim == 2: 250 | return median_pooled.squeeze(1) 251 | 252 | return median_pooled 253 | 254 | 255 | class CrepePitchExtractor(BasePitchExtractor): 256 | def __init__( 257 | self, 258 | hop_length: int = 512, 259 | f0_min: float = 50.0, 260 | f0_max: float = 1100.0, 261 | threshold: float = 0.05, 262 | keep_zeros: bool = False, 263 | device = None, 264 | model: Literal["full", "tiny"] = "full", 265 | use_fast_filters: bool = True, 266 | ): 267 | super().__init__(hop_length, f0_min, f0_max, keep_zeros) 268 | 269 | self.threshold = threshold 270 | self.model = model 271 | self.use_fast_filters = use_fast_filters 272 | self.hop_length = hop_length 273 | if device is None: 274 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") 275 | else: 276 | self.dev = torch.device(device) 277 | if self.use_fast_filters: 278 | self.median_filter = MaskedMedianPool1d(3, 1, 1).to(device) 279 | self.mean_filter = MaskedAvgPool1d(3, 1, 1).to(device) 280 | 281 | def __call__(self, x, sampling_rate=44100, pad_to=None): 282 | """Extract pitch using crepe. 283 | 284 | 285 | Args: 286 | x (torch.Tensor): Audio signal, shape (1, T). 287 | sampling_rate (int, optional): Sampling rate. Defaults to 44100. 288 | pad_to (int, optional): Pad to length. Defaults to None. 289 | 290 | Returns: 291 | torch.Tensor: Pitch, shape (T // hop_length,). 292 | """ 293 | 294 | assert x.ndim == 2, f"Expected 2D tensor, got {x.ndim}D tensor." 295 | assert x.shape[0] == 1, f"Expected 1 channel, got {x.shape[0]} channels." 
296 | 297 | x = x.to(self.dev) 298 | f0, pd = torchcrepe.predict( 299 | x, 300 | sampling_rate, 301 | self.hop_length, 302 | self.f0_min, 303 | self.f0_max, 304 | pad=True, 305 | model=self.model, 306 | batch_size=1024, 307 | device=x.device, 308 | return_periodicity=True, 309 | ) 310 | 311 | # Filter, remove silence, set uv threshold, refer to the original warehouse readme 312 | if self.use_fast_filters: 313 | pd = self.median_filter(pd) 314 | else: 315 | pd = torchcrepe.filter.median(pd, 3) 316 | 317 | pd = torchcrepe.threshold.Silence(-60.0)(pd, x, sampling_rate, 512) 318 | f0 = torchcrepe.threshold.At(self.threshold)(f0, pd) 319 | 320 | if self.use_fast_filters: 321 | f0 = self.mean_filter(f0) 322 | else: 323 | f0 = torchcrepe.filter.mean(f0, 3) 324 | 325 | f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0)[0] 326 | 327 | if torch.all(f0 == 0): 328 | rtn = f0.cpu().numpy() if pad_to==None else np.zeros(pad_to) 329 | return rtn,rtn 330 | 331 | return self.post_process(x, sampling_rate, f0, pad_to) 332 | -------------------------------------------------------------------------------- /data_utils_new_new.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.utils.data 7 | 8 | import commons 9 | from mel_processing import spectrogram_torch, spec_to_mel_torch 10 | from utils import load_wav_to_torch, load_filepaths_and_text, transform 11 | #import h5py 12 | 13 | 14 | """Multi speaker version""" 15 | class TextAudioSpeakerLoader(torch.utils.data.Dataset): 16 | """ 17 | 1) loads audio, speaker_id, text pairs 18 | 2) normalizes text and converts them to sequences of integers 19 | 3) computes spectrograms from audio files. 
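4) loads the matching pre-extracted HuBERT content features (*.content.npy) and F0 / voicing curves (*.f0.npy) written by dataset/encode.py.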
20 | """ 21 | def __init__(self, audiopaths, hparams): 22 | self.audiopaths = load_filepaths_and_text(audiopaths) 23 | self.max_wav_value = hparams.data.max_wav_value 24 | self.sampling_rate = hparams.data.sampling_rate 25 | self.filter_length = hparams.data.filter_length 26 | self.hop_length = hparams.data.hop_length 27 | self.win_length = hparams.data.win_length 28 | self.sampling_rate = hparams.data.sampling_rate 29 | self.use_sr = hparams.train.use_sr 30 | self.use_spk = hparams.model.use_spk 31 | self.spec_len = hparams.train.max_speclen 32 | 33 | random.seed(1243) 34 | random.shuffle(self.audiopaths) 35 | self._filter() 36 | 37 | def _filter(self): 38 | """ 39 | Filter text & store spec lengths 40 | """ 41 | # Store spectrogram lengths for Bucketing 42 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) 43 | # spec_length = wav_length // hop_length 44 | 45 | lengths = [] 46 | for audiopath in self.audiopaths: 47 | lengths.append(os.path.getsize(audiopath[0]) // (2 * self.hop_length)) 48 | self.lengths = lengths 49 | def get_audio(self, filename): 50 | audio_norm, sampling_rate = load_wav_to_torch(filename) 51 | audio_norm = audio_norm.unsqueeze(0) 52 | 53 | spec_filename = filename.replace(".wav", ".spec.pt") 54 | 55 | if os.path.exists(spec_filename): 56 | spec = torch.load(spec_filename) 57 | else: 58 | spec = spectrogram_torch(audio_norm, self.filter_length, 59 | self.sampling_rate, self.hop_length, self.win_length, 60 | center=False) 61 | spec = torch.squeeze(spec, 0) 62 | torch.save(spec, spec_filename) 63 | 64 | #i = 80#random.randint(68,92) 65 | 66 | c_filename = filename.replace(".wav", f".content.npy") 67 | c = np.load(c_filename)#.squeeze(0) 68 | c=c.transpose(1,0) 69 | c = torch.FloatTensor(c)#.squeeze(0) 70 | 71 | f0_filename = filename.replace(".wav", ".f0.npy") 72 | f0 = np.load(f0_filename) 73 | f0, uv = utils.interpolate_f0(f0) 74 | 75 | _,spec_len = spec.shape 76 | f0_len = f0.shape 77 | _,c_len = c.shape 78 | wav_len = audio_norm.shape 79 | length = max([spec_len, f0_len[0], c_len]) 80 | 81 | f0 = torch.FloatTensor(f0) 82 | uv = torch.FloatTensor(uv) 83 | 84 | c = utils.repeat_expand_2d(c.squeeze(0), length) 85 | """ 86 | if length == spec_len: 87 | c = utils.repeat_expand_2d(c.squeeze(0), length) 88 | f0 = utils.repeat_expand_2d(torch.unsqueeze(f0, 0), length)[0] 89 | uv = utils.repeat_expand_2d(torch.unsqueeze(uv, 0), length)[0] 90 | 91 | elif length == f0_len: 92 | c = utils.repeat_expand_2d(c.squeeze(0), length) 93 | spec = utils.repeat_expand_2d(spec, length) 94 | 95 | elif length == c_len: 96 | spec = utils.repeat_expand_2d(spec.squeeze(0), length) 97 | f0 = utils.repeat_expand_2d(torch.unsqueeze(f0, 0), length)[0] 98 | uv = utils.repeat_expand_2d(torch.unsqueeze(uv, 0), length)[0] 99 | """ 100 | 101 | lmin = min(c.size(-1), spec.size(-1)) 102 | assert abs(c.size(-1) - spec.size(-1)) < 3, (c.size(-1), spec.size(-1), f0.shape, filename) 103 | assert abs(audio_norm.shape[1]-lmin * self.hop_length) < 3 * self.hop_length 104 | 105 | spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin] 106 | audio_norm = audio_norm[:, :lmin * self.hop_length] 107 | 108 | 109 | return c, spec, audio_norm, f0, uv 110 | 111 | def random_slice(self, c, f0, spec, audio_norm, uv): 112 | # if spec.shape[1] < 30: 113 | # print("skip too short audio:", filename) 114 | # return None 115 | if spec.shape[1] > 800: 116 | start = random.randint(0, spec.shape[1]-800) 117 | end = start + 790 118 | spec, c, f0, uv = spec[:, start:end], c[:, start:end], 
f0[start:end], uv[start:end] 119 | audio_norm = audio_norm[:, start * self.hop_length : end * self.hop_length] 120 | 121 | return c, f0, spec, audio_norm, uv 122 | def __getitem__(self, index): 123 | return self.get_audio(self.audiopaths[index][0]) 124 | 125 | def __len__(self): 126 | return len(self.audiopaths) 127 | 128 | 129 | 130 | class TextAudioSpeakerCollate(): 131 | """ Zero-pads model inputs and targets 132 | """ 133 | def __init__(self, hps): 134 | self.hps = hps 135 | self.use_sr = hps.train.use_sr 136 | self.use_spk = hps.model.use_spk 137 | 138 | def __call__(self, batch): 139 | """Collate's training batch from normalized text, audio and speaker identities 140 | PARAMS 141 | ------ 142 | batch: [text_normalized, spec_normalized, wav_normalized, sid] 143 | """ 144 | # Right zero-pad all one-hot text sequences to max input length 145 | _, ids_sorted_decreasing = torch.sort( 146 | torch.LongTensor([x[0].size(1) for x in batch]), 147 | dim=0, descending=True) 148 | max_c_len = max([x[0].size(1) for x in batch]) 149 | max_spec_len = max([x[1].size(1) for x in batch]) 150 | max_wav_len = max([x[2].size(1) for x in batch]) 151 | max_f0_len = max([x[3].size(0) for x in batch]) 152 | max_uv_len = max([x[4].size(0) for x in batch]) 153 | #print(max_spec_len,max_c_len) 154 | spec_lengths = torch.LongTensor(len(batch)) 155 | wav_lengths = torch.LongTensor(len(batch)) 156 | #f0_lengths = torch.LongTensor(len(batch)) 157 | #uv_lengths = torch.LongTensor(len(batch)) 158 | 159 | #if self.use_spk: 160 | # spks = torch.FloatTensor(len(batch), batch[0][3].size(0)) 161 | #else: 162 | # spks = None 163 | spks = None 164 | 165 | # maybe spec f0 uv c length is same 166 | c_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len) 167 | spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) 168 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) 169 | f0_padded = torch.FloatTensor(len(batch), max_f0_len) 170 | uv_padded = torch.FloatTensor(len(batch), max_uv_len) 171 | c_padded.zero_() 172 | spec_padded.zero_() 173 | wav_padded.zero_() 174 | f0_padded.zero_() 175 | uv_padded.zero_() 176 | 177 | for i in range(len(ids_sorted_decreasing)): 178 | row = batch[ids_sorted_decreasing[i]] 179 | 180 | c = row[0] 181 | c_padded[i, :, :c.size(1)] = c 182 | 183 | spec = row[1] 184 | spec_padded[i, :, :spec.size(1)] = spec 185 | spec_lengths[i] = spec.size(1) 186 | 187 | wav = row[2] 188 | wav_padded[i, :, :wav.size(1)] = wav 189 | wav_lengths[i] = wav.size(1) 190 | 191 | f0 = row[3] 192 | f0_padded[i, :f0.size(0)] = f0 193 | 194 | uv = row[4] 195 | uv_padded[i, :uv.size(0)] = uv 196 | 197 | #if self.use_spk: 198 | # spks[i] = row[3] 199 | #""" 200 | spec_seglen = spec_lengths[-1] if spec_lengths[-1] < self.hps.train.max_speclen + 1 else self.hps.train.max_speclen + 1 201 | #print(spec_seglen) 202 | wav_seglen = spec_seglen * self.hps.data.hop_length 203 | #print(spec_padded.size(), spec_lengths, spec_seglen) 204 | spec_padded, ids_slice = commons.rand_spec_segments(spec_padded, spec_lengths, spec_seglen) 205 | wav_padded = commons.slice_segments(wav_padded, ids_slice * self.hps.data.hop_length, wav_seglen)[:,:,:-self.hps.data.hop_length] 206 | 207 | c_padded = commons.slice_segments(c_padded, ids_slice, spec_seglen)[:,:,:-1] 208 | f0_padded = commons.slice_segments_2dim(f0_padded, ids_slice , spec_seglen)[:,:-1] 209 | uv_padded = commons.slice_segments_2dim(uv_padded, ids_slice , spec_seglen)[:,:-1] 210 | 211 | spec_padded = spec_padded[:,:,:-1] 212 | #wav_padded 
= wav_padded[:,:,:-self.hps.data.hop_length] 213 | #""" 214 | #if self.use_spk: 215 | #return c_padded, spec_padded, wav_padded, spks 216 | #else: 217 | #return c_padded, spec_padded, wav_padded 218 | return c_padded, spec_padded, wav_padded, f0_padded, uv_padded 219 | 220 | 221 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): 222 | """ 223 | Maintain similar input lengths in a batch. 224 | Length groups are specified by boundaries. 225 | Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. 226 | 227 | It removes samples which are not included in the boundaries. 228 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. 229 | """ 230 | def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True): 231 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 232 | self.lengths = dataset.lengths 233 | self.batch_size = batch_size 234 | self.boundaries = boundaries 235 | 236 | self.buckets, self.num_samples_per_bucket = self._create_buckets() 237 | self.total_size = sum(self.num_samples_per_bucket) 238 | self.num_samples = self.total_size // self.num_replicas 239 | 240 | def _create_buckets(self): 241 | buckets = [[] for _ in range(len(self.boundaries) - 1)] 242 | for i in range(len(self.lengths)): 243 | length = self.lengths[i] 244 | idx_bucket = self._bisect(length) 245 | if idx_bucket != -1: 246 | buckets[idx_bucket].append(i) 247 | 248 | for i in range(len(buckets) - 1, 0, -1): 249 | if len(buckets[i]) == 0: 250 | buckets.pop(i) 251 | self.boundaries.pop(i+1) 252 | if len(buckets[0]) == 0: 253 | buckets.pop(0) 254 | self.boundaries.pop(0+1) 255 | 256 | num_samples_per_bucket = [] 257 | for i in range(len(buckets)): 258 | len_bucket = len(buckets[i]) 259 | total_batch_size = self.num_replicas * self.batch_size 260 | rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size 261 | num_samples_per_bucket.append(len_bucket + rem) 262 | return buckets, num_samples_per_bucket 263 | 264 | def __iter__(self): 265 | # deterministically shuffle based on epoch 266 | g = torch.Generator() 267 | g.manual_seed(self.epoch) 268 | 269 | indices = [] 270 | if self.shuffle: 271 | for bucket in self.buckets: 272 | indices.append(torch.randperm(len(bucket), generator=g).tolist()) 273 | else: 274 | for bucket in self.buckets: 275 | indices.append(list(range(len(bucket)))) 276 | 277 | batches = [] 278 | for i in range(len(self.buckets)): 279 | bucket = self.buckets[i] 280 | len_bucket = len(bucket) 281 | ids_bucket = indices[i] 282 | num_samples_bucket = self.num_samples_per_bucket[i] 283 | 284 | # add extra samples to make it evenly divisible 285 | rem = num_samples_bucket - len_bucket 286 | ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)] 287 | 288 | # subsample 289 | ids_bucket = ids_bucket[self.rank::self.num_replicas] 290 | 291 | # batching 292 | for j in range(len(ids_bucket) // self.batch_size): 293 | batch = [bucket[idx] for idx in ids_bucket[j*self.batch_size:(j+1)*self.batch_size]] 294 | batches.append(batch) 295 | 296 | if self.shuffle: 297 | batch_ids = torch.randperm(len(batches), generator=g).tolist() 298 | batches = [batches[i] for i in batch_ids] 299 | self.batches = batches 300 | 301 | assert len(self.batches) * self.batch_size == self.num_samples 302 | return iter(self.batches) 303 | 304 | def _bisect(self, x, lo=0, hi=None): 305 | if 
hi is None: 306 | hi = len(self.boundaries) - 1 307 | 308 | if hi > lo: 309 | mid = (hi + lo) // 2 310 | if self.boundaries[mid] < x and x <= self.boundaries[mid+1]: 311 | return mid 312 | elif x <= self.boundaries[mid]: 313 | return self._bisect(x, lo, mid) 314 | else: 315 | return self._bisect(x, mid + 1, hi) 316 | else: 317 | return -1 318 | 319 | def __len__(self): 320 | return self.num_samples // self.batch_size 321 | import utils 322 | from torch.utils.data import DataLoader 323 | if __name__ == "__main__": 324 | hps = utils.get_hparams() 325 | train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps) 326 | train_sampler = DistributedBucketSampler( 327 | train_dataset, 328 | hps.train.batch_size, 329 | [32,70,100,200,300,400,500,600,700,800,900,1000], 330 | num_replicas=1, 331 | rank=0, 332 | shuffle=True) 333 | collate_fn = TextAudioSpeakerCollate(hps) 334 | train_loader = DataLoader(train_dataset, num_workers=1, shuffle=False, pin_memory=True, 335 | collate_fn=collate_fn, batch_sampler=train_sampler) 336 | 337 | for batch_idx, (c, spec, y,f0, uv) in enumerate(train_loader): 338 | print(c.size(), spec.size(), y.size()) 339 | #print(batch_idx, c, spec, y) 340 | #break 341 | -------------------------------------------------------------------------------- /dataset/encode.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from tqdm import tqdm 4 | import os 5 | 6 | import torch 7 | import torchaudio 8 | from torchaudio.functional import resample 9 | 10 | ### add ### 11 | import torch 12 | from transformers import HubertModel 13 | import random 14 | 15 | import os 16 | import glob 17 | import argparse 18 | import logging 19 | import numpy 20 | from scipy.io.wavfile import read 21 | import torch 22 | MATPLOTLIB_FLAG = False 23 | 24 | from scipy.io.wavfile import read 25 | import torch 26 | from torch.nn import functional as F 27 | 28 | import pyworld 29 | 30 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 31 | 32 | def encode_dataset(args): 33 | 34 | filelist = glob.glob(f"./{args.in_dir}/**/*{args.extension}", recursive=True) 35 | out_files_list = list() 36 | 37 | if args.model == "japanese-hubert-base": 38 | model = HubertModel.from_pretrained("rinna/japanese-hubert-base") 39 | model.eval() 40 | model.to(device) 41 | 42 | for in_path in tqdm(filelist): 43 | out_path = in_path.replace(f"{args.in_dir}", f"{args.out_dir}") 44 | out_dir = "/".join(out_path.split("/")[:-1]) 45 | os.makedirs(out_dir, exist_ok=True) 46 | #out_path = args.out_dir / in_path.relative_to(args.in_dir) 47 | #if True:#not os.path.exists(out_path.with_suffix(".npy")): 48 | try: 49 | wav, sr = torchaudio.load(in_path) 50 | except: 51 | continue 52 | wav = resample(wav, sr, 16000).to(device) 53 | with torch.inference_mode(): 54 | units = model(wav)[0].squeeze().cpu().numpy() 55 | out_files_list.append(out_path) 56 | numpy.save(out_path.replace(args.extension,".content.npy"), units) 57 | 58 | else: 59 | print(f"Loading hubert checkpoint") 60 | hubert = torch.hub.load("bshall/hubert:main", f"hubert_soft").to(device).eval() 61 | print(f"Encoding dataset at {args.in_dir}") 62 | for in_path in tqdm(filelist): 63 | out_path = in_path.replace(f"{args.in_dir}", f"{args.out_dir}") 64 | out_dir = "/".join(out_path.split("/")[:-1]) 65 | os.makedirs(out_dir, exist_ok=True) 66 | #if True:#not os.path.exists(out_path.with_suffix(".npy")): 67 | try: 68 | wav, sr = torchaudio.load(in_path) 69 | except: 70 | 
continue 71 | wav = resample(wav, sr, 16000) 72 | wav = wav.unsqueeze(0).to(device) 73 | with torch.inference_mode(): 74 | units = hubert.units(wav) #[Batch, Frame, Hidden] 75 | out_files_list.append(out_path) 76 | numpy.save(out_path.replace(args.extension,".content.npy"), units.squeeze().cpu().numpy()) 77 | 78 | if args.f0 == "dio": 79 | for in_path in tqdm(filelist): 80 | out_path = in_path.replace(f"{args.in_dir}", f"{args.out_dir}") 81 | out_dir = "/".join(out_path.split("/")[:-1]) 82 | os.makedirs(out_dir, exist_ok=True) 83 | if True:#not os.path.exists(out_path.with_suffix(".npy")): 84 | try: 85 | wav, sr = torchaudio.load(in_path) 86 | except: 87 | continue 88 | wav = wav[0].to('cpu').detach().numpy().copy() 89 | f0_dio = compute_f0_dio( 90 | wav, sampling_rate=sr, hop_length=512) 91 | numpy.save(out_path.replace(args.extension,".f0.npy"), f0_dio) 92 | 93 | elif args.f0 == "harvest": 94 | for in_path in tqdm(filelist): 95 | out_path = in_path.replace(f"{args.in_dir}", f"{args.out_dir}") 96 | out_dir = "/".join(out_path.split("/")[:-1]) 97 | os.makedirs(out_dir, exist_ok=True) 98 | if True:#not os.path.exists(out_path.with_suffix(".npy")): 99 | try: 100 | wav, sr = torchaudio.load(in_path) 101 | except: 102 | continue 103 | wav = wav[0].to('cpu').detach().numpy().copy() 104 | f0_harvest = compute_f0_harvest( 105 | wav, sampling_rate=sr, hop_length=512) 106 | numpy.save(out_path.replace(args.extension,".f0.npy"), f0_harvest) 107 | 108 | elif args.f0 == "parselmouth": 109 | for in_path in tqdm(filelist): 110 | out_path = in_path.replace(f"{args.in_dir}", f"{args.out_dir}") 111 | out_dir = "/".join(out_path.split("/")[:-1]) 112 | os.makedirs(out_dir, exist_ok=True) 113 | if True:#not os.path.exists(out_path.with_suffix(".npy")): 114 | try: 115 | wav, sr = torchaudio.load(in_path) 116 | except: 117 | continue 118 | wav = wav[0].to('cpu').detach().numpy().copy() 119 | f0_parselmouth = compute_f0_parselmouth( 120 | wav, sampling_rate=sr, hop_length=512) 121 | numpy.save(out_path.replace(args.extension,".f0.npy"), f0_parselmouth) 122 | 123 | elif args.f0 == "crepe": 124 | for in_path in tqdm(filelist): 125 | out_path = in_path.replace(f"{args.in_dir}", f"{args.out_dir}") 126 | out_dir = "/".join(out_path.split("/")[:-1]) 127 | os.makedirs(out_dir, exist_ok=True) 128 | if True:#not os.path.exists(out_path.with_suffix(".npy")): 129 | try: 130 | wav, sr = torchaudio.load(in_path) 131 | except: 132 | continue 133 | wav = wav[0].to('cpu').detach().numpy().copy() 134 | f0_crepe, _= compute_f0_torchcrepe( 135 | wav, sampling_rate=sr, hop_length=512) 136 | numpy.save(out_path.replace(args.extension,".f0.npy"), f0_crepe) 137 | 138 | 139 | elif args.f0 == "check_f0_method": 140 | for in_path in tqdm(filelist): 141 | out_path = in_path.replace(f"{args.in_dir}", f"{args.out_dir}") 142 | out_dir = "/".join(out_path.split("/")[:-1]) 143 | os.makedirs(out_dir, exist_ok=True) 144 | if True:#not os.path.exists(out_path.with_suffix(".npy")): 145 | try: 146 | wav, sr = torchaudio.load(in_path) 147 | except: 148 | continue 149 | wav = wav[0].to('cpu').detach().numpy().copy() 150 | f0_crepe, _= compute_f0_torchcrepe( 151 | wav, sampling_rate=sr, hop_length=512) 152 | f0_harvest = compute_f0_harvest( 153 | wav, sampling_rate=sr, hop_length=512) 154 | f0_dio = compute_f0_dio( 155 | wav, sampling_rate=sr, hop_length=512) 156 | f0_parselmouth = compute_f0_parselmouth( 157 | wav, sampling_rate=sr, hop_length=512) 158 | import matplotlib.pyplot as plt 159 | import numpy as np 160 | x = numpy.linspace(0, 1, 
len(f0_parselmouth)) 161 | plt.plot(x, f0_crepe, label="crepe") 162 | plt.plot(x, f0_harvest, label="harvest") 163 | plt.plot(x, f0_dio, label="dio") 164 | plt.plot(x, f0_parselmouth, label="pm") 165 | plt.legend() 166 | plt.show() 167 | plt.close() 168 | 169 | 170 | n_files = len(out_files_list) 171 | 172 | test_list = list() 173 | for idx in range(int(n_files*0.05)): # 5% of all files are used for test 174 | target_idx = random.randint(a=0, b=int(n_files-idx-1)) 175 | path = out_files_list.pop(target_idx) 176 | path_str = str(path) + "\n" 177 | test_list.append(path_str) 178 | 179 | train_list = list() 180 | for path in out_files_list: 181 | path_str = str(path)+ "\n" 182 | train_list.append(path_str) 183 | 184 | os.makedirs("./filelist/", exist_ok=True) 185 | with open("./filelist/train.txt", mode="w", encoding="utf-8") as f: 186 | f.writelines(train_list) 187 | with open("./filelist/test.txt", mode="w", encoding="utf-8") as f: 188 | f.writelines(test_list) 189 | 190 | 191 | 192 | 193 | ################################################################ 194 | 195 | f0_bin = 256 196 | f0_max = 1100.0 197 | f0_min = 50.0 198 | f0_mel_min = 1127 * numpy.log(1 + f0_min / 700) 199 | f0_mel_max = 1127 * numpy.log(1 + f0_max / 700) 200 | 201 | def normalize_f0(f0, x_mask, uv, random_scale=True): 202 | # calculate means based on x_mask 203 | uv_sum = torch.sum(uv, dim=1, keepdim=True) 204 | uv_sum[uv_sum == 0] = 9999 205 | means = torch.sum(f0[:, 0, :] * uv, dim=1, keepdim=True) / uv_sum 206 | 207 | if random_scale: 208 | factor = torch.Tensor(f0.shape[0], 1).uniform_(0.8, 1.2).to(f0.device) 209 | else: 210 | factor = torch.ones(f0.shape[0], 1).to(f0.device) 211 | # normalize f0 based on means and factor 212 | f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1) 213 | if torch.isnan(f0_norm).any(): 214 | exit(0) 215 | return f0_norm * x_mask 216 | 217 | def compute_f0_torchcrepe(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512,device=None,cr_threshold=0.05): 218 | x = wav_numpy 219 | if p_len is None: 220 | p_len = x.shape[0]//hop_length 221 | else: 222 | assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error" 223 | 224 | x = torch.from_numpy(x.astype(numpy.float32)).clone() 225 | F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max,device=device,threshold=cr_threshold) 226 | f0,uv = F0Creper(x[None,:].float(),sampling_rate,pad_to=p_len) 227 | f0[uv<0.5] = 0 228 | return f0,uv 229 | 230 | def compute_f0_harvest(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512): 231 | x = wav_numpy 232 | if p_len is None: 233 | p_len = x.shape[0]//hop_length 234 | else: 235 | assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error" 236 | 237 | f0, t = pyworld.harvest( 238 | x.astype(numpy.double), 239 | fs=sampling_rate, 240 | f0_ceil=f0_max, 241 | f0_floor=f0_min, 242 | frame_period=1000 * hop_length / sampling_rate, 243 | ) 244 | f0 = pyworld.stonemask(x.astype(numpy.double), f0, t, sampling_rate) 245 | return resize_f0(f0, p_len) 246 | 247 | def compute_f0_parselmouth(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512): 248 | import parselmouth 249 | x = wav_numpy 250 | if p_len is None: 251 | p_len = x.shape[0]//hop_length 252 | else: 253 | assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error" 254 | time_step = hop_length / sampling_rate * 1000 255 | f0 = parselmouth.Sound(x, sampling_rate).to_pitch_ac( 256 | time_step=time_step / 1000, voicing_threshold=0.6, 257 | pitch_floor=f0_min, 
pitch_ceiling=f0_max).selected_array['frequency'] 258 | 259 | pad_size=(p_len - len(f0) + 1) // 2 260 | if(pad_size>0 or p_len - len(f0) - pad_size>0): 261 | f0 = numpy.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant') 262 | return f0 263 | 264 | def resize_f0(x, target_len): 265 | source = numpy.array(x) 266 | source[source<0.001] = numpy.nan 267 | target = numpy.interp(numpy.arange(0, len(source)*target_len, len(source))/ target_len, numpy.arange(0, len(source)), source) 268 | res = numpy.nan_to_num(target) 269 | return res 270 | 271 | def compute_f0_dio(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512): 272 | import pyworld 273 | if p_len is None: 274 | p_len = wav_numpy.shape[0]//hop_length 275 | f0, t = pyworld.dio( 276 | wav_numpy.astype(numpy.double), 277 | fs=sampling_rate, 278 | f0_ceil=f0_max, 279 | f0_floor=f0_min, 280 | frame_period=1000 * hop_length / sampling_rate, 281 | ) 282 | f0 = pyworld.stonemask(wav_numpy.astype(numpy.double), f0, t, sampling_rate) 283 | for index, pitch in enumerate(f0): 284 | f0[index] = round(pitch, 1) 285 | return resize_f0(f0, p_len) 286 | 287 | def f0_to_coarse(f0): 288 | is_torch = isinstance(f0, torch.Tensor) 289 | f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * numpy.log(1 + f0 / 700) 290 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 291 | 292 | f0_mel[f0_mel <= 1] = 1 293 | f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 294 | f0_coarse = (f0_mel + 0.5).int() if is_torch else numpy.rint(f0_mel).astype(numpy.int) 295 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min()) 296 | return f0_coarse 297 | 298 | 299 | def interpolate_f0(f0): 300 | 301 | data = numpy.reshape(f0, (f0.size, 1)) 302 | 303 | vuv_vector = numpy.zeros((data.size, 1), dtype=numpy.float32) 304 | vuv_vector[data > 0.0] = 1.0 305 | vuv_vector[data <= 0.0] = 0.0 306 | 307 | ip_data = data 308 | 309 | frame_number = data.size 310 | last_value = 0.0 311 | for i in range(frame_number): 312 | if data[i] <= 0.0: 313 | j = i + 1 314 | for j in range(i + 1, frame_number): 315 | if data[j] > 0.0: 316 | break 317 | if j < frame_number - 1: 318 | if last_value > 0.0: 319 | step = (data[j] - data[i - 1]) / float(j - i) 320 | for k in range(i, j): 321 | ip_data[k] = data[i - 1] + step * (k - i + 1) 322 | else: 323 | for k in range(i, j): 324 | ip_data[k] = data[j] 325 | else: 326 | for k in range(i, frame_number): 327 | ip_data[k] = last_value 328 | else: 329 | ip_data[i] = data[i] # this may not be necessary 330 | last_value = data[i] 331 | 332 | return ip_data[:,0], vuv_vector[:,0] 333 | 334 | from typing import Optional,Union 335 | try: 336 | from typing import Literal 337 | except Exception as e: 338 | from typing_extensions import Literal 339 | import numpy as np 340 | import torch 341 | import torchcrepe 342 | from torch import nn 343 | from torch.nn import functional as F 344 | import scipy 345 | 346 | #from:https://github.com/fishaudio/fish-diffusion 347 | 348 | def repeat_expand( 349 | content: Union[torch.Tensor, numpy.ndarray], target_len: int, mode: str = "nearest" 350 | ): 351 | """Repeat content to target length. 352 | This is a wrapper of torch.nn.functional.interpolate. 353 | 354 | Args: 355 | content (torch.Tensor): tensor 356 | target_len (int): target length 357 | mode (str, optional): interpolation mode. Defaults to "nearest". 
358 | 359 | Returns: 360 | torch.Tensor: tensor 361 | """ 362 | 363 | ndim = content.ndim 364 | 365 | if content.ndim == 1: 366 | content = content[None, None] 367 | elif content.ndim == 2: 368 | content = content[None] 369 | 370 | assert content.ndim == 3 371 | 372 | is_np = isinstance(content, numpy.ndarray) 373 | if is_np: 374 | content = torch.from_numpy(content) 375 | 376 | results = torch.nn.functional.interpolate(content, size=target_len, mode=mode) 377 | 378 | if is_np: 379 | results = results.numpy() 380 | 381 | if ndim == 1: 382 | return results[0, 0] 383 | elif ndim == 2: 384 | return results[0] 385 | 386 | 387 | class BasePitchExtractor: 388 | def __init__( 389 | self, 390 | hop_length: int = 512, 391 | f0_min: float = 50.0, 392 | f0_max: float = 1100.0, 393 | keep_zeros: bool = True, 394 | ): 395 | """Base pitch extractor. 396 | 397 | Args: 398 | hop_length (int, optional): Hop length. Defaults to 512. 399 | f0_min (float, optional): Minimum f0. Defaults to 50.0. 400 | f0_max (float, optional): Maximum f0. Defaults to 1100.0. 401 | keep_zeros (bool, optional): Whether keep zeros in pitch. Defaults to True. 402 | """ 403 | 404 | self.hop_length = hop_length 405 | self.f0_min = f0_min 406 | self.f0_max = f0_max 407 | self.keep_zeros = keep_zeros 408 | 409 | def __call__(self, x, sampling_rate=44100, pad_to=None): 410 | raise NotImplementedError("BasePitchExtractor is not callable.") 411 | 412 | def post_process(self, x, sampling_rate, f0, pad_to): 413 | if isinstance(f0, numpy.ndarray): 414 | f0 = torch.from_numpy(f0).float().to(x.device) 415 | 416 | if pad_to is None: 417 | return f0 418 | 419 | f0 = repeat_expand(f0, pad_to) 420 | 421 | if self.keep_zeros: 422 | return f0 423 | 424 | vuv_vector = torch.zeros_like(f0) 425 | vuv_vector[f0 > 0.0] = 1.0 426 | vuv_vector[f0 <= 0.0] = 0.0 427 | 428 | # Remove 0 frequency and apply linear interpolation 429 | nzindex = torch.nonzero(f0).squeeze() 430 | f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy() 431 | time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy() 432 | time_frame = numpy.arange(pad_to) * self.hop_length / sampling_rate 433 | 434 | if f0.shape[0] <= 0: 435 | return torch.zeros(pad_to, dtype=torch.float, device=x.device),torch.zeros(pad_to, dtype=torch.float, device=x.device) 436 | 437 | if f0.shape[0] == 1: 438 | return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],torch.ones(pad_to, dtype=torch.float, device=x.device) 439 | 440 | # Probably can be rewritten with torch? 441 | f0 = numpy.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) 442 | vuv_vector = vuv_vector.cpu().numpy() 443 | vuv_vector = numpy.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) 444 | 445 | return f0,vuv_vector 446 | 447 | 448 | class MaskedAvgPool1d(nn.Module): 449 | def __init__( 450 | self, kernel_size: int, stride: Optional[int] = None, padding: Optional[int] = 0 451 | ): 452 | """An implementation of mean pooling that supports masked values. 453 | 454 | Args: 455 | kernel_size (int): The size of the median pooling window. 456 | stride (int, optional): The stride of the median pooling window. Defaults to None. 457 | padding (int, optional): The padding of the median pooling window. Defaults to 0. 
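            Example (illustrative):
                pool = MaskedAvgPool1d(kernel_size=3, stride=1, padding=1)
                pool(torch.tensor([[1.0, float("nan"), 3.0]]))  # -> tensor([[1., 2., 3.]]); NaN positions are excluded from each window's average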
458 | """ 459 | 460 | super(MaskedAvgPool1d, self).__init__() 461 | self.kernel_size = kernel_size 462 | self.stride = stride or kernel_size 463 | self.padding = padding 464 | 465 | def forward(self, x, mask=None): 466 | ndim = x.dim() 467 | if ndim == 2: 468 | x = x.unsqueeze(1) 469 | 470 | assert ( 471 | x.dim() == 3 472 | ), "Input tensor must have 2 or 3 dimensions (batch_size, channels, width)" 473 | 474 | # Apply the mask by setting masked elements to zero, or make NaNs zero 475 | if mask is None: 476 | mask = ~torch.isnan(x) 477 | 478 | # Ensure mask has the same shape as the input tensor 479 | assert x.shape == mask.shape, "Input tensor and mask must have the same shape" 480 | 481 | masked_x = torch.where(mask, x, torch.zeros_like(x)) 482 | # Create a ones kernel with the same number of channels as the input tensor 483 | ones_kernel = torch.ones(x.size(1), 1, self.kernel_size, device=x.device) 484 | 485 | # Perform sum pooling 486 | sum_pooled = nn.functional.conv1d( 487 | masked_x, 488 | ones_kernel, 489 | stride=self.stride, 490 | padding=self.padding, 491 | groups=x.size(1), 492 | ) 493 | 494 | # Count the non-masked (valid) elements in each pooling window 495 | valid_count = nn.functional.conv1d( 496 | mask.float(), 497 | ones_kernel, 498 | stride=self.stride, 499 | padding=self.padding, 500 | groups=x.size(1), 501 | ) 502 | valid_count = valid_count.clamp(min=1) # Avoid division by zero 503 | 504 | # Perform masked average pooling 505 | avg_pooled = sum_pooled / valid_count 506 | 507 | # Fill zero values with NaNs 508 | avg_pooled[avg_pooled == 0] = float("nan") 509 | 510 | if ndim == 2: 511 | return avg_pooled.squeeze(1) 512 | 513 | return avg_pooled 514 | 515 | 516 | class MaskedMedianPool1d(nn.Module): 517 | def __init__( 518 | self, kernel_size: int, stride: Optional[int] = None, padding: Optional[int] = 0 519 | ): 520 | """An implementation of median pooling that supports masked values. 521 | 522 | This implementation is inspired by the median pooling implementation in 523 | https://gist.github.com/rwightman/f2d3849281624be7c0f11c85c87c1598 524 | 525 | Args: 526 | kernel_size (int): The size of the median pooling window. 527 | stride (int, optional): The stride of the median pooling window. Defaults to None. 528 | padding (int, optional): The padding of the median pooling window. Defaults to 0. 
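            Example (illustrative):
                pool = MaskedMedianPool1d(kernel_size=3, stride=1, padding=1)
                pool(torch.tensor([[1.0, float("nan"), 3.0]]))  # -> tensor([[1., 1., 3.]]) (lower median of the valid values in each window)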
529 | """ 530 | 531 | super(MaskedMedianPool1d, self).__init__() 532 | self.kernel_size = kernel_size 533 | self.stride = stride or kernel_size 534 | self.padding = padding 535 | 536 | def forward(self, x, mask=None): 537 | ndim = x.dim() 538 | if ndim == 2: 539 | x = x.unsqueeze(1) 540 | 541 | assert ( 542 | x.dim() == 3 543 | ), "Input tensor must have 2 or 3 dimensions (batch_size, channels, width)" 544 | 545 | if mask is None: 546 | mask = ~torch.isnan(x) 547 | 548 | assert x.shape == mask.shape, "Input tensor and mask must have the same shape" 549 | 550 | masked_x = torch.where(mask, x, torch.zeros_like(x)) 551 | 552 | x = F.pad(masked_x, (self.padding, self.padding), mode="reflect") 553 | mask = F.pad( 554 | mask.float(), (self.padding, self.padding), mode="constant", value=0 555 | ) 556 | 557 | x = x.unfold(2, self.kernel_size, self.stride) 558 | mask = mask.unfold(2, self.kernel_size, self.stride) 559 | 560 | x = x.contiguous().view(x.size()[:3] + (-1,)) 561 | mask = mask.contiguous().view(mask.size()[:3] + (-1,)).to(x.device) 562 | 563 | # Combine the mask with the input tensor 564 | #x_masked = torch.where(mask.bool(), x, torch.fill_(torch.zeros_like(x),float("inf"))) 565 | x_masked = torch.where(mask.bool(), x, torch.FloatTensor([float("inf")]).to(x.device)) 566 | 567 | # Sort the masked tensor along the last dimension 568 | x_sorted, _ = torch.sort(x_masked, dim=-1) 569 | 570 | # Compute the count of non-masked (valid) values 571 | valid_count = mask.sum(dim=-1) 572 | 573 | # Calculate the index of the median value for each pooling window 574 | median_idx = (torch.div((valid_count - 1), 2, rounding_mode='trunc')).clamp(min=0) 575 | 576 | # Gather the median values using the calculated indices 577 | median_pooled = x_sorted.gather(-1, median_idx.unsqueeze(-1).long()).squeeze(-1) 578 | 579 | # Fill infinite values with NaNs 580 | median_pooled[torch.isinf(median_pooled)] = float("nan") 581 | 582 | if ndim == 2: 583 | return median_pooled.squeeze(1) 584 | 585 | return median_pooled 586 | 587 | 588 | class CrepePitchExtractor(BasePitchExtractor): 589 | def __init__( 590 | self, 591 | hop_length: int = 512, 592 | f0_min: float = 50.0, 593 | f0_max: float = 1100.0, 594 | threshold: float = 0.05, 595 | keep_zeros: bool = False, 596 | device = None, 597 | model: Literal["full", "tiny"] = "full", 598 | use_fast_filters: bool = True, 599 | ): 600 | super().__init__(hop_length, f0_min, f0_max, keep_zeros) 601 | 602 | self.threshold = threshold 603 | self.model = model 604 | self.use_fast_filters = use_fast_filters 605 | self.hop_length = hop_length 606 | if device is None: 607 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") 608 | else: 609 | self.dev = torch.device(device) 610 | if self.use_fast_filters: 611 | self.median_filter = MaskedMedianPool1d(3, 1, 1).to(device) 612 | self.mean_filter = MaskedAvgPool1d(3, 1, 1).to(device) 613 | 614 | def __call__(self, x, sampling_rate=44100, pad_to=None): 615 | """Extract pitch using crepe. 616 | 617 | 618 | Args: 619 | x (torch.Tensor): Audio signal, shape (1, T). 620 | sampling_rate (int, optional): Sampling rate. Defaults to 44100. 621 | pad_to (int, optional): Pad to length. Defaults to None. 622 | 623 | Returns: 624 | torch.Tensor: Pitch, shape (T // hop_length,). 625 | """ 626 | 627 | assert x.ndim == 2, f"Expected 2D tensor, got {x.ndim}D tensor." 628 | assert x.shape[0] == 1, f"Expected 1 channel, got {x.shape[0]} channels." 
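        # The steps below: torchcrepe.predict returns per-frame F0 plus a periodicity (confidence)
        # track; the periodicity is median-filtered and silence-gated at -60 dB, F0 frames whose
        # confidence falls under `threshold` are dropped (eventually set to 0), the remaining track
        # is mean-filtered, and post_process() interpolates over the unvoiced gaps.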
629 | 630 | x = x.to(self.dev) 631 | f0, pd = torchcrepe.predict( 632 | x, 633 | sampling_rate, 634 | self.hop_length, 635 | self.f0_min, 636 | self.f0_max, 637 | pad=True, 638 | model=self.model, 639 | batch_size=1024, 640 | device=x.device, 641 | return_periodicity=True, 642 | ) 643 | 644 | # Filter, remove silence, set uv threshold, refer to the original warehouse readme 645 | if self.use_fast_filters: 646 | pd = self.median_filter(pd) 647 | else: 648 | pd = torchcrepe.filter.median(pd, 3) 649 | 650 | pd = torchcrepe.threshold.Silence(-60.0)(pd, x, sampling_rate, 512) 651 | f0 = torchcrepe.threshold.At(self.threshold)(f0, pd) 652 | 653 | if self.use_fast_filters: 654 | f0 = self.mean_filter(f0) 655 | else: 656 | f0 = torchcrepe.filter.mean(f0, 3) 657 | 658 | f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0)[0] 659 | 660 | if torch.all(f0 == 0): 661 | rtn = f0.cpu().numpy() if pad_to==None else numpy.zeros(pad_to) 662 | return rtn,rtn 663 | 664 | return self.post_process(x, sampling_rate, f0, pad_to) 665 | 666 | ### ### 667 | 668 | 669 | if __name__ == "__main__": 670 | parser = argparse.ArgumentParser(description="Encode an audio dataset.") 671 | parser.add_argument( 672 | "--model", 673 | # help="available models (HuBERT-Soft or HuBERT-Discrete)", 674 | help="available models (HuBERT-Soft or HuBERT-Discrete or japanese-hubert-base)", 675 | choices=["soft", "soft", "japanese-hubert-base"], 676 | default="japanese-hubert-base" 677 | ) 678 | parser.add_argument( 679 | "--f0", 680 | # help="available models (HuBERT-Soft or HuBERT-Discrete)", 681 | help="available F0 extractor", 682 | choices=["dio", "parselmouth", "harvest", "crepe"], 683 | default="harvest" 684 | ) 685 | parser.add_argument( 686 | "--in_dir", 687 | help="path to the dataset directory.", 688 | default="./dataset/", ### add ### 689 | type=Path, 690 | ) 691 | parser.add_argument( 692 | "--out_dir", 693 | help="path to the output directory.", 694 | default="./dataset/", ### add ### 695 | type=Path, 696 | ) 697 | parser.add_argument( 698 | "--extension", 699 | help="extension of the audio files (defaults to .flac).", 700 | default=".wav", 701 | type=str, 702 | ) 703 | args = parser.parse_args() 704 | encode_dataset(args) 705 | """ 706 | wav_path = "./dataset/jvs_ver1/jvs001/falset10/wav24kHz16bit/BASIC5000_0235.wav" 707 | 708 | model = HubertModel.from_pretrained("rinna/japanese-hubert-base") 709 | model.eval() 710 | model.to(device) 711 | 712 | wav, sr = torchaudio.load(wav_path) 713 | wav_16k = resample(wav, sr, 16000).to(device) 714 | #wav = wav.unsqueeze(0) 715 | with torch.inference_mode(): 716 | units = model(wav_16k) 717 | f0_harvest = compute_f0_harvest(wav[0].to('cpu').detach().numpy().copy(), sampling_rate=sr, hop_length=int(512)) 718 | print("") 719 | """ -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import librosa 5 | import time 6 | from scipy.io.wavfile import write 7 | 8 | import utils 9 | from models import SynthesizerTrn 10 | from mel_processing import mel_spectrogram_torch 11 | import logging 12 | logging.getLogger('numba').setLevel(logging.WARNING) 13 | 14 | # add # 15 | from transformers import HubertModel 16 | import soundcard as sc 17 | import logging 18 | import numpy 19 | from torch.nn import functional as F 20 | import pyworld 21 | import soundfile as sf 22 | 23 | 24 | def inference(args): 25 | 26 | # Set 
counter 27 | count = 0 28 | 29 | # Check device 30 | if torch.cuda.is_available() is True: 31 | print("Enter the device number to use.") 32 | key = input("GPU:0, CPU:1 ===> ") 33 | if key == "0": 34 | device="cuda:0" 35 | elif key=="1": 36 | device="cpu" 37 | print(f"Device : {device}") 38 | else: 39 | print(f"CUDA is not available. Device : cpu") 40 | device = "cpu" 41 | 42 | # play audio by system default 43 | speaker = sc.get_speaker(sc.default_speaker().name) 44 | 45 | # Load config 46 | hps = utils.get_hparams_from_file(args.hpfile) 47 | os.makedirs(args.outdir, exist_ok=True) 48 | 49 | # Init Generator 50 | net_g = SynthesizerTrn( 51 | hps.data.filter_length // 2 + 1, 52 | hps.train.segment_size // hps.data.hop_length, 53 | **hps.model).to(device) 54 | _ = net_g.eval() 55 | print("Loadied generator model.") 56 | total = sum([param.nelement() for param in net_g.parameters()]) 57 | _ = utils.load_checkpoint(args.ptfile, net_g, None) 58 | print("Loadied checkpoint...") 59 | 60 | # Select hubert model 61 | if args.hubert == "japanese-hubert-base": 62 | hubert_model = HubertModel.from_pretrained("rinna/japanese-hubert-base") 63 | hubert_model.eval() 64 | hubert_model.to(device) 65 | hubert_name = "rinna/japanese-hubert-base." 66 | print("Loaded rinna/japanese-hubert-base.") 67 | else: 68 | hubert_model = torch.hub.load("bshall/hubert:main", f"hubert_soft").to(device) 69 | hubert_name = "hubert_soft" 70 | print("Loaded soft hubert.") 71 | while True: 72 | # Load Target Speech 73 | target_wavpath = input("Enter the target speech wavpath ==> ") 74 | y_tgt, sr_tgt = sf.read(target_wavpath) 75 | y_tgt = y_tgt.astype(np.float32) 76 | print("Loaded the target speech. :",target_wavpath) 77 | if sr_tgt != hps.data.sampling_rate: 78 | y_tgt = librosa.resample(y_tgt, orig_sr=sr_tgt, target_sr=hps.data.sampling_rate) 79 | print(f"Detect {sr_tgt} Hz. Target speech is resampled to {hps.data.sampling_rate} Hz.") 80 | 81 | # Load Source Speech 82 | source_wavpath = input("Enter the source speech wavpath ==> ") 83 | y_src, sr_src = sf.read(source_wavpath) 84 | y_src = y_src.astype(np.float32) 85 | print("Loaded the target speech. :",source_wavpath) 86 | if sr_src != hps.data.sampling_rate: 87 | y_src = librosa.resample(y_src, orig_sr=sr_src, target_sr=hps.data.sampling_rate) 88 | print(f"Detect {sr_src} Hz. Source speech is resampled to {hps.data.sampling_rate} Hz.") 89 | 90 | # Select F0 method. 
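        # Map the entered value to an F0 extractor; an unrecognized entry leaves f0_function
        # unset (on the first loop pass), so the F0 step below would fail.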
91 | print("F0 method = [ dio:0 | parselmouth:1 | harvest:2 | crepe:3 ]") 92 | f0_method = input("Enter F0 method ==> ") 93 | if f0_method == "dio" or f0_method=="0": 94 | f0_function = compute_f0_dio 95 | f0_method = "dio" 96 | print("Set up calculation of F0 by dio") 97 | elif f0_method == "parselmouth" or f0_method=="1": 98 | f0_function = compute_f0_parselmouth 99 | f0_method = "parselmouth" 100 | print("Set up calculation of F0 by parselmouth") 101 | elif f0_method == "harvest" or f0_method=="2": 102 | f0_function = compute_f0_harvest 103 | f0_method = "harvest" 104 | print("Set up calculation of F0 by harvest") 105 | elif f0_method == "crepe" or f0_method=="3": 106 | f0_function = compute_f0_torchcrepe 107 | f0_method = "crepe" 108 | print("Set up calculation of F0 by crepe") 109 | 110 | # Synchronize CUDA 111 | torch.cuda.synchronize() 112 | time_S = time.time() 113 | 114 | # Start Voice Convertion 115 | with torch.inference_mode(): 116 | 117 | # Calculate target mel spectrogram 118 | time_S_tgt_mel = time.time() 119 | wav_tgt = torch.from_numpy(y_tgt).unsqueeze(0).to(device) 120 | mel_tgt = mel_spectrogram_torch( # [Batch, n_mels, Frame] 121 | wav_tgt, 122 | hps.data.filter_length, 123 | hps.data.n_mel_channels, 124 | hps.data.sampling_rate, 125 | hps.data.hop_length, 126 | hps.data.win_length, 127 | hps.data.mel_fmin, 128 | hps.data.mel_fmax 129 | ) 130 | _,_,length = mel_tgt.shape 131 | time_E_tgt_mel = time.time() 132 | 133 | # Calculate source content embeddings. 134 | time_S_HuBERT= time.time() 135 | wav_src = torch.from_numpy(y_src).unsqueeze(0).to(device) 136 | if args.hubert == "japanese-hubert-base": 137 | c = hubert_model(wav_src)[0].squeeze() # [Frame, Hidden] 138 | else: 139 | c = hubert_model.units(wav_src) 140 | c=c.transpose(2,1) 141 | c = utils.repeat_expand_2d(c.transpose(1,0), length).unsqueeze(0) 142 | time_E_HuBERT= time.time() 143 | 144 | # Calculate F0. 
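            # F0 is taken from the target waveform with a fixed hop of 512 samples
            # (roughly len(y_tgt) // 512 frames); 512 is assumed here to equal
            # hps.data.hop_length so the F0 frames line up with the mel/content frames.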
145 | time_S_F0= time.time() 146 | if f0_method=="crepe": 147 | f0, vuv = f0_function(y_tgt, sampling_rate=sr_tgt, hop_length=512) 148 | else: 149 | f0 = f0_function(y_tgt, sampling_rate=sr_tgt, hop_length=512) 150 | f0 = torch.from_numpy(f0).unsqueeze(0).to(device) # [Batch, Frame] 151 | time_E_F0= time.time() 152 | 153 | # Infer 154 | time_S_infer= time.time() 155 | audio = net_g.infer(c, mel=mel_tgt, f0=f0) 156 | audio = audio[0][0].data.cpu().float().numpy() 157 | time_E_infer= time.time() 158 | 159 | # Synchronize CUDA 160 | torch.cuda.synchronize() 161 | time_E = time.time() 162 | 163 | # Print time infomation 164 | time_all = time_E - time_S 165 | time_tgtmel = time_E_tgt_mel - time_S_tgt_mel 166 | time_HuBERT = time_E_HuBERT - time_S_HuBERT 167 | time_F0 = time_E_F0 - time_S_F0 168 | time_infer = time_E_infer - time_S_infer 169 | print(f"ALL Calc Time : {time_all}") 170 | print(f"Tgt mel Time : {time_tgtmel}") 171 | print(f"HuBERT Time : {time_HuBERT}") 172 | print(f"F0 Time : {time_F0}") 173 | print(f"Inference Time : {time_infer}") 174 | 175 | timestamp = time.strftime("%m-%d_%H-%M", time.localtime()) 176 | filepath = os.path.join(args.outdir, "{}.wav".format(timestamp+"_"+str(count).zfill(3))) 177 | write(filepath, hps.data.sampling_rate, audio) 178 | print("Inference audio is saved at",filepath) 179 | 180 | with open(os.path.join(args.outdir, "inference_logs.txt"), mode="a", encoding="utf-8") as f: 181 | if count==0: 182 | f.write("~~filepath~~|~~source_wavpath~~|~~target_wavpath~~|~~f0~~|~~HuBERT~~| \n") 183 | txt = f"{filepath}|{source_wavpath}|{target_wavpath}|{f0_method}|{hubert_name}\n" 184 | f.write(txt) 185 | 186 | print("Inference log is saved at",os.path.join(args.outdir, "inference_logs.txt")) 187 | 188 | count += 1 189 | # play audio 190 | print("Play generated audio") 191 | speaker.play(audio, hps.data.sampling_rate) 192 | 193 | print("~END~ \n\n") 194 | 195 | return 0 196 | 197 | 198 | 199 | f0_bin = 256 200 | f0_max = 1100.0 201 | f0_min = 50.0 202 | f0_mel_min = 1127 * numpy.log(1 + f0_min / 700) 203 | f0_mel_max = 1127 * numpy.log(1 + f0_max / 700) 204 | 205 | def normalize_f0(f0, x_mask, uv, random_scale=True): 206 | # calculate means based on x_mask 207 | uv_sum = torch.sum(uv, dim=1, keepdim=True) 208 | uv_sum[uv_sum == 0] = 9999 209 | means = torch.sum(f0[:, 0, :] * uv, dim=1, keepdim=True) / uv_sum 210 | 211 | if random_scale: 212 | factor = torch.Tensor(f0.shape[0], 1).uniform_(0.8, 1.2).to(f0.device) 213 | else: 214 | factor = torch.ones(f0.shape[0], 1).to(f0.device) 215 | # normalize f0 based on means and factor 216 | f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1) 217 | if torch.isnan(f0_norm).any(): 218 | exit(0) 219 | return f0_norm * x_mask 220 | 221 | def compute_f0_torchcrepe(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512,device=None,cr_threshold=0.05): 222 | x = wav_numpy 223 | if p_len is None: 224 | p_len = x.shape[0]//hop_length 225 | else: 226 | assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error" 227 | 228 | x = torch.from_numpy(x.astype(numpy.float32)).clone() 229 | F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max,device=device,threshold=cr_threshold) 230 | f0,uv = F0Creper(x[None,:].float(),sampling_rate,pad_to=p_len) 231 | f0[uv<0.5] = 0 232 | return f0,uv 233 | 234 | def compute_f0_harvest(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512): 235 | x = wav_numpy 236 | if p_len is None: 237 | p_len = x.shape[0]//hop_length 238 | else: 239 | assert 
abs(p_len-x.shape[0]//hop_length) < 4, "pad length error" 240 | 241 | f0, t = pyworld.harvest( 242 | x.astype(numpy.double), 243 | fs=sampling_rate, 244 | f0_ceil=f0_max, 245 | f0_floor=f0_min, 246 | frame_period=1000 * hop_length / sampling_rate, 247 | ) 248 | f0 = pyworld.stonemask(x.astype(numpy.double), f0, t, sampling_rate) 249 | return resize_f0(f0, p_len) 250 | 251 | def compute_f0_parselmouth(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512): 252 | import parselmouth 253 | x = wav_numpy 254 | if p_len is None: 255 | p_len = x.shape[0]//hop_length 256 | else: 257 | assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error" 258 | time_step = hop_length / sampling_rate * 1000 259 | f0 = parselmouth.Sound(x, sampling_rate).to_pitch_ac( 260 | time_step=time_step / 1000, voicing_threshold=0.6, 261 | pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency'] 262 | 263 | pad_size=(p_len - len(f0) + 1) // 2 264 | if(pad_size>0 or p_len - len(f0) - pad_size>0): 265 | f0 = numpy.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant') 266 | return f0 267 | 268 | def resize_f0(x, target_len): 269 | source = numpy.array(x) 270 | source[source<0.001] = numpy.nan 271 | target = numpy.interp(numpy.arange(0, len(source)*target_len, len(source))/ target_len, numpy.arange(0, len(source)), source) 272 | res = numpy.nan_to_num(target) 273 | return res 274 | 275 | def compute_f0_dio(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512): 276 | import pyworld 277 | if p_len is None: 278 | p_len = wav_numpy.shape[0]//hop_length 279 | f0, t = pyworld.dio( 280 | wav_numpy.astype(numpy.double), 281 | fs=sampling_rate, 282 | f0_ceil=f0_max, 283 | f0_floor=f0_min, 284 | frame_period=1000 * hop_length / sampling_rate, 285 | ) 286 | f0 = pyworld.stonemask(wav_numpy.astype(numpy.double), f0, t, sampling_rate) 287 | for index, pitch in enumerate(f0): 288 | f0[index] = round(pitch, 1) 289 | return resize_f0(f0, p_len) 290 | 291 | def f0_to_coarse(f0): 292 | is_torch = isinstance(f0, torch.Tensor) 293 | f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * numpy.log(1 + f0 / 700) 294 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 295 | 296 | f0_mel[f0_mel <= 1] = 1 297 | f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 298 | f0_coarse = (f0_mel + 0.5).int() if is_torch else numpy.rint(f0_mel).astype(numpy.int) 299 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min()) 300 | return f0_coarse 301 | 302 | 303 | def interpolate_f0(f0): 304 | 305 | data = numpy.reshape(f0, (f0.size, 1)) 306 | 307 | vuv_vector = numpy.zeros((data.size, 1), dtype=numpy.float32) 308 | vuv_vector[data > 0.0] = 1.0 309 | vuv_vector[data <= 0.0] = 0.0 310 | 311 | ip_data = data 312 | 313 | frame_number = data.size 314 | last_value = 0.0 315 | for i in range(frame_number): 316 | if data[i] <= 0.0: 317 | j = i + 1 318 | for j in range(i + 1, frame_number): 319 | if data[j] > 0.0: 320 | break 321 | if j < frame_number - 1: 322 | if last_value > 0.0: 323 | step = (data[j] - data[i - 1]) / float(j - i) 324 | for k in range(i, j): 325 | ip_data[k] = data[i - 1] + step * (k - i + 1) 326 | else: 327 | for k in range(i, j): 328 | ip_data[k] = data[j] 329 | else: 330 | for k in range(i, frame_number): 331 | ip_data[k] = last_value 332 | else: 333 | ip_data[i] = data[i] # this may not be necessary 334 | last_value = data[i] 335 | 336 | return ip_data[:,0], vuv_vector[:,0] 337 | 338 | from typing import Optional,Union 339 | 
try: 340 | from typing import Literal 341 | except Exception as e: 342 | from typing_extensions import Literal 343 | import numpy as np 344 | import torch 345 | import torchcrepe 346 | from torch import nn 347 | from torch.nn import functional as F 348 | import scipy 349 | 350 | #from:https://github.com/fishaudio/fish-diffusion 351 | 352 | def repeat_expand( 353 | content: Union[torch.Tensor, numpy.ndarray], target_len: int, mode: str = "nearest" 354 | ): 355 | """Repeat content to target length. 356 | This is a wrapper of torch.nn.functional.interpolate. 357 | 358 | Args: 359 | content (torch.Tensor): tensor 360 | target_len (int): target length 361 | mode (str, optional): interpolation mode. Defaults to "nearest". 362 | 363 | Returns: 364 | torch.Tensor: tensor 365 | """ 366 | 367 | ndim = content.ndim 368 | 369 | if content.ndim == 1: 370 | content = content[None, None] 371 | elif content.ndim == 2: 372 | content = content[None] 373 | 374 | assert content.ndim == 3 375 | 376 | is_np = isinstance(content, numpy.ndarray) 377 | if is_np: 378 | content = torch.from_numpy(content) 379 | 380 | results = torch.nn.functional.interpolate(content, size=target_len, mode=mode) 381 | 382 | if is_np: 383 | results = results.numpy() 384 | 385 | if ndim == 1: 386 | return results[0, 0] 387 | elif ndim == 2: 388 | return results[0] 389 | 390 | 391 | class BasePitchExtractor: 392 | def __init__( 393 | self, 394 | hop_length: int = 512, 395 | f0_min: float = 50.0, 396 | f0_max: float = 1100.0, 397 | keep_zeros: bool = True, 398 | ): 399 | """Base pitch extractor. 400 | 401 | Args: 402 | hop_length (int, optional): Hop length. Defaults to 512. 403 | f0_min (float, optional): Minimum f0. Defaults to 50.0. 404 | f0_max (float, optional): Maximum f0. Defaults to 1100.0. 405 | keep_zeros (bool, optional): Whether keep zeros in pitch. Defaults to True. 406 | """ 407 | 408 | self.hop_length = hop_length 409 | self.f0_min = f0_min 410 | self.f0_max = f0_max 411 | self.keep_zeros = keep_zeros 412 | 413 | def __call__(self, x, sampling_rate=44100, pad_to=None): 414 | raise NotImplementedError("BasePitchExtractor is not callable.") 415 | 416 | def post_process(self, x, sampling_rate, f0, pad_to): 417 | if isinstance(f0, numpy.ndarray): 418 | f0 = torch.from_numpy(f0).float().to(x.device) 419 | 420 | if pad_to is None: 421 | return f0 422 | 423 | f0 = repeat_expand(f0, pad_to) 424 | 425 | if self.keep_zeros: 426 | return f0 427 | 428 | vuv_vector = torch.zeros_like(f0) 429 | vuv_vector[f0 > 0.0] = 1.0 430 | vuv_vector[f0 <= 0.0] = 0.0 431 | 432 | # Remove 0 frequency and apply linear interpolation 433 | nzindex = torch.nonzero(f0).squeeze() 434 | f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy() 435 | time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy() 436 | time_frame = numpy.arange(pad_to) * self.hop_length / sampling_rate 437 | 438 | if f0.shape[0] <= 0: 439 | return torch.zeros(pad_to, dtype=torch.float, device=x.device),torch.zeros(pad_to, dtype=torch.float, device=x.device) 440 | 441 | if f0.shape[0] == 1: 442 | return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],torch.ones(pad_to, dtype=torch.float, device=x.device) 443 | 444 | # Probably can be rewritten with torch? 
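        # time_org holds the timestamps of the voiced (non-zero) frames; interpolating onto the
        # full frame grid (time_frame) fills the unvoiced gaps with linearly interpolated F0.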
445 | f0 = numpy.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) 446 | vuv_vector = vuv_vector.cpu().numpy() 447 | vuv_vector = numpy.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0)) 448 | 449 | return f0,vuv_vector 450 | 451 | 452 | class MaskedAvgPool1d(nn.Module): 453 | def __init__( 454 | self, kernel_size: int, stride: Optional[int] = None, padding: Optional[int] = 0 455 | ): 456 | """An implementation of mean pooling that supports masked values. 457 | 458 | Args: 459 | kernel_size (int): The size of the median pooling window. 460 | stride (int, optional): The stride of the median pooling window. Defaults to None. 461 | padding (int, optional): The padding of the median pooling window. Defaults to 0. 462 | """ 463 | 464 | super(MaskedAvgPool1d, self).__init__() 465 | self.kernel_size = kernel_size 466 | self.stride = stride or kernel_size 467 | self.padding = padding 468 | 469 | def forward(self, x, mask=None): 470 | ndim = x.dim() 471 | if ndim == 2: 472 | x = x.unsqueeze(1) 473 | 474 | assert ( 475 | x.dim() == 3 476 | ), "Input tensor must have 2 or 3 dimensions (batch_size, channels, width)" 477 | 478 | # Apply the mask by setting masked elements to zero, or make NaNs zero 479 | if mask is None: 480 | mask = ~torch.isnan(x) 481 | 482 | # Ensure mask has the same shape as the input tensor 483 | assert x.shape == mask.shape, "Input tensor and mask must have the same shape" 484 | 485 | masked_x = torch.where(mask, x, torch.zeros_like(x)) 486 | # Create a ones kernel with the same number of channels as the input tensor 487 | ones_kernel = torch.ones(x.size(1), 1, self.kernel_size, device=x.device) 488 | 489 | # Perform sum pooling 490 | sum_pooled = nn.functional.conv1d( 491 | masked_x, 492 | ones_kernel, 493 | stride=self.stride, 494 | padding=self.padding, 495 | groups=x.size(1), 496 | ) 497 | 498 | # Count the non-masked (valid) elements in each pooling window 499 | valid_count = nn.functional.conv1d( 500 | mask.float(), 501 | ones_kernel, 502 | stride=self.stride, 503 | padding=self.padding, 504 | groups=x.size(1), 505 | ) 506 | valid_count = valid_count.clamp(min=1) # Avoid division by zero 507 | 508 | # Perform masked average pooling 509 | avg_pooled = sum_pooled / valid_count 510 | 511 | # Fill zero values with NaNs 512 | avg_pooled[avg_pooled == 0] = float("nan") 513 | 514 | if ndim == 2: 515 | return avg_pooled.squeeze(1) 516 | 517 | return avg_pooled 518 | 519 | 520 | class MaskedMedianPool1d(nn.Module): 521 | def __init__( 522 | self, kernel_size: int, stride: Optional[int] = None, padding: Optional[int] = 0 523 | ): 524 | """An implementation of median pooling that supports masked values. 525 | 526 | This implementation is inspired by the median pooling implementation in 527 | https://gist.github.com/rwightman/f2d3849281624be7c0f11c85c87c1598 528 | 529 | Args: 530 | kernel_size (int): The size of the median pooling window. 531 | stride (int, optional): The stride of the median pooling window. Defaults to None. 532 | padding (int, optional): The padding of the median pooling window. Defaults to 0. 
533 | """ 534 | 535 | super(MaskedMedianPool1d, self).__init__() 536 | self.kernel_size = kernel_size 537 | self.stride = stride or kernel_size 538 | self.padding = padding 539 | 540 | def forward(self, x, mask=None): 541 | ndim = x.dim() 542 | if ndim == 2: 543 | x = x.unsqueeze(1) 544 | 545 | assert ( 546 | x.dim() == 3 547 | ), "Input tensor must have 2 or 3 dimensions (batch_size, channels, width)" 548 | 549 | if mask is None: 550 | mask = ~torch.isnan(x) 551 | 552 | assert x.shape == mask.shape, "Input tensor and mask must have the same shape" 553 | 554 | masked_x = torch.where(mask, x, torch.zeros_like(x)) 555 | 556 | x = F.pad(masked_x, (self.padding, self.padding), mode="reflect") 557 | mask = F.pad( 558 | mask.float(), (self.padding, self.padding), mode="constant", value=0 559 | ) 560 | 561 | x = x.unfold(2, self.kernel_size, self.stride) 562 | mask = mask.unfold(2, self.kernel_size, self.stride) 563 | 564 | x = x.contiguous().view(x.size()[:3] + (-1,)) 565 | mask = mask.contiguous().view(mask.size()[:3] + (-1,)).to(x.device) 566 | 567 | # Combine the mask with the input tensor 568 | #x_masked = torch.where(mask.bool(), x, torch.fill_(torch.zeros_like(x),float("inf"))) 569 | x_masked = torch.where(mask.bool(), x, torch.FloatTensor([float("inf")]).to(x.device)) 570 | 571 | # Sort the masked tensor along the last dimension 572 | x_sorted, _ = torch.sort(x_masked, dim=-1) 573 | 574 | # Compute the count of non-masked (valid) values 575 | valid_count = mask.sum(dim=-1) 576 | 577 | # Calculate the index of the median value for each pooling window 578 | median_idx = (torch.div((valid_count - 1), 2, rounding_mode='trunc')).clamp(min=0) 579 | 580 | # Gather the median values using the calculated indices 581 | median_pooled = x_sorted.gather(-1, median_idx.unsqueeze(-1).long()).squeeze(-1) 582 | 583 | # Fill infinite values with NaNs 584 | median_pooled[torch.isinf(median_pooled)] = float("nan") 585 | 586 | if ndim == 2: 587 | return median_pooled.squeeze(1) 588 | 589 | return median_pooled 590 | 591 | 592 | class CrepePitchExtractor(BasePitchExtractor): 593 | def __init__( 594 | self, 595 | hop_length: int = 512, 596 | f0_min: float = 50.0, 597 | f0_max: float = 1100.0, 598 | threshold: float = 0.05, 599 | keep_zeros: bool = False, 600 | device = None, 601 | model: Literal["full", "tiny"] = "full", 602 | use_fast_filters: bool = True, 603 | ): 604 | super().__init__(hop_length, f0_min, f0_max, keep_zeros) 605 | 606 | self.threshold = threshold 607 | self.model = model 608 | self.use_fast_filters = use_fast_filters 609 | self.hop_length = hop_length 610 | if device is None: 611 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") 612 | else: 613 | self.dev = torch.device(device) 614 | if self.use_fast_filters: 615 | self.median_filter = MaskedMedianPool1d(3, 1, 1).to(device) 616 | self.mean_filter = MaskedAvgPool1d(3, 1, 1).to(device) 617 | 618 | def __call__(self, x, sampling_rate=44100, pad_to=None): 619 | """Extract pitch using crepe. 620 | 621 | 622 | Args: 623 | x (torch.Tensor): Audio signal, shape (1, T). 624 | sampling_rate (int, optional): Sampling rate. Defaults to 44100. 625 | pad_to (int, optional): Pad to length. Defaults to None. 626 | 627 | Returns: 628 | torch.Tensor: Pitch, shape (T // hop_length,). 629 | """ 630 | 631 | assert x.ndim == 2, f"Expected 2D tensor, got {x.ndim}D tensor." 632 | assert x.shape[0] == 1, f"Expected 1 channel, got {x.shape[0]} channels." 
633 | 634 | x = x.to(self.dev) 635 | f0, pd = torchcrepe.predict( 636 | x, 637 | sampling_rate, 638 | self.hop_length, 639 | self.f0_min, 640 | self.f0_max, 641 | pad=True, 642 | model=self.model, 643 | batch_size=1024, 644 | device=x.device, 645 | return_periodicity=True, 646 | ) 647 | 648 | # Filter, remove silence, set uv threshold, refer to the original warehouse readme 649 | if self.use_fast_filters: 650 | pd = self.median_filter(pd) 651 | else: 652 | pd = torchcrepe.filter.median(pd, 3) 653 | 654 | pd = torchcrepe.threshold.Silence(-60.0)(pd, x, sampling_rate, 512) 655 | f0 = torchcrepe.threshold.At(self.threshold)(f0, pd) 656 | 657 | if self.use_fast_filters: 658 | f0 = self.mean_filter(f0) 659 | else: 660 | f0 = torchcrepe.filter.mean(f0, 3) 661 | 662 | f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0)[0] 663 | 664 | if torch.all(f0 == 0): 665 | rtn = f0.cpu().numpy() if pad_to==None else numpy.zeros(pad_to) 666 | return rtn,rtn 667 | 668 | return self.post_process(x, sampling_rate, f0, pad_to) 669 | 670 | 671 | ################################################################ 672 | if __name__ == "__main__": 673 | 674 | # get arg 675 | parser = argparse.ArgumentParser() 676 | 677 | parser.add_argument("--config", 678 | type=str, 679 | required=True, 680 | default="./path/to/config.json", 681 | help="path to json config file") 682 | 683 | parser.add_argument("--model_path", 684 | type=str, 685 | required=True, 686 | default="./path/to/G_xxx.pth", 687 | help="path to pth file") 688 | 689 | parser.add_argument("--hubert", 690 | type=str, 691 | default="japanese-hubert-base", 692 | help="path to txt file") 693 | 694 | parser.add_argument("--outdir", 695 | type=str, 696 | default="infer_logs", 697 | help="path to output dir") 698 | 699 | args = parser.parse_args() 700 | 701 | # start 702 | inference(args) 703 | ################################################################ 704 | 705 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from stft_loss import MultiResolutionSTFTLoss 4 | 5 | 6 | import commons 7 | 8 | 9 | def feature_loss(fmap_r, fmap_g): 10 | loss = 0 11 | for dr, dg in zip(fmap_r, fmap_g): 12 | for rl, gl in zip(dr, dg): 13 | rl = rl.float().detach() 14 | gl = gl.float() 15 | loss += torch.mean(torch.abs(rl - gl)) 16 | 17 | return loss * 2 18 | 19 | 20 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 21 | loss = 0 22 | r_losses = [] 23 | g_losses = [] 24 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 25 | dr = dr.float() 26 | dg = dg.float() 27 | r_loss = torch.mean((1-dr)**2) 28 | g_loss = torch.mean(dg**2) 29 | loss += (r_loss + g_loss) 30 | r_losses.append(r_loss.item()) 31 | g_losses.append(g_loss.item()) 32 | 33 | return loss, r_losses, g_losses 34 | 35 | 36 | def generator_loss(disc_outputs): 37 | loss = 0 38 | gen_losses = [] 39 | for dg in disc_outputs: 40 | dg = dg.float() 41 | l = torch.mean((1-dg)**2) 42 | gen_losses.append(l) 43 | loss += l 44 | 45 | return loss, gen_losses 46 | 47 | 48 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): 49 | """ 50 | z_p, logs_q: [b, h, t_t] 51 | m_p, logs_p: [b, h, t_t] 52 | """ 53 | z_p = z_p.float() 54 | logs_q = logs_q.float() 55 | m_p = m_p.float() 56 | logs_p = logs_p.float() 57 | z_mask = z_mask.float() 58 | 59 | kl = logs_p - logs_q - 0.5 60 | kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. 
* logs_p) 61 | kl = torch.sum(kl * z_mask) 62 | l = kl / torch.sum(z_mask) 63 | return l 64 | 65 | def subband_stft_loss(h, y_mb, y_hat_mb): 66 | sub_stft_loss = MultiResolutionSTFTLoss(h.train.fft_sizes, h.train.hop_sizes, h.train.win_lengths) 67 | y_mb = y_mb.view(-1, y_mb.size(2)) 68 | y_hat_mb = y_hat_mb.view(-1, y_hat_mb.size(2)) 69 | sub_sc_loss, sub_mag_loss = sub_stft_loss(y_hat_mb[:, :y_mb.size(-1)], y_mb) 70 | return sub_sc_loss+sub_mag_loss 71 | 72 | -------------------------------------------------------------------------------- /mel_processing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import torch.utils.data 8 | import numpy as np 9 | import librosa 10 | import librosa.util as librosa_util 11 | from librosa.util import normalize, pad_center, tiny 12 | from scipy.signal import get_window 13 | from scipy.io.wavfile import read 14 | from librosa.filters import mel as librosa_mel_fn 15 | 16 | MAX_WAV_VALUE = 32768.0 17 | 18 | 19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 20 | """ 21 | PARAMS 22 | ------ 23 | C: compression factor 24 | """ 25 | return torch.log(torch.clamp(x, min=clip_val) * C) 26 | 27 | 28 | def dynamic_range_decompression_torch(x, C=1): 29 | """ 30 | PARAMS 31 | ------ 32 | C: compression factor used to compress 33 | """ 34 | return torch.exp(x) / C 35 | 36 | 37 | def spectral_normalize_torch(magnitudes): 38 | output = dynamic_range_compression_torch(magnitudes) 39 | return output 40 | 41 | 42 | def spectral_de_normalize_torch(magnitudes): 43 | output = dynamic_range_decompression_torch(magnitudes) 44 | return output 45 | 46 | 47 | mel_basis = {} 48 | hann_window = {} 49 | 50 | 51 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 52 | if torch.min(y) < -1.: 53 | print('min value is ', torch.min(y)) 54 | if torch.max(y) > 1.: 55 | print('max value is ', torch.max(y)) 56 | 57 | global hann_window 58 | dtype_device = str(y.dtype) + '_' + str(y.device) 59 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 60 | if wnsize_dtype_device not in hann_window: 61 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 62 | 63 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 64 | y = y.squeeze(1) 65 | 66 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 67 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 68 | 69 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 70 | return spec 71 | 72 | 73 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 74 | global mel_basis 75 | dtype_device = str(spec.dtype) + '_' + str(spec.device) 76 | fmax_dtype_device = str(fmax) + '_' + dtype_device 77 | if fmax_dtype_device not in mel_basis: 78 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 79 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) 80 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 81 | spec = spectral_normalize_torch(spec) 82 | return spec 83 | 84 | 85 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 86 | if torch.min(y) < -1.: 87 | print('min value is ', torch.min(y)) 88 | if torch.max(y) 
> 1.: 89 | print('max value is ', torch.max(y)) 90 | 91 | global mel_basis, hann_window 92 | dtype_device = str(y.dtype) + '_' + str(y.device) 93 | fmax_dtype_device = str(fmax) + '_' + dtype_device 94 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 95 | if fmax_dtype_device not in mel_basis: 96 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 97 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) 98 | if wnsize_dtype_device not in hann_window: 99 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 100 | 101 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 102 | y = y.squeeze(1) 103 | 104 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 105 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 106 | 107 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 108 | 109 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 110 | spec = spectral_normalize_torch(spec) 111 | 112 | return spec 113 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | import commons 8 | import modules 9 | import attentions 10 | 11 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 12 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 13 | from commons import init_weights, get_padding 14 | from pqmf import PQMF 15 | from stft import TorchSTFT 16 | import math 17 | from utils import f0_to_coarse 18 | 19 | 20 | class StochasticDurationPredictor(nn.Module): 21 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): 22 | super().__init__() 23 | filter_channels = in_channels # it needs to be removed from future version. 
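    # (i.e. the filter_channels argument passed in is currently ignored and replaced by in_channels)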
24 | self.in_channels = in_channels 25 | self.filter_channels = filter_channels 26 | self.kernel_size = kernel_size 27 | self.p_dropout = p_dropout 28 | self.n_flows = n_flows 29 | self.gin_channels = gin_channels 30 | 31 | self.log_flow = modules.Log() 32 | self.flows = nn.ModuleList() 33 | self.flows.append(modules.ElementwiseAffine(2)) 34 | for i in range(n_flows): 35 | self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 36 | self.flows.append(modules.Flip()) 37 | 38 | self.post_pre = nn.Conv1d(1, filter_channels, 1) 39 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) 40 | self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 41 | self.post_flows = nn.ModuleList() 42 | self.post_flows.append(modules.ElementwiseAffine(2)) 43 | for i in range(4): 44 | self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 45 | self.post_flows.append(modules.Flip()) 46 | 47 | self.pre = nn.Conv1d(in_channels, filter_channels, 1) 48 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1) 49 | self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 50 | if gin_channels != 0: 51 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1) 52 | 53 | def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): 54 | x = torch.detach(x) 55 | x = self.pre(x) 56 | if g is not None: 57 | g = torch.detach(g) 58 | x = x + self.cond(g) 59 | x = self.convs(x, x_mask) 60 | x = self.proj(x) * x_mask 61 | 62 | if not reverse: 63 | flows = self.flows 64 | assert w is not None 65 | 66 | logdet_tot_q = 0 67 | h_w = self.post_pre(w) 68 | h_w = self.post_convs(h_w, x_mask) 69 | h_w = self.post_proj(h_w) * x_mask 70 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask 71 | z_q = e_q 72 | for flow in self.post_flows: 73 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) 74 | logdet_tot_q += logdet_q 75 | z_u, z1 = torch.split(z_q, [1, 1], 1) 76 | u = torch.sigmoid(z_u) * x_mask 77 | z0 = (w - u) * x_mask 78 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) 79 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q 80 | 81 | logdet_tot = 0 82 | z0, logdet = self.log_flow(z0, x_mask) 83 | logdet_tot += logdet 84 | z = torch.cat([z0, z1], 1) 85 | for flow in flows: 86 | z, logdet = flow(z, x_mask, g=x, reverse=reverse) 87 | logdet_tot = logdet_tot + logdet 88 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot 89 | return nll + logq # [b] 90 | else: 91 | flows = list(reversed(self.flows)) 92 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow 93 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale 94 | for flow in flows: 95 | z = flow(z, x_mask, g=x, reverse=reverse) 96 | z0, z1 = torch.split(z, [1, 1], 1) 97 | logw = z0 98 | return logw 99 | 100 | 101 | class DurationPredictor(nn.Module): 102 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): 103 | super().__init__() 104 | 105 | self.in_channels = in_channels 106 | self.filter_channels = filter_channels 107 | self.kernel_size = kernel_size 108 | self.p_dropout = p_dropout 109 | self.gin_channels = gin_channels 110 | 111 | self.drop = nn.Dropout(p_dropout) 112 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2) 113 | self.norm_1 = 
modules.LayerNorm(filter_channels) 114 | self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2) 115 | self.norm_2 = modules.LayerNorm(filter_channels) 116 | self.proj = nn.Conv1d(filter_channels, 1, 1) 117 | 118 | if gin_channels != 0: 119 | self.cond = nn.Conv1d(gin_channels, in_channels, 1) 120 | 121 | def forward(self, x, x_mask, g=None): 122 | x = torch.detach(x) 123 | if g is not None: 124 | g = torch.detach(g) 125 | x = x + self.cond(g) 126 | x = self.conv_1(x * x_mask) 127 | x = torch.relu(x) 128 | x = self.norm_1(x) 129 | x = self.drop(x) 130 | x = self.conv_2(x * x_mask) 131 | x = torch.relu(x) 132 | x = self.norm_2(x) 133 | x = self.drop(x) 134 | x = self.proj(x * x_mask) 135 | return x * x_mask 136 | 137 | """ 138 | class TextEncoder(nn.Module): 139 | def __init__(self, 140 | n_vocab, 141 | out_channels, 142 | hidden_channels, 143 | filter_channels, 144 | n_heads, 145 | n_layers, 146 | kernel_size, 147 | p_dropout): 148 | super().__init__() 149 | self.n_vocab = n_vocab 150 | self.out_channels = out_channels 151 | self.hidden_channels = hidden_channels 152 | self.filter_channels = filter_channels 153 | self.n_heads = n_heads 154 | self.n_layers = n_layers 155 | self.kernel_size = kernel_size 156 | self.p_dropout = p_dropout 157 | 158 | 159 | self.encoder = attentions.Encoder( 160 | hidden_channels, 161 | filter_channels, 162 | n_heads, 163 | n_layers, 164 | kernel_size, 165 | p_dropout) 166 | self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1) 167 | 168 | def forward(self, x, x_lengths): 169 | x = torch.transpose(x, 1, -1) # [b, h, t] 170 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 171 | 172 | x = self.encoder(x * x_mask, x_mask) 173 | stats = self.proj(x) * x_mask 174 | 175 | m, logs = torch.split(stats, self.out_channels, dim=1) 176 | return x, m, logs, x_mask 177 | """ 178 | 179 | class TextEncoder(nn.Module): 180 | def __init__(self, 181 | out_channels, 182 | hidden_channels, 183 | kernel_size, 184 | n_layers, 185 | gin_channels=0, 186 | filter_channels=None, 187 | n_heads=None, 188 | p_dropout=None): 189 | super().__init__() 190 | self.out_channels = out_channels 191 | self.hidden_channels = hidden_channels 192 | self.kernel_size = kernel_size 193 | self.n_layers = n_layers 194 | self.gin_channels = gin_channels 195 | self.lrelu = nn.LeakyReLU(0.1, inplace=True) 196 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 197 | self.f0_emb = nn.Embedding(256, hidden_channels) 198 | self.c_emb = nn.Linear(768, hidden_channels) 199 | 200 | self.enc_ = attentions.Encoder( 201 | hidden_channels, 202 | filter_channels, 203 | n_heads, 204 | n_layers, 205 | kernel_size, 206 | p_dropout) 207 | 208 | def forward(self, c, c_lemgth, f0, noice_scale=1): #c[B, Hid, Frame] f0[B, fra] 209 | #x = x + self.f0_emb(f0) 210 | f0_emb = self.f0_emb(f0) 211 | c = c.transpose(1,2) 212 | c_emb = self.c_emb(c) 213 | x = f0_emb + c_emb 214 | x = self.lrelu(x * math.sqrt(self.hidden_channels)) 215 | x = x.transpose(1,2) 216 | x_mask = torch.unsqueeze(commons.sequence_mask(c_lemgth, x.size(2)), 1).to(x.dtype) 217 | #x_mask = commons.sequence_mask(c_lemgth, x.size(2)).to(x.dtype) 218 | 219 | x = self.enc_(x * x_mask, x_mask) 220 | stats = self.proj(x) * x_mask 221 | m, logs = torch.split(stats, self.out_channels, dim=1) 222 | z = (m + torch.randn_like(m) * torch.exp(logs) * noice_scale) * x_mask 223 | 224 | return z, m, logs, x_mask 225 | 226 | 227 | class Encoder(nn.Module): 228 | def __init__(self, 229 | 
in_channels, 230 | out_channels, 231 | hidden_channels, 232 | kernel_size, 233 | dilation_rate, 234 | n_layers, 235 | gin_channels=0): 236 | super().__init__() 237 | self.in_channels = in_channels 238 | self.out_channels = out_channels 239 | self.hidden_channels = hidden_channels 240 | self.kernel_size = kernel_size 241 | self.dilation_rate = dilation_rate 242 | self.n_layers = n_layers 243 | self.gin_channels = gin_channels 244 | 245 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 246 | self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) 247 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 248 | 249 | def forward(self, x, x_lengths, g=None): 250 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 251 | x = self.pre(x) * x_mask 252 | x = self.enc(x, x_mask, g=g) 253 | stats = self.proj(x) * x_mask 254 | m, logs = torch.split(stats, self.out_channels, dim=1) 255 | z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask 256 | return z, m, logs, x_mask 257 | 258 | 259 | class ResidualCouplingBlock(nn.Module): 260 | def __init__(self, 261 | channels, 262 | hidden_channels, 263 | kernel_size, 264 | dilation_rate, 265 | n_layers, 266 | n_flows=4, 267 | gin_channels=0): 268 | super().__init__() 269 | self.channels = channels 270 | self.hidden_channels = hidden_channels 271 | self.kernel_size = kernel_size 272 | self.dilation_rate = dilation_rate 273 | self.n_layers = n_layers 274 | self.n_flows = n_flows 275 | self.gin_channels = gin_channels 276 | 277 | self.flows = nn.ModuleList() 278 | for i in range(n_flows): 279 | self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) 280 | self.flows.append(modules.Flip()) 281 | 282 | def forward(self, x, x_mask, g=None, reverse=False): 283 | if not reverse: 284 | for flow in self.flows: 285 | x, _ = flow(x, x_mask, g=g, reverse=reverse) 286 | else: 287 | for flow in reversed(self.flows): 288 | x = flow(x, x_mask, g=g, reverse=reverse) 289 | return x 290 | 291 | 292 | class PosteriorEncoder(nn.Module): 293 | def __init__(self, 294 | in_channels, 295 | out_channels, 296 | hidden_channels, 297 | kernel_size, 298 | dilation_rate, 299 | n_layers, 300 | gin_channels=0): 301 | super().__init__() 302 | self.in_channels = in_channels 303 | self.out_channels = out_channels 304 | self.hidden_channels = hidden_channels 305 | self.kernel_size = kernel_size 306 | self.dilation_rate = dilation_rate 307 | self.n_layers = n_layers 308 | self.gin_channels = gin_channels 309 | 310 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 311 | self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) 312 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 313 | 314 | def forward(self, x, x_lengths, g=None): 315 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 316 | x = self.pre(x) * x_mask 317 | x = self.enc(x, x_mask, g=g) 318 | stats = self.proj(x) * x_mask 319 | m, logs = torch.split(stats, self.out_channels, dim=1) 320 | z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask 321 | return z, m, logs, x_mask 322 | 323 | class iSTFT_Generator(torch.nn.Module): 324 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, gin_channels=0): 
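# (editor note) Overview of this decoder, based on the code that follows: conv_pre lifts the
# latent to upsample_initial_channel, the speaker embedding g is added once via self.cond,
# then ConvTranspose1d upsampling stages interleaved with ResBlocks refine the features.
# conv_post emits gen_istft_n_fft + 2 channels that are split evenly into a log-magnitude
# half (exponentiated) and a phase half (pi * sin), and TorchSTFT.inverse converts them
# back into a waveform.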
325 | super(iSTFT_Generator, self).__init__() 326 | # self.h = h 327 | self.gen_istft_n_fft = gen_istft_n_fft 328 | self.gen_istft_hop_size = gen_istft_hop_size 329 | 330 | self.num_kernels = len(resblock_kernel_sizes) 331 | self.num_upsamples = len(upsample_rates) 332 | self.conv_pre = weight_norm(Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)) 333 | resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 334 | 335 | self.ups = nn.ModuleList() 336 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 337 | self.ups.append(weight_norm( 338 | ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), 339 | k, u, padding=(k-u)//2))) 340 | 341 | self.resblocks = nn.ModuleList() 342 | for i in range(len(self.ups)): 343 | ch = upsample_initial_channel//(2**(i+1)) 344 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): 345 | self.resblocks.append(resblock(ch, k, d)) 346 | 347 | self.post_n_fft = self.gen_istft_n_fft 348 | self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3)) 349 | self.ups.apply(init_weights) 350 | self.conv_post.apply(init_weights) 351 | self.reflection_pad = torch.nn.ReflectionPad1d((1, 0)) 352 | self.cond = nn.Conv1d(256, 512, 1) 353 | self.stft = TorchSTFT(filter_length=self.gen_istft_n_fft, hop_length=self.gen_istft_hop_size, win_length=self.gen_istft_n_fft) 354 | def forward(self, x, g=None): 355 | 356 | x = self.conv_pre(x) 357 | x = x + self.cond(g) 358 | for i in range(self.num_upsamples): 359 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 360 | x = self.ups[i](x) 361 | xs = None 362 | for j in range(self.num_kernels): 363 | if xs is None: 364 | xs = self.resblocks[i*self.num_kernels+j](x) 365 | else: 366 | xs += self.resblocks[i*self.num_kernels+j](x) 367 | x = xs / self.num_kernels 368 | x = F.leaky_relu(x) 369 | x = self.reflection_pad(x) 370 | x = self.conv_post(x) 371 | spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :]) 372 | phase = math.pi*torch.sin(x[:, self.post_n_fft // 2 + 1:, :]) 373 | out = self.stft.inverse(spec, phase).to(x.device) 374 | return out, None 375 | 376 | def remove_weight_norm(self): 377 | print('Removing weight norm...') 378 | for l in self.ups: 379 | remove_weight_norm(l) 380 | for l in self.resblocks: 381 | l.remove_weight_norm() 382 | remove_weight_norm(self.conv_pre) 383 | remove_weight_norm(self.conv_post) 384 | 385 | 386 | class Multiband_iSTFT_Generator(torch.nn.Module): 387 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, subbands, gin_channels=0): 388 | super(Multiband_iSTFT_Generator, self).__init__() 389 | # self.h = h 390 | self.subbands = subbands 391 | self.num_kernels = len(resblock_kernel_sizes) 392 | self.num_upsamples = len(upsample_rates) 393 | self.conv_pre = weight_norm(Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)) 394 | resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 395 | 396 | self.ups = nn.ModuleList() 397 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 398 | self.ups.append(weight_norm( 399 | ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), 400 | k, u, padding=(k-u+1-i)//2,output_padding=1-i))) 401 | 402 | self.resblocks = nn.ModuleList() 403 | for i in range(len(self.ups)): 404 | ch = upsample_initial_channel//(2**(i+1)) 405 | 
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): 406 | self.resblocks.append(resblock(ch, k, d)) 407 | 408 | self.post_n_fft = gen_istft_n_fft 409 | self.ups.apply(init_weights) 410 | self.reflection_pad = torch.nn.ReflectionPad1d((1, 0)) 411 | self.reshape_pixelshuffle = [] 412 | 413 | self.subband_conv_post = weight_norm(Conv1d(ch, self.subbands*(self.post_n_fft + 2), 7, 1, padding=3)) 414 | 415 | self.subband_conv_post.apply(init_weights) 416 | self.cond = nn.Conv1d(256, 512, 1) 417 | self.gen_istft_n_fft = gen_istft_n_fft 418 | self.gen_istft_hop_size = gen_istft_hop_size 419 | 420 | 421 | def forward(self, x, g=None): 422 | 423 | stft = TorchSTFT(filter_length=self.gen_istft_n_fft, hop_length=self.gen_istft_hop_size, win_length=self.gen_istft_n_fft).to(x.device) 424 | #print(x.device) 425 | pqmf = PQMF(x.device) 426 | 427 | x = self.conv_pre(x)#[B, ch, length] 428 | x = x + self.cond(g) 429 | for i in range(self.num_upsamples): 430 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 431 | x = self.ups[i](x) 432 | 433 | 434 | xs = None 435 | for j in range(self.num_kernels): 436 | if xs is None: 437 | xs = self.resblocks[i*self.num_kernels+j](x) 438 | else: 439 | xs += self.resblocks[i*self.num_kernels+j](x) 440 | x = xs / self.num_kernels 441 | 442 | x = F.leaky_relu(x) 443 | x = self.reflection_pad(x) 444 | x = self.subband_conv_post(x) 445 | x = torch.reshape(x, (x.shape[0], self.subbands, x.shape[1]//self.subbands, x.shape[-1])) 446 | 447 | spec = torch.exp(x[:,:,:self.post_n_fft // 2 + 1, :]) 448 | phase = math.pi*torch.sin(x[:,:, self.post_n_fft // 2 + 1:, :]) 449 | 450 | y_mb_hat = stft.inverse(torch.reshape(spec, (spec.shape[0]*self.subbands, self.gen_istft_n_fft // 2 + 1, spec.shape[-1])), torch.reshape(phase, (phase.shape[0]*self.subbands, self.gen_istft_n_fft // 2 + 1, phase.shape[-1]))) 451 | y_mb_hat = torch.reshape(y_mb_hat, (x.shape[0], self.subbands, 1, y_mb_hat.shape[-1])) 452 | y_mb_hat = y_mb_hat.squeeze(-2) 453 | 454 | y_g_hat = pqmf.synthesis(y_mb_hat) 455 | 456 | return y_g_hat, y_mb_hat 457 | 458 | def remove_weight_norm(self): 459 | print('Removing weight norm...') 460 | for l in self.ups: 461 | remove_weight_norm(l) 462 | for l in self.resblocks: 463 | l.remove_weight_norm() 464 | 465 | 466 | class Multistream_iSTFT_Generator(torch.nn.Module): 467 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, subbands, gin_channels=0): 468 | super(Multistream_iSTFT_Generator, self).__init__() 469 | # self.h = h 470 | self.subbands = subbands 471 | self.num_kernels = len(resblock_kernel_sizes) 472 | self.num_upsamples = len(upsample_rates) 473 | self.conv_pre = weight_norm(Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)) 474 | resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 475 | 476 | self.ups = nn.ModuleList() 477 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 478 | self.ups.append(weight_norm( 479 | ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), 480 | #k, u, padding=(k-u+1-i)//2,output_padding=1-i)))#这里k和u不是成倍数的关系,对最终结果很有可能是有影响的,会有checkerboard artifacts的现象 481 | k, u, padding=(k-u)//2))) 482 | self.resblocks = nn.ModuleList() 483 | for i in range(len(self.ups)): 484 | ch = upsample_initial_channel//(2**(i+1)) 485 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): 
486 | self.resblocks.append(resblock(ch, k, d)) 487 | 488 | self.post_n_fft = gen_istft_n_fft 489 | self.ups.apply(init_weights) 490 | self.reflection_pad = torch.nn.ReflectionPad1d((1, 0)) 491 | self.reshape_pixelshuffle = [] 492 | 493 | self.subband_conv_post = weight_norm(Conv1d(ch, self.subbands*(self.post_n_fft + 2), 7, 1, padding=3)) 494 | 495 | self.subband_conv_post.apply(init_weights) 496 | 497 | self.gen_istft_n_fft = gen_istft_n_fft 498 | self.gen_istft_hop_size = gen_istft_hop_size 499 | 500 | updown_filter = torch.zeros((self.subbands, self.subbands, self.subbands)).float() 501 | for k in range(self.subbands): 502 | updown_filter[k, k, 0] = 1.0 503 | self.register_buffer("updown_filter", updown_filter) 504 | self.multistream_conv_post = weight_norm(Conv1d(self.subbands, 1, kernel_size=63, bias=False, padding=get_padding(63, 1))) 505 | self.multistream_conv_post.apply(init_weights) 506 | self.cond = nn.Conv1d(256, 512, 1) 507 | 508 | 509 | def forward(self, x, g=None): 510 | stft = TorchSTFT(filter_length=self.gen_istft_n_fft, hop_length=self.gen_istft_hop_size, win_length=self.gen_istft_n_fft).to(x.device) 511 | # pqmf = PQMF(x.device) 512 | 513 | x = self.conv_pre(x)#[B, ch, length] 514 | #print(x.size(),g.size()) 515 | x = x + self.cond(g) # g [b, 256, 1] => cond(g) [b, 512, 1] 516 | 517 | for i in range(self.num_upsamples): 518 | 519 | #print(x.size(),g.size()) 520 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 521 | #print(x.size(),g.size()) 522 | x = self.ups[i](x) 523 | 524 | #print(x.size(),g.size()) 525 | xs = None 526 | for j in range(self.num_kernels): 527 | if xs is None: 528 | xs = self.resblocks[i*self.num_kernels+j](x) 529 | else: 530 | xs += self.resblocks[i*self.num_kernels+j](x) 531 | x = xs / self.num_kernels 532 | #print(x.size(),g.size()) 533 | x = F.leaky_relu(x) 534 | x = self.reflection_pad(x) 535 | x = self.subband_conv_post(x) 536 | x = torch.reshape(x, (x.shape[0], self.subbands, x.shape[1]//self.subbands, x.shape[-1])) 537 | #print(x.size(),g.size()) 538 | spec = torch.exp(x[:,:,:self.post_n_fft // 2 + 1, :]) 539 | phase = math.pi*torch.sin(x[:,:, self.post_n_fft // 2 + 1:, :]) 540 | #print(spec.size(),phase.size()) 541 | y_mb_hat = stft.inverse(torch.reshape(spec, (spec.shape[0]*self.subbands, self.gen_istft_n_fft // 2 + 1, spec.shape[-1])), torch.reshape(phase, (phase.shape[0]*self.subbands, self.gen_istft_n_fft // 2 + 1, phase.shape[-1]))) 542 | #print(y_mb_hat.size()) 543 | y_mb_hat = torch.reshape(y_mb_hat, (x.shape[0], self.subbands, 1, y_mb_hat.shape[-1])) 544 | #print(y_mb_hat.size()) 545 | y_mb_hat = y_mb_hat.squeeze(-2) 546 | #print(y_mb_hat.size()) 547 | y_mb_hat = F.conv_transpose1d(y_mb_hat, self.updown_filter* self.subbands, stride=self.subbands)#.cuda(x.device) * self.subbands, stride=self.subbands) 548 | #print(y_mb_hat.size()) 549 | y_g_hat = self.multistream_conv_post(y_mb_hat) 550 | #print(y_g_hat.size(),y_mb_hat.size()) 551 | return y_g_hat, y_mb_hat 552 | 553 | def remove_weight_norm(self): 554 | print('Removing weight norm...') 555 | for l in self.ups: 556 | remove_weight_norm(l) 557 | for l in self.resblocks: 558 | l.remove_weight_norm() 559 | 560 | 561 | class DiscriminatorP(torch.nn.Module): 562 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 563 | super(DiscriminatorP, self).__init__() 564 | self.period = period 565 | self.use_spectral_norm = use_spectral_norm 566 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 567 | self.convs = nn.ModuleList([ 568 | norm_f(Conv2d(1, 32, 
(kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 569 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 570 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 571 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 572 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), 573 | ]) 574 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 575 | 576 | def forward(self, x): 577 | fmap = [] 578 | 579 | # 1d to 2d 580 | b, c, t = x.shape 581 | if t % self.period != 0: # pad first 582 | n_pad = self.period - (t % self.period) 583 | x = F.pad(x, (0, n_pad), "reflect") 584 | t = t + n_pad 585 | x = x.view(b, c, t // self.period, self.period) 586 | 587 | for l in self.convs: 588 | x = l(x) 589 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 590 | fmap.append(x) 591 | x = self.conv_post(x) 592 | fmap.append(x) 593 | x = torch.flatten(x, 1, -1) 594 | 595 | return x, fmap 596 | 597 | 598 | class DiscriminatorS(torch.nn.Module): 599 | def __init__(self, use_spectral_norm=False): 600 | super(DiscriminatorS, self).__init__() 601 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 602 | self.convs = nn.ModuleList([ 603 | norm_f(Conv1d(1, 16, 15, 1, padding=7)), 604 | norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), 605 | norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), 606 | norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), 607 | norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), 608 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 609 | ]) 610 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 611 | 612 | def forward(self, x): 613 | fmap = [] 614 | 615 | for l in self.convs: 616 | x = l(x) 617 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 618 | fmap.append(x) 619 | x = self.conv_post(x) 620 | fmap.append(x) 621 | x = torch.flatten(x, 1, -1) 622 | 623 | return x, fmap 624 | 625 | 626 | class MultiPeriodDiscriminator(torch.nn.Module): 627 | def __init__(self, use_spectral_norm=False): 628 | super(MultiPeriodDiscriminator, self).__init__() 629 | periods = [2,3,5,7,11] 630 | 631 | discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] 632 | discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] 633 | self.discriminators = nn.ModuleList(discs) 634 | 635 | def forward(self, y, y_hat): 636 | 637 | y_d_rs = [] 638 | y_d_gs = [] 639 | fmap_rs = [] 640 | fmap_gs = [] 641 | for i, d in enumerate(self.discriminators): 642 | y_d_r, fmap_r = d(y) 643 | y_d_g, fmap_g = d(y_hat) 644 | y_d_rs.append(y_d_r) 645 | y_d_gs.append(y_d_g) 646 | fmap_rs.append(fmap_r) 647 | fmap_gs.append(fmap_g) 648 | 649 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 650 | 651 | class SpeakerEncoder(torch.nn.Module): 652 | def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256): 653 | super(SpeakerEncoder, self).__init__() 654 | self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) 655 | self.linear = nn.Linear(model_hidden_size, model_embedding_size) 656 | self.relu = nn.ReLU() 657 | 658 | def forward(self, mels): 659 | self.lstm.flatten_parameters() 660 | _, (hidden, _) = self.lstm(mels) 661 | embeds_raw = self.relu(self.linear(hidden[-1])) 662 | return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) 663 | 664 | def 
compute_partial_slices(self, total_frames, partial_frames, partial_hop): 665 | mel_slices = [] 666 | for i in range(0, total_frames-partial_frames, partial_hop): 667 | mel_range = torch.arange(i, i+partial_frames) 668 | mel_slices.append(mel_range) 669 | 670 | return mel_slices 671 | 672 | def embed_utterance(self, mel, partial_frames=128, partial_hop=64): 673 | mel_len = mel.size(1) 674 | last_mel = mel[:,-partial_frames:] 675 | 676 | if mel_len > partial_frames: 677 | mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop) 678 | mels = list(mel[:,s] for s in mel_slices) 679 | mels.append(last_mel) 680 | mels = torch.stack(tuple(mels), 0).squeeze(1) 681 | 682 | with torch.no_grad(): 683 | partial_embeds = self(mels) 684 | embed = torch.mean(partial_embeds, axis=0).unsqueeze(0) 685 | #embed = embed / torch.linalg.norm(embed, 2) 686 | else: 687 | with torch.no_grad(): 688 | embed = self(last_mel) 689 | 690 | return embed 691 | 692 | class SynthesizerTrn(nn.Module): 693 | """ 694 | Synthesizer for Training 695 | """ 696 | 697 | def __init__(self, 698 | spec_channels, 699 | segment_size, 700 | inter_channels, 701 | hidden_channels, 702 | filter_channels, 703 | n_heads, 704 | n_layers, 705 | kernel_size, 706 | p_dropout, 707 | resblock, 708 | resblock_kernel_sizes, 709 | resblock_dilation_sizes, 710 | upsample_rates, 711 | upsample_initial_channel, 712 | upsample_kernel_sizes, 713 | gen_istft_n_fft, 714 | gen_istft_hop_size, 715 | n_speakers=0, 716 | gin_channels=0, 717 | use_sdp=False, 718 | ms_istft_vits=False, 719 | mb_istft_vits = False, 720 | subbands = False, 721 | istft_vits=False, 722 | **kwargs): 723 | 724 | super().__init__() 725 | self.spec_channels = spec_channels 726 | self.inter_channels = inter_channels 727 | self.hidden_channels = hidden_channels 728 | self.filter_channels = filter_channels 729 | self.n_heads = n_heads 730 | self.n_layers = n_layers 731 | self.kernel_size = kernel_size 732 | self.p_dropout = p_dropout 733 | self.resblock = resblock 734 | self.resblock_kernel_sizes = resblock_kernel_sizes 735 | self.resblock_dilation_sizes = resblock_dilation_sizes 736 | self.upsample_rates = upsample_rates 737 | self.upsample_initial_channel = upsample_initial_channel 738 | self.upsample_kernel_sizes = upsample_kernel_sizes 739 | self.segment_size = segment_size 740 | self.n_speakers = n_speakers 741 | self.gin_channels = gin_channels 742 | self.ms_istft_vits = ms_istft_vits 743 | self.mb_istft_vits = mb_istft_vits 744 | self.istft_vits = istft_vits 745 | 746 | self.use_sdp = use_sdp 747 | 748 | #self.enc_p = PosteriorEncoder(768, inter_channels, hidden_channels, 5, 1, 16)#768, inter_channels, hidden_channels, 5, 1, 16) 749 | 750 | self.enc_p = TextEncoder( 751 | inter_channels, 752 | hidden_channels, 753 | filter_channels=filter_channels, 754 | n_heads=n_heads, 755 | n_layers=n_layers, 756 | kernel_size=kernel_size, 757 | p_dropout=p_dropout 758 | ) 759 | 760 | if mb_istft_vits == True: 761 | print('Mutli-band iSTFT VITS') 762 | self.dec = Multiband_iSTFT_Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, subbands, gin_channels=gin_channels) 763 | elif ms_istft_vits == True: 764 | print('Mutli-stream iSTFT VITS') 765 | self.dec = Multistream_iSTFT_Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gen_istft_n_fft, 
gen_istft_hop_size, subbands, gin_channels=gin_channels) 766 | elif istft_vits == True: 767 | print('iSTFT-VITS') 768 | self.dec = iSTFT_Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, gin_channels=gin_channels) 769 | else: 770 | print('Decoder Error in json file') 771 | 772 | self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) 773 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) 774 | 775 | self.enc_spk = SpeakerEncoder(model_hidden_size=gin_channels, model_embedding_size=gin_channels) 776 | 777 | def forward(self, c, spec, g=None, mel=None, f0=None, uv=None, c_lengths=None, spec_lengths=None): 778 | if c_lengths == None: 779 | c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) 780 | if spec_lengths == None: 781 | spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device) 782 | 783 | g = self.enc_spk(mel.transpose(1,2)) 784 | g = g.unsqueeze(-1) 785 | 786 | _, m_p, logs_p, _ = self.enc_p(c, c_lengths, f0=f0_to_coarse(f0)) 787 | z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g) 788 | z_p = self.flow(z, spec_mask, g=g) 789 | 790 | z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size) 791 | o, o_mb = self.dec(z_slice, g=g) 792 | 793 | return o, o_mb, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q) 794 | 795 | def infer(self, c, g=None, mel=None, c_lengths=None, f0=None, uv=None,): 796 | if c_lengths == None: 797 | c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) 798 | g = self.enc_spk.embed_utterance(mel.transpose(1,2)) 799 | g = g.unsqueeze(-1) 800 | 801 | z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths, f0=f0_to_coarse(f0)) 802 | z = self.flow(z_p, c_mask, g=g, reverse=True) 803 | o,o_mb = self.dec(z * c_mask, g=g) 804 | 805 | return o 806 | 807 | if __name__ == "__main__": 808 | x = torch.rand(size=(4, 768,500)) 809 | c_len = torch.FloatTensor(torch.rand(4)) 810 | f0 = torch.zeros(size=(4,500),dtype=torch.long) 811 | 812 | enc_p = TextEncoder( 813 | out_channels=192, 814 | hidden_channels=192, 815 | filter_channels=768, 816 | n_heads=2, 817 | n_layers=6, 818 | kernel_size=3, 819 | p_dropout=0.1 820 | ).to("cpu") 821 | 822 | x = enc_p(x, c_len, f0, noice_scale=1) -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import scipy 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 10 | from torch.nn.utils import weight_norm, remove_weight_norm 11 | 12 | import commons 13 | from commons import init_weights, get_padding 14 | from transforms import piecewise_rational_quadratic_transform 15 | 16 | 17 | LRELU_SLOPE = 0.1 18 | 19 | 20 | class LayerNorm(nn.Module): 21 | def __init__(self, channels, eps=1e-5): 22 | super().__init__() 23 | self.channels = channels 24 | self.eps = eps 25 | 26 | self.gamma = nn.Parameter(torch.ones(channels)) 27 | self.beta = nn.Parameter(torch.zeros(channels)) 28 | 29 | def forward(self, x): 30 | x = x.transpose(1, -1) 31 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 32 | return x.transpose(1, -1) 33 | 34 | 35 | 
class ConvReluNorm(nn.Module): 36 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): 37 | super().__init__() 38 | self.in_channels = in_channels 39 | self.hidden_channels = hidden_channels 40 | self.out_channels = out_channels 41 | self.kernel_size = kernel_size 42 | self.n_layers = n_layers 43 | self.p_dropout = p_dropout 44 | assert n_layers > 1, "Number of layers should be larger than 0." 45 | 46 | self.conv_layers = nn.ModuleList() 47 | self.norm_layers = nn.ModuleList() 48 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 49 | self.norm_layers.append(LayerNorm(hidden_channels)) 50 | self.relu_drop = nn.Sequential( 51 | nn.ReLU(), 52 | nn.Dropout(p_dropout)) 53 | for _ in range(n_layers-1): 54 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 55 | self.norm_layers.append(LayerNorm(hidden_channels)) 56 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 57 | self.proj.weight.data.zero_() 58 | self.proj.bias.data.zero_() 59 | 60 | def forward(self, x, x_mask): 61 | x_org = x 62 | for i in range(self.n_layers): 63 | x = self.conv_layers[i](x * x_mask) 64 | x = self.norm_layers[i](x) 65 | x = self.relu_drop(x) 66 | x = x_org + self.proj(x) 67 | return x * x_mask 68 | 69 | 70 | class DDSConv(nn.Module): 71 | """ 72 | Dialted and Depth-Separable Convolution 73 | """ 74 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): 75 | super().__init__() 76 | self.channels = channels 77 | self.kernel_size = kernel_size 78 | self.n_layers = n_layers 79 | self.p_dropout = p_dropout 80 | 81 | self.drop = nn.Dropout(p_dropout) 82 | self.convs_sep = nn.ModuleList() 83 | self.convs_1x1 = nn.ModuleList() 84 | self.norms_1 = nn.ModuleList() 85 | self.norms_2 = nn.ModuleList() 86 | for i in range(n_layers): 87 | dilation = kernel_size ** i 88 | padding = (kernel_size * dilation - dilation) // 2 89 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, 90 | groups=channels, dilation=dilation, padding=padding 91 | )) 92 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 93 | self.norms_1.append(LayerNorm(channels)) 94 | self.norms_2.append(LayerNorm(channels)) 95 | 96 | def forward(self, x, x_mask, g=None): 97 | if g is not None: 98 | x = x + g 99 | for i in range(self.n_layers): 100 | y = self.convs_sep[i](x * x_mask) 101 | y = self.norms_1[i](y) 102 | y = F.gelu(y) 103 | y = self.convs_1x1[i](y) 104 | y = self.norms_2[i](y) 105 | y = F.gelu(y) 106 | y = self.drop(y) 107 | x = x + y 108 | return x * x_mask 109 | 110 | 111 | class WN(torch.nn.Module): 112 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 113 | super(WN, self).__init__() 114 | assert(kernel_size % 2 == 1) 115 | self.hidden_channels =hidden_channels 116 | self.kernel_size = kernel_size, 117 | self.dilation_rate = dilation_rate 118 | self.n_layers = n_layers 119 | self.gin_channels = gin_channels 120 | self.p_dropout = p_dropout 121 | 122 | self.in_layers = torch.nn.ModuleList() 123 | self.res_skip_layers = torch.nn.ModuleList() 124 | self.drop = nn.Dropout(p_dropout) 125 | 126 | if gin_channels != 0: 127 | cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) 128 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 129 | 130 | for i in range(n_layers): 131 | dilation = dilation_rate ** i 132 | padding = int((kernel_size * dilation - dilation) / 2) 133 | 
in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, 134 | dilation=dilation, padding=padding) 135 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 136 | self.in_layers.append(in_layer) 137 | 138 | # last one is not necessary 139 | if i < n_layers - 1: 140 | res_skip_channels = 2 * hidden_channels 141 | else: 142 | res_skip_channels = hidden_channels 143 | 144 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 145 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 146 | self.res_skip_layers.append(res_skip_layer) 147 | 148 | def forward(self, x, x_mask, g=None, **kwargs): 149 | output = torch.zeros_like(x) 150 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 151 | 152 | if g is not None: 153 | g = self.cond_layer(g) 154 | 155 | for i in range(self.n_layers): 156 | x_in = self.in_layers[i](x) 157 | if g is not None: 158 | cond_offset = i * 2 * self.hidden_channels 159 | g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] 160 | else: 161 | g_l = torch.zeros_like(x_in) 162 | 163 | acts = commons.fused_add_tanh_sigmoid_multiply( 164 | x_in, 165 | g_l, 166 | n_channels_tensor) 167 | acts = self.drop(acts) 168 | 169 | res_skip_acts = self.res_skip_layers[i](acts) 170 | if i < self.n_layers - 1: 171 | res_acts = res_skip_acts[:,:self.hidden_channels,:] 172 | x = (x + res_acts) * x_mask 173 | output = output + res_skip_acts[:,self.hidden_channels:,:] 174 | else: 175 | output = output + res_skip_acts 176 | return output * x_mask 177 | 178 | def remove_weight_norm(self): 179 | if self.gin_channels != 0: 180 | torch.nn.utils.remove_weight_norm(self.cond_layer) 181 | for l in self.in_layers: 182 | torch.nn.utils.remove_weight_norm(l) 183 | for l in self.res_skip_layers: 184 | torch.nn.utils.remove_weight_norm(l) 185 | 186 | 187 | class ResBlock1(torch.nn.Module): 188 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 189 | super(ResBlock1, self).__init__() 190 | self.convs1 = nn.ModuleList([ 191 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 192 | padding=get_padding(kernel_size, dilation[0]))), 193 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 194 | padding=get_padding(kernel_size, dilation[1]))), 195 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 196 | padding=get_padding(kernel_size, dilation[2]))) 197 | ]) 198 | self.convs1.apply(init_weights) 199 | 200 | self.convs2 = nn.ModuleList([ 201 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 202 | padding=get_padding(kernel_size, 1))), 203 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 204 | padding=get_padding(kernel_size, 1))), 205 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 206 | padding=get_padding(kernel_size, 1))) 207 | ]) 208 | self.convs2.apply(init_weights) 209 | 210 | def forward(self, x, x_mask=None): 211 | for c1, c2 in zip(self.convs1, self.convs2): 212 | xt = F.leaky_relu(x, LRELU_SLOPE) 213 | if x_mask is not None: 214 | xt = xt * x_mask 215 | xt = c1(xt) 216 | xt = F.leaky_relu(xt, LRELU_SLOPE) 217 | #print(xt.size()) 218 | if x_mask is not None: 219 | xt = xt * x_mask 220 | xt = c2(xt) 221 | #print(xt.size()) 222 | x = xt + x 223 | if x_mask is not None: 224 | x = x * x_mask 225 | return x 226 | 227 | def remove_weight_norm(self): 228 | for l in self.convs1: 229 | remove_weight_norm(l) 230 | for l in self.convs2: 231 | remove_weight_norm(l) 232 | 233 | 
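# (editor note) ResBlock1 above is the HiFi-GAN style residual block: each dilated
# convolution in convs1 is paired with a non-dilated convolution in convs2, both wrapped in
# LeakyReLU and residual additions, so the input shape [B, channels, T] is preserved.
# A minimal sanity check (illustrative only, the values are arbitrary):
#
#     import torch
#     block = ResBlock1(channels=192, kernel_size=3, dilation=(1, 3, 5))
#     out = block(torch.randn(2, 192, 100))
#     assert out.shape == (2, 192, 100)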
234 | class ResBlock2(torch.nn.Module): 235 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 236 | super(ResBlock2, self).__init__() 237 | self.convs = nn.ModuleList([ 238 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 239 | padding=get_padding(kernel_size, dilation[0]))), 240 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 241 | padding=get_padding(kernel_size, dilation[1]))) 242 | ]) 243 | self.convs.apply(init_weights) 244 | 245 | def forward(self, x, x_mask=None): 246 | for c in self.convs: 247 | xt = F.leaky_relu(x, LRELU_SLOPE) 248 | if x_mask is not None: 249 | xt = xt * x_mask 250 | xt = c(xt) 251 | x = xt + x 252 | if x_mask is not None: 253 | x = x * x_mask 254 | return x 255 | 256 | def remove_weight_norm(self): 257 | for l in self.convs: 258 | remove_weight_norm(l) 259 | 260 | 261 | class Log(nn.Module): 262 | def forward(self, x, x_mask, reverse=False, **kwargs): 263 | if not reverse: 264 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 265 | logdet = torch.sum(-y, [1, 2]) 266 | return y, logdet 267 | else: 268 | x = torch.exp(x) * x_mask 269 | return x 270 | 271 | 272 | class Flip(nn.Module): 273 | def forward(self, x, *args, reverse=False, **kwargs): 274 | x = torch.flip(x, [1]) 275 | if not reverse: 276 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 277 | return x, logdet 278 | else: 279 | return x 280 | 281 | 282 | class ElementwiseAffine(nn.Module): 283 | def __init__(self, channels): 284 | super().__init__() 285 | self.channels = channels 286 | self.m = nn.Parameter(torch.zeros(channels,1)) 287 | self.logs = nn.Parameter(torch.zeros(channels,1)) 288 | 289 | def forward(self, x, x_mask, reverse=False, **kwargs): 290 | if not reverse: 291 | y = self.m + torch.exp(self.logs) * x 292 | y = y * x_mask 293 | logdet = torch.sum(self.logs * x_mask, [1,2]) 294 | return y, logdet 295 | else: 296 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 297 | return x 298 | 299 | 300 | class ResidualCouplingLayer(nn.Module): 301 | def __init__(self, 302 | channels, 303 | hidden_channels, 304 | kernel_size, 305 | dilation_rate, 306 | n_layers, 307 | p_dropout=0, 308 | gin_channels=0, 309 | mean_only=False): 310 | assert channels % 2 == 0, "channels should be divisible by 2" 311 | super().__init__() 312 | self.channels = channels 313 | self.hidden_channels = hidden_channels 314 | self.kernel_size = kernel_size 315 | self.dilation_rate = dilation_rate 316 | self.n_layers = n_layers 317 | self.half_channels = channels // 2 318 | self.mean_only = mean_only 319 | 320 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 321 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) 322 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 323 | self.post.weight.data.zero_() 324 | self.post.bias.data.zero_() 325 | 326 | def forward(self, x, x_mask, g=None, reverse=False): 327 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 328 | h = self.pre(x0) * x_mask 329 | h = self.enc(h, x_mask, g=g) 330 | stats = self.post(h) * x_mask 331 | if not self.mean_only: 332 | m, logs = torch.split(stats, [self.half_channels]*2, 1) 333 | else: 334 | m = stats 335 | logs = torch.zeros_like(m) 336 | 337 | if not reverse: 338 | x1 = m + x1 * torch.exp(logs) * x_mask 339 | x = torch.cat([x0, x1], 1) 340 | logdet = torch.sum(logs, [1,2]) 341 | return x, logdet 342 | else: 343 | x1 = (x1 - m) * torch.exp(-logs) * 
x_mask 344 | x = torch.cat([x0, x1], 1) 345 | return x 346 | 347 | 348 | class ConvFlow(nn.Module): 349 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): 350 | super().__init__() 351 | self.in_channels = in_channels 352 | self.filter_channels = filter_channels 353 | self.kernel_size = kernel_size 354 | self.n_layers = n_layers 355 | self.num_bins = num_bins 356 | self.tail_bound = tail_bound 357 | self.half_channels = in_channels // 2 358 | 359 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 360 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) 361 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) 362 | self.proj.weight.data.zero_() 363 | self.proj.bias.data.zero_() 364 | 365 | def forward(self, x, x_mask, g=None, reverse=False): 366 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 367 | h = self.pre(x0) 368 | h = self.convs(h, x_mask, g=g) 369 | h = self.proj(h) * x_mask 370 | 371 | b, c, t = x0.shape 372 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 373 | 374 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) 375 | unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels) 376 | unnormalized_derivatives = h[..., 2 * self.num_bins:] 377 | 378 | x1, logabsdet = piecewise_rational_quadratic_transform(x1, 379 | unnormalized_widths, 380 | unnormalized_heights, 381 | unnormalized_derivatives, 382 | inverse=reverse, 383 | tails='linear', 384 | tail_bound=self.tail_bound 385 | ) 386 | 387 | x = torch.cat([x0, x1], 1) * x_mask 388 | logdet = torch.sum(logabsdet * x_mask, [1,2]) 389 | if not reverse: 390 | return x, logdet 391 | else: 392 | return x 393 | -------------------------------------------------------------------------------- /pqmf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2020 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Pseudo QMF modules.""" 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from scipy.signal import kaiser 13 | 14 | 15 | def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0): 16 | """Design prototype filter for PQMF. 17 | This method is based on `A Kaiser window approach for the design of prototype 18 | filters of cosine modulated filterbanks`_. 19 | Args: 20 | taps (int): The number of filter taps. 21 | cutoff_ratio (float): Cut-off frequency ratio. 22 | beta (float): Beta coefficient for kaiser window. 23 | Returns: 24 | ndarray: Impluse response of prototype filter (taps + 1,). 25 | .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`: 26 | https://ieeexplore.ieee.org/abstract/document/681427 27 | """ 28 | # check the arguments are valid 29 | assert taps % 2 == 0, "The number of taps mush be even number." 30 | assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0." 
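# (editor note) The filter constructed below is a truncated ideal lowpass,
#   h_i[n] = sin(omega_c * (n - taps/2)) / (pi * (n - taps/2)),  omega_c = pi * cutoff_ratio,
# with the center sample set to its limit value cutoff_ratio, then shaped by a Kaiser
# window whose beta parameter trades main-lobe width against stop-band attenuation.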
31 | 32 | # make initial filter 33 | omega_c = np.pi * cutoff_ratio 34 | with np.errstate(invalid='ignore'): 35 | h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) \ 36 | / (np.pi * (np.arange(taps + 1) - 0.5 * taps)) 37 | h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form 38 | 39 | # apply kaiser window 40 | w = kaiser(taps + 1, beta) 41 | h = h_i * w 42 | 43 | return h 44 | 45 | 46 | class PQMF(torch.nn.Module): 47 | """PQMF module. 48 | This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_. 49 | .. _`Near-perfect-reconstruction pseudo-QMF banks`: 50 | https://ieeexplore.ieee.org/document/258122 51 | """ 52 | 53 | def __init__(self, device, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0): 54 | """Initilize PQMF module. 55 | Args: 56 | subbands (int): The number of subbands. 57 | taps (int): The number of filter taps. 58 | cutoff_ratio (float): Cut-off frequency ratio. 59 | beta (float): Beta coefficient for kaiser window. 60 | """ 61 | super(PQMF, self).__init__() 62 | 63 | # define filter coefficient 64 | h_proto = design_prototype_filter(taps, cutoff_ratio, beta) 65 | h_analysis = np.zeros((subbands, len(h_proto))) 66 | h_synthesis = np.zeros((subbands, len(h_proto))) 67 | for k in range(subbands): 68 | h_analysis[k] = 2 * h_proto * np.cos( 69 | (2 * k + 1) * (np.pi / (2 * subbands)) * 70 | (np.arange(taps + 1) - ((taps - 1) / 2)) + 71 | (-1) ** k * np.pi / 4) 72 | h_synthesis[k] = 2 * h_proto * np.cos( 73 | (2 * k + 1) * (np.pi / (2 * subbands)) * 74 | (np.arange(taps + 1) - ((taps - 1) / 2)) - 75 | (-1) ** k * np.pi / 4) 76 | 77 | # convert to tensor 78 | analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1).cuda(device) 79 | synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0).cuda(device) 80 | 81 | # register coefficients as beffer 82 | self.register_buffer("analysis_filter", analysis_filter) 83 | self.register_buffer("synthesis_filter", synthesis_filter) 84 | 85 | # filter for downsampling & upsampling 86 | updown_filter = torch.zeros((subbands, subbands, subbands)).float().cuda(device) 87 | for k in range(subbands): 88 | updown_filter[k, k, 0] = 1.0 89 | self.register_buffer("updown_filter", updown_filter) 90 | self.subbands = subbands 91 | 92 | # keep padding info 93 | self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0) 94 | 95 | def analysis(self, x): 96 | """Analysis with PQMF. 97 | Args: 98 | x (Tensor): Input tensor (B, 1, T). 99 | Returns: 100 | Tensor: Output tensor (B, subbands, T // subbands). 101 | """ 102 | x = F.conv1d(self.pad_fn(x), self.analysis_filter) 103 | return F.conv1d(x, self.updown_filter, stride=self.subbands) 104 | 105 | def synthesis(self, x): 106 | """Synthesis with PQMF. 107 | Args: 108 | x (Tensor): Input tensor (B, subbands, T // subbands). 109 | Returns: 110 | Tensor: Output tensor (B, 1, T). 111 | """ 112 | # NOTE(kan-bayashi): Power will be dreased so here multipy by # subbands. 113 | # Not sure this is the correct way, it is better to check again. 
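# (editor note) conv_transpose1d with the one-hot updown_filter below performs
# zero-insertion upsampling of each band by a factor of `subbands`; the filter is scaled
# by `subbands` to compensate for the power lost in that upsampling before the synthesis
# filter recombines the bands.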
114 | # TODO(kan-bayashi): Understand the reconstruction procedure 115 | x = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands) 116 | return F.conv1d(self.pad_fn(x), self.synthesis_filter) -------------------------------------------------------------------------------- /qvcfinalwhite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnetonne814/QuickVC-44100-Ja_HuBERT/d46f661c81b05fd9b8a38e54570492bfbeb40ddf/qvcfinalwhite.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | transformers 3 | scikit-learn 4 | tensorboard 5 | librosa==0.9.1 6 | matplotlib 7 | torchcrepe 8 | pyworld 9 | praat-parselmouth 10 | soundcard -------------------------------------------------------------------------------- /stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | Copyright (c) 2017, Prem Seetharaman 4 | All rights reserved. 5 | * Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | * Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above copyright notice, this 10 | list of conditions and the following disclaimer in the 11 | documentation and/or other materials provided with the distribution. 12 | * Neither the name of the copyright holder nor the names of its 13 | contributors may be used to endorse or promote products derived from this 14 | software without specific prior written permission. 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 22 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | """ 26 | 27 | import torch 28 | import numpy as np 29 | import torch.nn.functional as F 30 | from torch.autograd import Variable 31 | from scipy.signal import get_window 32 | from librosa.util import pad_center, tiny 33 | import librosa.util as librosa_util 34 | 35 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800, 36 | n_fft=800, dtype=np.float32, norm=None): 37 | """ 38 | # from librosa 0.6 39 | Compute the sum-square envelope of a window function at a given hop length. 40 | This is used to estimate modulation effects induced by windowing 41 | observations in short-time fourier transforms. 
42 | Parameters 43 | ---------- 44 | window : string, tuple, number, callable, or list-like 45 | Window specification, as in `get_window` 46 | n_frames : int > 0 47 | The number of analysis frames 48 | hop_length : int > 0 49 | The number of samples to advance between frames 50 | win_length : [optional] 51 | The length of the window function. By default, this matches `n_fft`. 52 | n_fft : int > 0 53 | The length of each analysis frame. 54 | dtype : np.dtype 55 | The data type of the output 56 | Returns 57 | ------- 58 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 59 | The sum-squared envelope of the window function 60 | """ 61 | if win_length is None: 62 | win_length = n_fft 63 | 64 | n = n_fft + hop_length * (n_frames - 1) 65 | x = np.zeros(n, dtype=dtype) 66 | 67 | # Compute the squared window at the desired length 68 | win_sq = get_window(window, win_length, fftbins=True) 69 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 70 | win_sq = librosa_util.pad_center(win_sq, n_fft) 71 | 72 | # Fill the envelope 73 | for i in range(n_frames): 74 | sample = i * hop_length 75 | x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 76 | return x 77 | 78 | 79 | class STFT(torch.nn.Module): 80 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 81 | def __init__(self, filter_length=800, hop_length=200, win_length=800, 82 | window='hann'): 83 | super(STFT, self).__init__() 84 | self.filter_length = filter_length 85 | self.hop_length = hop_length 86 | self.win_length = win_length 87 | self.window = window 88 | self.forward_transform = None 89 | scale = self.filter_length / self.hop_length 90 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 91 | 92 | cutoff = int((self.filter_length / 2 + 1)) 93 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 94 | np.imag(fourier_basis[:cutoff, :])]) 95 | 96 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 97 | inverse_basis = torch.FloatTensor( 98 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 99 | 100 | if window is not None: 101 | assert(filter_length >= win_length) 102 | # get window and zero center pad it to filter_length 103 | fft_window = get_window(window, win_length, fftbins=True) 104 | fft_window = pad_center(fft_window, filter_length) 105 | fft_window = torch.from_numpy(fft_window).float() 106 | 107 | # window the bases 108 | forward_basis *= fft_window 109 | inverse_basis *= fft_window 110 | 111 | self.register_buffer('forward_basis', forward_basis.float()) 112 | self.register_buffer('inverse_basis', inverse_basis.float()) 113 | 114 | def transform(self, input_data): 115 | num_batches = input_data.size(0) 116 | num_samples = input_data.size(1) 117 | 118 | self.num_samples = num_samples 119 | 120 | # similar to librosa, reflect-pad the input 121 | input_data = input_data.view(num_batches, 1, num_samples) 122 | input_data = F.pad( 123 | input_data.unsqueeze(1), 124 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 125 | mode='reflect') 126 | input_data = input_data.squeeze(1) 127 | 128 | forward_transform = F.conv1d( 129 | input_data, 130 | Variable(self.forward_basis, requires_grad=False), 131 | stride=self.hop_length, 132 | padding=0) 133 | 134 | cutoff = int((self.filter_length / 2) + 1) 135 | real_part = forward_transform[:, :cutoff, :] 136 | imag_part = forward_transform[:, cutoff:, :] 137 | 138 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 139 | phase = torch.autograd.Variable( 140 | 
torch.atan2(imag_part.data, real_part.data)) 141 | 142 | return magnitude, phase 143 | 144 | def inverse(self, magnitude, phase): 145 | recombine_magnitude_phase = torch.cat( 146 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 147 | 148 | inverse_transform = F.conv_transpose1d( 149 | recombine_magnitude_phase, 150 | Variable(self.inverse_basis, requires_grad=False), 151 | stride=self.hop_length, 152 | padding=0) 153 | 154 | if self.window is not None: 155 | window_sum = window_sumsquare( 156 | self.window, magnitude.size(-1), hop_length=self.hop_length, 157 | win_length=self.win_length, n_fft=self.filter_length, 158 | dtype=np.float32) 159 | # remove modulation effects 160 | approx_nonzero_indices = torch.from_numpy( 161 | np.where(window_sum > tiny(window_sum))[0]) 162 | window_sum = torch.autograd.Variable( 163 | torch.from_numpy(window_sum), requires_grad=False) 164 | window_sum = window_sum.to(inverse_transform.device()) if magnitude.is_cuda else window_sum 165 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 166 | 167 | # scale by hop ratio 168 | inverse_transform *= float(self.filter_length) / self.hop_length 169 | 170 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 171 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] 172 | 173 | return inverse_transform 174 | 175 | def forward(self, input_data): 176 | self.magnitude, self.phase = self.transform(input_data) 177 | reconstruction = self.inverse(self.magnitude, self.phase) 178 | return reconstruction 179 | 180 | 181 | class TorchSTFT(torch.nn.Module): 182 | def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'): 183 | super().__init__() 184 | self.filter_length = filter_length 185 | self.hop_length = hop_length 186 | self.win_length = win_length 187 | self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32)) 188 | 189 | def transform(self, input_data): 190 | forward_transform = torch.stft( 191 | input_data, 192 | self.filter_length, self.hop_length, self.win_length, window=self.window, 193 | return_complex=True) 194 | 195 | return torch.abs(forward_transform), torch.angle(forward_transform) 196 | 197 | def inverse(self, magnitude, phase): 198 | inverse_transform = torch.istft( 199 | magnitude * torch.exp(phase * 1j), 200 | self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device)) 201 | 202 | return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation 203 | 204 | def forward(self, input_data): 205 | self.magnitude, self.phase = self.transform(input_data) 206 | reconstruction = self.inverse(self.magnitude, self.phase) 207 | return reconstruction 208 | 209 | 210 | -------------------------------------------------------------------------------- /stft_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """STFT-based Loss modules.""" 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | def stft(x, fft_size, hop_size, win_length, window): 13 | """Perform STFT and convert to magnitude spectrogram. 14 | Args: 15 | x (Tensor): Input signal tensor (B, T). 16 | fft_size (int): FFT size. 17 | hop_size (int): Hop size. 18 | win_length (int): Window length. 19 | window (str): Window function type. 
20 | Returns: 21 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). 22 | """ 23 | x_stft = torch.stft(x, fft_size, hop_size, win_length, window.to(x.device)) 24 | real = x_stft[..., 0] 25 | imag = x_stft[..., 1] 26 | 27 | # NOTE(kan-bayashi): clamp is needed to avoid nan or inf 28 | return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) 29 | 30 | 31 | class SpectralConvergengeLoss(torch.nn.Module): 32 | """Spectral convergence loss module.""" 33 | 34 | def __init__(self): 35 | """Initilize spectral convergence loss module.""" 36 | super(SpectralConvergengeLoss, self).__init__() 37 | 38 | def forward(self, x_mag, y_mag): 39 | """Calculate forward propagation. 40 | Args: 41 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 42 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 43 | Returns: 44 | Tensor: Spectral convergence loss value. 45 | """ 46 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") 47 | 48 | 49 | class LogSTFTMagnitudeLoss(torch.nn.Module): 50 | """Log STFT magnitude loss module.""" 51 | 52 | def __init__(self): 53 | """Initilize los STFT magnitude loss module.""" 54 | super(LogSTFTMagnitudeLoss, self).__init__() 55 | 56 | def forward(self, x_mag, y_mag): 57 | """Calculate forward propagation. 58 | Args: 59 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 60 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 61 | Returns: 62 | Tensor: Log STFT magnitude loss value. 63 | """ 64 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) 65 | 66 | 67 | class STFTLoss(torch.nn.Module): 68 | """STFT loss module.""" 69 | 70 | def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"): 71 | """Initialize STFT loss module.""" 72 | super(STFTLoss, self).__init__() 73 | self.fft_size = fft_size 74 | self.shift_size = shift_size 75 | self.win_length = win_length 76 | self.window = getattr(torch, window)(win_length) 77 | self.spectral_convergenge_loss = SpectralConvergengeLoss() 78 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 79 | 80 | def forward(self, x, y): 81 | """Calculate forward propagation. 82 | Args: 83 | x (Tensor): Predicted signal (B, T). 84 | y (Tensor): Groundtruth signal (B, T). 85 | Returns: 86 | Tensor: Spectral convergence loss value. 87 | Tensor: Log STFT magnitude loss value. 88 | """ 89 | x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) 90 | y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) 91 | sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) 92 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 93 | 94 | return sc_loss, mag_loss 95 | 96 | 97 | class MultiResolutionSTFTLoss(torch.nn.Module): 98 | """Multi resolution STFT loss module.""" 99 | 100 | def __init__(self, 101 | fft_sizes=[1024, 2048, 512], 102 | hop_sizes=[120, 240, 50], 103 | win_lengths=[600, 1200, 240], 104 | window="hann_window"): 105 | """Initialize Multi resolution STFT loss module. 106 | Args: 107 | fft_sizes (list): List of FFT sizes. 108 | hop_sizes (list): List of hop sizes. 109 | win_lengths (list): List of window lengths. 110 | window (str): Window function type. 
111 | """ 112 | super(MultiResolutionSTFTLoss, self).__init__() 113 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 114 | self.stft_losses = torch.nn.ModuleList() 115 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 116 | self.stft_losses += [STFTLoss(fs, ss, wl, window)] 117 | 118 | def forward(self, x, y): 119 | """Calculate forward propagation. 120 | Args: 121 | x (Tensor): Predicted signal (B, T). 122 | y (Tensor): Groundtruth signal (B, T). 123 | Returns: 124 | Tensor: Multi resolution spectral convergence loss value. 125 | Tensor: Multi resolution log STFT magnitude loss value. 126 | """ 127 | sc_loss = 0.0 128 | mag_loss = 0.0 129 | for f in self.stft_losses: 130 | sc_l, mag_l = f(x, y) 131 | sc_loss += sc_l 132 | mag_loss += mag_l 133 | sc_loss /= len(self.stft_losses) 134 | mag_loss /= len(self.stft_losses) 135 | 136 | return sc_loss, mag_loss -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import itertools 5 | import math 6 | import torch 7 | from torch import nn, optim 8 | from torch.nn import functional as F 9 | from torch.utils.data import DataLoader 10 | from torch.utils.tensorboard import SummaryWriter 11 | import torch.multiprocessing as mp 12 | import torch.distributed as dist 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from torch.cuda.amp import autocast, GradScaler 15 | from pqmf import PQMF 16 | 17 | import commons 18 | import utils 19 | 20 | from data_utils_new_new import ( 21 | TextAudioSpeakerLoader, 22 | TextAudioSpeakerCollate, 23 | DistributedBucketSampler 24 | ) 25 | from models import ( 26 | SynthesizerTrn, 27 | MultiPeriodDiscriminator, 28 | ) 29 | from losses import ( 30 | generator_loss, 31 | discriminator_loss, 32 | feature_loss, 33 | kl_loss, 34 | subband_stft_loss 35 | ) 36 | from mel_processing import mel_spectrogram_torch, spec_to_mel_torch 37 | #from text.symbols import symbols 38 | 39 | torch.autograd.set_detect_anomaly(True) 40 | torch.backends.cudnn.benchmark = True 41 | global_step = 0 42 | 43 | 44 | def main(): 45 | """Assume Single Node Multi GPUs Training Only""" 46 | assert torch.cuda.is_available(), "CPU training is not allowed." 
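# (editor note) MASTER_ADDR / MASTER_PORT set below are consumed by the env:// init method
# of dist.init_process_group() in run(). The mp.spawn call is commented out, so training
# currently launches a single process via run(0, 1, hps) on GPU 0 despite the multi-GPU
# scaffolding.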
47 | 48 | n_gpus = torch.cuda.device_count() 49 | os.environ['MASTER_ADDR'] = 'localhost' 50 | os.environ['MASTER_PORT'] = '65520' 51 | # n_gpus = 1 52 | 53 | hps = utils.get_hparams() 54 | run(0,1,hps) 55 | #mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,)) 56 | 57 | 58 | def run(rank, n_gpus, hps): 59 | global global_step 60 | if rank == 0: 61 | logger = utils.get_logger(hps.model_dir) 62 | logger.info(hps) 63 | utils.check_git_hash(hps.model_dir) 64 | writer = SummaryWriter(log_dir=hps.model_dir) 65 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) 66 | 67 | dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank) 68 | torch.manual_seed(hps.train.seed) 69 | torch.cuda.set_device(rank) 70 | 71 | 72 | train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps) 73 | train_sampler = DistributedBucketSampler( 74 | train_dataset, 75 | hps.train.batch_size, 76 | [32,40,50,60,70,80,90,100,110,120,160,200,230,260,300,350,400,450,500,600,700,800,900,1000], 77 | num_replicas=n_gpus, 78 | rank=rank, 79 | shuffle=True) 80 | collate_fn = TextAudioSpeakerCollate(hps) 81 | train_loader = DataLoader(train_dataset, num_workers=24, shuffle=False, pin_memory=True, 82 | collate_fn=collate_fn, batch_sampler=train_sampler) 83 | if rank == 0: 84 | eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps) 85 | eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=True, 86 | batch_size=1, pin_memory=False, 87 | drop_last=False) 88 | 89 | net_g = SynthesizerTrn( 90 | hps.data.filter_length // 2 + 1, 91 | hps.train.segment_size // hps.data.hop_length, 92 | **hps.model).cuda(rank) 93 | net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) 94 | optim_g = torch.optim.AdamW( 95 | net_g.parameters(), 96 | hps.train.learning_rate, 97 | betas=hps.train.betas, 98 | eps=hps.train.eps) 99 | optim_d = torch.optim.AdamW( 100 | net_d.parameters(), 101 | hps.train.learning_rate, 102 | betas=hps.train.betas, 103 | eps=hps.train.eps) 104 | net_g = DDP(net_g, device_ids=[rank]) 105 | net_d = DDP(net_d, device_ids=[rank]) 106 | 107 | try: 108 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g) 109 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d) 110 | global_step = (epoch_str - 1) * len(train_loader) 111 | except: 112 | epoch_str = 1 113 | global_step = 0 114 | 115 | try: 116 | net_g = utils.load_model_diffsize(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, hps, optim_g) 117 | net_d = utils.load_model_diffsize(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, hps, optim_d) 118 | except: 119 | pass 120 | 121 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str-2) 122 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str-2) 123 | 124 | scaler = GradScaler(enabled=hps.train.fp16_run) 125 | 126 | for epoch in range(epoch_str, hps.train.epochs + 1): 127 | if rank==0: 128 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval]) 129 | else: 130 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None) 131 | scheduler_g.step() 132 | scheduler_d.step() 133 | 134 | 
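# NOTE (editor): illustrative helper, not called anywhere in this repository. run()
# above rebuilds global_step as (epoch_str - 1) * len(train_loader) when resuming and
# decays both optimizers' learning rates by hps.train.lr_decay once per epoch via
# ExponentialLR, so the nominal schedule is base_lr * lr_decay ** (epoch - 1):
def _nominal_lr_at_epoch(base_lr: float, lr_decay: float, epoch: int) -> float:
    """Editor's sketch of the learning rate the schedulers above aim for at a 1-indexed epoch."""
    return base_lr * lr_decay ** (epoch - 1)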
135 | 136 | def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): 137 | net_g, net_d = nets 138 | optim_g, optim_d = optims 139 | scheduler_g, scheduler_d = schedulers 140 | train_loader, eval_loader = loaders 141 | if writers is not None: 142 | writer, writer_eval = writers 143 | #tmp=0 144 | #tmp1=1000000000 145 | #train_loader.batch_sampler.set_epoch(epoch) 146 | global global_step 147 | 148 | net_g.train() 149 | net_d.train() 150 | for batch_idx, (c, spec, y, f0, uv) in enumerate(train_loader): 151 | g = None 152 | 153 | spec = spec .cuda(rank, non_blocking=True) 154 | y = y .cuda(rank, non_blocking=True) 155 | c = c .cuda(rank, non_blocking=True) 156 | f0 = f0 .cuda(rank, non_blocking=True) 157 | uv = uv .cuda(rank, non_blocking=True) 158 | 159 | 160 | mel = spec_to_mel_torch( 161 | spec, 162 | hps.data.filter_length, 163 | hps.data.n_mel_channels, 164 | hps.data.sampling_rate, 165 | hps.data.mel_fmin, 166 | hps.data.mel_fmax) 167 | 168 | 169 | with autocast(enabled=hps.train.fp16_run): 170 | #print(c.size()) 171 | y_hat, y_hat_mb, ids_slice, z_mask,\ 172 | (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(c, spec, g=g, mel=mel, f0=f0, uv=uv) 173 | 174 | mel = spec_to_mel_torch( 175 | spec, 176 | hps.data.filter_length, 177 | hps.data.n_mel_channels, 178 | hps.data.sampling_rate, 179 | hps.data.mel_fmin, 180 | hps.data.mel_fmax) 181 | y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) 182 | y_hat_mel = mel_spectrogram_torch( 183 | y_hat.squeeze(1), 184 | hps.data.filter_length, 185 | hps.data.n_mel_channels, 186 | hps.data.sampling_rate, 187 | hps.data.hop_length, 188 | hps.data.win_length, 189 | hps.data.mel_fmin, 190 | hps.data.mel_fmax 191 | ) 192 | 193 | #tmp=max(tmp,y.size()[2]) 194 | #tmp1=min(tmp1,y.size()[2]) 195 | y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice 196 | 197 | if y.shape != y_hat.shape or y_mel.shape != y_hat_mel.shape: 198 | print("output shape != audio data shape") 199 | 200 | y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) 201 | with autocast(enabled=False): 202 | loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) 203 | loss_disc_all = loss_disc 204 | optim_d.zero_grad() 205 | scaler.scale(loss_disc_all).backward() 206 | scaler.unscale_(optim_d) 207 | grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) 208 | scaler.step(optim_d) 209 | 210 | with autocast(enabled=hps.train.fp16_run): 211 | # Generator 212 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) 213 | with autocast(enabled=False): 214 | #loss_dur = torch.sum(l_length.float()) 215 | loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel 216 | loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl 217 | 218 | loss_fm = feature_loss(fmap_r, fmap_g) 219 | loss_gen, losses_gen = generator_loss(y_d_hat_g) 220 | 221 | if hps.model.mb_istft_vits == True: 222 | pqmf = PQMF(y.device) 223 | y_mb = pqmf.analysis(y) 224 | loss_subband = subband_stft_loss(hps, y_mb, y_hat_mb) 225 | else: 226 | loss_subband = torch.tensor(0.0) 227 | 228 | loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + loss_subband#+ loss_dur 229 | 230 | optim_g.zero_grad() 231 | scaler.scale(loss_gen_all).backward() 232 | scaler.unscale_(optim_g) 233 | grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) 234 | scaler.step(optim_g) 235 | scaler.update() 236 | 237 | if rank==0: 238 | if global_step % hps.train.log_interval == 0: 
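                # NOTE (editor): every log_interval steps the block below writes the
                # individual losses, learning rate and gradient norms as scalars, plus
                # ground-truth / generated mel slices as images, to TensorBoard via
                # utils.summarize.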
239 | lr = optim_g.param_groups[0]['lr'] 240 | losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl, loss_subband] 241 | logger.info('Train Epoch: {} [{:.0f}%]'.format( 242 | epoch, 243 | 100. * batch_idx / len(train_loader))) 244 | logger.info([x.item() for x in losses] + [global_step, lr]) 245 | 246 | scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g} 247 | scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl, "loss/g/subband": loss_subband}) 248 | 249 | scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) 250 | scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) 251 | scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) 252 | image_dict = { 253 | "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), 254 | "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), 255 | "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), 256 | #"all/attn": utils.plot_alignment_to_numpy(attn[0,0].data.cpu().numpy()) 257 | } 258 | utils.summarize( 259 | writer=writer, 260 | global_step=global_step, 261 | images=image_dict, 262 | scalars=scalar_dict) 263 | 264 | if global_step % hps.train.eval_interval == 0: 265 | evaluate(hps, net_g, eval_loader, writer_eval) 266 | utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) 267 | utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step))) 268 | global_step += 1 269 | 270 | 271 | if rank == 0: 272 | logger.info('====> Epoch: {}'.format(epoch)) 273 | #print(tmp,tmp1) 274 | 275 | 276 | 277 | def evaluate(hps, generator, eval_loader, writer_eval): 278 | generator.eval() 279 | with torch.no_grad(): 280 | for batch_idx, (c, spec, y, f0, uv) in enumerate(eval_loader): 281 | g = None 282 | spec= spec[:1].cuda(0) 283 | y = y[:1].cuda(0) 284 | c = c[:1].cuda(0) 285 | f0 = f0[:1].cuda(0) 286 | uv = uv[:1].cuda(0) 287 | 288 | break 289 | mel = spec_to_mel_torch( 290 | spec, 291 | hps.data.filter_length, 292 | hps.data.n_mel_channels, 293 | hps.data.sampling_rate, 294 | hps.data.mel_fmin, 295 | hps.data.mel_fmax) 296 | #y_hat, y_hat_mb, attn, mask, *_ = generator.module.infer(x, x_lengths, max_len=1000) 297 | #y_hat_lengths = mask.sum([1,2]).long() * hps.data.hop_length 298 | y_hat = generator.module.infer(c, g=g, mel=mel, f0=f0) 299 | mel = spec_to_mel_torch( 300 | spec, 301 | hps.data.filter_length, 302 | hps.data.n_mel_channels, 303 | hps.data.sampling_rate, 304 | hps.data.mel_fmin, 305 | hps.data.mel_fmax) 306 | y_hat_mel = mel_spectrogram_torch( 307 | y_hat.squeeze(1).float(), 308 | hps.data.filter_length, 309 | hps.data.n_mel_channels, 310 | hps.data.sampling_rate, 311 | hps.data.hop_length, 312 | hps.data.win_length, 313 | hps.data.mel_fmin, 314 | hps.data.mel_fmax 315 | ) 316 | image_dict = { 317 | "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()), 318 | "gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()) 319 | } 320 | audio_dict = { 321 | "gen/audio": y_hat[0], 322 | "gt/audio": y[0] 323 | } 324 | 325 | #import torchaudio 326 | #y_gt=y*32768 327 | #print(y_hat.size()) 328 | #torchaudio.save("temp_result/vctkms_new_tem_result_{}.wav".format(global_step),y_hat[0, :, :].cpu(),16000) 329 | 
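        # NOTE (editor): the torchaudio.save calls above and below are left disabled;
        # evaluation mel images and audio are instead written to TensorBoard through
        # utils.summarize at the end of this function.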
#torchaudio.save("temp_result/vctkms_new_tem_result_gt_{}.wav".format(global_step),y[0, :, :].cpu(),16000) 330 | #torchaudio.save("tem_result_gt32768_{}.wav".format(global_step),y_gt[0, :, :].cpu(),16000) 331 | 332 | utils.summarize( 333 | writer=writer_eval, 334 | global_step=global_step, 335 | images=image_dict, 336 | audios=audio_dict, 337 | audio_sampling_rate=hps.data.sampling_rate 338 | ) 339 | generator.train() 340 | 341 | 342 | if __name__ == "__main__": 343 | os.environ[ 344 | "TORCH_DISTRIBUTED_DEBUG" 345 | ] = "DETAIL" 346 | main() 347 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform(inputs, 13 | unnormalized_widths, 14 | unnormalized_heights, 15 | unnormalized_derivatives, 16 | inverse=False, 17 | tails=None, 18 | tail_bound=1., 19 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 20 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 21 | min_derivative=DEFAULT_MIN_DERIVATIVE): 22 | 23 | if tails is None: 24 | spline_fn = rational_quadratic_spline 25 | spline_kwargs = {} 26 | else: 27 | spline_fn = unconstrained_rational_quadratic_spline 28 | spline_kwargs = { 29 | 'tails': tails, 30 | 'tail_bound': tail_bound 31 | } 32 | 33 | outputs, logabsdet = spline_fn( 34 | inputs=inputs, 35 | unnormalized_widths=unnormalized_widths, 36 | unnormalized_heights=unnormalized_heights, 37 | unnormalized_derivatives=unnormalized_derivatives, 38 | inverse=inverse, 39 | min_bin_width=min_bin_width, 40 | min_bin_height=min_bin_height, 41 | min_derivative=min_derivative, 42 | **spline_kwargs 43 | ) 44 | return outputs, logabsdet 45 | 46 | 47 | def searchsorted(bin_locations, inputs, eps=1e-6): 48 | bin_locations[..., -1] += eps 49 | return torch.sum( 50 | inputs[..., None] >= bin_locations, 51 | dim=-1 52 | ) - 1 53 | 54 | 55 | def unconstrained_rational_quadratic_spline(inputs, 56 | unnormalized_widths, 57 | unnormalized_heights, 58 | unnormalized_derivatives, 59 | inverse=False, 60 | tails='linear', 61 | tail_bound=1., 62 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 63 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 64 | min_derivative=DEFAULT_MIN_DERIVATIVE): 65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 66 | outside_interval_mask = ~inside_interval_mask 67 | 68 | outputs = torch.zeros_like(inputs) 69 | logabsdet = torch.zeros_like(inputs) 70 | 71 | if tails == 'linear': 72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 73 | constant = np.log(np.exp(1 - min_derivative) - 1) 74 | unnormalized_derivatives[..., 0] = constant 75 | unnormalized_derivatives[..., -1] = constant 76 | 77 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 78 | logabsdet[outside_interval_mask] = 0 79 | else: 80 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 81 | 82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, right=tail_bound, 
bottom=-tail_bound, top=tail_bound, 89 | min_bin_width=min_bin_width, 90 | min_bin_height=min_bin_height, 91 | min_derivative=min_derivative 92 | ) 93 | 94 | return outputs, logabsdet 95 | 96 | def rational_quadratic_spline(inputs, 97 | unnormalized_widths, 98 | unnormalized_heights, 99 | unnormalized_derivatives, 100 | inverse=False, 101 | left=0., right=1., bottom=0., top=1., 102 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 103 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 104 | min_derivative=DEFAULT_MIN_DERIVATIVE): 105 | if torch.min(inputs) < left or torch.max(inputs) > right: 106 | raise ValueError('Input to a transform is not within its domain') 107 | 108 | num_bins = unnormalized_widths.shape[-1] 109 | 110 | if min_bin_width * num_bins > 1.0: 111 | raise ValueError('Minimal bin width too large for the number of bins') 112 | if min_bin_height * num_bins > 1.0: 113 | raise ValueError('Minimal bin height too large for the number of bins') 114 | 115 | widths = F.softmax(unnormalized_widths, dim=-1) 116 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 117 | cumwidths = torch.cumsum(widths, dim=-1) 118 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 119 | cumwidths = (right - left) * cumwidths + left 120 | cumwidths[..., 0] = left 121 | cumwidths[..., -1] = right 122 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 123 | 124 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 125 | 126 | heights = F.softmax(unnormalized_heights, dim=-1) 127 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 128 | cumheights = torch.cumsum(heights, dim=-1) 129 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 130 | cumheights = (top - bottom) * cumheights + bottom 131 | cumheights[..., 0] = bottom 132 | cumheights[..., -1] = top 133 | heights = cumheights[..., 1:] - cumheights[..., :-1] 134 | 135 | if inverse: 136 | bin_idx = searchsorted(cumheights, inputs)[..., None] 137 | else: 138 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 139 | 140 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 141 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 142 | 143 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 144 | delta = heights / widths 145 | input_delta = delta.gather(-1, bin_idx)[..., 0] 146 | 147 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 148 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 149 | 150 | input_heights = heights.gather(-1, bin_idx)[..., 0] 151 | 152 | if inverse: 153 | a = (((inputs - input_cumheights) * (input_derivatives 154 | + input_derivatives_plus_one 155 | - 2 * input_delta) 156 | + input_heights * (input_delta - input_derivatives))) 157 | b = (input_heights * input_derivatives 158 | - (inputs - input_cumheights) * (input_derivatives 159 | + input_derivatives_plus_one 160 | - 2 * input_delta)) 161 | c = - input_delta * (inputs - input_cumheights) 162 | 163 | discriminant = b.pow(2) - 4 * a * c 164 | assert (discriminant >= 0).all() 165 | 166 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 167 | outputs = root * input_bin_widths + input_cumwidths 168 | 169 | theta_one_minus_theta = root * (1 - root) 170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 171 | * theta_one_minus_theta) 172 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 173 | + 2 * input_delta * theta_one_minus_theta 174 | + input_derivatives * (1 - 
root).pow(2)) 175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 176 | 177 | return outputs, -logabsdet 178 | else: 179 | theta = (inputs - input_cumwidths) / input_bin_widths 180 | theta_one_minus_theta = theta * (1 - theta) 181 | 182 | numerator = input_heights * (input_delta * theta.pow(2) 183 | + input_derivatives * theta_one_minus_theta) 184 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 185 | * theta_one_minus_theta) 186 | outputs = input_cumheights + numerator / denominator 187 | 188 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 189 | + 2 * input_delta * theta_one_minus_theta 190 | + input_derivatives * (1 - theta).pow(2)) 191 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 192 | 193 | return outputs, logabsdet 194 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | from scipy.io.wavfile import read 10 | import torch 11 | import torchvision 12 | MATPLOTLIB_FLAG = False 13 | 14 | import numpy as np 15 | from scipy.io.wavfile import read 16 | import torch 17 | from torch.nn import functional as F 18 | 19 | 20 | logging.basicConfig(stream=sys.stdout, level=logging.WARNING) 21 | logger = logging 22 | 23 | def load_model_diffsize(checkpoint_path, model,hps, optimizer=None): 24 | assert os.path.isfile(checkpoint_path) 25 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')["model"] 26 | 27 | if hasattr(model, 'module'): 28 | state_dict = model.module.state_dict() 29 | else: 30 | state_dict = model.state_dict() 31 | 32 | for k, v in checkpoint_dict.items(): 33 | if k in state_dict and state_dict[k].size() == v.size(): 34 | state_dict[k] = v 35 | else: 36 | print("Diffsize ",k) 37 | 38 | if hasattr(model, 'module'): 39 | model.module.load_state_dict(state_dict, strict=False) 40 | else: 41 | model.load_state_dict(state_dict, strict=False) 42 | 43 | return model 44 | 45 | 46 | 47 | def load_checkpoint(checkpoint_path, model, optimizer=None): 48 | assert os.path.isfile(checkpoint_path) 49 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 50 | iteration = checkpoint_dict['iteration'] 51 | learning_rate = checkpoint_dict['learning_rate'] 52 | if optimizer is not None: 53 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 54 | saved_state_dict = checkpoint_dict['model'] 55 | if hasattr(model, 'module'): 56 | state_dict = model.module.state_dict() 57 | else: 58 | state_dict = model.state_dict() 59 | new_state_dict= {} 60 | for k, v in state_dict.items(): 61 | try: 62 | new_state_dict[k] = saved_state_dict[k] 63 | except: 64 | logger.info("%s is not in the checkpoint" % k) 65 | new_state_dict[k] = v 66 | if hasattr(model, 'module'): 67 | model.module.load_state_dict(new_state_dict) 68 | else: 69 | model.load_state_dict(new_state_dict) 70 | logger.info("Loaded checkpoint '{}' (iteration {})" .format( 71 | checkpoint_path, iteration)) 72 | return model, optimizer, learning_rate, iteration 73 | 74 | 75 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 76 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 77 | iteration, checkpoint_path)) 78 | if hasattr(model, 'module'): 79 | state_dict = 
model.module.state_dict() 80 | else: 81 | state_dict = model.state_dict() 82 | torch.save({'model': state_dict, 83 | 'iteration': iteration, 84 | 'optimizer': optimizer.state_dict(), 85 | 'learning_rate': learning_rate}, checkpoint_path) 86 | 87 | 88 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): 89 | for k, v in scalars.items(): 90 | writer.add_scalar(k, v, global_step) 91 | for k, v in histograms.items(): 92 | writer.add_histogram(k, v, global_step) 93 | for k, v in images.items(): 94 | writer.add_image(k, v, global_step, dataformats='HWC') 95 | for k, v in audios.items(): 96 | writer.add_audio(k, v, global_step, audio_sampling_rate) 97 | 98 | 99 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 100 | f_list = glob.glob(os.path.join(dir_path, regex)) 101 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 102 | x = f_list[-1] 103 | print(x) 104 | return x 105 | 106 | 107 | def plot_spectrogram_to_numpy(spectrogram): 108 | global MATPLOTLIB_FLAG 109 | if not MATPLOTLIB_FLAG: 110 | import matplotlib 111 | matplotlib.use("Agg") 112 | MATPLOTLIB_FLAG = True 113 | mpl_logger = logging.getLogger('matplotlib') 114 | mpl_logger.setLevel(logging.WARNING) 115 | import matplotlib.pylab as plt 116 | import numpy as np 117 | 118 | fig, ax = plt.subplots(figsize=(10,2)) 119 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 120 | interpolation='none') 121 | plt.colorbar(im, ax=ax) 122 | plt.xlabel("Frames") 123 | plt.ylabel("Channels") 124 | plt.tight_layout() 125 | 126 | fig.canvas.draw() 127 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 128 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 129 | plt.close() 130 | return data 131 | 132 | 133 | def plot_alignment_to_numpy(alignment, info=None): 134 | global MATPLOTLIB_FLAG 135 | if not MATPLOTLIB_FLAG: 136 | import matplotlib 137 | matplotlib.use("Agg") 138 | MATPLOTLIB_FLAG = True 139 | mpl_logger = logging.getLogger('matplotlib') 140 | mpl_logger.setLevel(logging.WARNING) 141 | import matplotlib.pylab as plt 142 | import numpy as np 143 | 144 | fig, ax = plt.subplots(figsize=(6, 4)) 145 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', 146 | interpolation='none') 147 | fig.colorbar(im, ax=ax) 148 | xlabel = 'Decoder timestep' 149 | if info is not None: 150 | xlabel += '\n\n' + info 151 | plt.xlabel(xlabel) 152 | plt.ylabel('Encoder timestep') 153 | plt.tight_layout() 154 | 155 | fig.canvas.draw() 156 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 157 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 158 | plt.close() 159 | return data 160 | 161 | """ 162 | def load_wav_to_torch(full_path): 163 | sampling_rate, data = read(full_path) 164 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 165 | """ 166 | import soundfile as sf 167 | def load_wav_to_torch(full_path): 168 | sampling_rate, wav = read(full_path.replace("\\", "/")) ### modify .replace("\\", "/") ### 169 | #sampling_rate, wav = sf.read(full_path.replace("\\", "/")) ### modify .replace("\\", "/") ### 170 | 171 | if len(wav.shape) == 2: 172 | wav = wav[:, 0] 173 | if wav.dtype == np.int16: 174 | wav = wav / 32768.0 175 | elif wav.dtype == np.int32: 176 | wav = wav / 2147483648.0 177 | elif wav.dtype == np.uint8: 178 | wav = (wav - 128) / 128.0 179 | wav = wav.astype(np.float32) 180 | 181 | if sampling_rate != 44100: 182 | print("ERROR SAMPLINGRATE") 183 | pass 184 | return 
torch.FloatTensor(wav), sampling_rate 185 | 186 | def load_filepaths_and_text(filename, split="|"): 187 | with open(filename, encoding='utf-8') as f: 188 | filepaths_and_text = [line.strip().split(split) for line in f] 189 | return filepaths_and_text 190 | 191 | 192 | def get_hparams(init=True): 193 | parser = argparse.ArgumentParser() 194 | parser.add_argument('-c', '--config', type=str, default="./configs/quickvc_44100.json", 195 | help='JSON file for configuration') 196 | parser.add_argument('-m', '--model', type=str,default="QuickVC_Ja", 197 | help='Model name') 198 | 199 | args = parser.parse_args() 200 | model_dir = os.path.join("./logs", args.model) 201 | 202 | if not os.path.exists(model_dir): 203 | os.makedirs(model_dir) 204 | 205 | config_path = args.config 206 | config_save_path = os.path.join(model_dir, "config.json") 207 | if init: 208 | with open(config_path, "r") as f: 209 | data = f.read() 210 | with open(config_save_path, "w") as f: 211 | f.write(data) 212 | else: 213 | with open(config_save_path, "r") as f: 214 | data = f.read() 215 | config = json.loads(data) 216 | 217 | hparams = HParams(**config) 218 | hparams.model_dir = model_dir 219 | return hparams 220 | 221 | def transform(mel, height): # 68-92 222 | #r = np.random.random() 223 | #rate = r * 0.3 + 0.85 # 0.85-1.15 224 | #height = int(mel.size(-2) * rate) 225 | tgt = torchvision.transforms.functional.resize(mel, (height, mel.size(-1))) 226 | if height >= mel.size(-2): 227 | return tgt[:, :mel.size(-2), :] 228 | else: 229 | silence = tgt[:,-1:,:].repeat(1,mel.size(-2)-height,1) 230 | silence += torch.randn_like(silence) / 10 231 | return torch.cat((tgt, silence), 1) 232 | 233 | def get_hparams_from_dir(model_dir): 234 | config_save_path = os.path.join(model_dir, "config.json") 235 | with open(config_save_path, "r") as f: 236 | data = f.read() 237 | config = json.loads(data) 238 | 239 | hparams =HParams(**config) 240 | hparams.model_dir = model_dir 241 | return hparams 242 | 243 | 244 | def get_hparams_from_file(config_path): 245 | with open(config_path, "r") as f: 246 | data = f.read() 247 | config = json.loads(data) 248 | 249 | hparams =HParams(**config) 250 | return hparams 251 | 252 | 253 | def check_git_hash(model_dir): 254 | source_dir = os.path.dirname(os.path.realpath(__file__)) 255 | if not os.path.exists(os.path.join(source_dir, ".git")): 256 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( 257 | source_dir 258 | )) 259 | return 260 | 261 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 262 | 263 | path = os.path.join(model_dir, "githash") 264 | if os.path.exists(path): 265 | saved_hash = open(path).read() 266 | if saved_hash != cur_hash: 267 | logger.warn("git hash values are different. 
{}(saved) != {}(current)".format( 268 | saved_hash[:8], cur_hash[:8])) 269 | else: 270 | open(path, "w").write(cur_hash) 271 | 272 | 273 | def get_logger(model_dir, filename="train.log"): 274 | global logger 275 | logger = logging.getLogger(os.path.basename(model_dir)) 276 | logger.setLevel(logging.DEBUG) 277 | 278 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 279 | if not os.path.exists(model_dir): 280 | os.makedirs(model_dir) 281 | h = logging.FileHandler(os.path.join(model_dir, filename)) 282 | h.setLevel(logging.DEBUG) 283 | h.setFormatter(formatter) 284 | logger.addHandler(h) 285 | return logger 286 | 287 | 288 | class HParams(): 289 | def __init__(self, **kwargs): 290 | for k, v in kwargs.items(): 291 | if type(v) == dict: 292 | v = HParams(**v) 293 | self[k] = v 294 | 295 | def keys(self): 296 | return self.__dict__.keys() 297 | 298 | def items(self): 299 | return self.__dict__.items() 300 | 301 | def values(self): 302 | return self.__dict__.values() 303 | 304 | def __len__(self): 305 | return len(self.__dict__) 306 | 307 | def __getitem__(self, key): 308 | return getattr(self, key) 309 | 310 | def __setitem__(self, key, value): 311 | return setattr(self, key, value) 312 | 313 | def __contains__(self, key): 314 | return key in self.__dict__ 315 | 316 | def __repr__(self): 317 | return self.__dict__.__repr__() 318 | 319 | 320 | 321 | ################################################################ 322 | ### add from https://github.com/svc-develop-team/so-vits-svc ### 323 | ################################################################ 324 | 325 | f0_bin = 256 326 | f0_max = 1100.0 327 | f0_min = 50.0 328 | f0_mel_min = 1127 * np.log(1 + f0_min / 700) 329 | f0_mel_max = 1127 * np.log(1 + f0_max / 700) 330 | 331 | def normalize_f0(f0, x_mask, uv, random_scale=True): 332 | # calculate means based on x_mask 333 | uv_sum = torch.sum(uv, dim=1, keepdim=True) 334 | uv_sum[uv_sum == 0] = 9999 335 | means = torch.sum(f0[:, 0, :] * uv, dim=1, keepdim=True) / uv_sum 336 | 337 | if random_scale: 338 | factor = torch.Tensor(f0.shape[0], 1).uniform_(0.8, 1.2).to(f0.device) 339 | else: 340 | factor = torch.ones(f0.shape[0], 1).to(f0.device) 341 | # normalize f0 based on means and factor 342 | f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1) 343 | if torch.isnan(f0_norm).any(): 344 | exit(0) 345 | return f0_norm * x_mask 346 | 347 | def compute_f0_uv_torchcrepe(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512,device=None,cr_threshold=0.05): 348 | from crepe import CrepePitchExtractor 349 | x = wav_numpy 350 | if p_len is None: 351 | p_len = x.shape[0]//hop_length 352 | else: 353 | assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error" 354 | 355 | F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max,device=device,threshold=cr_threshold) 356 | f0,uv = F0Creper(x[None,:].float(),sampling_rate,pad_to=p_len) 357 | return f0,uv 358 | 359 | 360 | def compute_f0_parselmouth(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512): 361 | import parselmouth 362 | x = wav_numpy 363 | if p_len is None: 364 | p_len = x.shape[0]//hop_length 365 | else: 366 | assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error" 367 | time_step = hop_length / sampling_rate * 1000 368 | f0 = parselmouth.Sound(x, sampling_rate).to_pitch_ac( 369 | time_step=time_step / 1000, voicing_threshold=0.6, 370 | pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency'] 371 | 372 | pad_size=(p_len - len(f0) + 1) // 2 
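    # NOTE (editor): parselmouth usually returns a few frames fewer than p_len, so the
    # F0 track is zero-padded (roughly centred) below to match the expected frame count.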
373 | if(pad_size>0 or p_len - len(f0) - pad_size>0): 374 | f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant') 375 | return f0 376 | 377 | def resize_f0(x, target_len): 378 | source = np.array(x) 379 | source[source<0.001] = np.nan 380 | target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source) 381 | res = np.nan_to_num(target) 382 | return res 383 | 384 | def compute_f0_dio(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512): 385 | import pyworld 386 | if p_len is None: 387 | p_len = wav_numpy.shape[0]//hop_length 388 | f0, t = pyworld.dio( 389 | wav_numpy.astype(np.double), 390 | fs=sampling_rate, 391 | f0_ceil=800, 392 | frame_period=1000 * hop_length / sampling_rate, 393 | ) 394 | f0 = pyworld.stonemask(wav_numpy.astype(np.double), f0, t, sampling_rate) 395 | for index, pitch in enumerate(f0): 396 | f0[index] = round(pitch, 1) 397 | return resize_f0(f0, p_len) 398 | 399 | def f0_to_coarse(f0): 400 | is_torch = isinstance(f0, torch.Tensor) 401 | f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) 402 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 403 | 404 | f0_mel[f0_mel <= 1] = 1 405 | f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 406 | f0_coarse = (f0_mel + 0.5).int() if is_torch else np.rint(f0_mel).astype(np.int) 407 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min()) 408 | return f0_coarse 409 | 410 | 411 | def interpolate_f0(f0): 412 | 413 | data = np.reshape(f0, (f0.size, 1)) 414 | 415 | vuv_vector = np.zeros((data.size, 1), dtype=np.float32) 416 | vuv_vector[data > 0.0] = 1.0 417 | vuv_vector[data <= 0.0] = 0.0 418 | 419 | ip_data = data 420 | 421 | frame_number = data.size 422 | last_value = 0.0 423 | for i in range(frame_number): 424 | if data[i] <= 0.0: 425 | j = i + 1 426 | for j in range(i + 1, frame_number): 427 | if data[j] > 0.0: 428 | break 429 | if j < frame_number - 1: 430 | if last_value > 0.0: 431 | step = (data[j] - data[i - 1]) / float(j - i) 432 | for k in range(i, j): 433 | ip_data[k] = data[i - 1] + step * (k - i + 1) 434 | else: 435 | for k in range(i, j): 436 | ip_data[k] = data[j] 437 | else: 438 | for k in range(i, frame_number): 439 | ip_data[k] = last_value 440 | else: 441 | ip_data[i] = data[i] # this may not be necessary 442 | last_value = data[i] 443 | 444 | return ip_data[:,0], vuv_vector[:,0] 445 | 446 | 447 | def repeat_expand_2d(content, target_len): 448 | # content : [h, t] 449 | 450 | src_len = content.shape[-1] 451 | target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device) 452 | temp = torch.arange(src_len+1) * target_len / src_len 453 | current_pos = 0 454 | for i in range(target_len): 455 | if i < temp[current_pos+1]: 456 | target[:, i] = content[:, current_pos] 457 | else: 458 | current_pos += 1 459 | target[:, i] = content[:, current_pos] 460 | 461 | return target 462 | 463 | ### ### --------------------------------------------------------------------------------
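Editor's note: the F0 helpers above (taken from so-vits-svc) are typically chained as in
the minimal sketch below — extract F0 with one of the compute_f0_* backends, fill unvoiced
gaps with interpolate_f0 to obtain a continuous contour plus a voiced/unvoiced mask, then
quantise it with f0_to_coarse into embedding bins. This is an illustration only; the wav
path is a placeholder and hop_length=512 is assumed rather than taken from a config.

    import torch
    import utils

    wav, sr = utils.load_wav_to_torch("path/to/44100Hz_mono.wav")      # placeholder path
    f0 = utils.compute_f0_dio(wav.numpy(), sampling_rate=sr, hop_length=512)
    f0, uv = utils.interpolate_f0(f0)                  # continuous F0 + voiced/unvoiced flags
    coarse = utils.f0_to_coarse(torch.from_numpy(f0))  # integer pitch bins in [1, 255]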