├── LICENSE ├── README.md ├── commons.py ├── configs └── hubert-neuraldec-vits.json ├── conv.py ├── convert.ipynb ├── data_utils.py ├── downsample.py ├── extra ├── DSConv.py ├── attentions.py ├── commons.py └── modules.py ├── filelists ├── test.txt ├── train.txt └── val.txt ├── hifigan ├── __init__.py ├── config.json ├── generator_v1.txt └── models.py ├── losses.py ├── lstm.py ├── mel_processing.py ├── models.py ├── modules.py ├── norm.py ├── preprocess_code.py ├── preprocess_flist.py ├── preprocess_spk.py ├── preprocess_sr.py ├── requirements.txt ├── resources └── NeurlVC.png ├── speaker_encoder ├── __init__.py ├── audio.py ├── ckpt │ ├── pretrained_bak_5805000.pt │ └── pretrained_bak_5805000.pt.txt ├── compute_embed.py ├── config.py ├── data_objects │ ├── __init__.py │ ├── random_cycler.py │ ├── speaker.py │ ├── speaker_batch.py │ ├── speaker_verification_dataset.py │ └── utterance.py ├── hparams.py ├── inference.py ├── model.py ├── params_data.py ├── params_model.py ├── preprocess.py ├── train.py ├── visualizations.py └── voice_encoder.py ├── train.py ├── utils.py └── wavlm ├── WavLM-Large.pt.txt ├── WavLM.py ├── __init__.py └── modules.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jingyi Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeuralVC Any-to-Any Voice Conversion Using Neural Networks Decoder For Real-Time Voice Conversion 2 | 3 | 4 | In this paper, we adopt the end-to-end [VITS](https://arxiv.org/abs/2106.06103) framework for high-quality waveform reconstruction. By introducing HuBERT-Soft, we extract clean speech content information, and by incorporating a pre-trained speaker encoder, we extract speaker characteristics from the speech. Inspired by the structure of [speech compression models](https://arxiv.org/abs/2210.13438), we propose a **neural decoder** that synthesizes converted speech with the target speaker's voice by adding preprocessing and conditioning networks to receive and interpret speaker information. Additionally, we significantly improve the model's inference speed, achieving real-time voice conversion. 
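At a glance, the pipeline above works as follows: HuBERT-Soft extracts content features from the source utterance, the pretrained speaker encoder produces an utterance-level embedding of the target speaker, and the neural decoder synthesizes the converted waveform conditioned on both. The sketch below condenses the steps of `convert.ipynb` (included in this repo); the checkpoint path and the source/target wav paths are placeholders.

```python
# Condensed inference sketch; see convert.ipynb for the full notebook.
# All file paths below are placeholders.
import torch
import librosa
from scipy.io.wavfile import write

import utils
from models import HuBERT_NeuralDec_VITS
from speaker_encoder.voice_encoder import SpeakerEncoder

hps = utils.get_hparams_from_file("configs/hubert-neuraldec-vits.json")
net_g = HuBERT_NeuralDec_VITS(
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model).eval()
utils.load_checkpoint("path/to/G_xxxxxx.pth", net_g, None, True)

hubert = torch.hub.load("bshall/hubert:main", "hubert_soft").eval()        # content encoder
smodel = SpeakerEncoder("speaker_encoder/ckpt/pretrained_bak_5805000.pt")  # speaker encoder

# Target speaker embedding
wav_tgt, _ = librosa.load("target.wav", sr=hps.data.sampling_rate)
wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
g_tgt = torch.from_numpy(smodel.embed_utterance(wav_tgt)).unsqueeze(0)

# Source content features -> converted waveform
wav_src, _ = librosa.load("source.wav", sr=hps.data.sampling_rate)
with torch.no_grad():
    c = hubert.units(torch.from_numpy(wav_src).unsqueeze(0).unsqueeze(0)).transpose(1, 2)
    audio = net_g.infer(c, g=g_tgt)[0][0].cpu().float().numpy()
write("converted.wav", hps.data.sampling_rate, audio)
```

For batch conversion and organized on-disk outputs, see the `convert()` helper in `convert.ipynb`.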
5 | 6 | Audio samples: https://jinyuanzhang999.github.io/NeuralVC_Demo.github.io/ 7 | 8 | We also provide the [pretrained models](https://1drv.ms/f/c/87587ec0bae9be5a/Ek_2ur6Uwr5Lq1g-C5-5FFUB5JkhHHhLPg9iQxKxFvHm0w?e=Zpcxec). 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 |
![Model Framework](resources/NeurlVC.png)
18 | 19 | 20 | ## Pre-requisites 21 | 22 | 1. Clone this repo: `git clone https://github.com/zzy1hjq/NeutralVC.git` 23 | 24 | 2. CD into this repo: `cd NeuralVC` 25 | 26 | 3. Install python requirements: `pip install -r requirements.txt` 27 | 28 | 4. Download the [VCTK](https://datashare.ed.ac.uk/handle/10283/3443) dataset (for training only) 29 | 30 | 31 | ## Inference Example 32 | 33 | Download the pretrained checkpoints and run: 34 | 35 | ```python 36 | # inference with NeuralVC 37 | # Replace the corresponding parameters 38 | convert.ipynb 39 | ``` 40 | 41 | ## Training Example 42 | 43 | 1. Preprocess 44 | 45 | ```python 46 | 47 | # run this if you want a different train-val-test split 48 | python preprocess_flist.py 49 | 50 | # run this if you want to use pretrained speaker encoder 51 | python preprocess_spk.py 52 | 53 | # run this if you want to use a different content feature extractor. 54 | python preprocess_code.py 55 | 56 | ``` 57 | 58 | 2. Train 59 | 60 | ```python 61 | # train NeuralVC 62 | python train.py 63 | 64 | 65 | ``` 66 | 67 | ## References 68 | 69 | - https://github.com/jaywalnut310/vits 70 | - https://github.com/OlaWod/FreeVC 71 | - https://github.com/quickvc/QuickVC-VoiceConversion 72 | - https://github.com/facebookresearch/encodec 73 | -------------------------------------------------------------------------------- /commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size*dilation - dilation)/2) 16 | 17 | 18 | def convert_pad_shape(pad_shape): 19 | l = pad_shape[::-1] 20 | pad_shape = [item for sublist in l for item in sublist] 21 | return pad_shape 22 | 23 | 24 | def intersperse(lst, item): 25 | result = [item] * (len(lst) * 2 + 1) 26 | result[1::2] = lst 27 | return result 28 | 29 | 30 | def kl_divergence(m_p, logs_p, m_q, logs_q): 31 | """KL(P||Q)""" 32 | kl = (logs_q - logs_p) - 0.5 33 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. 
* logs_q) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | ret = torch.zeros_like(x[:, :, :segment_size]) 50 | for i in range(x.size(0)): 51 | idx_str = ids_str[i] 52 | idx_end = idx_str + segment_size 53 | ret[i] = x[i, :, idx_str:idx_end] 54 | return ret 55 | 56 | 57 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 58 | b, d, t = x.size() 59 | if x_lengths is None: 60 | x_lengths = t 61 | ids_str_max = x_lengths - segment_size + 1 62 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 63 | ret = slice_segments(x, ids_str, segment_size) 64 | return ret, ids_str 65 | 66 | 67 | def rand_spec_segments(x, x_lengths=None, segment_size=4): 68 | b, d, t = x.size() 69 | if x_lengths is None: 70 | x_lengths = t 71 | ids_str_max = x_lengths - segment_size 72 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 73 | ret = slice_segments(x, ids_str, segment_size) 74 | return ret, ids_str 75 | 76 | 77 | def get_timing_signal_1d( 78 | length, channels, min_timescale=1.0, max_timescale=1.0e4): 79 | position = torch.arange(length, dtype=torch.float) 80 | num_timescales = channels // 2 81 | log_timescale_increment = ( 82 | math.log(float(max_timescale) / float(min_timescale)) / 83 | (num_timescales - 1)) 84 | inv_timescales = min_timescale * torch.exp( 85 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 86 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 87 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 88 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 89 | signal = signal.view(1, channels, length) 90 | return signal 91 | 92 | 93 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 94 | b, channels, length = x.size() 95 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 96 | return x + signal.to(dtype=x.dtype, device=x.device) 97 | 98 | 99 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 100 | b, channels, length = x.size() 101 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 102 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 103 | 104 | 105 | def subsequent_mask(length): 106 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 107 | return mask 108 | 109 | 110 | @torch.jit.script 111 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 112 | n_channels_int = n_channels[0] 113 | in_act = input_a + input_b 114 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 115 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 116 | acts = t_act * s_act 117 | return acts 118 | 119 | 120 | def convert_pad_shape(pad_shape): 121 | l = pad_shape[::-1] 122 | pad_shape = [item for sublist in l for item in sublist] 123 | return pad_shape 124 | 125 | 126 | def shift_1d(x): 127 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 128 | return x 129 | 130 | 131 | def sequence_mask(length, max_length=None): 132 | if max_length is None: 133 | max_length = length.max() 134 | x = torch.arange(max_length, 
dtype=length.dtype, device=length.device) 135 | return x.unsqueeze(0) < length.unsqueeze(1) 136 | 137 | 138 | def generate_path(duration, mask): 139 | """ 140 | duration: [b, 1, t_x] 141 | mask: [b, 1, t_y, t_x] 142 | """ 143 | device = duration.device 144 | 145 | b, _, t_y, t_x = mask.shape 146 | cum_duration = torch.cumsum(duration, -1) 147 | 148 | cum_duration_flat = cum_duration.view(b * t_x) 149 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 150 | path = path.view(b, t_x, t_y) 151 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 152 | path = path.unsqueeze(1).transpose(2,3) * mask 153 | return path 154 | 155 | 156 | def clip_grad_value_(parameters, clip_value, norm_type=2): 157 | if isinstance(parameters, torch.Tensor): 158 | parameters = [parameters] 159 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 160 | norm_type = float(norm_type) 161 | if clip_value is not None: 162 | clip_value = float(clip_value) 163 | 164 | total_norm = 0 165 | for p in parameters: 166 | param_norm = p.grad.data.norm(norm_type) 167 | total_norm += param_norm.item() ** norm_type 168 | if clip_value is not None: 169 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 170 | total_norm = total_norm ** (1. / norm_type) 171 | return total_norm 172 | -------------------------------------------------------------------------------- /configs/hubert-neuraldec-vits.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 10000, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 64, 11 | "fp16_run": false, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8960, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0, 18 | "use_sr": false, 19 | "max_speclen": 128, 20 | "port": "8001" 21 | }, 22 | "data": { 23 | "training_files":"filelists/train.txt", 24 | "validation_files":"filelists/val.txt", 25 | "max_wav_value": 32768.0, 26 | "sampling_rate": 16000, 27 | "filter_length": 1280, 28 | "hop_length": 320, 29 | "win_length": 1280, 30 | "n_mel_channels": 80, 31 | "mel_fmin": 0.0, 32 | "mel_fmax": null 33 | }, 34 | "model": { 35 | "inter_channels": 192, 36 | "hidden_channels": 192, 37 | "filter_channels": 768, 38 | "n_heads": 2, 39 | "n_layers": 6, 40 | "kernel_size": 3, 41 | "p_dropout": 0.1, 42 | "resblock": "1", 43 | "resblock_kernel_sizes": [3,7,11], 44 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 45 | "upsample_rates": [10,8,2,2], 46 | "upsample_initial_channel": 512, 47 | "upsample_kernel_sizes": [16,16,4,4], 48 | "n_layers_q": 3, 49 | "use_spectral_norm": false, 50 | "gin_channels": 256, 51 | "ssl_dim": 256, 52 | "use_spk": true 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | """Convolutional layers wrappers and utilities.""" 8 | 9 | import math 10 | import typing as tp 11 | import warnings 12 | 13 | import torch 14 | from torch import nn 15 | from torch.nn import functional as F 16 | from torch.nn.utils import spectral_norm, weight_norm 17 | 18 | from norm import ConvLayerNorm 19 | 20 | 21 | CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', 22 | 'time_layer_norm', 'layer_norm', 'time_group_norm']) 23 | 24 | 25 | def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module: 26 | assert norm in CONV_NORMALIZATIONS 27 | if norm == 'weight_norm': 28 | return weight_norm(module) 29 | elif norm == 'spectral_norm': 30 | return spectral_norm(module) 31 | else: 32 | # We already check was in CONV_NORMALIZATION, so any other choice 33 | # doesn't need reparametrization. 34 | return module 35 | 36 | 37 | def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module: 38 | """Return the proper normalization module. If causal is True, this will ensure the returned 39 | module is causal, or return an error if the normalization doesn't support causal evaluation. 40 | """ 41 | assert norm in CONV_NORMALIZATIONS 42 | if norm == 'layer_norm': 43 | assert isinstance(module, nn.modules.conv._ConvNd) 44 | return ConvLayerNorm(module.out_channels, **norm_kwargs) 45 | elif norm == 'time_group_norm': 46 | if causal: 47 | raise ValueError("GroupNorm doesn't support causal evaluation.") 48 | assert isinstance(module, nn.modules.conv._ConvNd) 49 | return nn.GroupNorm(1, module.out_channels, **norm_kwargs) 50 | else: 51 | return nn.Identity() 52 | 53 | 54 | def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, 55 | padding_total: int = 0) -> int: 56 | """See `pad_for_conv1d`. 57 | """ 58 | length = x.shape[-1] 59 | n_frames = (length - kernel_size + padding_total) / stride + 1 60 | ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) 61 | return ideal_length - length 62 | 63 | 64 | def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): 65 | """Pad for a convolution to make sure that the last window is full. 66 | Extra padding is added at the end. This is required to ensure that we can rebuild 67 | an output of the same length, as otherwise, even with padding, some time steps 68 | might get removed. 69 | For instance, with total padding = 4, kernel size = 4, stride = 2: 70 | 0 0 1 2 3 4 5 0 0 # (0s are padding) 71 | 1 2 3 # (output frames of a convolution, last 0 is never used) 72 | 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) 73 | 1 2 3 4 # once you removed padding, we are missing one time step ! 74 | """ 75 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) 76 | return F.pad(x, (0, extra_padding)) 77 | 78 | 79 | def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.): 80 | """Tiny wrapper around F.pad, just to allow for reflect padding on small input. 81 | If this is the case, we insert extra 0 padding to the right before the reflection happen. 
82 | """ 83 | length = x.shape[-1] 84 | padding_left, padding_right = paddings 85 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) 86 | if mode == 'reflect': 87 | max_pad = max(padding_left, padding_right) 88 | extra_pad = 0 89 | if length <= max_pad: 90 | extra_pad = max_pad - length + 1 91 | x = F.pad(x, (0, extra_pad)) 92 | padded = F.pad(x, paddings, mode, value) 93 | end = padded.shape[-1] - extra_pad 94 | return padded[..., :end] 95 | else: 96 | return F.pad(x, paddings, mode, value) 97 | 98 | 99 | def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): 100 | """Remove padding from x, handling properly zero padding. Only for 1d!""" 101 | padding_left, padding_right = paddings 102 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) 103 | assert (padding_left + padding_right) <= x.shape[-1] 104 | end = x.shape[-1] - padding_right 105 | return x[..., padding_left: end] 106 | 107 | 108 | class NormConv1d(nn.Module): 109 | """Wrapper around Conv1d and normalization applied to this conv 110 | to provide a uniform interface across normalization approaches. 111 | """ 112 | def __init__(self, *args, causal: bool = False, norm: str = 'none', 113 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 114 | super().__init__() 115 | self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) 116 | self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) 117 | self.norm_type = norm 118 | 119 | def forward(self, x): 120 | x = self.conv(x) 121 | x = self.norm(x) 122 | return x 123 | 124 | 125 | class NormConv2d(nn.Module): 126 | """Wrapper around Conv2d and normalization applied to this conv 127 | to provide a uniform interface across normalization approaches. 128 | """ 129 | def __init__(self, *args, norm: str = 'none', 130 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 131 | super().__init__() 132 | self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) 133 | self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) 134 | self.norm_type = norm 135 | 136 | def forward(self, x): 137 | x = self.conv(x) 138 | x = self.norm(x) 139 | return x 140 | 141 | 142 | class NormConvTranspose1d(nn.Module): 143 | """Wrapper around ConvTranspose1d and normalization applied to this conv 144 | to provide a uniform interface across normalization approaches. 145 | """ 146 | def __init__(self, *args, causal: bool = False, norm: str = 'none', 147 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 148 | super().__init__() 149 | self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) 150 | self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) 151 | self.norm_type = norm 152 | 153 | def forward(self, x): 154 | x = self.convtr(x) 155 | x = self.norm(x) 156 | return x 157 | 158 | 159 | class NormConvTranspose2d(nn.Module): 160 | """Wrapper around ConvTranspose2d and normalization applied to this conv 161 | to provide a uniform interface across normalization approaches. 
162 | """ 163 | def __init__(self, *args, norm: str = 'none', 164 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 165 | super().__init__() 166 | self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) 167 | self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) 168 | 169 | def forward(self, x): 170 | x = self.convtr(x) 171 | x = self.norm(x) 172 | return x 173 | 174 | 175 | class SConv1d(nn.Module): 176 | """Conv1d with some builtin handling of asymmetric or causal padding 177 | and normalization. 178 | """ 179 | def __init__(self, in_channels: int, out_channels: int, 180 | kernel_size: int, stride: int = 1, dilation: int = 1, 181 | groups: int = 1, bias: bool = True, causal: bool = False, 182 | norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, 183 | pad_mode: str = 'reflect'): 184 | super().__init__() 185 | # warn user on unusual setup between dilation and stride 186 | if stride > 1 and dilation > 1: 187 | warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1' 188 | f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).') 189 | self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, 190 | dilation=dilation, groups=groups, bias=bias, causal=causal, 191 | norm=norm, norm_kwargs=norm_kwargs) 192 | self.causal = causal 193 | self.pad_mode = pad_mode 194 | 195 | def forward(self, x): 196 | B, C, T = x.shape 197 | kernel_size = self.conv.conv.kernel_size[0] 198 | stride = self.conv.conv.stride[0] 199 | dilation = self.conv.conv.dilation[0] 200 | kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations 201 | padding_total = kernel_size - stride 202 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) 203 | if self.causal: 204 | # Left padding for causal 205 | x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) 206 | else: 207 | # Asymmetric padding required for odd strides 208 | padding_right = padding_total // 2 209 | padding_left = padding_total - padding_right 210 | x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) 211 | return self.conv(x) 212 | 213 | 214 | class SConvTranspose1d(nn.Module): 215 | """ConvTranspose1d with some builtin handling of asymmetric or causal padding 216 | and normalization. 217 | """ 218 | def __init__(self, in_channels: int, out_channels: int, 219 | kernel_size: int, stride: int = 1, causal: bool = False, 220 | norm: str = 'none', trim_right_ratio: float = 1., 221 | norm_kwargs: tp.Dict[str, tp.Any] = {}): 222 | super().__init__() 223 | self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, 224 | causal=causal, norm=norm, norm_kwargs=norm_kwargs) 225 | self.causal = causal 226 | self.trim_right_ratio = trim_right_ratio 227 | assert self.causal or self.trim_right_ratio == 1., \ 228 | "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" 229 | assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. 230 | 231 | def forward(self, x): 232 | kernel_size = self.convtr.convtr.kernel_size[0] 233 | stride = self.convtr.convtr.stride[0] 234 | padding_total = kernel_size - stride 235 | 236 | y = self.convtr(x) 237 | 238 | # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be 239 | # removed at the very end, when keeping only the right length for the output, 240 | # as removing it here would require also passing the length at the matching layer 241 | # in the encoder. 
242 | if self.causal: 243 | # Trim the padding on the right according to the specified ratio 244 | # if trim_right_ratio = 1.0, trim everything from right 245 | padding_right = math.ceil(padding_total * self.trim_right_ratio) 246 | padding_left = padding_total - padding_right 247 | y = unpad1d(y, (padding_left, padding_right)) 248 | else: 249 | # Asymmetric padding required for odd strides 250 | padding_right = padding_total // 2 251 | padding_left = padding_total - padding_right 252 | y = unpad1d(y, (padding_left, padding_right)) 253 | return y 254 | -------------------------------------------------------------------------------- /convert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import argparse\n", 11 | "import torch\n", 12 | "import librosa\n", 13 | "import time\n", 14 | "from scipy.io.wavfile import write\n", 15 | "from tqdm import tqdm\n", 16 | "import soundfile as sf\n", 17 | "import utils\n", 18 | "import time\n", 19 | "from models import HuBERT_NeuralDec_VITS\n", 20 | "from mel_processing import mel_spectrogram_torch\n", 21 | "import logging\n", 22 | "\n", 23 | "from speaker_encoder.voice_encoder import SpeakerEncoder\n", 24 | "logging.getLogger('numba').setLevel(logging.WARNING)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "class Parameters:\n", 34 | " def __init__(self):\n", 35 | " self.hpfile = \"logs/neuralvc/config.json\"\n", 36 | " self.ptfile = \"logs/neuralvc/G_990000.pth\"\n", 37 | " self.model_name = \"hubert-neuraldec-vits\"\n", 38 | " self.outdir = \"output/temp\"\n", 39 | " self.use_timestamp = False\n", 40 | "args = Parameters()" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "if not os.path.exists(args.outdir):\n", 50 | " os.makedirs(args.outdir)\n", 51 | "\n", 52 | "# hps = utils.get_hparams_from_file(args.hpfile)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "os.makedirs(args.outdir, exist_ok=True)\n", 62 | "hps = utils.get_hparams_from_file(args.hpfile)\n", 63 | "\n", 64 | "print(\"Loading model...\")\n", 65 | "net_g = HuBERT_NeuralDec_VITS(\n", 66 | " hps.data.filter_length // 2 + 1,\n", 67 | " hps.train.segment_size // hps.data.hop_length,\n", 68 | " **hps.model)\n", 69 | "_ = net_g.eval()\n", 70 | "\n", 71 | "print(\"Loading checkpoint...\")\n", 72 | "_ = utils.load_checkpoint(args.ptfile, net_g, None, True)\n", 73 | "\n", 74 | "print(\"Loading hubert...\")\n", 75 | "hubert = torch.hub.load(\"bshall/hubert:main\", f\"hubert_soft\").eval() \n", 76 | "\n", 77 | "if hps.model.use_spk:\n", 78 | " print(\"Loading speaker encoder...\")\n", 79 | " smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt')\n", 80 | "print(\"ok\")" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "from tqdm import tqdm\n", 90 | "\n", 91 | "def convert(src_list, tgt):\n", 92 | " tgtname = tgt.split(\"/\")[-1].split(\".\")[0]\n", 93 | " wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate)\n", 94 | " if not os.path.exists(os.path.join(args.outdir, tgtname)):\n", 95 | " 
os.makedirs(os.path.join(args.outdir, tgtname))\n", 96 | " sf.write(os.path.join(args.outdir, tgtname, f\"tgt_{tgtname}.wav\"), wav_tgt, hps.data.sampling_rate)\n", 97 | " wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)\n", 98 | " g_tgt = smodel.embed_utterance(wav_tgt)\n", 99 | " g_tgt = torch.from_numpy(g_tgt).unsqueeze(0)\n", 100 | " for src in tqdm(src_list):\n", 101 | " srcname = src.split(\"/\")[-1].split(\".\")[0]\n", 102 | " title = srcname + \"-\" + tgtname\n", 103 | " with torch.no_grad():\n", 104 | " wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate)\n", 105 | " sf.write(os.path.join(args.outdir, tgtname, f\"src_{srcname}.wav\"), wav_src, hps.data.sampling_rate)\n", 106 | " wav_src = torch.from_numpy(wav_src).unsqueeze(0).unsqueeze(0)\n", 107 | " c = hubert.units(wav_src)\n", 108 | " c = c.transpose(1,2)\n", 109 | " audio = net_g.infer(c, g=g_tgt)\n", 110 | " audio = audio[0][0].data.cpu().float().numpy()\n", 111 | " write(os.path.join(args.outdir, tgtname, f\"{title}.wav\"), hps.data.sampling_rate, audio)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "# Test\n", 121 | "import time\n", 122 | "\n", 123 | "tgt1 = \"/mnt/hd/cma/zzy/dataset/test/M_5105_28233_000016_000001.wav\"\n", 124 | "\n", 125 | "src_list1 = [\"/mnt/hd/cma/zzy/dataset/test/F_3575_170457_000032_000001.wav\"]\n", 126 | "\n", 127 | "convert(src_list1, tgt1)" 128 | ] 129 | } 130 | ], 131 | "metadata": { 132 | "language_info": { 133 | "name": "python" 134 | } 135 | }, 136 | "nbformat": 4, 137 | "nbformat_minor": 2 138 | } 139 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.utils.data 7 | 8 | import commons 9 | from mel_processing import spectrogram_torch, spec_to_mel_torch 10 | from utils import load_wav_to_torch, load_filepaths_and_text, transform 11 | #import h5py 12 | 13 | 14 | """Multi speaker version""" 15 | class TextAudioSpeakerLoader(torch.utils.data.Dataset): 16 | """ 17 | 1) loads audio, speaker_id, text pairs 18 | 2) normalizes text and converts them to sequences of integers 19 | 3) computes spectrograms from audio files. 
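    In this repository each __getitem__ call returns a tuple of
    (content features, spectrogram, normalized waveform, speaker embedding)
    assembled in get_audio() below; the .pt content features and .npy speaker
    embeddings are precomputed by the preprocess scripts.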
20 | """ 21 | def __init__(self, audiopaths, hparams): 22 | self.audiopaths = load_filepaths_and_text(audiopaths) 23 | self.max_wav_value = hparams.data.max_wav_value 24 | self.sampling_rate = hparams.data.sampling_rate 25 | self.filter_length = hparams.data.filter_length 26 | self.hop_length = hparams.data.hop_length 27 | self.win_length = hparams.data.win_length 28 | self.sampling_rate = hparams.data.sampling_rate 29 | self.use_sr = hparams.train.use_sr 30 | self.use_spk = hparams.model.use_spk 31 | self.spec_len = hparams.train.max_speclen 32 | 33 | random.seed(1235) 34 | random.shuffle(self.audiopaths) 35 | self._filter() 36 | 37 | def _filter(self): 38 | """ 39 | Filter text & store spec lengths 40 | """ 41 | # Store spectrogram lengths for Bucketing 42 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) 43 | # spec_length = wav_length // hop_length 44 | 45 | lengths = [] 46 | for audiopath in self.audiopaths: 47 | lengths.append(os.path.getsize(audiopath[0]) // (2 * self.hop_length)) 48 | self.lengths = lengths 49 | 50 | def get_audio(self, filename): 51 | audio, sampling_rate = load_wav_to_torch(filename) 52 | if sampling_rate != self.sampling_rate: 53 | raise ValueError("{} SR doesn't match target {} SR,the audio is{}".format( 54 | sampling_rate, self.sampling_rate,filename)) 55 | audio_norm = audio / self.max_wav_value 56 | audio_norm = audio_norm.unsqueeze(0) 57 | spec_filename = filename.replace(".wav", ".spec.pt") 58 | if os.path.exists(spec_filename): 59 | spec = torch.load(spec_filename) 60 | else: 61 | spec = spectrogram_torch(audio_norm, self.filter_length, 62 | self.sampling_rate, self.hop_length, self.win_length, 63 | center=False) 64 | spec = torch.squeeze(spec, 0) 65 | torch.save(spec, spec_filename) 66 | 67 | if self.use_spk: 68 | spk_filename = filename.replace(".wav", ".npy").replace("vctk-mini-16k", "spk") 69 | spk = torch.from_numpy(np.load(spk_filename)) 70 | 71 | c_filename = filename.replace(".wav", ".pt") 72 | c=torch.load(c_filename) 73 | c=c.transpose(1,0).squeeze(1) 74 | 75 | return c, spec, audio_norm, spk 76 | 77 | def __getitem__(self, index): 78 | return self.get_audio(self.audiopaths[index][0]) 79 | 80 | def __len__(self): 81 | return len(self.audiopaths) 82 | 83 | 84 | 85 | class TextAudioSpeakerCollate(): 86 | """ Zero-pads model inputs and targets 87 | """ 88 | def __init__(self, hps): 89 | self.hps = hps 90 | self.use_sr = hps.train.use_sr 91 | self.use_spk = hps.model.use_spk 92 | 93 | def __call__(self, batch): 94 | """Collate's training batch from normalized text, audio and speaker identities 95 | PARAMS 96 | ------ 97 | batch: [text_normalized, spec_normalized, wav_normalized, sid] 98 | """ 99 | # Right zero-pad all one-hot text sequences to max input length 100 | _, ids_sorted_decreasing = torch.sort( 101 | torch.LongTensor([x[0].size(1) for x in batch]), 102 | dim=0, descending=True) 103 | 104 | max_c_len = max([x[0].size(1) for x in batch]) 105 | max_spec_len = max([x[1].size(1) for x in batch]) 106 | max_wav_len = max([x[2].size(1) for x in batch]) 107 | 108 | spec_lengths = torch.LongTensor(len(batch)) 109 | wav_lengths = torch.LongTensor(len(batch)) 110 | if self.use_spk: 111 | spks = torch.FloatTensor(len(batch), batch[0][3].size(0)) 112 | else: 113 | spks = None 114 | 115 | c_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_c_len) 116 | spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) 117 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) 118 | 
c_padded.zero_() 119 | spec_padded.zero_() 120 | wav_padded.zero_() 121 | 122 | for i in range(len(ids_sorted_decreasing)): 123 | row = batch[ids_sorted_decreasing[i]] 124 | 125 | c = row[0] 126 | c_padded[i, :, :c.size(1)] = c 127 | 128 | spec = row[1] 129 | spec_padded[i, :, :spec.size(1)] = spec 130 | spec_lengths[i] = spec.size(1) 131 | 132 | wav = row[2] 133 | wav_padded[i, :, :wav.size(1)] = wav 134 | wav_lengths[i] = wav.size(1) 135 | 136 | if self.use_spk: 137 | spks[i] = row[3] 138 | 139 | spec_seglen = spec_lengths[-1] if spec_lengths[-1] < self.hps.train.max_speclen + 1 else self.hps.train.max_speclen + 1 140 | wav_seglen = spec_seglen * self.hps.data.hop_length 141 | 142 | spec_padded, ids_slice = commons.rand_spec_segments(spec_padded, spec_lengths, spec_seglen) 143 | wav_padded = commons.slice_segments(wav_padded, ids_slice * self.hps.data.hop_length, wav_seglen) 144 | 145 | c_padded = commons.slice_segments(c_padded, ids_slice, spec_seglen)[:,:,:-1] 146 | 147 | spec_padded = spec_padded[:,:,:-1] 148 | wav_padded = wav_padded[:,:,:-self.hps.data.hop_length] 149 | 150 | if self.use_spk: 151 | return c_padded, spec_padded, wav_padded, spks 152 | else: 153 | return c_padded, spec_padded, wav_padded 154 | 155 | 156 | 157 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): 158 | """ 159 | Maintain similar input lengths in a batch. 160 | Length groups are specified by boundaries. 161 | Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. 162 | 163 | It removes samples which are not included in the boundaries. 164 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. 165 | """ 166 | def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True): 167 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 168 | self.lengths = dataset.lengths 169 | self.batch_size = batch_size 170 | self.boundaries = boundaries 171 | 172 | self.buckets, self.num_samples_per_bucket = self._create_buckets() 173 | self.total_size = sum(self.num_samples_per_bucket) 174 | self.num_samples = self.total_size // self.num_replicas 175 | 176 | def _create_buckets(self): 177 | buckets = [[] for _ in range(len(self.boundaries) - 1)] 178 | for i in range(len(self.lengths)): 179 | length = self.lengths[i] 180 | idx_bucket = self._bisect(length) 181 | if idx_bucket != -1: 182 | buckets[idx_bucket].append(i) 183 | 184 | for i in range(len(buckets) - 1, 0, -1): 185 | if len(buckets[i]) == 0: 186 | buckets.pop(i) 187 | self.boundaries.pop(i+1) 188 | 189 | num_samples_per_bucket = [] 190 | for i in range(len(buckets)): 191 | len_bucket = len(buckets[i]) 192 | total_batch_size = self.num_replicas * self.batch_size 193 | rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size 194 | num_samples_per_bucket.append(len_bucket + rem) 195 | return buckets, num_samples_per_bucket 196 | 197 | def __iter__(self): 198 | # deterministically shuffle based on epoch 199 | g = torch.Generator() 200 | g.manual_seed(self.epoch) 201 | 202 | indices = [] 203 | if self.shuffle: 204 | for bucket in self.buckets: 205 | indices.append(torch.randperm(len(bucket), generator=g).tolist()) 206 | else: 207 | for bucket in self.buckets: 208 | indices.append(list(range(len(bucket)))) 209 | 210 | batches = [] 211 | for i in range(len(self.buckets)): 212 | bucket = self.buckets[i] 213 | len_bucket = len(bucket) 214 | ids_bucket = 
indices[i] 215 | num_samples_bucket = self.num_samples_per_bucket[i] 216 | 217 | # add extra samples to make it evenly divisible 218 | rem = num_samples_bucket - len_bucket 219 | ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)] 220 | 221 | # subsample 222 | ids_bucket = ids_bucket[self.rank::self.num_replicas] 223 | 224 | # batching 225 | for j in range(len(ids_bucket) // self.batch_size): 226 | batch = [bucket[idx] for idx in ids_bucket[j*self.batch_size:(j+1)*self.batch_size]] 227 | batches.append(batch) 228 | 229 | if self.shuffle: 230 | batch_ids = torch.randperm(len(batches), generator=g).tolist() 231 | batches = [batches[i] for i in batch_ids] 232 | self.batches = batches 233 | 234 | assert len(self.batches) * self.batch_size == self.num_samples 235 | return iter(self.batches) 236 | 237 | def _bisect(self, x, lo=0, hi=None): 238 | if hi is None: 239 | hi = len(self.boundaries) - 1 240 | 241 | if hi > lo: 242 | mid = (hi + lo) // 2 243 | if self.boundaries[mid] < x and x <= self.boundaries[mid+1]: 244 | return mid 245 | elif x <= self.boundaries[mid]: 246 | return self._bisect(x, lo, mid) 247 | else: 248 | return self._bisect(x, mid + 1, hi) 249 | else: 250 | return -1 251 | 252 | def __len__(self): 253 | return self.num_samples // self.batch_size 254 | -------------------------------------------------------------------------------- /downsample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import librosa 4 | import numpy as np 5 | from multiprocessing import Pool, cpu_count 6 | from scipy.io import wavfile 7 | from tqdm import tqdm 8 | 9 | 10 | def process(wav_name): 11 | # speaker 's5', 'p280', 'p315' are excluded, 12 | speaker = wav_name[:4] 13 | wav_path = os.path.join(args.in_dir, speaker, wav_name) 14 | if os.path.exists(wav_path) and '_mic2.flac' in wav_path: 15 | os.makedirs(os.path.join(args.out_dir1, speaker), exist_ok=True) 16 | os.makedirs(os.path.join(args.out_dir2, speaker), exist_ok=True) 17 | wav, sr = librosa.load(wav_path) 18 | wav, _ = librosa.effects.trim(wav, top_db=20) 19 | peak = np.abs(wav).max() 20 | if peak > 1.0: 21 | wav = 0.98 * wav / peak 22 | wav1 = librosa.resample(wav, orig_sr=sr, target_sr=args.sr1) 23 | wav2 = librosa.resample(wav, orig_sr=sr, target_sr=args.sr2) 24 | save_name = wav_name.replace("_mic2.flac", ".wav") 25 | save_path1 = os.path.join(args.out_dir1, speaker, save_name) 26 | save_path2 = os.path.join(args.out_dir2, speaker, save_name) 27 | wavfile.write( 28 | save_path1, 29 | args.sr1, 30 | (wav1 * np.iinfo(np.int16).max).astype(np.int16) 31 | ) 32 | wavfile.write( 33 | save_path2, 34 | args.sr2, 35 | (wav2 * np.iinfo(np.int16).max).astype(np.int16) 36 | ) 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--sr1", type=int, default=16000, help="sampling rate") 42 | parser.add_argument("--sr2", type=int, default=22050, help="sampling rate") 43 | parser.add_argument("--in_dir", type=str, default="/home/Datasets/lijingyi/data/vctk/wav48_silence_trimmed/", help="path to source dir") 44 | parser.add_argument("--out_dir1", type=str, default="./dataset/vctk-16k", help="path to target dir") 45 | parser.add_argument("--out_dir2", type=str, default="./dataset/vctk-22k", help="path to target dir") 46 | args = parser.parse_args() 47 | 48 | pool = Pool(processes=cpu_count()-2) 49 | 50 | for speaker in os.listdir(args.in_dir): 51 | spk_dir = os.path.join(args.in_dir, speaker) 
52 | if os.path.isdir(spk_dir): 53 | for _ in tqdm(pool.imap_unordered(process, os.listdir(spk_dir))): 54 | pass 55 | 56 | -------------------------------------------------------------------------------- /extra/DSConv.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.utils import remove_weight_norm, weight_norm 3 | 4 | 5 | class Depthwise_Separable_Conv1D(nn.Module): 6 | def __init__( 7 | self, 8 | in_channels, 9 | out_channels, 10 | kernel_size, 11 | stride=1, 12 | padding=0, 13 | dilation=1, 14 | bias=True, 15 | padding_mode='zeros', # TODO: refine this type 16 | device=None, 17 | dtype=None 18 | ): 19 | super().__init__() 20 | self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, 21 | groups=in_channels, stride=stride, padding=padding, dilation=dilation, bias=bias, 22 | padding_mode=padding_mode, device=device, dtype=dtype) 23 | self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, 24 | device=device, dtype=dtype) 25 | 26 | def forward(self, input): 27 | return self.point_conv(self.depth_conv(input)) 28 | 29 | def weight_norm(self): 30 | self.depth_conv = weight_norm(self.depth_conv, name='weight') 31 | self.point_conv = weight_norm(self.point_conv, name='weight') 32 | 33 | def remove_weight_norm(self): 34 | self.depth_conv = remove_weight_norm(self.depth_conv, name='weight') 35 | self.point_conv = remove_weight_norm(self.point_conv, name='weight') 36 | 37 | 38 | class Depthwise_Separable_TransposeConv1D(nn.Module): 39 | def __init__( 40 | self, 41 | in_channels, 42 | out_channels, 43 | kernel_size, 44 | stride=1, 45 | padding=0, 46 | output_padding=0, 47 | bias=True, 48 | dilation=1, 49 | padding_mode='zeros', # TODO: refine this type 50 | device=None, 51 | dtype=None 52 | ): 53 | super().__init__() 54 | self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, 55 | groups=in_channels, stride=stride, output_padding=output_padding, 56 | padding=padding, dilation=dilation, bias=bias, padding_mode=padding_mode, 57 | device=device, dtype=dtype) 58 | self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, 59 | device=device, dtype=dtype) 60 | 61 | def forward(self, input): 62 | return self.point_conv(self.depth_conv(input)) 63 | 64 | def weight_norm(self): 65 | self.depth_conv = weight_norm(self.depth_conv, name='weight') 66 | self.point_conv = weight_norm(self.point_conv, name='weight') 67 | 68 | def remove_weight_norm(self): 69 | remove_weight_norm(self.depth_conv, name='weight') 70 | remove_weight_norm(self.point_conv, name='weight') 71 | 72 | 73 | def weight_norm_modules(module, name='weight', dim=0): 74 | if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D): 75 | module.weight_norm() 76 | return module 77 | else: 78 | return weight_norm(module, name, dim) 79 | 80 | 81 | def remove_weight_norm_modules(module, name='weight'): 82 | if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D): 83 | module.remove_weight_norm() 84 | else: 85 | remove_weight_norm(module, name) -------------------------------------------------------------------------------- /extra/attentions.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch 
import nn 5 | from torch.nn import functional as F 6 | 7 | import extra.commons as commons 8 | from extra.DSConv import weight_norm_modules 9 | from extra.modules import LayerNorm 10 | 11 | 12 | class FFT(nn.Module): 13 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0., 14 | proximal_bias=False, proximal_init=True, isflow=False, **kwargs): 15 | super().__init__() 16 | self.hidden_channels = hidden_channels 17 | self.filter_channels = filter_channels 18 | self.n_heads = n_heads 19 | self.n_layers = n_layers 20 | self.kernel_size = kernel_size 21 | self.p_dropout = p_dropout 22 | self.proximal_bias = proximal_bias 23 | self.proximal_init = proximal_init 24 | if isflow: 25 | cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1) 26 | self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1) 27 | self.cond_layer = weight_norm_modules(cond_layer, name='weight') 28 | self.gin_channels = kwargs["gin_channels"] 29 | self.drop = nn.Dropout(p_dropout) 30 | self.self_attn_layers = nn.ModuleList() 31 | self.norm_layers_0 = nn.ModuleList() 32 | self.ffn_layers = nn.ModuleList() 33 | self.norm_layers_1 = nn.ModuleList() 34 | for i in range(self.n_layers): 35 | self.self_attn_layers.append( 36 | MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, 37 | proximal_bias=proximal_bias, 38 | proximal_init=proximal_init)) 39 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 40 | self.ffn_layers.append( 41 | FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) 42 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 43 | 44 | def forward(self, x, x_mask, g=None): 45 | """ 46 | x: decoder input 47 | h: encoder output 48 | """ 49 | if g is not None: 50 | g = self.cond_layer(g) 51 | 52 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) 53 | x = x * x_mask 54 | for i in range(self.n_layers): 55 | if g is not None: 56 | x = self.cond_pre(x) 57 | cond_offset = i * 2 * self.hidden_channels 58 | g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] 59 | x = commons.fused_add_tanh_sigmoid_multiply( 60 | x, 61 | g_l, 62 | torch.IntTensor([self.hidden_channels])) 63 | y = self.self_attn_layers[i](x, x, self_attn_mask) 64 | y = self.drop(y) 65 | x = self.norm_layers_0[i](x + y) 66 | 67 | y = self.ffn_layers[i](x, x_mask) 68 | y = self.drop(y) 69 | x = self.norm_layers_1[i](x + y) 70 | x = x * x_mask 71 | return x 72 | 73 | 74 | class Encoder(nn.Module): 75 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, 76 | **kwargs): 77 | super().__init__() 78 | self.hidden_channels = hidden_channels 79 | self.filter_channels = filter_channels 80 | self.n_heads = n_heads 81 | self.n_layers = n_layers 82 | self.kernel_size = kernel_size 83 | self.p_dropout = p_dropout 84 | self.window_size = window_size 85 | 86 | self.drop = nn.Dropout(p_dropout) 87 | self.attn_layers = nn.ModuleList() 88 | self.norm_layers_1 = nn.ModuleList() 89 | self.ffn_layers = nn.ModuleList() 90 | self.norm_layers_2 = nn.ModuleList() 91 | for i in range(self.n_layers): 92 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, 93 | window_size=window_size)) 94 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 95 | self.ffn_layers.append( 96 | FFN(hidden_channels, hidden_channels, filter_channels, 
kernel_size, p_dropout=p_dropout)) 97 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 98 | 99 | def forward(self, x, x_mask): 100 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 101 | x = x * x_mask 102 | for i in range(self.n_layers): 103 | y = self.attn_layers[i](x, x, attn_mask) 104 | y = self.drop(y) 105 | x = self.norm_layers_1[i](x + y) 106 | 107 | y = self.ffn_layers[i](x, x_mask) 108 | y = self.drop(y) 109 | x = self.norm_layers_2[i](x + y) 110 | x = x * x_mask 111 | return x 112 | 113 | 114 | class Decoder(nn.Module): 115 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., 116 | proximal_bias=False, proximal_init=True, **kwargs): 117 | super().__init__() 118 | self.hidden_channels = hidden_channels 119 | self.filter_channels = filter_channels 120 | self.n_heads = n_heads 121 | self.n_layers = n_layers 122 | self.kernel_size = kernel_size 123 | self.p_dropout = p_dropout 124 | self.proximal_bias = proximal_bias 125 | self.proximal_init = proximal_init 126 | 127 | self.drop = nn.Dropout(p_dropout) 128 | self.self_attn_layers = nn.ModuleList() 129 | self.norm_layers_0 = nn.ModuleList() 130 | self.encdec_attn_layers = nn.ModuleList() 131 | self.norm_layers_1 = nn.ModuleList() 132 | self.ffn_layers = nn.ModuleList() 133 | self.norm_layers_2 = nn.ModuleList() 134 | for i in range(self.n_layers): 135 | self.self_attn_layers.append( 136 | MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, 137 | proximal_bias=proximal_bias, proximal_init=proximal_init)) 138 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 139 | self.encdec_attn_layers.append( 140 | MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) 141 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 142 | self.ffn_layers.append( 143 | FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) 144 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 145 | 146 | def forward(self, x, x_mask, h, h_mask): 147 | """ 148 | x: decoder input 149 | h: encoder output 150 | """ 151 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) 152 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 153 | x = x * x_mask 154 | for i in range(self.n_layers): 155 | y = self.self_attn_layers[i](x, x, self_attn_mask) 156 | y = self.drop(y) 157 | x = self.norm_layers_0[i](x + y) 158 | 159 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 160 | y = self.drop(y) 161 | x = self.norm_layers_1[i](x + y) 162 | 163 | y = self.ffn_layers[i](x, x_mask) 164 | y = self.drop(y) 165 | x = self.norm_layers_2[i](x + y) 166 | x = x * x_mask 167 | return x 168 | 169 | 170 | class MultiHeadAttention(nn.Module): 171 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, 172 | block_length=None, proximal_bias=False, proximal_init=False): 173 | super().__init__() 174 | assert channels % n_heads == 0 175 | 176 | self.channels = channels 177 | self.out_channels = out_channels 178 | self.n_heads = n_heads 179 | self.p_dropout = p_dropout 180 | self.window_size = window_size 181 | self.heads_share = heads_share 182 | self.block_length = block_length 183 | self.proximal_bias = proximal_bias 184 | self.proximal_init = proximal_init 185 | self.attn = None 186 | 187 | self.k_channels = channels // n_heads 188 | self.conv_q = nn.Conv1d(channels, channels, 1) 189 | self.conv_k = nn.Conv1d(channels, 
channels, 1) 190 | self.conv_v = nn.Conv1d(channels, channels, 1) 191 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 192 | self.drop = nn.Dropout(p_dropout) 193 | 194 | if window_size is not None: 195 | n_heads_rel = 1 if heads_share else n_heads 196 | rel_stddev = self.k_channels ** -0.5 197 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 198 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 199 | 200 | nn.init.xavier_uniform_(self.conv_q.weight) 201 | nn.init.xavier_uniform_(self.conv_k.weight) 202 | nn.init.xavier_uniform_(self.conv_v.weight) 203 | if proximal_init: 204 | with torch.no_grad(): 205 | self.conv_k.weight.copy_(self.conv_q.weight) 206 | self.conv_k.bias.copy_(self.conv_q.bias) 207 | 208 | def forward(self, x, c, attn_mask=None): 209 | q = self.conv_q(x) 210 | k = self.conv_k(c) 211 | v = self.conv_v(c) 212 | 213 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 214 | 215 | x = self.conv_o(x) 216 | return x 217 | 218 | def attention(self, query, key, value, mask=None): 219 | # reshape [b, d, t] -> [b, n_h, t, d_k] 220 | b, d, t_s, t_t = (*key.size(), query.size(2)) 221 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 222 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 223 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 224 | 225 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 226 | if self.window_size is not None: 227 | assert t_s == t_t, "Relative attention is only available for self-attention." 228 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 229 | rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) 230 | scores_local = self._relative_position_to_absolute_position(rel_logits) 231 | scores = scores + scores_local 232 | if self.proximal_bias: 233 | assert t_s == t_t, "Proximal bias is only available for self-attention." 234 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) 235 | if mask is not None: 236 | scores = scores.masked_fill(mask == 0, -1e4) 237 | if self.block_length is not None: 238 | assert t_s == t_t, "Local attention is only available for self-attention." 
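                # Banded (local) attention: keep only scores within block_length
                # positions of the diagonal and mask the rest out.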
239 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) 240 | scores = scores.masked_fill(block_mask == 0, -1e4) 241 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 242 | p_attn = self.drop(p_attn) 243 | output = torch.matmul(p_attn, value) 244 | if self.window_size is not None: 245 | relative_weights = self._absolute_position_to_relative_position(p_attn) 246 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 247 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) 248 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] 249 | return output, p_attn 250 | 251 | def _matmul_with_relative_values(self, x, y): 252 | """ 253 | x: [b, h, l, m] 254 | y: [h or 1, m, d] 255 | ret: [b, h, l, d] 256 | """ 257 | ret = torch.matmul(x, y.unsqueeze(0)) 258 | return ret 259 | 260 | def _matmul_with_relative_keys(self, x, y): 261 | """ 262 | x: [b, h, l, d] 263 | y: [h or 1, m, d] 264 | ret: [b, h, l, m] 265 | """ 266 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 267 | return ret 268 | 269 | def _get_relative_embeddings(self, relative_embeddings, length): 270 | 2 * self.window_size + 1 271 | # Pad first before slice to avoid using cond ops. 272 | pad_length = max(length - (self.window_size + 1), 0) 273 | slice_start_position = max((self.window_size + 1) - length, 0) 274 | slice_end_position = slice_start_position + 2 * length - 1 275 | if pad_length > 0: 276 | padded_relative_embeddings = F.pad( 277 | relative_embeddings, 278 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) 279 | else: 280 | padded_relative_embeddings = relative_embeddings 281 | used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] 282 | return used_relative_embeddings 283 | 284 | def _relative_position_to_absolute_position(self, x): 285 | """ 286 | x: [b, h, l, 2*l-1] 287 | ret: [b, h, l, l] 288 | """ 289 | batch, heads, length, _ = x.size() 290 | # Concat columns of pad to shift from relative to absolute indexing. 291 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) 292 | 293 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 294 | x_flat = x.view([batch, heads, length * 2 * length]) 295 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) 296 | 297 | # Reshape and slice out the padded elements. 298 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:] 299 | return x_final 300 | 301 | def _absolute_position_to_relative_position(self, x): 302 | """ 303 | x: [b, h, l, l] 304 | ret: [b, h, l, 2*l-1] 305 | """ 306 | batch, heads, length, _ = x.size() 307 | # padd along column 308 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) 309 | x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)]) 310 | # add 0's in the beginning that will skew the elements after reshape 311 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 312 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] 313 | return x_final 314 | 315 | def _attention_bias_proximal(self, length): 316 | """Bias for self-attention to encourage attention to close positions. 317 | Args: 318 | length: an integer scalar. 
319 | Returns: 320 | a Tensor with shape [1, 1, length, length] 321 | """ 322 | r = torch.arange(length, dtype=torch.float32) 323 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 324 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 325 | 326 | 327 | class FFN(nn.Module): 328 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, 329 | causal=False): 330 | super().__init__() 331 | self.in_channels = in_channels 332 | self.out_channels = out_channels 333 | self.filter_channels = filter_channels 334 | self.kernel_size = kernel_size 335 | self.p_dropout = p_dropout 336 | self.activation = activation 337 | self.causal = causal 338 | 339 | if causal: 340 | self.padding = self._causal_padding 341 | else: 342 | self.padding = self._same_padding 343 | 344 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 345 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 346 | self.drop = nn.Dropout(p_dropout) 347 | 348 | def forward(self, x, x_mask): 349 | x = self.conv_1(self.padding(x * x_mask)) 350 | if self.activation == "gelu": 351 | x = x * torch.sigmoid(1.702 * x) 352 | else: 353 | x = torch.relu(x) 354 | x = self.drop(x) 355 | x = self.conv_2(self.padding(x * x_mask)) 356 | return x * x_mask 357 | 358 | def _causal_padding(self, x): 359 | if self.kernel_size == 1: 360 | return x 361 | pad_l = self.kernel_size - 1 362 | pad_r = 0 363 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 364 | x = F.pad(x, commons.convert_pad_shape(padding)) 365 | return x 366 | 367 | def _same_padding(self, x): 368 | if self.kernel_size == 1: 369 | return x 370 | pad_l = (self.kernel_size - 1) // 2 371 | pad_r = self.kernel_size // 2 372 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 373 | x = F.pad(x, commons.convert_pad_shape(padding)) 374 | return x 375 | -------------------------------------------------------------------------------- /extra/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def slice_pitch_segments(x, ids_str, segment_size=4): 8 | ret = torch.zeros_like(x[:, :segment_size]) 9 | for i in range(x.size(0)): 10 | idx_str = ids_str[i] 11 | idx_end = idx_str + segment_size 12 | ret[i] = x[i, idx_str:idx_end] 13 | return ret 14 | 15 | 16 | def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4): 17 | b, d, t = x.size() 18 | if x_lengths is None: 19 | x_lengths = t 20 | ids_str_max = x_lengths - segment_size + 1 21 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 22 | ret = slice_segments(x, ids_str, segment_size) 23 | ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size) 24 | return ret, ret_pitch, ids_str 25 | 26 | 27 | def init_weights(m, mean=0.0, std=0.01): 28 | classname = m.__class__.__name__ 29 | if "Depthwise_Separable" in classname: 30 | m.depth_conv.weight.data.normal_(mean, std) 31 | m.point_conv.weight.data.normal_(mean, std) 32 | elif classname.find("Conv") != -1: 33 | m.weight.data.normal_(mean, std) 34 | 35 | 36 | def get_padding(kernel_size, dilation=1): 37 | return int((kernel_size * dilation - dilation) / 2) 38 | 39 | 40 | def convert_pad_shape(pad_shape): 41 | l = pad_shape[::-1] 42 | pad_shape = [item for sublist in l for item in sublist] 43 | return pad_shape 44 | 45 | 46 | def intersperse(lst, item): 47 | result = [item] * (len(lst) * 2 + 1) 48 | result[1::2] = lst 49 | 
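    # e.g. intersperse([1, 2, 3], 0) -> [0, 1, 0, 2, 0, 3, 0]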
return result 50 | 51 | 52 | def kl_divergence(m_p, logs_p, m_q, logs_q): 53 | """KL(P||Q)""" 54 | kl = (logs_q - logs_p) - 0.5 55 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2. * logs_q) 56 | return kl 57 | 58 | 59 | def rand_gumbel(shape): 60 | """Sample from the Gumbel distribution, protect from overflows.""" 61 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 62 | return -torch.log(-torch.log(uniform_samples)) 63 | 64 | 65 | def rand_gumbel_like(x): 66 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 67 | return g 68 | 69 | 70 | def slice_segments(x, ids_str, segment_size=4): 71 | ret = torch.zeros_like(x[:, :, :segment_size]) 72 | for i in range(x.size(0)): 73 | idx_str = ids_str[i] 74 | idx_end = idx_str + segment_size 75 | ret[i] = x[i, :, idx_str:idx_end] 76 | return ret 77 | 78 | 79 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 80 | b, d, t = x.size() 81 | if x_lengths is None: 82 | x_lengths = t 83 | ids_str_max = x_lengths - segment_size + 1 84 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 85 | ret = slice_segments(x, ids_str, segment_size) 86 | return ret, ids_str 87 | 88 | 89 | def rand_spec_segments(x, x_lengths=None, segment_size=4): 90 | b, d, t = x.size() 91 | if x_lengths is None: 92 | x_lengths = t 93 | ids_str_max = x_lengths - segment_size 94 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 95 | ret = slice_segments(x, ids_str, segment_size) 96 | return ret, ids_str 97 | 98 | 99 | def get_timing_signal_1d( 100 | length, channels, min_timescale=1.0, max_timescale=1.0e4): 101 | position = torch.arange(length, dtype=torch.float) 102 | num_timescales = channels // 2 103 | log_timescale_increment = ( 104 | math.log(float(max_timescale) / float(min_timescale)) / 105 | (num_timescales - 1)) 106 | inv_timescales = min_timescale * torch.exp( 107 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 108 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 109 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 110 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 111 | signal = signal.view(1, channels, length) 112 | return signal 113 | 114 | 115 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 116 | b, channels, length = x.size() 117 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 118 | return x + signal.to(dtype=x.dtype, device=x.device) 119 | 120 | 121 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 122 | b, channels, length = x.size() 123 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 124 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 125 | 126 | 127 | def subsequent_mask(length): 128 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 129 | return mask 130 | 131 | 132 | @torch.jit.script 133 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 134 | n_channels_int = n_channels[0] 135 | in_act = input_a + input_b 136 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 137 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 138 | acts = t_act * s_act 139 | return acts 140 | 141 | 142 | def shift_1d(x): 143 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 144 | return x 145 | 146 | 147 | def sequence_mask(length, max_length=None): 148 | if max_length is None: 149 | 
max_length = length.max() 150 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 151 | return x.unsqueeze(0) < length.unsqueeze(1) 152 | 153 | 154 | def generate_path(duration, mask): 155 | """ 156 | duration: [b, 1, t_x] 157 | mask: [b, 1, t_y, t_x] 158 | """ 159 | 160 | b, _, t_y, t_x = mask.shape 161 | cum_duration = torch.cumsum(duration, -1) 162 | 163 | cum_duration_flat = cum_duration.view(b * t_x) 164 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 165 | path = path.view(b, t_x, t_y) 166 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 167 | path = path.unsqueeze(1).transpose(2, 3) * mask 168 | return path 169 | 170 | 171 | def clip_grad_value_(parameters, clip_value, norm_type=2): 172 | if isinstance(parameters, torch.Tensor): 173 | parameters = [parameters] 174 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 175 | norm_type = float(norm_type) 176 | if clip_value is not None: 177 | clip_value = float(clip_value) 178 | 179 | total_norm = 0 180 | for p in parameters: 181 | param_norm = p.grad.data.norm(norm_type) 182 | total_norm += param_norm.item() ** norm_type 183 | if clip_value is not None: 184 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 185 | total_norm = total_norm ** (1. / norm_type) 186 | return total_norm 187 | -------------------------------------------------------------------------------- /extra/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | import extra.attentions as attentions 6 | import extra.commons as commons 7 | from extra.commons import get_padding, init_weights 8 | from extra.DSConv import ( 9 | Depthwise_Separable_Conv1D, 10 | remove_weight_norm_modules, 11 | weight_norm_modules, 12 | ) 13 | 14 | LRELU_SLOPE = 0.1 15 | 16 | Conv1dModel = nn.Conv1d 17 | 18 | 19 | def set_Conv1dModel(use_depthwise_conv): 20 | global Conv1dModel 21 | Conv1dModel = Depthwise_Separable_Conv1D if use_depthwise_conv else nn.Conv1d 22 | 23 | 24 | class LayerNorm(nn.Module): 25 | def __init__(self, channels, eps=1e-5): 26 | super().__init__() 27 | self.channels = channels 28 | self.eps = eps 29 | 30 | self.gamma = nn.Parameter(torch.ones(channels)) 31 | self.beta = nn.Parameter(torch.zeros(channels)) 32 | 33 | def forward(self, x): 34 | x = x.transpose(1, -1) 35 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 36 | return x.transpose(1, -1) 37 | 38 | 39 | class ConvReluNorm(nn.Module): 40 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): 41 | super().__init__() 42 | self.in_channels = in_channels 43 | self.hidden_channels = hidden_channels 44 | self.out_channels = out_channels 45 | self.kernel_size = kernel_size 46 | self.n_layers = n_layers 47 | self.p_dropout = p_dropout 48 | assert n_layers > 1, "Number of layers should be larger than 0." 
49 | 50 | self.conv_layers = nn.ModuleList() 51 | self.norm_layers = nn.ModuleList() 52 | self.conv_layers.append(Conv1dModel(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) 53 | self.norm_layers.append(LayerNorm(hidden_channels)) 54 | self.relu_drop = nn.Sequential( 55 | nn.ReLU(), 56 | nn.Dropout(p_dropout)) 57 | for _ in range(n_layers - 1): 58 | self.conv_layers.append( 59 | Conv1dModel(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) 60 | self.norm_layers.append(LayerNorm(hidden_channels)) 61 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 62 | self.proj.weight.data.zero_() 63 | self.proj.bias.data.zero_() 64 | 65 | def forward(self, x, x_mask): 66 | x_org = x 67 | for i in range(self.n_layers): 68 | x = self.conv_layers[i](x * x_mask) 69 | x = self.norm_layers[i](x) 70 | x = self.relu_drop(x) 71 | x = x_org + self.proj(x) 72 | return x * x_mask 73 | 74 | 75 | class WN(torch.nn.Module): 76 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 77 | super(WN, self).__init__() 78 | assert (kernel_size % 2 == 1) 79 | self.hidden_channels = hidden_channels 80 | self.kernel_size = kernel_size, 81 | self.dilation_rate = dilation_rate 82 | self.n_layers = n_layers 83 | self.gin_channels = gin_channels 84 | self.p_dropout = p_dropout 85 | 86 | self.in_layers = torch.nn.ModuleList() 87 | self.res_skip_layers = torch.nn.ModuleList() 88 | self.drop = nn.Dropout(p_dropout) 89 | 90 | if gin_channels != 0: 91 | cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) 92 | self.cond_layer = weight_norm_modules(cond_layer, name='weight') 93 | 94 | for i in range(n_layers): 95 | dilation = dilation_rate ** i 96 | padding = int((kernel_size * dilation - dilation) / 2) 97 | in_layer = Conv1dModel(hidden_channels, 2 * hidden_channels, kernel_size, 98 | dilation=dilation, padding=padding) 99 | in_layer = weight_norm_modules(in_layer, name='weight') 100 | self.in_layers.append(in_layer) 101 | 102 | # last one is not necessary 103 | if i < n_layers - 1: 104 | res_skip_channels = 2 * hidden_channels 105 | else: 106 | res_skip_channels = hidden_channels 107 | 108 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 109 | res_skip_layer = weight_norm_modules(res_skip_layer, name='weight') 110 | self.res_skip_layers.append(res_skip_layer) 111 | 112 | def forward(self, x, x_mask, g=None, **kwargs): 113 | output = torch.zeros_like(x) 114 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 115 | 116 | if g is not None: 117 | g = self.cond_layer(g) 118 | 119 | for i in range(self.n_layers): 120 | x_in = self.in_layers[i](x) 121 | if g is not None: 122 | cond_offset = i * 2 * self.hidden_channels 123 | g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] 124 | else: 125 | g_l = torch.zeros_like(x_in) 126 | 127 | acts = commons.fused_add_tanh_sigmoid_multiply( 128 | x_in, 129 | g_l, 130 | n_channels_tensor) 131 | acts = self.drop(acts) 132 | 133 | res_skip_acts = self.res_skip_layers[i](acts) 134 | if i < self.n_layers - 1: 135 | res_acts = res_skip_acts[:, :self.hidden_channels, :] 136 | x = (x + res_acts) * x_mask 137 | output = output + res_skip_acts[:, self.hidden_channels:, :] 138 | else: 139 | output = output + res_skip_acts 140 | return output * x_mask 141 | 142 | def remove_weight_norm(self): 143 | if self.gin_channels != 0: 144 | remove_weight_norm_modules(self.cond_layer) 145 | for l in self.in_layers: 146 | remove_weight_norm_modules(l) 
147 | for l in self.res_skip_layers: 148 | remove_weight_norm_modules(l) 149 | 150 | 151 | class ResBlock1(torch.nn.Module): 152 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 153 | super(ResBlock1, self).__init__() 154 | self.convs1 = nn.ModuleList([ 155 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0], 156 | padding=get_padding(kernel_size, dilation[0]))), 157 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1], 158 | padding=get_padding(kernel_size, dilation[1]))), 159 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[2], 160 | padding=get_padding(kernel_size, dilation[2]))) 161 | ]) 162 | self.convs1.apply(init_weights) 163 | 164 | self.convs2 = nn.ModuleList([ 165 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1, 166 | padding=get_padding(kernel_size, 1))), 167 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1, 168 | padding=get_padding(kernel_size, 1))), 169 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1, 170 | padding=get_padding(kernel_size, 1))) 171 | ]) 172 | self.convs2.apply(init_weights) 173 | 174 | def forward(self, x, x_mask=None): 175 | for c1, c2 in zip(self.convs1, self.convs2): 176 | xt = F.leaky_relu(x, LRELU_SLOPE) 177 | if x_mask is not None: 178 | xt = xt * x_mask 179 | xt = c1(xt) 180 | xt = F.leaky_relu(xt, LRELU_SLOPE) 181 | if x_mask is not None: 182 | xt = xt * x_mask 183 | xt = c2(xt) 184 | x = xt + x 185 | if x_mask is not None: 186 | x = x * x_mask 187 | return x 188 | 189 | def remove_weight_norm(self): 190 | for l in self.convs1: 191 | remove_weight_norm_modules(l) 192 | for l in self.convs2: 193 | remove_weight_norm_modules(l) 194 | 195 | 196 | class ResBlock2(torch.nn.Module): 197 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 198 | super(ResBlock2, self).__init__() 199 | self.convs = nn.ModuleList([ 200 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0], 201 | padding=get_padding(kernel_size, dilation[0]))), 202 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1], 203 | padding=get_padding(kernel_size, dilation[1]))) 204 | ]) 205 | self.convs.apply(init_weights) 206 | 207 | def forward(self, x, x_mask=None): 208 | for c in self.convs: 209 | xt = F.leaky_relu(x, LRELU_SLOPE) 210 | if x_mask is not None: 211 | xt = xt * x_mask 212 | xt = c(xt) 213 | x = xt + x 214 | if x_mask is not None: 215 | x = x * x_mask 216 | return x 217 | 218 | def remove_weight_norm(self): 219 | for l in self.convs: 220 | remove_weight_norm_modules(l) 221 | 222 | 223 | class Log(nn.Module): 224 | def forward(self, x, x_mask, reverse=False, **kwargs): 225 | if not reverse: 226 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 227 | logdet = torch.sum(-y, [1, 2]) 228 | return y, logdet 229 | else: 230 | x = torch.exp(x) * x_mask 231 | return x 232 | 233 | 234 | class Flip(nn.Module): 235 | def forward(self, x, *args, reverse=False, **kwargs): 236 | x = torch.flip(x, [1]) 237 | if not reverse: 238 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 239 | return x, logdet 240 | else: 241 | return x 242 | 243 | 244 | class ElementwiseAffine(nn.Module): 245 | def __init__(self, channels): 246 | super().__init__() 247 | self.channels = channels 248 | self.m = nn.Parameter(torch.zeros(channels, 1)) 249 | self.logs = 
nn.Parameter(torch.zeros(channels, 1)) 250 | 251 | def forward(self, x, x_mask, reverse=False, **kwargs): 252 | if not reverse: 253 | y = self.m + torch.exp(self.logs) * x 254 | y = y * x_mask 255 | logdet = torch.sum(self.logs * x_mask, [1, 2]) 256 | return y, logdet 257 | else: 258 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 259 | return x 260 | 261 | 262 | class ResidualCouplingLayer(nn.Module): 263 | def __init__(self, 264 | channels, 265 | hidden_channels, 266 | kernel_size, 267 | dilation_rate, 268 | n_layers, 269 | p_dropout=0, 270 | gin_channels=0, 271 | mean_only=False, 272 | wn_sharing_parameter=None 273 | ): 274 | assert channels % 2 == 0, "channels should be divisible by 2" 275 | super().__init__() 276 | self.channels = channels 277 | self.hidden_channels = hidden_channels 278 | self.kernel_size = kernel_size 279 | self.dilation_rate = dilation_rate 280 | self.n_layers = n_layers 281 | self.half_channels = channels // 2 282 | self.mean_only = mean_only 283 | 284 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 285 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, 286 | gin_channels=gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter 287 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 288 | self.post.weight.data.zero_() 289 | self.post.bias.data.zero_() 290 | 291 | def forward(self, x, x_mask, g=None, reverse=False): 292 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 293 | h = self.pre(x0) * x_mask 294 | h = self.enc(h, x_mask, g=g) 295 | stats = self.post(h) * x_mask 296 | if not self.mean_only: 297 | m, logs = torch.split(stats, [self.half_channels] * 2, 1) 298 | else: 299 | m = stats 300 | logs = torch.zeros_like(m) 301 | 302 | if not reverse: 303 | x1 = m + x1 * torch.exp(logs) * x_mask 304 | x = torch.cat([x0, x1], 1) 305 | logdet = torch.sum(logs, [1, 2]) 306 | return x, logdet 307 | else: 308 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 309 | x = torch.cat([x0, x1], 1) 310 | return x 311 | 312 | 313 | class TransformerCouplingLayer(nn.Module): 314 | def __init__(self, 315 | channels, 316 | hidden_channels, 317 | kernel_size, 318 | n_layers, 319 | n_heads, 320 | p_dropout=0, 321 | filter_channels=0, 322 | mean_only=False, 323 | wn_sharing_parameter=None, 324 | gin_channels=0 325 | ): 326 | assert channels % 2 == 0, "channels should be divisible by 2" 327 | super().__init__() 328 | self.channels = channels 329 | self.hidden_channels = hidden_channels 330 | self.kernel_size = kernel_size 331 | self.n_layers = n_layers 332 | self.half_channels = channels // 2 333 | self.mean_only = mean_only 334 | 335 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 336 | self.enc = attentions.FFT(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, 337 | isflow=True, 338 | gin_channels=gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter 339 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 340 | self.post.weight.data.zero_() 341 | self.post.bias.data.zero_() 342 | 343 | def forward(self, x, x_mask, g=None, reverse=False): 344 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 345 | h = self.pre(x0) * x_mask 346 | h = self.enc(h, x_mask, g=g) 347 | stats = self.post(h) * x_mask 348 | if not self.mean_only: 349 | m, logs = torch.split(stats, [self.half_channels] * 2, 1) 350 | else: 351 | m = stats 352 | logs = torch.zeros_like(m) 353 | 354 | if not reverse: 355 | x1 = m + x1 * 
torch.exp(logs) * x_mask 356 | x = torch.cat([x0, x1], 1) 357 | logdet = torch.sum(logs, [1, 2]) 358 | return x, logdet 359 | else: 360 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 361 | x = torch.cat([x0, x1], 1) 362 | return x 363 | -------------------------------------------------------------------------------- /filelists/val.txt: -------------------------------------------------------------------------------- 1 | DUMMY/p376/p376_392.wav 2 | DUMMY/p286/p286_273.wav 3 | DUMMY/p234/p234_034.wav 4 | DUMMY/p374/p374_392.wav 5 | DUMMY/p287/p287_200.wav 6 | DUMMY/p271/p271_053.wav 7 | DUMMY/p227/p227_159.wav 8 | DUMMY/p261/p261_019.wav 9 | DUMMY/p268/p268_174.wav 10 | DUMMY/p225/p225_021.wav 11 | DUMMY/p361/p361_283.wav 12 | DUMMY/p230/p230_016.wav 13 | DUMMY/p238/p238_193.wav 14 | DUMMY/p257/p257_063.wav 15 | DUMMY/p237/p237_058.wav 16 | DUMMY/p272/p272_381.wav 17 | DUMMY/p343/p343_313.wav 18 | DUMMY/p313/p313_201.wav 19 | DUMMY/p243/p243_041.wav 20 | DUMMY/p363/p363_231.wav 21 | DUMMY/p292/p292_394.wav 22 | DUMMY/p260/p260_027.wav 23 | DUMMY/p273/p273_206.wav 24 | DUMMY/p329/p329_355.wav 25 | DUMMY/p318/p318_197.wav 26 | DUMMY/p293/p293_223.wav 27 | DUMMY/p274/p274_289.wav 28 | DUMMY/p362/p362_019.wav 29 | DUMMY/p229/p229_226.wav 30 | DUMMY/p228/p228_063.wav 31 | DUMMY/p288/p288_171.wav 32 | DUMMY/p243/p243_130.wav 33 | DUMMY/p360/p360_250.wav 34 | DUMMY/p264/p264_078.wav 35 | DUMMY/p301/p301_268.wav 36 | DUMMY/p239/p239_269.wav 37 | DUMMY/p330/p330_306.wav 38 | DUMMY/p273/p273_311.wav 39 | DUMMY/p329/p329_094.wav 40 | DUMMY/p305/p305_163.wav 41 | DUMMY/p347/p347_226.wav 42 | DUMMY/p265/p265_051.wav 43 | DUMMY/p282/p282_078.wav 44 | DUMMY/p226/p226_181.wav 45 | DUMMY/p326/p326_173.wav 46 | DUMMY/p310/p310_254.wav 47 | DUMMY/p313/p313_410.wav 48 | DUMMY/p301/p301_209.wav 49 | DUMMY/p239/p239_343.wav 50 | DUMMY/p249/p249_025.wav 51 | DUMMY/p267/p267_333.wav 52 | DUMMY/p312/p312_337.wav 53 | DUMMY/p340/p340_020.wav 54 | DUMMY/p229/p229_368.wav 55 | DUMMY/p270/p270_450.wav 56 | DUMMY/p298/p298_145.wav 57 | DUMMY/p316/p316_034.wav 58 | DUMMY/p253/p253_148.wav 59 | DUMMY/p279/p279_301.wav 60 | DUMMY/p300/p300_231.wav 61 | DUMMY/p270/p270_068.wav 62 | DUMMY/p258/p258_061.wav 63 | DUMMY/p282/p282_231.wav 64 | DUMMY/p277/p277_315.wav 65 | DUMMY/p362/p362_134.wav 66 | DUMMY/p244/p244_063.wav 67 | DUMMY/p275/p275_212.wav 68 | DUMMY/p233/p233_044.wav 69 | DUMMY/p284/p284_192.wav 70 | DUMMY/p304/p304_156.wav 71 | DUMMY/p249/p249_102.wav 72 | DUMMY/p236/p236_076.wav 73 | DUMMY/p312/p312_373.wav 74 | DUMMY/p259/p259_299.wav 75 | DUMMY/p347/p347_164.wav 76 | DUMMY/p330/p330_363.wav 77 | DUMMY/p303/p303_046.wav 78 | DUMMY/p304/p304_167.wav 79 | DUMMY/p314/p314_235.wav 80 | DUMMY/p336/p336_040.wav 81 | DUMMY/p317/p317_077.wav 82 | DUMMY/p281/p281_402.wav 83 | DUMMY/p241/p241_345.wav 84 | DUMMY/p292/p292_022.wav 85 | DUMMY/p262/p262_240.wav 86 | DUMMY/p263/p263_098.wav 87 | DUMMY/p250/p250_304.wav 88 | DUMMY/p376/p376_062.wav 89 | DUMMY/p264/p264_181.wav 90 | DUMMY/p260/p260_109.wav 91 | DUMMY/p333/p333_282.wav 92 | DUMMY/p310/p310_073.wav 93 | DUMMY/p343/p343_189.wav 94 | DUMMY/p257/p257_217.wav 95 | DUMMY/p288/p288_135.wav 96 | DUMMY/p285/p285_076.wav 97 | DUMMY/p265/p265_242.wav 98 | DUMMY/p226/p226_337.wav 99 | DUMMY/p302/p302_125.wav 100 | DUMMY/p341/p341_020.wav 101 | DUMMY/p246/p246_060.wav 102 | DUMMY/p244/p244_381.wav 103 | DUMMY/p283/p283_222.wav 104 | DUMMY/p266/p266_335.wav 105 | DUMMY/p297/p297_374.wav 106 | DUMMY/p245/p245_190.wav 107 | DUMMY/p231/p231_471.wav 108 | DUMMY/p284/p284_217.wav 109 | 
DUMMY/p245/p245_118.wav 110 | DUMMY/p240/p240_196.wav 111 | DUMMY/p236/p236_205.wav 112 | DUMMY/p256/p256_042.wav 113 | DUMMY/p326/p326_064.wav 114 | DUMMY/p255/p255_353.wav 115 | DUMMY/p311/p311_138.wav 116 | DUMMY/p345/p345_057.wav 117 | DUMMY/p351/p351_418.wav 118 | DUMMY/p234/p234_023.wav 119 | DUMMY/p307/p307_067.wav 120 | DUMMY/p283/p283_137.wav 121 | DUMMY/p268/p268_058.wav 122 | DUMMY/p339/p339_143.wav 123 | DUMMY/p258/p258_287.wav 124 | DUMMY/p363/p363_322.wav 125 | DUMMY/p237/p237_196.wav 126 | DUMMY/p341/p341_008.wav 127 | DUMMY/p323/p323_278.wav 128 | DUMMY/p231/p231_453.wav 129 | DUMMY/p307/p307_412.wav 130 | DUMMY/p267/p267_266.wav 131 | DUMMY/p293/p293_272.wav 132 | DUMMY/p306/p306_321.wav 133 | DUMMY/p262/p262_393.wav 134 | DUMMY/p314/p314_267.wav 135 | DUMMY/p274/p274_373.wav 136 | DUMMY/p250/p250_336.wav 137 | DUMMY/p334/p334_148.wav 138 | DUMMY/p251/p251_294.wav 139 | DUMMY/p255/p255_042.wav 140 | DUMMY/p294/p294_251.wav 141 | DUMMY/p254/p254_217.wav 142 | DUMMY/p299/p299_122.wav 143 | DUMMY/p269/p269_150.wav 144 | DUMMY/p272/p272_084.wav 145 | DUMMY/p345/p345_320.wav 146 | DUMMY/p300/p300_267.wav 147 | DUMMY/p299/p299_109.wav 148 | DUMMY/p246/p246_160.wav 149 | DUMMY/p278/p278_366.wav 150 | DUMMY/p241/p241_306.wav 151 | DUMMY/p240/p240_213.wav 152 | DUMMY/p311/p311_376.wav 153 | DUMMY/p256/p256_006.wav 154 | DUMMY/p254/p254_356.wav 155 | DUMMY/p276/p276_111.wav 156 | DUMMY/p263/p263_270.wav 157 | DUMMY/p295/p295_371.wav 158 | DUMMY/p230/p230_184.wav 159 | DUMMY/p286/p286_341.wav 160 | DUMMY/p302/p302_160.wav 161 | DUMMY/p232/p232_396.wav 162 | DUMMY/p278/p278_077.wav 163 | DUMMY/p281/p281_367.wav 164 | DUMMY/p336/p336_316.wav 165 | DUMMY/p335/p335_365.wav 166 | DUMMY/p233/p233_041.wav 167 | DUMMY/p225/p225_038.wav 168 | DUMMY/p248/p248_160.wav 169 | DUMMY/p228/p228_230.wav 170 | DUMMY/p285/p285_197.wav 171 | DUMMY/p360/p360_157.wav 172 | DUMMY/p333/p333_259.wav 173 | DUMMY/p308/p308_335.wav 174 | DUMMY/p339/p339_218.wav 175 | DUMMY/p247/p247_320.wav 176 | DUMMY/p364/p364_014.wav 177 | DUMMY/p227/p227_255.wav 178 | DUMMY/p238/p238_060.wav 179 | DUMMY/p323/p323_373.wav 180 | DUMMY/p277/p277_045.wav 181 | DUMMY/p361/p361_152.wav 182 | DUMMY/p275/p275_380.wav 183 | DUMMY/p232/p232_180.wav 184 | DUMMY/p269/p269_130.wav 185 | DUMMY/p316/p316_055.wav 186 | DUMMY/p252/p252_247.wav 187 | DUMMY/p340/p340_036.wav 188 | DUMMY/p294/p294_414.wav 189 | DUMMY/p298/p298_228.wav 190 | DUMMY/p287/p287_348.wav 191 | DUMMY/p295/p295_214.wav 192 | DUMMY/p251/p251_222.wav 193 | DUMMY/p253/p253_339.wav 194 | DUMMY/p305/p305_327.wav 195 | DUMMY/p279/p279_283.wav 196 | DUMMY/p318/p318_342.wav 197 | DUMMY/p351/p351_194.wav 198 | DUMMY/p248/p248_016.wav 199 | DUMMY/p276/p276_321.wav 200 | DUMMY/p259/p259_262.wav 201 | DUMMY/p261/p261_018.wav 202 | DUMMY/p303/p303_320.wav 203 | DUMMY/p297/p297_122.wav 204 | DUMMY/p374/p374_106.wav 205 | DUMMY/p271/p271_200.wav 206 | DUMMY/p247/p247_315.wav 207 | DUMMY/p252/p252_402.wav 208 | DUMMY/p335/p335_308.wav 209 | DUMMY/p308/p308_104.wav 210 | DUMMY/p266/p266_040.wav 211 | DUMMY/p306/p306_312.wav 212 | DUMMY/p317/p317_201.wav 213 | DUMMY/p334/p334_357.wav 214 | DUMMY/p364/p364_027.wav 215 | -------------------------------------------------------------------------------- /hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import Generator 2 | 3 | 4 | class AttrDict(dict): 5 | def __init__(self, *args, **kwargs): 6 | super(AttrDict, self).__init__(*args, **kwargs) 7 | self.__dict__ = self 
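AttrDict simply re-exposes the JSON config keys as attributes (h.num_mels, h.upsample_rates, ...), which is how the Generator defined below in hifigan/models.py reads them; preprocess_sr.py obtains the same vocoder through utils.get_vocoder. A minimal usage sketch — the config path is the one in this repo, but the checkpoint filename and the "generator" state-dict key are assumptions carried over from the official HiFi-GAN release linked in generator_v1.txt:

import json
import torch
from hifigan import AttrDict, Generator

with open("hifigan/config.json") as f:
    h = AttrDict(json.load(f))      # config values become attributes, e.g. h.sampling_rate == 22050

vocoder = Generator(h)
ckpt = torch.load("hifigan/generator_v1", map_location="cpu")   # assumed checkpoint path and layout
vocoder.load_state_dict(ckpt["generator"])
vocoder.eval()
vocoder.remove_weight_norm()

with torch.no_grad():
    mel = torch.randn(1, h.num_mels, 100)   # dummy 80-band mel, [B, n_mels, T]
    wav = vocoder(mel)                      # [B, 1, T * prod(upsample_rates)] = [1, 1, 25600]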
-------------------------------------------------------------------------------- /hifigan/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,2,2], 12 | "upsample_kernel_sizes": [16,16,4,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | "resblock_initial_channel": 256, 17 | 18 | "segment_size": 8192, 19 | "num_mels": 80, 20 | "num_freq": 1025, 21 | "n_fft": 1024, 22 | "hop_size": 256, 23 | "win_size": 1024, 24 | 25 | "sampling_rate": 22050, 26 | 27 | "fmin": 0, 28 | "fmax": 8000, 29 | "fmax_loss": null, 30 | 31 | "num_workers": 4, 32 | 33 | "dist_config": { 34 | "dist_backend": "nccl", 35 | "dist_url": "tcp://localhost:54321", 36 | "world_size": 1 37 | } 38 | } -------------------------------------------------------------------------------- /hifigan/generator_v1.txt: -------------------------------------------------------------------------------- 1 | https://github.com/jik876/hifi-gan -------------------------------------------------------------------------------- /hifigan/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv1d, ConvTranspose1d 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResBlock(torch.nn.Module): 21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 22 | super(ResBlock, self).__init__() 23 | self.h = h 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[0], 33 | padding=get_padding(kernel_size, dilation[0]), 34 | ) 35 | ), 36 | weight_norm( 37 | Conv1d( 38 | channels, 39 | channels, 40 | kernel_size, 41 | 1, 42 | dilation=dilation[1], 43 | padding=get_padding(kernel_size, dilation[1]), 44 | ) 45 | ), 46 | weight_norm( 47 | Conv1d( 48 | channels, 49 | channels, 50 | kernel_size, 51 | 1, 52 | dilation=dilation[2], 53 | padding=get_padding(kernel_size, dilation[2]), 54 | ) 55 | ), 56 | ] 57 | ) 58 | self.convs1.apply(init_weights) 59 | 60 | self.convs2 = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | Conv1d( 64 | channels, 65 | channels, 66 | kernel_size, 67 | 1, 68 | dilation=1, 69 | padding=get_padding(kernel_size, 1), 70 | ) 71 | ), 72 | weight_norm( 73 | Conv1d( 74 | channels, 75 | channels, 76 | kernel_size, 77 | 1, 78 | dilation=1, 79 | padding=get_padding(kernel_size, 1), 80 | ) 81 | ), 82 | weight_norm( 83 | Conv1d( 84 | channels, 85 | channels, 86 | kernel_size, 87 | 1, 88 | dilation=1, 89 | padding=get_padding(kernel_size, 1), 90 | ) 91 | ), 92 | ] 93 | ) 94 | self.convs2.apply(init_weights) 95 | 96 | def forward(self, x): 97 | for c1, c2 in zip(self.convs1, self.convs2): 98 | xt = F.leaky_relu(x, LRELU_SLOPE) 99 | xt = c1(xt) 100 | xt = F.leaky_relu(xt, LRELU_SLOPE) 101 | xt = c2(xt) 102 | x = xt + x 103 | return x 104 
| 105 | def remove_weight_norm(self): 106 | for l in self.convs1: 107 | remove_weight_norm(l) 108 | for l in self.convs2: 109 | remove_weight_norm(l) 110 | 111 | 112 | class Generator(torch.nn.Module): 113 | def __init__(self, h): 114 | super(Generator, self).__init__() 115 | self.h = h 116 | self.num_kernels = len(h.resblock_kernel_sizes) 117 | self.num_upsamples = len(h.upsample_rates) 118 | self.conv_pre = weight_norm( 119 | Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3) 120 | ) 121 | resblock = ResBlock 122 | 123 | self.ups = nn.ModuleList() 124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 125 | self.ups.append( 126 | weight_norm( 127 | ConvTranspose1d( 128 | h.upsample_initial_channel // (2 ** i), 129 | h.upsample_initial_channel // (2 ** (i + 1)), 130 | k, 131 | u, 132 | padding=(k - u) // 2, 133 | ) 134 | ) 135 | ) 136 | 137 | self.resblocks = nn.ModuleList() 138 | for i in range(len(self.ups)): 139 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 140 | for j, (k, d) in enumerate( 141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 142 | ): 143 | self.resblocks.append(resblock(h, ch, k, d)) 144 | 145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 146 | self.ups.apply(init_weights) 147 | self.conv_post.apply(init_weights) 148 | 149 | def forward(self, x): 150 | x = self.conv_pre(x) 151 | for i in range(self.num_upsamples): 152 | x = F.leaky_relu(x, LRELU_SLOPE) 153 | x = self.ups[i](x) 154 | xs = None 155 | for j in range(self.num_kernels): 156 | if xs is None: 157 | xs = self.resblocks[i * self.num_kernels + j](x) 158 | else: 159 | xs += self.resblocks[i * self.num_kernels + j](x) 160 | x = xs / self.num_kernels 161 | x = F.leaky_relu(x) 162 | x = self.conv_post(x) 163 | x = torch.tanh(x) 164 | 165 | return x 166 | 167 | def remove_weight_norm(self): 168 | print("Removing weight norm...") 169 | for l in self.ups: 170 | remove_weight_norm(l) 171 | for l in self.resblocks: 172 | l.remove_weight_norm() 173 | remove_weight_norm(self.conv_pre) 174 | remove_weight_norm(self.conv_post) -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import commons 5 | 6 | 7 | def feature_loss(fmap_r, fmap_g): 8 | loss = 0 9 | for dr, dg in zip(fmap_r, fmap_g): 10 | for rl, gl in zip(dr, dg): 11 | rl = rl.float().detach() 12 | gl = gl.float() 13 | loss += torch.mean(torch.abs(rl - gl)) 14 | 15 | return loss * 2 16 | 17 | 18 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 19 | loss = 0 20 | r_losses = [] 21 | g_losses = [] 22 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 23 | dr = dr.float() 24 | dg = dg.float() 25 | r_loss = torch.mean((1-dr)**2) 26 | g_loss = torch.mean(dg**2) 27 | loss += (r_loss + g_loss) 28 | r_losses.append(r_loss.item()) 29 | g_losses.append(g_loss.item()) 30 | 31 | return loss, r_losses, g_losses 32 | 33 | 34 | def generator_loss(disc_outputs): 35 | loss = 0 36 | gen_losses = [] 37 | for dg in disc_outputs: 38 | dg = dg.float() 39 | l = torch.mean((1-dg)**2) 40 | gen_losses.append(l) 41 | loss += l 42 | 43 | return loss, gen_losses 44 | 45 | 46 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): 47 | """ 48 | z_p, logs_q: [b, h, t_t] 49 | m_p, logs_p: [b, h, t_t] 50 | """ 51 | z_p = z_p.float() 52 | logs_q = logs_q.float() 53 | m_p = m_p.float() 54 | logs_p = logs_p.float() 55 | z_mask = 
z_mask.float() 56 | #print(logs_p) 57 | kl = logs_p - logs_q - 0.5 58 | kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p) 59 | kl = torch.sum(kl * z_mask) 60 | l = kl / torch.sum(z_mask) 61 | return l 62 | -------------------------------------------------------------------------------- /lstm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """LSTM layers module.""" 8 | 9 | from torch import nn 10 | 11 | 12 | class SLSTM(nn.Module): 13 | """ 14 | LSTM without worrying about the hidden state, nor the layout of the data. 15 | Expects input as convolutional layout. 16 | """ 17 | def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): 18 | super().__init__() 19 | self.skip = skip 20 | self.lstm = nn.LSTM(dimension, dimension, num_layers) 21 | 22 | def forward(self, x): 23 | x = x.permute(2, 0, 1) 24 | y, _ = self.lstm(x) 25 | if self.skip: 26 | y = y + x 27 | y = y.permute(1, 2, 0) 28 | return y 29 | -------------------------------------------------------------------------------- /mel_processing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import torch.utils.data 8 | import numpy as np 9 | import librosa 10 | import librosa.util as librosa_util 11 | from librosa.util import normalize, pad_center, tiny 12 | from scipy.signal import get_window 13 | from scipy.io.wavfile import read 14 | from librosa.filters import mel as librosa_mel_fn 15 | 16 | MAX_WAV_VALUE = 32768.0 17 | 18 | 19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 20 | """ 21 | PARAMS 22 | ------ 23 | C: compression factor 24 | """ 25 | return torch.log(torch.clamp(x, min=clip_val) * C) 26 | 27 | 28 | def dynamic_range_decompression_torch(x, C=1): 29 | """ 30 | PARAMS 31 | ------ 32 | C: compression factor used to compress 33 | """ 34 | return torch.exp(x) / C 35 | 36 | 37 | def spectral_normalize_torch(magnitudes): 38 | output = dynamic_range_compression_torch(magnitudes) 39 | return output 40 | 41 | 42 | def spectral_de_normalize_torch(magnitudes): 43 | output = dynamic_range_decompression_torch(magnitudes) 44 | return output 45 | 46 | 47 | mel_basis = {} 48 | hann_window = {} 49 | 50 | 51 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 52 | if torch.min(y) < -1.: 53 | print('min value is ', torch.min(y)) 54 | if torch.max(y) > 1.: 55 | print('max value is ', torch.max(y)) 56 | 57 | global hann_window 58 | dtype_device = str(y.dtype) + '_' + str(y.device) 59 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 60 | if wnsize_dtype_device not in hann_window: 61 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 62 | 63 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 64 | y = y.squeeze(1) 65 | 66 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 67 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 68 | 69 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 70 | return spec 71 | 72 | 73 | def spec_to_mel_torch(spec, 
n_fft, num_mels, sampling_rate, fmin, fmax): 74 | global mel_basis 75 | dtype_device = str(spec.dtype) + '_' + str(spec.device) 76 | fmax_dtype_device = str(fmax) + '_' + dtype_device 77 | if fmax_dtype_device not in mel_basis: 78 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 79 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) 80 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 81 | spec = spectral_normalize_torch(spec) 82 | return spec 83 | 84 | 85 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 86 | if torch.min(y) < -1.: 87 | print('min value is ', torch.min(y)) 88 | if torch.max(y) > 1.: 89 | print('max value is ', torch.max(y)) 90 | 91 | global mel_basis, hann_window 92 | dtype_device = str(y.dtype) + '_' + str(y.device) 93 | fmax_dtype_device = str(fmax) + '_' + dtype_device 94 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 95 | if fmax_dtype_device not in mel_basis: 96 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 97 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) 98 | if wnsize_dtype_device not in hann_window: 99 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 100 | 101 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 102 | y = y.squeeze(1) 103 | 104 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 105 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 106 | 107 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 108 | 109 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 110 | spec = spectral_normalize_torch(spec) 111 | 112 | return spec 113 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import scipy 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 10 | from torch.nn.utils import weight_norm, remove_weight_norm 11 | 12 | import commons 13 | from commons import init_weights, get_padding 14 | 15 | 16 | LRELU_SLOPE = 0.1 17 | 18 | 19 | class LayerNorm(nn.Module): 20 | def __init__(self, channels, eps=1e-5): 21 | super().__init__() 22 | self.channels = channels 23 | self.eps = eps 24 | 25 | self.gamma = nn.Parameter(torch.ones(channels)) 26 | self.beta = nn.Parameter(torch.zeros(channels)) 27 | 28 | def forward(self, x): 29 | x = x.transpose(1, -1) 30 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 31 | return x.transpose(1, -1) 32 | 33 | 34 | class ConvReluNorm(nn.Module): 35 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): 36 | super().__init__() 37 | self.in_channels = in_channels 38 | self.hidden_channels = hidden_channels 39 | self.out_channels = out_channels 40 | self.kernel_size = kernel_size 41 | self.n_layers = n_layers 42 | self.p_dropout = p_dropout 43 | assert n_layers > 1, "Number of layers should be larger than 0." 
44 | 45 | self.conv_layers = nn.ModuleList() 46 | self.norm_layers = nn.ModuleList() 47 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 48 | self.norm_layers.append(LayerNorm(hidden_channels)) 49 | self.relu_drop = nn.Sequential( 50 | nn.ReLU(), 51 | nn.Dropout(p_dropout)) 52 | for _ in range(n_layers-1): 53 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 54 | self.norm_layers.append(LayerNorm(hidden_channels)) 55 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 56 | self.proj.weight.data.zero_() 57 | self.proj.bias.data.zero_() 58 | 59 | def forward(self, x, x_mask): 60 | x_org = x 61 | for i in range(self.n_layers): 62 | x = self.conv_layers[i](x * x_mask) 63 | x = self.norm_layers[i](x) 64 | x = self.relu_drop(x) 65 | x = x_org + self.proj(x) 66 | return x * x_mask 67 | 68 | 69 | class DDSConv(nn.Module): 70 | """ 71 | Dialted and Depth-Separable Convolution 72 | """ 73 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): 74 | super().__init__() 75 | self.channels = channels 76 | self.kernel_size = kernel_size 77 | self.n_layers = n_layers 78 | self.p_dropout = p_dropout 79 | 80 | self.drop = nn.Dropout(p_dropout) 81 | self.convs_sep = nn.ModuleList() 82 | self.convs_1x1 = nn.ModuleList() 83 | self.norms_1 = nn.ModuleList() 84 | self.norms_2 = nn.ModuleList() 85 | for i in range(n_layers): 86 | dilation = kernel_size ** i 87 | padding = (kernel_size * dilation - dilation) // 2 88 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, 89 | groups=channels, dilation=dilation, padding=padding 90 | )) 91 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 92 | self.norms_1.append(LayerNorm(channels)) 93 | self.norms_2.append(LayerNorm(channels)) 94 | 95 | def forward(self, x, x_mask, g=None): 96 | if g is not None: 97 | x = x + g 98 | for i in range(self.n_layers): 99 | y = self.convs_sep[i](x * x_mask) 100 | y = self.norms_1[i](y) 101 | y = F.gelu(y) 102 | y = self.convs_1x1[i](y) 103 | y = self.norms_2[i](y) 104 | y = F.gelu(y) 105 | y = self.drop(y) 106 | x = x + y 107 | return x * x_mask 108 | 109 | 110 | class WN(torch.nn.Module): 111 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 112 | super(WN, self).__init__() 113 | assert(kernel_size % 2 == 1) 114 | self.hidden_channels =hidden_channels 115 | self.kernel_size = kernel_size, 116 | self.dilation_rate = dilation_rate 117 | self.n_layers = n_layers 118 | self.gin_channels = gin_channels 119 | self.p_dropout = p_dropout 120 | 121 | self.in_layers = torch.nn.ModuleList() 122 | self.res_skip_layers = torch.nn.ModuleList() 123 | self.drop = nn.Dropout(p_dropout) 124 | 125 | if gin_channels != 0: 126 | cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) 127 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 128 | 129 | for i in range(n_layers): 130 | dilation = dilation_rate ** i 131 | padding = int((kernel_size * dilation - dilation) / 2) 132 | in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, 133 | dilation=dilation, padding=padding) 134 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 135 | self.in_layers.append(in_layer) 136 | 137 | # last one is not necessary 138 | if i < n_layers - 1: 139 | res_skip_channels = 2 * hidden_channels 140 | else: 141 | res_skip_channels = hidden_channels 142 | 143 | res_skip_layer = 
torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 144 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 145 | self.res_skip_layers.append(res_skip_layer) 146 | 147 | def forward(self, x, x_mask, g=None, **kwargs): 148 | output = torch.zeros_like(x) 149 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 150 | 151 | if g is not None: 152 | g = self.cond_layer(g) 153 | 154 | for i in range(self.n_layers): 155 | x_in = self.in_layers[i](x) 156 | if g is not None: 157 | cond_offset = i * 2 * self.hidden_channels 158 | g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] 159 | else: 160 | g_l = torch.zeros_like(x_in) 161 | 162 | acts = commons.fused_add_tanh_sigmoid_multiply( 163 | x_in, 164 | g_l, 165 | n_channels_tensor) 166 | acts = self.drop(acts) 167 | 168 | res_skip_acts = self.res_skip_layers[i](acts) 169 | if i < self.n_layers - 1: 170 | res_acts = res_skip_acts[:,:self.hidden_channels,:] 171 | x = (x + res_acts) * x_mask 172 | output = output + res_skip_acts[:,self.hidden_channels:,:] 173 | else: 174 | output = output + res_skip_acts 175 | return output * x_mask 176 | 177 | def remove_weight_norm(self): 178 | if self.gin_channels != 0: 179 | torch.nn.utils.remove_weight_norm(self.cond_layer) 180 | for l in self.in_layers: 181 | torch.nn.utils.remove_weight_norm(l) 182 | for l in self.res_skip_layers: 183 | torch.nn.utils.remove_weight_norm(l) 184 | 185 | 186 | class ResBlock1(torch.nn.Module): 187 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 188 | super(ResBlock1, self).__init__() 189 | self.convs1 = nn.ModuleList([ 190 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 191 | padding=get_padding(kernel_size, dilation[0]))), 192 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 193 | padding=get_padding(kernel_size, dilation[1]))), 194 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 195 | padding=get_padding(kernel_size, dilation[2]))) 196 | ]) 197 | self.convs1.apply(init_weights) 198 | 199 | self.convs2 = nn.ModuleList([ 200 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 201 | padding=get_padding(kernel_size, 1))), 202 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 203 | padding=get_padding(kernel_size, 1))), 204 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 205 | padding=get_padding(kernel_size, 1))) 206 | ]) 207 | self.convs2.apply(init_weights) 208 | 209 | def forward(self, x, x_mask=None): 210 | for c1, c2 in zip(self.convs1, self.convs2): 211 | xt = F.leaky_relu(x, LRELU_SLOPE) 212 | if x_mask is not None: 213 | xt = xt * x_mask 214 | xt = c1(xt) 215 | xt = F.leaky_relu(xt, LRELU_SLOPE) 216 | if x_mask is not None: 217 | xt = xt * x_mask 218 | xt = c2(xt) 219 | x = xt + x 220 | if x_mask is not None: 221 | x = x * x_mask 222 | return x 223 | 224 | def remove_weight_norm(self): 225 | for l in self.convs1: 226 | remove_weight_norm(l) 227 | for l in self.convs2: 228 | remove_weight_norm(l) 229 | 230 | 231 | class ResBlock2(torch.nn.Module): 232 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 233 | super(ResBlock2, self).__init__() 234 | self.convs = nn.ModuleList([ 235 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 236 | padding=get_padding(kernel_size, dilation[0]))), 237 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 238 | padding=get_padding(kernel_size, dilation[1]))) 239 | ]) 
240 | self.convs.apply(init_weights) 241 | 242 | def forward(self, x, x_mask=None): 243 | for c in self.convs: 244 | xt = F.leaky_relu(x, LRELU_SLOPE) 245 | if x_mask is not None: 246 | xt = xt * x_mask 247 | xt = c(xt) 248 | x = xt + x 249 | if x_mask is not None: 250 | x = x * x_mask 251 | return x 252 | 253 | def remove_weight_norm(self): 254 | for l in self.convs: 255 | remove_weight_norm(l) 256 | 257 | 258 | class Log(nn.Module): 259 | def forward(self, x, x_mask, reverse=False, **kwargs): 260 | if not reverse: 261 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 262 | logdet = torch.sum(-y, [1, 2]) 263 | return y, logdet 264 | else: 265 | x = torch.exp(x) * x_mask 266 | return x 267 | 268 | 269 | class Flip(nn.Module): 270 | def forward(self, x, *args, reverse=False, **kwargs): 271 | x = torch.flip(x, [1]) 272 | if not reverse: 273 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 274 | return x, logdet 275 | else: 276 | return x 277 | 278 | 279 | class ElementwiseAffine(nn.Module): 280 | def __init__(self, channels): 281 | super().__init__() 282 | self.channels = channels 283 | self.m = nn.Parameter(torch.zeros(channels,1)) 284 | self.logs = nn.Parameter(torch.zeros(channels,1)) 285 | 286 | def forward(self, x, x_mask, reverse=False, **kwargs): 287 | if not reverse: 288 | y = self.m + torch.exp(self.logs) * x 289 | y = y * x_mask 290 | logdet = torch.sum(self.logs * x_mask, [1,2]) 291 | return y, logdet 292 | else: 293 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 294 | return x 295 | 296 | 297 | class ResidualCouplingLayer(nn.Module): 298 | def __init__(self, 299 | channels, 300 | hidden_channels, 301 | kernel_size, 302 | dilation_rate, 303 | n_layers, 304 | p_dropout=0, 305 | gin_channels=0, 306 | mean_only=False): 307 | assert channels % 2 == 0, "channels should be divisible by 2" 308 | super().__init__() 309 | self.channels = channels 310 | self.hidden_channels = hidden_channels 311 | self.kernel_size = kernel_size 312 | self.dilation_rate = dilation_rate 313 | self.n_layers = n_layers 314 | self.half_channels = channels // 2 315 | self.mean_only = mean_only 316 | 317 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 318 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) 319 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 320 | self.post.weight.data.zero_() 321 | self.post.bias.data.zero_() 322 | 323 | def forward(self, x, x_mask, g=None, reverse=False): 324 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 325 | h = self.pre(x0) * x_mask 326 | h = self.enc(h, x_mask, g=g) 327 | stats = self.post(h) * x_mask 328 | if not self.mean_only: 329 | m, logs = torch.split(stats, [self.half_channels]*2, 1) 330 | else: 331 | m = stats 332 | logs = torch.zeros_like(m) 333 | 334 | if not reverse: 335 | x1 = m + x1 * torch.exp(logs) * x_mask 336 | x = torch.cat([x0, x1], 1) 337 | logdet = torch.sum(logs, [1,2]) 338 | return x, logdet 339 | else: 340 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 341 | x = torch.cat([x0, x1], 1) 342 | return x 343 | -------------------------------------------------------------------------------- /norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | """Normalization modules.""" 8 | 9 | import typing as tp 10 | 11 | import einops 12 | import torch 13 | from torch import nn 14 | 15 | 16 | class ConvLayerNorm(nn.LayerNorm): 17 | """ 18 | Convolution-friendly LayerNorm that moves channels to last dimensions 19 | before running the normalization and moves them back to original position right after. 20 | """ 21 | def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs): 22 | super().__init__(normalized_shape, **kwargs) 23 | 24 | def forward(self, x): 25 | x = einops.rearrange(x, 'b ... t -> b t ...') 26 | x = super().forward(x) 27 | x = einops.rearrange(x, 'b t ... -> b ... t') 28 | return x 29 | -------------------------------------------------------------------------------- /preprocess_code.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import librosa 4 | 5 | hubert = torch.hub.load("bshall/hubert:main", "hubert_soft").eval() 6 | 7 | def get_code(vctk_path): 8 | speakers = os.listdir(vctk_path) 9 | for spk in speakers: 10 | files_path = f"{vctk_path}/{spk}" 11 | wavs_path = os.listdir(files_path) 12 | for wav_name in wavs_path: 13 | wav_path = f"{files_path}/{wav_name}" 14 | wav, r = librosa.load(wav_path, sr=16000) # load at 16 kHz, the rate HuBERT-Soft expects 15 | wav = torch.from_numpy(wav).unsqueeze(0).unsqueeze(0) 16 | c = hubert.units(wav) 17 | c = c.transpose(1,2) 18 | torch.save(c, wav_path.replace(".wav", ".pt")) 19 | c_path = wav_path.replace(".wav", ".pt") 20 | print(f"content code saved in {c_path}") 21 | 22 | if __name__ == "__main__": 23 | vctk_path = "./dataset/vctk-16k" 24 | get_code(vctk_path) -------------------------------------------------------------------------------- /preprocess_flist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | from random import shuffle 5 | 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list") 10 | parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list") 11 | parser.add_argument("--test_list", type=str, default="./filelists/test.txt", help="path to test list") 12 | parser.add_argument("--source_dir", type=str, default="./dataset/vctk-16k", help="path to source dir") 13 | args = parser.parse_args() 14 | 15 | train = [] 16 | val = [] 17 | test = [] 18 | idx = 0 19 | 20 | for speaker in tqdm(os.listdir(args.source_dir)): 21 | wavs = os.listdir(os.path.join(args.source_dir, speaker)) 22 | shuffle(wavs) 23 | train += wavs[2:-10] 24 | val += wavs[:2] 25 | test += wavs[-10:] 26 | 27 | shuffle(train) 28 | shuffle(val) 29 | shuffle(test) 30 | 31 | print("Writing", args.train_list) 32 | with open(args.train_list, "w") as f: 33 | for fname in tqdm(train): 34 | speaker = fname[:4] 35 | wavpath = os.path.join("DUMMY", speaker, fname) 36 | f.write(wavpath + "\n") 37 | 38 | print("Writing", args.val_list) 39 | with open(args.val_list, "w") as f: 40 | for fname in tqdm(val): 41 | speaker = fname[:4] 42 | wavpath = os.path.join("DUMMY", speaker, fname) 43 | f.write(wavpath + "\n") 44 | 45 | print("Writing", args.test_list) 46 | with open(args.test_list, "w") as f: 47 | for fname in tqdm(test): 48 | speaker = fname[:4] 49 | wavpath = os.path.join("DUMMY", speaker, fname) 50 | f.write(wavpath + "\n") 51 | --------------------------------------------------------------------------------
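Before running get_code from preprocess_code.py over the whole corpus, it can help to sanity-check the HuBERT-Soft extraction on a single utterance. A short sketch, assuming a 16 kHz VCTK wav (the path is illustrative) and the usual HuBERT-Soft output of one 256-dimensional soft unit per ~20 ms frame:

import torch
import librosa

# same HuBERT-Soft hub entry used by preprocess_code.py
hubert = torch.hub.load("bshall/hubert:main", "hubert_soft").eval()

wav, _ = librosa.load("dataset/vctk-16k/p225/p225_001.wav", sr=16000)   # illustrative path
wav = torch.from_numpy(wav).unsqueeze(0).unsqueeze(0)                   # [1, 1, T]

with torch.no_grad():
    units = hubert.units(wav)               # [1, T', 256] soft content units
print(units.transpose(1, 2).shape)          # [1, 256, T'], the layout get_code saves to .pt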
/preprocess_spk.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from speaker_encoder.voice_encoder import SpeakerEncoder 3 | from speaker_encoder.audio import preprocess_wav 4 | from pathlib import Path 5 | import numpy as np 6 | from os.path import join, basename, split 7 | from tqdm import tqdm 8 | from multiprocessing import cpu_count 9 | from concurrent.futures import ProcessPoolExecutor 10 | from functools import partial 11 | import glob 12 | import argparse 13 | 14 | 15 | def build_from_path(in_dir, out_dir, weights_fpath, num_workers=1): 16 | executor = ProcessPoolExecutor(max_workers=num_workers) 17 | futures = [] 18 | wavfile_paths = glob.glob(os.path.join(in_dir, '*.wav')) 19 | wavfile_paths= sorted(wavfile_paths) 20 | for wav_path in wavfile_paths: 21 | futures.append(executor.submit( 22 | partial(_compute_spkEmbed, out_dir, wav_path, weights_fpath))) 23 | return [future.result() for future in tqdm(futures)] 24 | 25 | def _compute_spkEmbed(out_dir, wav_path, weights_fpath): 26 | utt_id = os.path.basename(wav_path).rstrip(".wav") 27 | fpath = Path(wav_path) 28 | wav = preprocess_wav(fpath) 29 | 30 | encoder = SpeakerEncoder(weights_fpath) 31 | embed = encoder.embed_utterance(wav) 32 | fname_save = os.path.join(out_dir, f"{utt_id}.npy") 33 | np.save(fname_save, embed, allow_pickle=False) 34 | return os.path.basename(fname_save) 35 | 36 | def preprocess(in_dir, out_dir_root, spk, weights_fpath, num_workers): 37 | out_dir = os.path.join(out_dir_root, spk) 38 | os.makedirs(out_dir, exist_ok=True) 39 | metadata = build_from_path(in_dir, out_dir, weights_fpath, num_workers) 40 | 41 | if __name__ == "__main__": 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--in_dir', type=str, 44 | default='dataset/vctk-16k/') 45 | parser.add_argument('--num_workers', type=int, default=12) 46 | parser.add_argument('--out_dir_root', type=str, 47 | default='dataset/vctk-16k/') 48 | parser.add_argument('--spk_encoder_ckpt', type=str, \ 49 | default='speaker_encoder/ckpt/pretrained_bak_5805000.pt') 50 | 51 | args = parser.parse_args() 52 | 53 | #split_list = ['train-clean-100', 'train-clean-360'] 54 | 55 | sub_folder_list = os.listdir(args.in_dir) 56 | sub_folder_list.sort() 57 | 58 | args.num_workers = args.num_workers if args.num_workers is not None else cpu_count() 59 | print("Number of workers: ", args.num_workers) 60 | ckpt_step = os.path.basename(args.spk_encoder_ckpt).split('.')[0].split('_')[-1] 61 | spk_embed_out_dir = os.path.join(args.out_dir_root, "spk") 62 | print("[INFO] spk_embed_out_dir: ", spk_embed_out_dir) 63 | os.makedirs(spk_embed_out_dir, exist_ok=True) 64 | 65 | #for data_split in split_list: 66 | # sub_folder_list = os.listdir(args.in_dir, data_split) 67 | for spk in sub_folder_list: 68 | print("Preprocessing {} ...".format(spk)) 69 | in_dir = os.path.join(args.in_dir, spk) 70 | if not os.path.isdir(in_dir): 71 | continue 72 | #out_dir = os.path.join(args.out_dir, spk) 73 | preprocess(in_dir, spk_embed_out_dir, spk, args.spk_encoder_ckpt, args.num_workers) 74 | ''' 75 | for data_split in split_list: 76 | in_dir = os.path.join(args.in_dir, data_split) 77 | preprocess(in_dir, spk_embed_out_dir, args.spk_encoder_ckpt, args.num_workers) 78 | ''' 79 | 80 | print("DONE!") 81 | sys.exit(0) 82 | 83 | 84 | -------------------------------------------------------------------------------- /preprocess_sr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
argparse 3 | import torch 4 | import librosa 5 | import json 6 | from glob import glob 7 | from tqdm import tqdm 8 | from scipy.io import wavfile 9 | 10 | import utils 11 | from mel_processing import mel_spectrogram_torch 12 | from wavlm import WavLM, WavLMConfig 13 | #import h5py 14 | import logging 15 | logging.getLogger('numba').setLevel(logging.WARNING) 16 | 17 | 18 | def process(filename): 19 | basename = os.path.basename(filename) 20 | speaker = filename.split("/")[-2]#basename[:4] 21 | wav_dir = os.path.join(args.wav_dir, speaker) 22 | ssl_dir = os.path.join(args.ssl_dir, speaker) 23 | os.makedirs(wav_dir, exist_ok=True) 24 | os.makedirs(ssl_dir, exist_ok=True) 25 | wav, _ = librosa.load(filename, sr=hps.sampling_rate) 26 | wav = torch.from_numpy(wav).unsqueeze(0).cuda() 27 | mel = mel_spectrogram_torch( 28 | wav, 29 | hps.n_fft, 30 | hps.num_mels, 31 | hps.sampling_rate, 32 | hps.hop_size, 33 | hps.win_size, 34 | hps.fmin, 35 | hps.fmax 36 | ) 37 | ''' 38 | f = {} 39 | for i in range(args.min, args.max+1): 40 | fpath = os.path.join(ssl_dir, f"{i}.hdf5") 41 | f[i] = h5py.File(fpath, "a") 42 | ''' 43 | for i in range(args.min, args.max+1): 44 | mel_rs = utils.transform(mel, i) 45 | wav_rs = vocoder(mel_rs)[0][0].detach().cpu().numpy() 46 | _wav_rs = librosa.resample(wav_rs, orig_sr=hps.sampling_rate, target_sr=args.sr) 47 | wav_rs = torch.from_numpy(_wav_rs).cuda().unsqueeze(0) 48 | c = utils.get_content(cmodel, wav_rs) 49 | ssl_path = os.path.join(ssl_dir, basename.replace(".wav", f"_{i}.pt")) 50 | torch.save(c.cpu(), ssl_path) 51 | #print(wav_rs.size(), c.size()) 52 | wav_path = os.path.join(wav_dir, basename.replace(".wav", f"_{i}.wav")) 53 | wavfile.write( 54 | wav_path, 55 | args.sr, 56 | _wav_rs 57 | ) 58 | ''' 59 | f[i][basename[:-4]] = c.cpu() 60 | for i in range(args.min, args.max+1): 61 | f[i].close() 62 | ''' 63 | 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument("--sr", type=int, default=16000, help="sampling rate") 69 | parser.add_argument("--min", type=int, default=68, help="min") 70 | parser.add_argument("--max", type=int, default=92, help="max") 71 | parser.add_argument("--config", type=str, default="hifigan/config.json", help="path to config file") 72 | parser.add_argument("--in_dir", type=str, default="dataset/vctk-22k", help="path to input dir") 73 | parser.add_argument("--wav_dir", type=str, default="dataset/sr/wav", help="path to output wav dir") 74 | parser.add_argument("--ssl_dir", type=str, default="dataset/sr/wavlm", help="path to output ssl dir") 75 | args = parser.parse_args() 76 | 77 | print("Loading WavLM for content...") 78 | checkpoint = torch.load('wavlm/WavLM-Large.pt') 79 | cfg = WavLMConfig(checkpoint['cfg']) 80 | cmodel = WavLM(cfg).cuda() 81 | cmodel.load_state_dict(checkpoint['model']) 82 | cmodel.eval() 83 | print("Loaded WavLM.") 84 | 85 | print("Loading vocoder...") 86 | vocoder = utils.get_vocoder(0) 87 | vocoder.eval() 88 | print("Loaded vocoder.") 89 | 90 | config_path = args.config 91 | with open(config_path, "r") as f: 92 | data = f.read() 93 | config = json.loads(data) 94 | hps = utils.HParams(**config) 95 | 96 | filenames = glob(f'{args.in_dir}/*/*.wav', recursive=True)#[:10] 97 | 98 | for filename in tqdm(filenames): 99 | process(filename) 100 | 101 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | glob2==0.7 2 | tqdm==4.62.3 3 | librosa==0.8.1 
4 | numpy==1.21.6 5 | scipy==1.7.2 6 | tensorboard==2.7.0 7 | torch==1.10.0 8 | torchvision==0.9.0 9 | webrtcvad==2.0.10 10 | -------------------------------------------------------------------------------- /resources/NeurlVC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzy1hjq/NeuralVC/9b562356ac008c76ead4b31251ea956a0c15eda2/resources/NeurlVC.png -------------------------------------------------------------------------------- /speaker_encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzy1hjq/NeuralVC/9b562356ac008c76ead4b31251ea956a0c15eda2/speaker_encoder/__init__.py -------------------------------------------------------------------------------- /speaker_encoder/audio.py: -------------------------------------------------------------------------------- 1 | from scipy.ndimage.morphology import binary_dilation 2 | from speaker_encoder.params_data import * 3 | from pathlib import Path 4 | from typing import Optional, Union 5 | import numpy as np 6 | import webrtcvad 7 | import librosa 8 | import struct 9 | 10 | int16_max = (2 ** 15) - 1 11 | 12 | 13 | def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], 14 | source_sr: Optional[int] = None): 15 | """ 16 | Applies the preprocessing operations used in training the Speaker Encoder to a waveform 17 | either on disk or in memory. The waveform will be resampled to match the data hyperparameters. 18 | 19 | :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not 20 | just .wav), either the waveform as a numpy array of floats. 21 | :param source_sr: if passing an audio waveform, the sampling rate of the waveform before 22 | preprocessing. After preprocessing, the waveform's sampling rate will match the data 23 | hyperparameters. If passing a filepath, the sampling rate will be automatically detected and 24 | this argument will be ignored. 25 | """ 26 | # Load the wav from disk if needed 27 | if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): 28 | wav, source_sr = librosa.load(fpath_or_wav, sr=None) 29 | else: 30 | wav = fpath_or_wav 31 | 32 | # Resample the wav if needed 33 | if source_sr is not None and source_sr != sampling_rate: 34 | wav = librosa.resample(wav, source_sr, sampling_rate) 35 | 36 | # Apply the preprocessing: normalize volume and shorten long silences 37 | wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True) 38 | wav = trim_long_silences(wav) 39 | 40 | return wav 41 | 42 | 43 | def wav_to_mel_spectrogram(wav): 44 | """ 45 | Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform. 46 | Note: this not a log-mel spectrogram. 47 | """ 48 | frames = librosa.feature.melspectrogram( 49 | y=wav, 50 | sr=sampling_rate, 51 | n_fft=int(sampling_rate * mel_window_length / 1000), 52 | hop_length=int(sampling_rate * mel_window_step / 1000), 53 | n_mels=mel_n_channels 54 | ) 55 | return frames.astype(np.float32).T 56 | 57 | 58 | def trim_long_silences(wav): 59 | """ 60 | Ensures that segments without voice in the waveform remain no longer than a 61 | threshold determined by the VAD parameters in params.py. 
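    (With the defaults in params_data.py, i.e. 30 ms VAD windows and vad_max_silence_length = 6,
    the dilation step below bridges silent gaps of up to about 6 * 30 ms = 180 ms, so only
    longer pauses are actually trimmed.)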
62 | 63 | :param wav: the raw waveform as a numpy array of floats 64 | :return: the same waveform with silences trimmed away (length <= original wav length) 65 | """ 66 | # Compute the voice detection window size 67 | samples_per_window = (vad_window_length * sampling_rate) // 1000 68 | 69 | # Trim the end of the audio to have a multiple of the window size 70 | wav = wav[:len(wav) - (len(wav) % samples_per_window)] 71 | 72 | # Convert the float waveform to 16-bit mono PCM 73 | pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16)) 74 | 75 | # Perform voice activation detection 76 | voice_flags = [] 77 | vad = webrtcvad.Vad(mode=3) 78 | for window_start in range(0, len(wav), samples_per_window): 79 | window_end = window_start + samples_per_window 80 | voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2], 81 | sample_rate=sampling_rate)) 82 | voice_flags = np.array(voice_flags) 83 | 84 | # Smooth the voice detection with a moving average 85 | def moving_average(array, width): 86 | array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2))) 87 | ret = np.cumsum(array_padded, dtype=float) 88 | ret[width:] = ret[width:] - ret[:-width] 89 | return ret[width - 1:] / width 90 | 91 | audio_mask = moving_average(voice_flags, vad_moving_average_width) 92 | audio_mask = np.round(audio_mask).astype(np.bool) 93 | 94 | # Dilate the voiced regions 95 | audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1)) 96 | audio_mask = np.repeat(audio_mask, samples_per_window) 97 | 98 | return wav[audio_mask == True] 99 | 100 | 101 | def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False): 102 | if increase_only and decrease_only: 103 | raise ValueError("Both increase only and decrease only are set") 104 | dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2)) 105 | if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only): 106 | return wav 107 | return wav * (10 ** (dBFS_change / 20)) 108 | -------------------------------------------------------------------------------- /speaker_encoder/ckpt/pretrained_bak_5805000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzy1hjq/NeuralVC/9b562356ac008c76ead4b31251ea956a0c15eda2/speaker_encoder/ckpt/pretrained_bak_5805000.pt -------------------------------------------------------------------------------- /speaker_encoder/ckpt/pretrained_bak_5805000.pt.txt: -------------------------------------------------------------------------------- 1 | https://github.com/liusongxiang/ppg-vc/tree/main/speaker_encoder/ckpt -------------------------------------------------------------------------------- /speaker_encoder/compute_embed.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder import inference as encoder 2 | from multiprocessing.pool import Pool 3 | from functools import partial 4 | from pathlib import Path 5 | # from utils import logmmse 6 | # from tqdm import tqdm 7 | # import numpy as np 8 | # import librosa 9 | 10 | 11 | def embed_utterance(fpaths, encoder_model_fpath): 12 | if not encoder.is_loaded(): 13 | encoder.load_model(encoder_model_fpath) 14 | 15 | # Compute the speaker embedding of the utterance 16 | wav_fpath, embed_fpath = fpaths 17 | wav = np.load(wav_fpath) 18 | wav = encoder.preprocess_wav(wav) 19 | embed = encoder.embed_utterance(wav) 20 | np.save(embed_fpath, embed, 
allow_pickle=False) 21 | 22 | 23 | def create_embeddings(outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int): 24 | 25 | wav_dir = outdir_root.joinpath("audio") 26 | metadata_fpath = synthesizer_root.joinpath("train.txt") 27 | assert wav_dir.exists() and metadata_fpath.exists() 28 | embed_dir = synthesizer_root.joinpath("embeds") 29 | embed_dir.mkdir(exist_ok=True) 30 | 31 | # Gather the input wave filepath and the target output embed filepath 32 | with metadata_fpath.open("r") as metadata_file: 33 | metadata = [line.split("|") for line in metadata_file] 34 | fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata] 35 | 36 | # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here. 37 | # Embed the utterances in separate threads 38 | func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath) 39 | job = Pool(n_processes).imap(func, fpaths) 40 | list(tqdm(job, "Embedding", len(fpaths), unit="utterances")) -------------------------------------------------------------------------------- /speaker_encoder/config.py: -------------------------------------------------------------------------------- 1 | librispeech_datasets = { 2 | "train": { 3 | "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"], 4 | "other": ["LibriSpeech/train-other-500"] 5 | }, 6 | "test": { 7 | "clean": ["LibriSpeech/test-clean"], 8 | "other": ["LibriSpeech/test-other"] 9 | }, 10 | "dev": { 11 | "clean": ["LibriSpeech/dev-clean"], 12 | "other": ["LibriSpeech/dev-other"] 13 | }, 14 | } 15 | libritts_datasets = { 16 | "train": { 17 | "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"], 18 | "other": ["LibriTTS/train-other-500"] 19 | }, 20 | "test": { 21 | "clean": ["LibriTTS/test-clean"], 22 | "other": ["LibriTTS/test-other"] 23 | }, 24 | "dev": { 25 | "clean": ["LibriTTS/dev-clean"], 26 | "other": ["LibriTTS/dev-other"] 27 | }, 28 | } 29 | voxceleb_datasets = { 30 | "voxceleb1" : { 31 | "train": ["VoxCeleb1/wav"], 32 | "test": ["VoxCeleb1/test_wav"] 33 | }, 34 | "voxceleb2" : { 35 | "train": ["VoxCeleb2/dev/aac"], 36 | "test": ["VoxCeleb2/test_wav"] 37 | } 38 | } 39 | 40 | other_datasets = [ 41 | "LJSpeech-1.1", 42 | "VCTK-Corpus/wav48", 43 | ] 44 | 45 | anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"] 46 | -------------------------------------------------------------------------------- /speaker_encoder/data_objects/__init__.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset 2 | from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader 3 | -------------------------------------------------------------------------------- /speaker_encoder/data_objects/random_cycler.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | class RandomCycler: 4 | """ 5 | Creates an internal copy of a sequence and allows access to its items in a constrained random 6 | order. For a source sequence of n items and one or several consecutive queries of a total 7 | of m items, the following guarantees hold (one implies the other): 8 | - Each item will be returned between m // n and ((m - 1) // n) + 1 times. 9 | - Between two appearances of the same item, there may be at most 2 * (n - 1) other items. 
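    Worked example (illustrative): with source = "abc" (n = 3) and a single query of m = 5
    items, every item comes back either m // n = 1 or ((m - 1) // n) + 1 = 2 times, e.g.

        rc = RandomCycler("abc")
        rc.sample(5)    # one possible draw: ['b', 'a', 'c', 'c', 'b']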
10 | """ 11 | 12 | def __init__(self, source): 13 | if len(source) == 0: 14 | raise Exception("Can't create RandomCycler from an empty collection") 15 | self.all_items = list(source) 16 | self.next_items = [] 17 | 18 | def sample(self, count: int): 19 | shuffle = lambda l: random.sample(l, len(l)) 20 | 21 | out = [] 22 | while count > 0: 23 | if count >= len(self.all_items): 24 | out.extend(shuffle(list(self.all_items))) 25 | count -= len(self.all_items) 26 | continue 27 | n = min(count, len(self.next_items)) 28 | out.extend(self.next_items[:n]) 29 | count -= n 30 | self.next_items = self.next_items[n:] 31 | if len(self.next_items) == 0: 32 | self.next_items = shuffle(list(self.all_items)) 33 | return out 34 | 35 | def __next__(self): 36 | return self.sample(1)[0] 37 | 38 | -------------------------------------------------------------------------------- /speaker_encoder/data_objects/speaker.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder.data_objects.random_cycler import RandomCycler 2 | from speaker_encoder.data_objects.utterance import Utterance 3 | from pathlib import Path 4 | 5 | # Contains the set of utterances of a single speaker 6 | class Speaker: 7 | def __init__(self, root: Path): 8 | self.root = root 9 | self.name = root.name 10 | self.utterances = None 11 | self.utterance_cycler = None 12 | 13 | def _load_utterances(self): 14 | with self.root.joinpath("_sources.txt").open("r") as sources_file: 15 | sources = [l.split(",") for l in sources_file] 16 | sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources} 17 | self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()] 18 | self.utterance_cycler = RandomCycler(self.utterances) 19 | 20 | def random_partial(self, count, n_frames): 21 | """ 22 | Samples a batch of unique partial utterances from the disk in a way that all 23 | utterances come up at least once every two cycles and in a random order every time. 24 | 25 | :param count: The number of partial utterances to sample from the set of utterances from 26 | that speaker. Utterances are guaranteed not to be repeated if is not larger than 27 | the number of utterances available. 28 | :param n_frames: The number of frames in the partial utterance. 29 | :return: A list of tuples (utterance, frames, range) where utterance is an Utterance, 30 | frames are the frames of the partial utterances and range is the range of the partial 31 | utterance with regard to the complete utterance. 32 | """ 33 | if self.utterances is None: 34 | self._load_utterances() 35 | 36 | utterances = self.utterance_cycler.sample(count) 37 | 38 | a = [(u,) + u.random_partial(n_frames) for u in utterances] 39 | 40 | return a 41 | -------------------------------------------------------------------------------- /speaker_encoder/data_objects/speaker_batch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import List 3 | from speaker_encoder.data_objects.speaker import Speaker 4 | 5 | class SpeakerBatch: 6 | def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int): 7 | self.speakers = speakers 8 | self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers} 9 | 10 | # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. 
for 3 speakers with 11 | # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40) 12 | self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]]) 13 | -------------------------------------------------------------------------------- /speaker_encoder/data_objects/speaker_verification_dataset.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder.data_objects.random_cycler import RandomCycler 2 | from speaker_encoder.data_objects.speaker_batch import SpeakerBatch 3 | from speaker_encoder.data_objects.speaker import Speaker 4 | from speaker_encoder.params_data import partials_n_frames 5 | from torch.utils.data import Dataset, DataLoader 6 | from pathlib import Path 7 | 8 | # TODO: improve with a pool of speakers for data efficiency 9 | 10 | class SpeakerVerificationDataset(Dataset): 11 | def __init__(self, datasets_root: Path): 12 | self.root = datasets_root 13 | speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()] 14 | if len(speaker_dirs) == 0: 15 | raise Exception("No speakers found. Make sure you are pointing to the directory " 16 | "containing all preprocessed speaker directories.") 17 | self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs] 18 | self.speaker_cycler = RandomCycler(self.speakers) 19 | 20 | def __len__(self): 21 | return int(1e10) 22 | 23 | def __getitem__(self, index): 24 | return next(self.speaker_cycler) 25 | 26 | def get_logs(self): 27 | log_string = "" 28 | for log_fpath in self.root.glob("*.txt"): 29 | with log_fpath.open("r") as log_file: 30 | log_string += "".join(log_file.readlines()) 31 | return log_string 32 | 33 | 34 | class SpeakerVerificationDataLoader(DataLoader): 35 | def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None, 36 | batch_sampler=None, num_workers=0, pin_memory=False, timeout=0, 37 | worker_init_fn=None): 38 | self.utterances_per_speaker = utterances_per_speaker 39 | 40 | super().__init__( 41 | dataset=dataset, 42 | batch_size=speakers_per_batch, 43 | shuffle=False, 44 | sampler=sampler, 45 | batch_sampler=batch_sampler, 46 | num_workers=num_workers, 47 | collate_fn=self.collate, 48 | pin_memory=pin_memory, 49 | drop_last=False, 50 | timeout=timeout, 51 | worker_init_fn=worker_init_fn 52 | ) 53 | 54 | def collate(self, speakers): 55 | return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames) 56 | -------------------------------------------------------------------------------- /speaker_encoder/data_objects/utterance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Utterance: 5 | def __init__(self, frames_fpath, wave_fpath): 6 | self.frames_fpath = frames_fpath 7 | self.wave_fpath = wave_fpath 8 | 9 | def get_frames(self): 10 | return np.load(self.frames_fpath) 11 | 12 | def random_partial(self, n_frames): 13 | """ 14 | Crops the frames into a partial utterance of n_frames 15 | 16 | :param n_frames: The number of frames of the partial utterance 17 | :return: the partial utterance frames and a tuple indicating the start and end of the 18 | partial utterance in the complete utterance. 
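        Example (illustrative; assumes "p225_001.npy" holds a saved mel array of shape (300, 40),
        both paths are hypothetical):

            utt = Utterance("p225_001.npy", "p225_001.wav")
            frames, (start, end) = utt.random_partial(160)
            # frames.shape == (160, 40) and end - start == 160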
19 | """ 20 | frames = self.get_frames() 21 | if frames.shape[0] == n_frames: 22 | start = 0 23 | else: 24 | start = np.random.randint(0, frames.shape[0] - n_frames) 25 | end = start + n_frames 26 | return frames[start:end], (start, end) -------------------------------------------------------------------------------- /speaker_encoder/hparams.py: -------------------------------------------------------------------------------- 1 | ## Mel-filterbank 2 | mel_window_length = 25 # In milliseconds 3 | mel_window_step = 10 # In milliseconds 4 | mel_n_channels = 40 5 | 6 | 7 | ## Audio 8 | sampling_rate = 16000 9 | # Number of spectrogram frames in a partial utterance 10 | partials_n_frames = 160 # 1600 ms 11 | 12 | 13 | ## Voice Activation Detection 14 | # Window size of the VAD. Must be either 10, 20 or 30 milliseconds. 15 | # This sets the granularity of the VAD. Should not need to be changed. 16 | vad_window_length = 30 # In milliseconds 17 | # Number of frames to average together when performing the moving average smoothing. 18 | # The larger this value, the larger the VAD variations must be to not get smoothed out. 19 | vad_moving_average_width = 8 20 | # Maximum number of consecutive silent frames a segment can have. 21 | vad_max_silence_length = 6 22 | 23 | 24 | ## Audio volume normalization 25 | audio_norm_target_dBFS = -30 26 | 27 | 28 | ## Model parameters 29 | model_hidden_size = 256 30 | model_embedding_size = 256 31 | model_num_layers = 3 -------------------------------------------------------------------------------- /speaker_encoder/inference.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder.params_data import * 2 | from speaker_encoder.model import SpeakerEncoder 3 | from speaker_encoder.audio import preprocess_wav # We want to expose this function from here 4 | from matplotlib import cm 5 | from speaker_encoder import audio 6 | from pathlib import Path 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import torch 10 | 11 | _model = None # type: SpeakerEncoder 12 | _device = None # type: torch.device 13 | 14 | 15 | def load_model(weights_fpath: Path, device=None): 16 | """ 17 | Loads the model in memory. If this function is not explicitely called, it will be run on the 18 | first call to embed_frames() with the default weights file. 19 | 20 | :param weights_fpath: the path to saved model weights. 21 | :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The 22 | model will be loaded and will run on this device. Outputs will however always be on the cpu. 23 | If None, will default to your GPU if it"s available, otherwise your CPU. 24 | """ 25 | # TODO: I think the slow loading of the encoder might have something to do with the device it 26 | # was saved on. Worth investigating. 27 | global _model, _device 28 | if device is None: 29 | _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | elif isinstance(device, str): 31 | _device = torch.device(device) 32 | _model = SpeakerEncoder(_device, torch.device("cpu")) 33 | checkpoint = torch.load(weights_fpath) 34 | _model.load_state_dict(checkpoint["model_state"]) 35 | _model.eval() 36 | print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"])) 37 | 38 | 39 | def is_loaded(): 40 | return _model is not None 41 | 42 | 43 | def embed_frames_batch(frames_batch): 44 | """ 45 | Computes embeddings for a batch of mel spectrogram. 
46 | 47 | :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape 48 | (batch_size, n_frames, n_channels) 49 | :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size) 50 | """ 51 | if _model is None: 52 | raise Exception("Model was not loaded. Call load_model() before inference.") 53 | 54 | frames = torch.from_numpy(frames_batch).to(_device) 55 | embed = _model.forward(frames).detach().cpu().numpy() 56 | return embed 57 | 58 | 59 | def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames, 60 | min_pad_coverage=0.75, overlap=0.5): 61 | """ 62 | Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain 63 | partial utterances of each. Both the waveform and the mel 64 | spectrogram slices are returned, so as to make each partial utterance waveform correspond to 65 | its spectrogram. This function assumes that the mel spectrogram parameters used are those 66 | defined in params_data.py. 67 | 68 | The returned ranges may be indexing further than the length of the waveform. It is 69 | recommended that you pad the waveform with zeros up to wave_slices[-1].stop. 70 | 71 | :param n_samples: the number of samples in the waveform 72 | :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial 73 | utterance 74 | :param min_pad_coverage: when reaching the last partial utterance, it may or may not have 75 | enough frames. If at least of are present, 76 | then the last partial utterance will be considered, as if we padded the audio. Otherwise, 77 | it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial 78 | utterance, this parameter is ignored so that the function always returns at least 1 slice. 79 | :param overlap: by how much the partial utterance should overlap. If set to 0, the partial 80 | utterances are entirely disjoint. 81 | :return: the waveform slices and mel spectrogram slices as lists of array slices. Index 82 | respectively the waveform and the mel spectrogram with these slices to obtain the partial 83 | utterances. 84 | """ 85 | assert 0 <= overlap < 1 86 | assert 0 < min_pad_coverage <= 1 87 | 88 | samples_per_frame = int((sampling_rate * mel_window_step / 1000)) 89 | n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) 90 | frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1) 91 | 92 | # Compute the slices 93 | wav_slices, mel_slices = [], [] 94 | steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1) 95 | for i in range(0, steps, frame_step): 96 | mel_range = np.array([i, i + partial_utterance_n_frames]) 97 | wav_range = mel_range * samples_per_frame 98 | mel_slices.append(slice(*mel_range)) 99 | wav_slices.append(slice(*wav_range)) 100 | 101 | # Evaluate whether extra padding is warranted or not 102 | last_wav_range = wav_slices[-1] 103 | coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) 104 | if coverage < min_pad_coverage and len(mel_slices) > 1: 105 | mel_slices = mel_slices[:-1] 106 | wav_slices = wav_slices[:-1] 107 | 108 | return wav_slices, mel_slices 109 | 110 | 111 | def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs): 112 | """ 113 | Computes an embedding for a single utterance. 
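        Usage sketch (illustrative; "speaker.wav" is a placeholder file name):

            wav = preprocess_wav("speaker.wav")
            embed, partial_embeds, wav_slices = embed_utterance(wav, return_partials=True)
            # embed.shape == (256,), unit L2 norm; partial_embeds.shape == (n_partials, 256)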
114 | 115 | # TODO: handle multiple wavs to benefit from batching on GPU 116 | :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32 117 | :param using_partials: if True, then the utterance is split in partial utterances of 118 | frames and the utterance embedding is computed from their 119 | normalized average. If False, the utterance is instead computed from feeding the entire 120 | spectogram to the network. 121 | :param return_partials: if True, the partial embeddings will also be returned along with the 122 | wav slices that correspond to the partial embeddings. 123 | :param kwargs: additional arguments to compute_partial_splits() 124 | :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If 125 | is True, the partial utterances as a numpy array of float32 of shape 126 | (n_partials, model_embedding_size) and the wav partials as a list of slices will also be 127 | returned. If is simultaneously set to False, both these values will be None 128 | instead. 129 | """ 130 | # Process the entire utterance if not using partials 131 | if not using_partials: 132 | frames = audio.wav_to_mel_spectrogram(wav) 133 | embed = embed_frames_batch(frames[None, ...])[0] 134 | if return_partials: 135 | return embed, None, None 136 | return embed 137 | 138 | # Compute where to split the utterance into partials and pad if necessary 139 | wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs) 140 | max_wave_length = wave_slices[-1].stop 141 | if max_wave_length >= len(wav): 142 | wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") 143 | 144 | # Split the utterance into partials 145 | frames = audio.wav_to_mel_spectrogram(wav) 146 | frames_batch = np.array([frames[s] for s in mel_slices]) 147 | partial_embeds = embed_frames_batch(frames_batch) 148 | 149 | # Compute the utterance embedding from the partial embeddings 150 | raw_embed = np.mean(partial_embeds, axis=0) 151 | embed = raw_embed / np.linalg.norm(raw_embed, 2) 152 | 153 | if return_partials: 154 | return embed, partial_embeds, wave_slices 155 | return embed 156 | 157 | 158 | def embed_speaker(wavs, **kwargs): 159 | raise NotImplemented() 160 | 161 | 162 | def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)): 163 | if ax is None: 164 | ax = plt.gca() 165 | 166 | if shape is None: 167 | height = int(np.sqrt(len(embed))) 168 | shape = (height, -1) 169 | embed = embed.reshape(shape) 170 | 171 | cmap = cm.get_cmap() 172 | mappable = ax.imshow(embed, cmap=cmap) 173 | cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04) 174 | cbar.set_clim(*color_range) 175 | 176 | ax.set_xticks([]), ax.set_yticks([]) 177 | ax.set_title(title) 178 | -------------------------------------------------------------------------------- /speaker_encoder/model.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder.params_model import * 2 | from speaker_encoder.params_data import * 3 | from scipy.interpolate import interp1d 4 | from sklearn.metrics import roc_curve 5 | from torch.nn.utils import clip_grad_norm_ 6 | from scipy.optimize import brentq 7 | from torch import nn 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class SpeakerEncoder(nn.Module): 13 | def __init__(self, device, loss_device): 14 | super().__init__() 15 | self.loss_device = loss_device 16 | 17 | # Network defition 18 | self.lstm = nn.LSTM(input_size=mel_n_channels, # 40 19 | hidden_size=model_hidden_size, # 256 20 
| num_layers=model_num_layers, # 3 21 | batch_first=True).to(device) 22 | self.linear = nn.Linear(in_features=model_hidden_size, 23 | out_features=model_embedding_size).to(device) 24 | self.relu = torch.nn.ReLU().to(device) 25 | 26 | # Cosine similarity scaling (with fixed initial parameter values) 27 | self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device) 28 | self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device) 29 | 30 | # Loss 31 | self.loss_fn = nn.CrossEntropyLoss().to(loss_device) 32 | 33 | def do_gradient_ops(self): 34 | # Gradient scale 35 | self.similarity_weight.grad *= 0.01 36 | self.similarity_bias.grad *= 0.01 37 | 38 | # Gradient clipping 39 | clip_grad_norm_(self.parameters(), 3, norm_type=2) 40 | 41 | def forward(self, utterances, hidden_init=None): 42 | """ 43 | Computes the embeddings of a batch of utterance spectrograms. 44 | 45 | :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape 46 | (batch_size, n_frames, n_channels) 47 | :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, 48 | batch_size, hidden_size). Will default to a tensor of zeros if None. 49 | :return: the embeddings as a tensor of shape (batch_size, embedding_size) 50 | """ 51 | # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state 52 | # and the final cell state. 53 | out, (hidden, cell) = self.lstm(utterances, hidden_init) 54 | 55 | # We take only the hidden state of the last layer 56 | embeds_raw = self.relu(self.linear(hidden[-1])) 57 | 58 | # L2-normalize it 59 | embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) 60 | 61 | return embeds 62 | 63 | def similarity_matrix(self, embeds): 64 | """ 65 | Computes the similarity matrix according the section 2.1 of GE2E. 66 | 67 | :param embeds: the embeddings as a tensor of shape (speakers_per_batch, 68 | utterances_per_speaker, embedding_size) 69 | :return: the similarity matrix as a tensor of shape (speakers_per_batch, 70 | utterances_per_speaker, speakers_per_batch) 71 | """ 72 | speakers_per_batch, utterances_per_speaker = embeds.shape[:2] 73 | 74 | # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation 75 | centroids_incl = torch.mean(embeds, dim=1, keepdim=True) 76 | centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True) 77 | 78 | # Exclusive centroids (1 per utterance) 79 | centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds) 80 | centroids_excl /= (utterances_per_speaker - 1) 81 | centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True) 82 | 83 | # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot 84 | # product of these vectors (which is just an element-wise multiplication reduced by a sum). 85 | # We vectorize the computation for efficiency. 
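        # In GE2E terms: sim[j, i, k] = w * cos(e_ji, c_k) + b, where c_k is the inclusive
        # centroid of speaker k except for k == j, which uses the centroid that excludes e_ji
        # itself (w = similarity_weight, b = similarity_bias, applied at the end below).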
86 | sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker, 87 | speakers_per_batch).to(self.loss_device) 88 | mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int) 89 | for j in range(speakers_per_batch): 90 | mask = np.where(mask_matrix[j])[0] 91 | sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2) 92 | sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1) 93 | 94 | ## Even more vectorized version (slower maybe because of transpose) 95 | # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker 96 | # ).to(self.loss_device) 97 | # eye = np.eye(speakers_per_batch, dtype=np.int) 98 | # mask = np.where(1 - eye) 99 | # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2) 100 | # mask = np.where(eye) 101 | # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2) 102 | # sim_matrix2 = sim_matrix2.transpose(1, 2) 103 | 104 | sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias 105 | return sim_matrix 106 | 107 | def loss(self, embeds): 108 | """ 109 | Computes the softmax loss according the section 2.1 of GE2E. 110 | 111 | :param embeds: the embeddings as a tensor of shape (speakers_per_batch, 112 | utterances_per_speaker, embedding_size) 113 | :return: the loss and the EER for this batch of embeddings. 114 | """ 115 | speakers_per_batch, utterances_per_speaker = embeds.shape[:2] 116 | 117 | # Loss 118 | sim_matrix = self.similarity_matrix(embeds) 119 | sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker, 120 | speakers_per_batch)) 121 | ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker) 122 | target = torch.from_numpy(ground_truth).long().to(self.loss_device) 123 | loss = self.loss_fn(sim_matrix, target) 124 | 125 | # EER (not backpropagated) 126 | with torch.no_grad(): 127 | inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0] 128 | labels = np.array([inv_argmax(i) for i in ground_truth]) 129 | preds = sim_matrix.detach().cpu().numpy() 130 | 131 | # Snippet from https://yangcha.github.io/EER-ROC/ 132 | fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten()) 133 | eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.) 134 | 135 | return loss, eer -------------------------------------------------------------------------------- /speaker_encoder/params_data.py: -------------------------------------------------------------------------------- 1 | 2 | ## Mel-filterbank 3 | mel_window_length = 25 # In milliseconds 4 | mel_window_step = 10 # In milliseconds 5 | mel_n_channels = 40 6 | 7 | 8 | ## Audio 9 | sampling_rate = 16000 10 | # Number of spectrogram frames in a partial utterance 11 | partials_n_frames = 160 # 1600 ms 12 | # Number of spectrogram frames at inference 13 | inference_n_frames = 80 # 800 ms 14 | 15 | 16 | ## Voice Activation Detection 17 | # Window size of the VAD. Must be either 10, 20 or 30 milliseconds. 18 | # This sets the granularity of the VAD. Should not need to be changed. 19 | vad_window_length = 30 # In milliseconds 20 | # Number of frames to average together when performing the moving average smoothing. 21 | # The larger this value, the larger the VAD variations must be to not get smoothed out. 22 | vad_moving_average_width = 8 23 | # Maximum number of consecutive silent frames a segment can have. 
24 | vad_max_silence_length = 6 25 | 26 | 27 | ## Audio volume normalization 28 | audio_norm_target_dBFS = -30 29 | 30 | -------------------------------------------------------------------------------- /speaker_encoder/params_model.py: -------------------------------------------------------------------------------- 1 | 2 | ## Model parameters 3 | model_hidden_size = 256 4 | model_embedding_size = 256 5 | model_num_layers = 3 6 | 7 | 8 | ## Training parameters 9 | learning_rate_init = 1e-4 10 | speakers_per_batch = 64 11 | utterances_per_speaker = 10 12 | -------------------------------------------------------------------------------- /speaker_encoder/preprocess.py: -------------------------------------------------------------------------------- 1 | from multiprocess.pool import ThreadPool 2 | from speaker_encoder.params_data import * 3 | from speaker_encoder.config import librispeech_datasets, anglophone_nationalites 4 | from datetime import datetime 5 | from speaker_encoder import audio 6 | from pathlib import Path 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | 11 | class DatasetLog: 12 | """ 13 | Registers metadata about the dataset in a text file. 14 | """ 15 | def __init__(self, root, name): 16 | self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w") 17 | self.sample_data = dict() 18 | 19 | start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M")) 20 | self.write_line("Creating dataset %s on %s" % (name, start_time)) 21 | self.write_line("-----") 22 | self._log_params() 23 | 24 | def _log_params(self): 25 | from speaker_encoder import params_data 26 | self.write_line("Parameter values:") 27 | for param_name in (p for p in dir(params_data) if not p.startswith("__")): 28 | value = getattr(params_data, param_name) 29 | self.write_line("\t%s: %s" % (param_name, value)) 30 | self.write_line("-----") 31 | 32 | def write_line(self, line): 33 | self.text_file.write("%s\n" % line) 34 | 35 | def add_sample(self, **kwargs): 36 | for param_name, value in kwargs.items(): 37 | if not param_name in self.sample_data: 38 | self.sample_data[param_name] = [] 39 | self.sample_data[param_name].append(value) 40 | 41 | def finalize(self): 42 | self.write_line("Statistics:") 43 | for param_name, values in self.sample_data.items(): 44 | self.write_line("\t%s:" % param_name) 45 | self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values))) 46 | self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values))) 47 | self.write_line("-----") 48 | end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M")) 49 | self.write_line("Finished on %s" % end_time) 50 | self.text_file.close() 51 | 52 | 53 | def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog): 54 | dataset_root = datasets_root.joinpath(dataset_name) 55 | if not dataset_root.exists(): 56 | print("Couldn\'t find %s, skipping this dataset." % dataset_root) 57 | return None, None 58 | return dataset_root, DatasetLog(out_dir, dataset_name) 59 | 60 | 61 | def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension, 62 | skip_existing, logger): 63 | print("%s: Preprocessing data for %d speakers." 
% (dataset_name, len(speaker_dirs))) 64 | 65 | # Function to preprocess utterances for one speaker 66 | def preprocess_speaker(speaker_dir: Path): 67 | # Give a name to the speaker that includes its dataset 68 | speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts) 69 | 70 | # Create an output directory with that name, as well as a txt file containing a 71 | # reference to each source file. 72 | speaker_out_dir = out_dir.joinpath(speaker_name) 73 | speaker_out_dir.mkdir(exist_ok=True) 74 | sources_fpath = speaker_out_dir.joinpath("_sources.txt") 75 | 76 | # There's a possibility that the preprocessing was interrupted earlier, check if 77 | # there already is a sources file. 78 | if sources_fpath.exists(): 79 | try: 80 | with sources_fpath.open("r") as sources_file: 81 | existing_fnames = {line.split(",")[0] for line in sources_file} 82 | except: 83 | existing_fnames = {} 84 | else: 85 | existing_fnames = {} 86 | 87 | # Gather all audio files for that speaker recursively 88 | sources_file = sources_fpath.open("a" if skip_existing else "w") 89 | for in_fpath in speaker_dir.glob("**/*.%s" % extension): 90 | # Check if the target output file already exists 91 | out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts) 92 | out_fname = out_fname.replace(".%s" % extension, ".npy") 93 | if skip_existing and out_fname in existing_fnames: 94 | continue 95 | 96 | # Load and preprocess the waveform 97 | wav = audio.preprocess_wav(in_fpath) 98 | if len(wav) == 0: 99 | continue 100 | 101 | # Create the mel spectrogram, discard those that are too short 102 | frames = audio.wav_to_mel_spectrogram(wav) 103 | if len(frames) < partials_n_frames: 104 | continue 105 | 106 | out_fpath = speaker_out_dir.joinpath(out_fname) 107 | np.save(out_fpath, frames) 108 | logger.add_sample(duration=len(wav) / sampling_rate) 109 | sources_file.write("%s,%s\n" % (out_fname, in_fpath)) 110 | 111 | sources_file.close() 112 | 113 | # Process the utterances for each speaker 114 | with ThreadPool(8) as pool: 115 | list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs), 116 | unit="speakers")) 117 | logger.finalize() 118 | print("Done preprocessing %s.\n" % dataset_name) 119 | 120 | 121 | # Function to preprocess utterances for one speaker 122 | def __preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, extension: str, skip_existing: bool): 123 | # Give a name to the speaker that includes its dataset 124 | speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts) 125 | 126 | # Create an output directory with that name, as well as a txt file containing a 127 | # reference to each source file. 128 | speaker_out_dir = out_dir.joinpath(speaker_name) 129 | speaker_out_dir.mkdir(exist_ok=True) 130 | sources_fpath = speaker_out_dir.joinpath("_sources.txt") 131 | 132 | # There's a possibility that the preprocessing was interrupted earlier, check if 133 | # there already is a sources file. 
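    # (Note: in this variant the resume check below is commented out, so existing_fnames
    # stays empty; skip_existing then only decides whether _sources.txt is opened in
    # append or write mode, and no utterances are actually skipped.)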
134 | # if sources_fpath.exists(): 135 | # try: 136 | # with sources_fpath.open("r") as sources_file: 137 | # existing_fnames = {line.split(",")[0] for line in sources_file} 138 | # except: 139 | # existing_fnames = {} 140 | # else: 141 | # existing_fnames = {} 142 | existing_fnames = {} 143 | # Gather all audio files for that speaker recursively 144 | sources_file = sources_fpath.open("a" if skip_existing else "w") 145 | 146 | for in_fpath in speaker_dir.glob("**/*.%s" % extension): 147 | # Check if the target output file already exists 148 | out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts) 149 | out_fname = out_fname.replace(".%s" % extension, ".npy") 150 | if skip_existing and out_fname in existing_fnames: 151 | continue 152 | 153 | # Load and preprocess the waveform 154 | wav = audio.preprocess_wav(in_fpath) 155 | if len(wav) == 0: 156 | continue 157 | 158 | # Create the mel spectrogram, discard those that are too short 159 | frames = audio.wav_to_mel_spectrogram(wav) 160 | if len(frames) < partials_n_frames: 161 | continue 162 | 163 | out_fpath = speaker_out_dir.joinpath(out_fname) 164 | np.save(out_fpath, frames) 165 | # logger.add_sample(duration=len(wav) / sampling_rate) 166 | sources_file.write("%s,%s\n" % (out_fname, in_fpath)) 167 | 168 | sources_file.close() 169 | return len(wav) 170 | 171 | def _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, extension, 172 | skip_existing, logger): 173 | # from multiprocessing import Pool, cpu_count 174 | from pathos.multiprocessing import ProcessingPool as Pool 175 | # Function to preprocess utterances for one speaker 176 | def __preprocess_speaker(speaker_dir: Path): 177 | # Give a name to the speaker that includes its dataset 178 | speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts) 179 | 180 | # Create an output directory with that name, as well as a txt file containing a 181 | # reference to each source file. 182 | speaker_out_dir = out_dir.joinpath(speaker_name) 183 | speaker_out_dir.mkdir(exist_ok=True) 184 | sources_fpath = speaker_out_dir.joinpath("_sources.txt") 185 | 186 | existing_fnames = {} 187 | # Gather all audio files for that speaker recursively 188 | sources_file = sources_fpath.open("a" if skip_existing else "w") 189 | wav_lens = [] 190 | for in_fpath in speaker_dir.glob("**/*.%s" % extension): 191 | # Check if the target output file already exists 192 | out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts) 193 | out_fname = out_fname.replace(".%s" % extension, ".npy") 194 | if skip_existing and out_fname in existing_fnames: 195 | continue 196 | 197 | # Load and preprocess the waveform 198 | wav = audio.preprocess_wav(in_fpath) 199 | if len(wav) == 0: 200 | continue 201 | 202 | # Create the mel spectrogram, discard those that are too short 203 | frames = audio.wav_to_mel_spectrogram(wav) 204 | if len(frames) < partials_n_frames: 205 | continue 206 | 207 | out_fpath = speaker_out_dir.joinpath(out_fname) 208 | np.save(out_fpath, frames) 209 | # logger.add_sample(duration=len(wav) / sampling_rate) 210 | sources_file.write("%s,%s\n" % (out_fname, in_fpath)) 211 | wav_lens.append(len(wav)) 212 | sources_file.close() 213 | return wav_lens 214 | 215 | print("%s: Preprocessing data for %d speakers." 
% (dataset_name, len(speaker_dirs))) 216 | # Process the utterances for each speaker 217 | # with ThreadPool(8) as pool: 218 | # list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs), 219 | # unit="speakers")) 220 | pool = Pool(processes=20) 221 | for i, wav_lens in enumerate(pool.map(__preprocess_speaker, speaker_dirs), 1): 222 | for wav_len in wav_lens: 223 | logger.add_sample(duration=wav_len / sampling_rate) 224 | print(f'{i}/{len(speaker_dirs)} \r') 225 | 226 | logger.finalize() 227 | print("Done preprocessing %s.\n" % dataset_name) 228 | 229 | 230 | def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False): 231 | for dataset_name in librispeech_datasets["train"]["other"]: 232 | # Initialize the preprocessing 233 | dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) 234 | if not dataset_root: 235 | return 236 | 237 | # Preprocess all speakers 238 | speaker_dirs = list(dataset_root.glob("*")) 239 | _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac", 240 | skip_existing, logger) 241 | 242 | 243 | def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False): 244 | # Initialize the preprocessing 245 | dataset_name = "VoxCeleb1" 246 | dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) 247 | if not dataset_root: 248 | return 249 | 250 | # Get the contents of the meta file 251 | with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile: 252 | metadata = [line.split("\t") for line in metafile][1:] 253 | 254 | # Select the ID and the nationality, filter out non-anglophone speakers 255 | nationalities = {line[0]: line[3] for line in metadata} 256 | # keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if 257 | # nationality.lower() in anglophone_nationalites] 258 | keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items()] 259 | print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." % 260 | (len(keep_speaker_ids), len(nationalities))) 261 | 262 | # Get the speaker directories for anglophone speakers only 263 | speaker_dirs = dataset_root.joinpath("wav").glob("*") 264 | speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if 265 | speaker_dir.name in keep_speaker_ids] 266 | print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." 
% 267 | (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs))) 268 | 269 | # Preprocess all speakers 270 | _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav", 271 | skip_existing, logger) 272 | 273 | 274 | def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False): 275 | # Initialize the preprocessing 276 | dataset_name = "VoxCeleb2" 277 | dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) 278 | if not dataset_root: 279 | return 280 | 281 | # Get the speaker directories 282 | # Preprocess all speakers 283 | speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*")) 284 | _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a", 285 | skip_existing, logger) 286 | -------------------------------------------------------------------------------- /speaker_encoder/train.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder.visualizations import Visualizations 2 | from speaker_encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset 3 | from speaker_encoder.params_model import * 4 | from speaker_encoder.model import SpeakerEncoder 5 | from utils.profiler import Profiler 6 | from pathlib import Path 7 | import torch 8 | 9 | def sync(device: torch.device): 10 | # FIXME 11 | return 12 | # For correct profiling (cuda operations are async) 13 | if device.type == "cuda": 14 | torch.cuda.synchronize(device) 15 | 16 | def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int, 17 | backup_every: int, vis_every: int, force_restart: bool, visdom_server: str, 18 | no_visdom: bool): 19 | # Create a dataset and a dataloader 20 | dataset = SpeakerVerificationDataset(clean_data_root) 21 | loader = SpeakerVerificationDataLoader( 22 | dataset, 23 | speakers_per_batch, # 64 24 | utterances_per_speaker, # 10 25 | num_workers=8, 26 | ) 27 | 28 | # Setup the device on which to run the forward pass and the loss. These can be different, 29 | # because the forward pass is faster on the GPU whereas the loss is often (depending on your 30 | # hyperparameters) faster on the CPU. 31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 32 | # FIXME: currently, the gradient is None if loss_device is cuda 33 | loss_device = torch.device("cpu") 34 | 35 | # Create the model and the optimizer 36 | model = SpeakerEncoder(device, loss_device) 37 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init) 38 | init_step = 1 39 | 40 | # Configure file path for the model 41 | state_fpath = models_dir.joinpath(run_id + ".pt") 42 | backup_dir = models_dir.joinpath(run_id + "_backups") 43 | 44 | # Load any existing model 45 | if not force_restart: 46 | if state_fpath.exists(): 47 | print("Found existing model \"%s\", loading it and resuming training." % run_id) 48 | checkpoint = torch.load(state_fpath) 49 | init_step = checkpoint["step"] 50 | model.load_state_dict(checkpoint["model_state"]) 51 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 52 | optimizer.param_groups[0]["lr"] = learning_rate_init 53 | else: 54 | print("No model \"%s\" found, starting training from scratch." 
% run_id) 55 | else: 56 | print("Starting the training from scratch.") 57 | model.train() 58 | 59 | # Initialize the visualization environment 60 | vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom) 61 | vis.log_dataset(dataset) 62 | vis.log_params() 63 | device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU") 64 | vis.log_implementation({"Device": device_name}) 65 | 66 | # Training loop 67 | profiler = Profiler(summarize_every=10, disabled=False) 68 | for step, speaker_batch in enumerate(loader, init_step): 69 | profiler.tick("Blocking, waiting for batch (threaded)") 70 | 71 | # Forward pass 72 | inputs = torch.from_numpy(speaker_batch.data).to(device) 73 | sync(device) 74 | profiler.tick("Data to %s" % device) 75 | embeds = model(inputs) 76 | sync(device) 77 | profiler.tick("Forward pass") 78 | embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device) 79 | loss, eer = model.loss(embeds_loss) 80 | sync(loss_device) 81 | profiler.tick("Loss") 82 | 83 | # Backward pass 84 | model.zero_grad() 85 | loss.backward() 86 | profiler.tick("Backward pass") 87 | model.do_gradient_ops() 88 | optimizer.step() 89 | profiler.tick("Parameter update") 90 | 91 | # Update visualizations 92 | # learning_rate = optimizer.param_groups[0]["lr"] 93 | vis.update(loss.item(), eer, step) 94 | 95 | # Draw projections and save them to the backup folder 96 | if umap_every != 0 and step % umap_every == 0: 97 | print("Drawing and saving projections (step %d)" % step) 98 | backup_dir.mkdir(exist_ok=True) 99 | projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step)) 100 | embeds = embeds.detach().cpu().numpy() 101 | vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath) 102 | vis.save() 103 | 104 | # Overwrite the latest version of the model 105 | if save_every != 0 and step % save_every == 0: 106 | print("Saving the model (step %d)" % step) 107 | torch.save({ 108 | "step": step + 1, 109 | "model_state": model.state_dict(), 110 | "optimizer_state": optimizer.state_dict(), 111 | }, state_fpath) 112 | 113 | # Make a backup 114 | if backup_every != 0 and step % backup_every == 0: 115 | print("Making a backup (step %d)" % step) 116 | backup_dir.mkdir(exist_ok=True) 117 | backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step)) 118 | torch.save({ 119 | "step": step + 1, 120 | "model_state": model.state_dict(), 121 | "optimizer_state": optimizer.state_dict(), 122 | }, backup_fpath) 123 | 124 | profiler.tick("Extras (visualizations, saving)") 125 | -------------------------------------------------------------------------------- /speaker_encoder/visualizations.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset 2 | from datetime import datetime 3 | from time import perf_counter as timer 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | # import webbrowser 7 | import visdom 8 | import umap 9 | 10 | colormap = np.array([ 11 | [76, 255, 0], 12 | [0, 127, 70], 13 | [255, 0, 0], 14 | [255, 217, 38], 15 | [0, 135, 255], 16 | [165, 0, 165], 17 | [255, 167, 255], 18 | [0, 255, 255], 19 | [255, 96, 38], 20 | [142, 76, 0], 21 | [33, 0, 127], 22 | [0, 0, 0], 23 | [183, 183, 183], 24 | ], dtype=np.float) / 255 25 | 26 | 27 | class Visualizations: 28 | def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False): 29 | 
# Tracking data 30 | self.last_update_timestamp = timer() 31 | self.update_every = update_every 32 | self.step_times = [] 33 | self.losses = [] 34 | self.eers = [] 35 | print("Updating the visualizations every %d steps." % update_every) 36 | 37 | # If visdom is disabled TODO: use a better paradigm for that 38 | self.disabled = disabled 39 | if self.disabled: 40 | return 41 | 42 | # Set the environment name 43 | now = str(datetime.now().strftime("%d-%m %Hh%M")) 44 | if env_name is None: 45 | self.env_name = now 46 | else: 47 | self.env_name = "%s (%s)" % (env_name, now) 48 | 49 | # Connect to visdom and open the corresponding window in the browser 50 | try: 51 | self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True) 52 | except ConnectionError: 53 | raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to " 54 | "start it.") 55 | # webbrowser.open("http://localhost:8097/env/" + self.env_name) 56 | 57 | # Create the windows 58 | self.loss_win = None 59 | self.eer_win = None 60 | # self.lr_win = None 61 | self.implementation_win = None 62 | self.projection_win = None 63 | self.implementation_string = "" 64 | 65 | def log_params(self): 66 | if self.disabled: 67 | return 68 | from speaker_encoder import params_data 69 | from speaker_encoder import params_model 70 | param_string = "Model parameters:
" 71 | for param_name in (p for p in dir(params_model) if not p.startswith("__")): 72 | value = getattr(params_model, param_name) 73 | param_string += "\t%s: %s
" % (param_name, value) 74 | param_string += "Data parameters:
" 75 | for param_name in (p for p in dir(params_data) if not p.startswith("__")): 76 | value = getattr(params_data, param_name) 77 | param_string += "\t%s: %s
" % (param_name, value) 78 | self.vis.text(param_string, opts={"title": "Parameters"}) 79 | 80 | def log_dataset(self, dataset: SpeakerVerificationDataset): 81 | if self.disabled: 82 | return 83 | dataset_string = "" 84 | dataset_string += "Speakers: %s\n" % len(dataset.speakers) 85 | dataset_string += "\n" + dataset.get_logs() 86 | dataset_string = dataset_string.replace("\n", "
") 87 | self.vis.text(dataset_string, opts={"title": "Dataset"}) 88 | 89 | def log_implementation(self, params): 90 | if self.disabled: 91 | return 92 | implementation_string = "" 93 | for param, value in params.items(): 94 | implementation_string += "%s: %s\n" % (param, value) 95 | implementation_string = implementation_string.replace("\n", "
") 96 | self.implementation_string = implementation_string 97 | self.implementation_win = self.vis.text( 98 | implementation_string, 99 | opts={"title": "Training implementation"} 100 | ) 101 | 102 | def update(self, loss, eer, step): 103 | # Update the tracking data 104 | now = timer() 105 | self.step_times.append(1000 * (now - self.last_update_timestamp)) 106 | self.last_update_timestamp = now 107 | self.losses.append(loss) 108 | self.eers.append(eer) 109 | print(".", end="") 110 | 111 | # Update the plots every steps 112 | if step % self.update_every != 0: 113 | return 114 | time_string = "Step time: mean: %5dms std: %5dms" % \ 115 | (int(np.mean(self.step_times)), int(np.std(self.step_times))) 116 | print("\nStep %6d Loss: %.4f EER: %.4f %s" % 117 | (step, np.mean(self.losses), np.mean(self.eers), time_string)) 118 | if not self.disabled: 119 | self.loss_win = self.vis.line( 120 | [np.mean(self.losses)], 121 | [step], 122 | win=self.loss_win, 123 | update="append" if self.loss_win else None, 124 | opts=dict( 125 | legend=["Avg. loss"], 126 | xlabel="Step", 127 | ylabel="Loss", 128 | title="Loss", 129 | ) 130 | ) 131 | self.eer_win = self.vis.line( 132 | [np.mean(self.eers)], 133 | [step], 134 | win=self.eer_win, 135 | update="append" if self.eer_win else None, 136 | opts=dict( 137 | legend=["Avg. EER"], 138 | xlabel="Step", 139 | ylabel="EER", 140 | title="Equal error rate" 141 | ) 142 | ) 143 | if self.implementation_win is not None: 144 | self.vis.text( 145 | self.implementation_string + ("%s" % time_string), 146 | win=self.implementation_win, 147 | opts={"title": "Training implementation"}, 148 | ) 149 | 150 | # Reset the tracking 151 | self.losses.clear() 152 | self.eers.clear() 153 | self.step_times.clear() 154 | 155 | def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, 156 | max_speakers=10): 157 | max_speakers = min(max_speakers, len(colormap)) 158 | embeds = embeds[:max_speakers * utterances_per_speaker] 159 | 160 | n_speakers = len(embeds) // utterances_per_speaker 161 | ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker) 162 | colors = [colormap[i] for i in ground_truth] 163 | 164 | reducer = umap.UMAP() 165 | projected = reducer.fit_transform(embeds) 166 | plt.scatter(projected[:, 0], projected[:, 1], c=colors) 167 | plt.gca().set_aspect("equal", "datalim") 168 | plt.title("UMAP projection (step %d)" % step) 169 | if not self.disabled: 170 | self.projection_win = self.vis.matplot(plt, win=self.projection_win) 171 | if out_fpath is not None: 172 | plt.savefig(out_fpath) 173 | plt.clf() 174 | 175 | def save(self): 176 | if not self.disabled: 177 | self.vis.save([self.env_name]) 178 | -------------------------------------------------------------------------------- /speaker_encoder/voice_encoder.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder.hparams import * 2 | from speaker_encoder import audio 3 | from pathlib import Path 4 | from typing import Union, List 5 | from torch import nn 6 | from time import perf_counter as timer 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class SpeakerEncoder(nn.Module): 12 | def __init__(self, weights_fpath, device: Union[str, torch.device]=None, verbose=True): 13 | """ 14 | :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). 15 | If None, defaults to cuda if it is available on your machine, otherwise the model will 16 | run on cpu. Outputs are always returned on the cpu, as numpy arrays. 
17 | """ 18 | super().__init__() 19 | 20 | # Define the network 21 | self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) 22 | self.linear = nn.Linear(model_hidden_size, model_embedding_size) 23 | self.relu = nn.ReLU() 24 | 25 | # Get the target device 26 | if device is None: 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | elif isinstance(device, str): 29 | device = torch.device(device) 30 | self.device = device 31 | 32 | # Load the pretrained model'speaker weights 33 | # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt") 34 | # if not weights_fpath.exists(): 35 | # raise Exception("Couldn't find the voice encoder pretrained model at %s." % 36 | # weights_fpath) 37 | 38 | start = timer() 39 | checkpoint = torch.load(weights_fpath, map_location="cpu") 40 | 41 | self.load_state_dict(checkpoint["model_state"], strict=False) 42 | self.to(device) 43 | 44 | if verbose: 45 | print("Loaded the voice encoder model on %s in %.2f seconds." % 46 | (device.type, timer() - start)) 47 | 48 | def forward(self, mels: torch.FloatTensor): 49 | """ 50 | Computes the embeddings of a batch of utterance spectrograms. 51 | :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape 52 | (batch_size, n_frames, n_channels) 53 | :return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size). 54 | Embeddings are positive and L2-normed, thus they lay in the range [0, 1]. 55 | """ 56 | # Pass the input through the LSTM layers and retrieve the final hidden state of the last 57 | # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings. 58 | _, (hidden, _) = self.lstm(mels) 59 | embeds_raw = self.relu(self.linear(hidden[-1])) 60 | return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) 61 | 62 | @staticmethod 63 | def compute_partial_slices(n_samples: int, rate, min_coverage): 64 | """ 65 | Computes where to split an utterance waveform and its corresponding mel spectrogram to 66 | obtain partial utterances of each. Both the waveform and the 67 | mel spectrogram slices are returned, so as to make each partial utterance waveform 68 | correspond to its spectrogram. 69 | 70 | The returned ranges may be indexing further than the length of the waveform. It is 71 | recommended that you pad the waveform with zeros up to wav_slices[-1].stop. 72 | 73 | :param n_samples: the number of samples in the waveform 74 | :param rate: how many partial utterances should occur per second. Partial utterances must 75 | cover the span of the entire utterance, thus the rate should not be lower than the inverse 76 | of the duration of a partial utterance. By default, partial utterances are 1.6s long and 77 | the minimum rate is thus 0.625. 78 | :param min_coverage: when reaching the last partial utterance, it may or may not have 79 | enough frames. If at least of are present, 80 | then the last partial utterance will be considered by zero-padding the audio. Otherwise, 81 | it will be discarded. If there aren't enough frames for one partial utterance, 82 | this parameter is ignored so that the function always returns at least one slice. 83 | :return: the waveform slices and mel spectrogram slices as lists of array slices. Index 84 | respectively the waveform and the mel spectrogram with these slices to obtain the partial 85 | utterances. 
86 | """ 87 | assert 0 < min_coverage <= 1 88 | 89 | # Compute how many frames separate two partial utterances 90 | samples_per_frame = int((sampling_rate * mel_window_step / 1000)) 91 | n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) 92 | frame_step = int(np.round((sampling_rate / rate) / samples_per_frame)) 93 | assert 0 < frame_step, "The rate is too high" 94 | assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \ 95 | (sampling_rate / (samples_per_frame * partials_n_frames)) 96 | 97 | # Compute the slices 98 | wav_slices, mel_slices = [], [] 99 | steps = max(1, n_frames - partials_n_frames + frame_step + 1) 100 | for i in range(0, steps, frame_step): 101 | mel_range = np.array([i, i + partials_n_frames]) 102 | wav_range = mel_range * samples_per_frame 103 | mel_slices.append(slice(*mel_range)) 104 | wav_slices.append(slice(*wav_range)) 105 | 106 | # Evaluate whether extra padding is warranted or not 107 | last_wav_range = wav_slices[-1] 108 | coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) 109 | if coverage < min_coverage and len(mel_slices) > 1: 110 | mel_slices = mel_slices[:-1] 111 | wav_slices = wav_slices[:-1] 112 | 113 | return wav_slices, mel_slices 114 | 115 | def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75): 116 | """ 117 | Computes an embedding for a single utterance. The utterance is divided in partial 118 | utterances and an embedding is computed for each. The complete utterance embedding is the 119 | L2-normed average embedding of the partial utterances. 120 | 121 | TODO: independent batched version of this function 122 | 123 | :param wav: a preprocessed utterance waveform as a numpy array of float32 124 | :param return_partials: if True, the partial embeddings will also be returned along with 125 | the wav slices corresponding to each partial utterance. 126 | :param rate: how many partial utterances should occur per second. Partial utterances must 127 | cover the span of the entire utterance, thus the rate should not be lower than the inverse 128 | of the duration of a partial utterance. By default, partial utterances are 1.6s long and 129 | the minimum rate is thus 0.625. 130 | :param min_coverage: when reaching the last partial utterance, it may or may not have 131 | enough frames. If at least of are present, 132 | then the last partial utterance will be considered by zero-padding the audio. Otherwise, 133 | it will be discarded. If there aren't enough frames for one partial utterance, 134 | this parameter is ignored so that the function always returns at least one slice. 135 | :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If 136 | is True, the partial utterances as a numpy array of float32 of shape 137 | (n_partials, model_embedding_size) and the wav partials as a list of slices will also be 138 | returned. 139 | """ 140 | # Compute where to split the utterance into partials and pad the waveform with zeros if 141 | # the partial utterances cover a larger range. 
142 | wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage) 143 | max_wave_length = wav_slices[-1].stop 144 | if max_wave_length >= len(wav): 145 | wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") 146 | 147 | # Split the utterance into partials and forward them through the model 148 | mel = audio.wav_to_mel_spectrogram(wav) 149 | mels = np.array([mel[s] for s in mel_slices]) 150 | with torch.no_grad(): 151 | mels = torch.from_numpy(mels).to(self.device) 152 | partial_embeds = self(mels).cpu().numpy() 153 | 154 | # Compute the utterance embedding from the partial embeddings 155 | raw_embed = np.mean(partial_embeds, axis=0) 156 | embed = raw_embed / np.linalg.norm(raw_embed, 2) 157 | 158 | if return_partials: 159 | return embed, partial_embeds, wav_slices 160 | return embed 161 | 162 | def embed_speaker(self, wavs: List[np.ndarray], **kwargs): 163 | """ 164 | Compute the embedding of a collection of wavs (presumably from the same speaker) by 165 | averaging their embedding and L2-normalizing it. 166 | 167 | :param wavs: list of wavs a numpy arrays of float32. 168 | :param kwargs: extra arguments to embed_utterance() 169 | :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). 170 | """ 171 | raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) \ 172 | for wav in wavs], axis=0) 173 | return raw_embed / np.linalg.norm(raw_embed, 2) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import itertools 5 | import math 6 | import torch 7 | from torch import nn, optim 8 | from torch.nn import functional as F 9 | from torch.utils.data import DataLoader 10 | from utils import HParams 11 | 12 | from torch.utils.tensorboard import SummaryWriter 13 | from torch.cuda.amp import autocast, GradScaler 14 | 15 | import commons 16 | import utils 17 | from data_utils import ( 18 | TextAudioSpeakerLoader, 19 | TextAudioSpeakerCollate, 20 | DistributedBucketSampler 21 | ) 22 | from models import ( 23 | HuBERT_NeuralDec_VITS, 24 | MultiPeriodDiscriminator, 25 | ) 26 | from losses import ( 27 | generator_loss, 28 | discriminator_loss, 29 | feature_loss, 30 | kl_loss 31 | ) 32 | from mel_processing import mel_spectrogram_torch, spec_to_mel_torch 33 | from speaker_encoder.voice_encoder import SpeakerEncoder 34 | smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt') 35 | torch.backends.cudnn.benchmark = True 36 | global_step = 0 37 | 38 | 39 | class Parameters: 40 | def __init__(self): 41 | self.config = "./configs/hubert-neuraldec-vits.json" 42 | self.model = "neuralvc-temp" 43 | 44 | 45 | args = Parameters() 46 | 47 | def get_hparams(init=True): 48 | 49 | model_dir = os.path.join("./logs", args.model) 50 | 51 | if not os.path.exists(model_dir): 52 | os.makedirs(model_dir) 53 | 54 | config_path = args.config 55 | config_save_path = os.path.join(model_dir, "config.json") 56 | if init: 57 | with open(config_path, "r") as f: 58 | data = f.read() 59 | with open(config_save_path, "w") as f: 60 | f.write(data) 61 | else: 62 | with open(config_save_path, "r") as f: 63 | data = f.read() 64 | config = json.loads(data) 65 | 66 | hparams = HParams(**config) 67 | hparams.model_dir = model_dir 68 | return hparams 69 | 70 | hps = get_hparams() 71 | 72 | 73 | def spk_loss(tgt, gen, batch_size): 74 | loss = 0.0 75 | for i in 
range(batch_size): 76 | tgt_emb = smodel.embed_utterance(tgt[i][0]) 77 | gen_emb = smodel.embed_utterance(gen[i][0]) 78 | 79 | loss += F.l1_loss(torch.from_numpy(tgt_emb), torch.from_numpy(gen_emb)) 80 | 81 | return loss/batch_size 82 | 83 | def main(): 84 | """Assume Single Node Multi GPUs Training Only""" 85 | assert torch.cuda.is_available(), "CPU training is not allowed." 86 | n_gpus = torch.cuda.device_count() 87 | run(0,n_gpus, hps) 88 | 89 | def run(rank, n_gpus, hps): 90 | global global_step 91 | if rank == 0: 92 | logger = utils.get_logger(hps.model_dir) 93 | logger.info(hps) 94 | utils.check_git_hash(hps.model_dir) 95 | writer = SummaryWriter(log_dir=hps.model_dir) 96 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) 97 | 98 | #dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank) 99 | torch.manual_seed(hps.train.seed) 100 | torch.cuda.set_device(rank) 101 | 102 | train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps) 103 | train_sampler = DistributedBucketSampler( 104 | train_dataset, 105 | hps.train.batch_size, 106 | [75,100,125,150,175,200,225,250,300,350,400,450,500,550,600,650,700,750,800,850,900,950,1000,1100,1200,1300,1400,1500,2000,3000,4000,5000], 107 | num_replicas=n_gpus, 108 | rank=rank, 109 | shuffle=True) 110 | collate_fn = TextAudioSpeakerCollate(hps) 111 | train_loader = DataLoader(train_dataset, num_workers=0, shuffle=False, pin_memory=True, 112 | collate_fn=collate_fn, batch_sampler=train_sampler) 113 | if rank == 0: 114 | eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps) 115 | eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=True, 116 | batch_size=hps.train.batch_size, pin_memory=False, 117 | drop_last=False, collate_fn=collate_fn) 118 | 119 | net_g = HuBERT_NeuralDec_VITS( 120 | hps.data.filter_length // 2 + 1, 121 | hps.train.segment_size // hps.data.hop_length, 122 | **hps.model).cuda(rank) 123 | net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) 124 | optim_g = torch.optim.AdamW( 125 | net_g.parameters(), 126 | hps.train.learning_rate, 127 | betas=hps.train.betas, 128 | eps=hps.train.eps) 129 | optim_d = torch.optim.AdamW( 130 | net_d.parameters(), 131 | hps.train.learning_rate, 132 | betas=hps.train.betas, 133 | eps=hps.train.eps) 134 | 135 | try: 136 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g) 137 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d) 138 | global_step = (epoch_str - 1) * len(train_loader) 139 | except: 140 | epoch_str = 1 141 | global_step = 0 142 | 143 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str-2) 144 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str-2) 145 | 146 | scaler = GradScaler(enabled=hps.train.fp16_run) 147 | 148 | for epoch in range(epoch_str, hps.train.epochs + 1): 149 | if rank==0: 150 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval]) 151 | else: 152 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None) 153 | scheduler_g.step() 154 | scheduler_d.step() 155 | 156 | def train_and_evaluate(rank, epoch, hps, nets, optims, 
schedulers, scaler, loaders, logger, writers): 157 | 158 | net_g, net_d = nets 159 | optim_g, optim_d = optims 160 | scheduler_g, scheduler_d = schedulers 161 | train_loader, eval_loader = loaders 162 | if writers is not None: 163 | writer, writer_eval = writers 164 | 165 | train_loader.batch_sampler.set_epoch(epoch) 166 | global global_step 167 | 168 | net_g.train() 169 | net_d.train() 170 | for batch_idx, items in enumerate(train_loader): 171 | if hps.model.use_spk: 172 | c, spec, y, spk = items 173 | g = spk.cuda(rank, non_blocking=True) 174 | else: 175 | c, spec, y = items 176 | g = None 177 | spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True) 178 | c = c.cuda(rank, non_blocking=True) 179 | # print("***" + spec.shape) 180 | mel = spec_to_mel_torch( 181 | spec, 182 | hps.data.filter_length, 183 | hps.data.n_mel_channels, 184 | hps.data.sampling_rate, 185 | hps.data.mel_fmin, 186 | hps.data.mel_fmax) 187 | real_mel = mel_spectrogram_torch( 188 | y.squeeze(1), 189 | hps.data.filter_length, 190 | hps.data.n_mel_channels, 191 | hps.data.sampling_rate, 192 | hps.data.hop_length, 193 | hps.data.win_length, 194 | hps.data.mel_fmin, 195 | hps.data.mel_fmax 196 | ) 197 | #print(torch.max(mel),torch.min(mel),torch.max(real_mel),torch.min(real_mel)) 198 | with autocast(enabled=hps.train.fp16_run): 199 | y_hat, ids_slice, z_mask,\ 200 | (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(c, spec, g=g, mel=mel) 201 | #print(torch.max(y),torch.min(y),torch.max(y_hat),torch.min(y_hat)) 202 | y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) 203 | y_hat_mel = mel_spectrogram_torch( 204 | y_hat.squeeze(1), 205 | hps.data.filter_length, 206 | hps.data.n_mel_channels, 207 | hps.data.sampling_rate, 208 | hps.data.hop_length, 209 | hps.data.win_length, 210 | hps.data.mel_fmin, 211 | hps.data.mel_fmax 212 | ) 213 | y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice 214 | 215 | # Discriminator 216 | y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) 217 | with autocast(enabled=False): 218 | loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) 219 | loss_disc_all = loss_disc 220 | optim_d.zero_grad() 221 | scaler.scale(loss_disc_all).backward() 222 | scaler.unscale_(optim_d) 223 | grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) 224 | scaler.step(optim_d) 225 | with autocast(enabled=hps.train.fp16_run): 226 | # Generator 227 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) 228 | with autocast(enabled=False): 229 | loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl 230 | loss_mel = F.l1_loss(y_hat_mel, y_mel) * hps.train.c_mel 231 | loss_fm = feature_loss(fmap_r, fmap_g) 232 | loss_gen, losses_gen = generator_loss(y_d_hat_g) 233 | loss_spk = spk_loss(y.detach().cpu().numpy(), y_hat.detach().cpu().numpy(), hps.train.batch_size) 234 | loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + loss_spk 235 | optim_g.zero_grad() 236 | scaler.scale(loss_gen_all).backward() 237 | scaler.unscale_(optim_g) 238 | grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) 239 | scaler.step(optim_g) 240 | scaler.update() 241 | 242 | if rank==0: 243 | if global_step % hps.train.log_interval == 0: 244 | lr = optim_g.param_groups[0]['lr'] 245 | losses = [loss_disc, loss_gen, loss_fm, loss_mel] 246 | logger.info('Train Epoch: {} [{:.0f}%]'.format( 247 | epoch, 248 | 100. 
* batch_idx / len(train_loader))) 249 | logger.info([x.item() for x in losses] + [global_step, lr]) 250 | 251 | scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g} 252 | scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel}) 253 | 254 | scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) 255 | scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) 256 | scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) 257 | image_dict = { 258 | "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), 259 | "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), 260 | "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), 261 | } 262 | utils.summarize( 263 | writer=writer, 264 | global_step=global_step, 265 | images=image_dict, 266 | scalars=scalar_dict) 267 | 268 | if global_step % hps.train.eval_interval == 0: 269 | # evaluate(hps, net_g, eval_loader, writer_eval) 270 | utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) 271 | utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step))) 272 | global_step += 1 273 | 274 | if rank == 0: 275 | logger.info('====> Epoch: {}'.format(epoch)) 276 | 277 | 278 | if __name__ == "__main__": 279 | main() 280 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | from scipy.io.wavfile import read 10 | import torch 11 | import torchvision 12 | from torch.nn import functional as F 13 | from commons import sequence_mask 14 | import hifigan 15 | from wavlm import WavLM, WavLMConfig 16 | 17 | MATPLOTLIB_FLAG = False 18 | 19 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 20 | logger = logging 21 | 22 | 23 | def get_cmodel(rank): 24 | checkpoint = torch.load('wavlm/WavLM-Large.pt') 25 | cfg = WavLMConfig(checkpoint['cfg']) 26 | cmodel = WavLM(cfg).cuda(rank) 27 | cmodel.load_state_dict(checkpoint['model']) 28 | cmodel.eval() 29 | return cmodel 30 | 31 | 32 | def get_content(cmodel, y): 33 | with torch.no_grad(): 34 | c = cmodel.extract_features(y.squeeze(1))[0] 35 | c = c.transpose(1, 2) 36 | return c 37 | 38 | 39 | def get_vocoder(rank): 40 | with open("hifigan/config.json", "r") as f: 41 | config = json.load(f) 42 | config = hifigan.AttrDict(config) 43 | vocoder = hifigan.Generator(config) 44 | ckpt = torch.load("hifigan/generator_v1") 45 | vocoder.load_state_dict(ckpt["generator"]) 46 | vocoder.eval() 47 | vocoder.remove_weight_norm() 48 | vocoder.cuda(rank) 49 | return vocoder 50 | 51 | 52 | def transform(mel, height): # 68-92 53 | #r = np.random.random() 54 | #rate = r * 0.3 + 0.85 # 0.85-1.15 55 | #height = int(mel.size(-2) * rate) 56 | tgt = torchvision.transforms.functional.resize(mel, (height, mel.size(-1))) 57 | if height >= mel.size(-2): 58 | return tgt[:, :mel.size(-2), :] 59 | else: 60 | silence = tgt[:,-1:,:].repeat(1,mel.size(-2)-height,1) 61 | silence += torch.randn_like(silence) / 10 62 | return torch.cat((tgt, silence), 1) 63 | 64 | 65 | def stretch(mel, width): # 
0.5-2 66 | return torchvision.transforms.functional.resize(mel, (mel.size(-2), width)) 67 | 68 | 69 | def load_checkpoint(checkpoint_path, model, optimizer=None, strict=False): 70 | assert os.path.isfile(checkpoint_path) 71 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 72 | iteration = checkpoint_dict['iteration'] 73 | learning_rate = checkpoint_dict['learning_rate'] 74 | if optimizer is not None: 75 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 76 | saved_state_dict = checkpoint_dict['model'] 77 | if hasattr(model, 'module'): 78 | state_dict = model.module.state_dict() 79 | else: 80 | state_dict = model.state_dict() 81 | if strict: 82 | assert state_dict.keys() == saved_state_dict.keys(), "Mismatched model config and checkpoint." 83 | new_state_dict= {} 84 | for k, v in state_dict.items(): 85 | try: 86 | new_state_dict[k] = saved_state_dict[k] 87 | except: 88 | logger.info("%s is not in the checkpoint" % k) 89 | new_state_dict[k] = v 90 | if hasattr(model, 'module'): 91 | model.module.load_state_dict(new_state_dict) 92 | else: 93 | model.load_state_dict(new_state_dict) 94 | logger.info("Loaded checkpoint '{}' (iteration {})" .format( 95 | checkpoint_path, iteration)) 96 | return model, optimizer, learning_rate, iteration 97 | 98 | 99 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 100 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 101 | iteration, checkpoint_path)) 102 | if hasattr(model, 'module'): 103 | state_dict = model.module.state_dict() 104 | else: 105 | state_dict = model.state_dict() 106 | torch.save({'model': state_dict, 107 | 'iteration': iteration, 108 | 'optimizer': optimizer.state_dict(), 109 | 'learning_rate': learning_rate}, checkpoint_path) 110 | 111 | 112 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): 113 | for k, v in scalars.items(): 114 | writer.add_scalar(k, v, global_step) 115 | for k, v in histograms.items(): 116 | writer.add_histogram(k, v, global_step) 117 | for k, v in images.items(): 118 | writer.add_image(k, v, global_step, dataformats='HWC') 119 | for k, v in audios.items(): 120 | writer.add_audio(k, v, global_step, audio_sampling_rate) 121 | 122 | 123 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 124 | f_list = glob.glob(os.path.join(dir_path, regex)) 125 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 126 | x = f_list[-1] 127 | print(x) 128 | return x 129 | 130 | 131 | def plot_spectrogram_to_numpy(spectrogram): 132 | global MATPLOTLIB_FLAG 133 | if not MATPLOTLIB_FLAG: 134 | import matplotlib 135 | matplotlib.use("Agg") 136 | MATPLOTLIB_FLAG = True 137 | mpl_logger = logging.getLogger('matplotlib') 138 | mpl_logger.setLevel(logging.WARNING) 139 | import matplotlib.pylab as plt 140 | import numpy as np 141 | 142 | fig, ax = plt.subplots(figsize=(10,2)) 143 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 144 | interpolation='none') 145 | plt.colorbar(im, ax=ax) 146 | plt.xlabel("Frames") 147 | plt.ylabel("Channels") 148 | plt.tight_layout() 149 | 150 | fig.canvas.draw() 151 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 152 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 153 | plt.close() 154 | return data 155 | 156 | 157 | def plot_alignment_to_numpy(alignment, info=None): 158 | global MATPLOTLIB_FLAG 159 | if not MATPLOTLIB_FLAG: 160 | import matplotlib 161 | matplotlib.use("Agg") 162 | MATPLOTLIB_FLAG = 
True 163 | mpl_logger = logging.getLogger('matplotlib') 164 | mpl_logger.setLevel(logging.WARNING) 165 | import matplotlib.pylab as plt 166 | import numpy as np 167 | 168 | fig, ax = plt.subplots(figsize=(6, 4)) 169 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', 170 | interpolation='none') 171 | fig.colorbar(im, ax=ax) 172 | xlabel = 'Decoder timestep' 173 | if info is not None: 174 | xlabel += '\n\n' + info 175 | plt.xlabel(xlabel) 176 | plt.ylabel('Encoder timestep') 177 | plt.tight_layout() 178 | 179 | fig.canvas.draw() 180 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 181 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 182 | plt.close() 183 | return data 184 | 185 | 186 | def load_wav_to_torch(full_path): 187 | sampling_rate, data = read(full_path) 188 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 189 | 190 | 191 | def load_filepaths_and_text(filename, split="|"): 192 | with open(filename, encoding='utf-8') as f: 193 | filepaths_and_text = [line.strip().split(split) for line in f] 194 | return filepaths_and_text 195 | 196 | 197 | def get_hparams(init=True): 198 | parser = argparse.ArgumentParser() 199 | parser.add_argument('-c', '--config', type=str, default="./configs/base.json", 200 | help='JSON file for configuration') 201 | parser.add_argument('-m', '--model', type=str, required=True, 202 | help='Model name') 203 | 204 | args = parser.parse_args() 205 | model_dir = os.path.join("./logs", args.model) 206 | 207 | if not os.path.exists(model_dir): 208 | os.makedirs(model_dir) 209 | 210 | config_path = args.config 211 | config_save_path = os.path.join(model_dir, "config.json") 212 | if init: 213 | with open(config_path, "r") as f: 214 | data = f.read() 215 | with open(config_save_path, "w") as f: 216 | f.write(data) 217 | else: 218 | with open(config_save_path, "r") as f: 219 | data = f.read() 220 | config = json.loads(data) 221 | 222 | hparams = HParams(**config) 223 | hparams.model_dir = model_dir 224 | return hparams 225 | 226 | 227 | def get_hparams_from_dir(model_dir): 228 | config_save_path = os.path.join(model_dir, "config.json") 229 | with open(config_save_path, "r") as f: 230 | data = f.read() 231 | config = json.loads(data) 232 | 233 | hparams =HParams(**config) 234 | hparams.model_dir = model_dir 235 | return hparams 236 | 237 | 238 | def get_hparams_from_file(config_path): 239 | with open(config_path, "r") as f: 240 | data = f.read() 241 | config = json.loads(data) 242 | 243 | hparams =HParams(**config) 244 | return hparams 245 | 246 | 247 | def check_git_hash(model_dir): 248 | source_dir = os.path.dirname(os.path.realpath(__file__)) 249 | if not os.path.exists(os.path.join(source_dir, ".git")): 250 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( 251 | source_dir 252 | )) 253 | return 254 | 255 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 256 | 257 | path = os.path.join(model_dir, "githash") 258 | if os.path.exists(path): 259 | saved_hash = open(path).read() 260 | if saved_hash != cur_hash: 261 | logger.warn("git hash values are different. 
{}(saved) != {}(current)".format( 262 | saved_hash[:8], cur_hash[:8])) 263 | else: 264 | open(path, "w").write(cur_hash) 265 | 266 | 267 | def get_logger(model_dir, filename="train.log"): 268 | global logger 269 | logger = logging.getLogger(os.path.basename(model_dir)) 270 | logger.setLevel(logging.DEBUG) 271 | 272 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 273 | if not os.path.exists(model_dir): 274 | os.makedirs(model_dir) 275 | h = logging.FileHandler(os.path.join(model_dir, filename)) 276 | h.setLevel(logging.DEBUG) 277 | h.setFormatter(formatter) 278 | logger.addHandler(h) 279 | return logger 280 | 281 | 282 | class HParams(): 283 | def __init__(self, **kwargs): 284 | for k, v in kwargs.items(): 285 | if type(v) == dict: 286 | v = HParams(**v) 287 | self[k] = v 288 | 289 | def keys(self): 290 | return self.__dict__.keys() 291 | 292 | def items(self): 293 | return self.__dict__.items() 294 | 295 | def values(self): 296 | return self.__dict__.values() 297 | 298 | def __len__(self): 299 | return len(self.__dict__) 300 | 301 | def __getitem__(self, key): 302 | return getattr(self, key) 303 | 304 | def __setitem__(self, key, value): 305 | return setattr(self, key, value) 306 | 307 | def __contains__(self, key): 308 | return key in self.__dict__ 309 | 310 | def __repr__(self): 311 | return self.__dict__.__repr__() 312 | -------------------------------------------------------------------------------- /wavlm/WavLM-Large.pt.txt: -------------------------------------------------------------------------------- 1 | https://github.com/microsoft/unilm/tree/master/wavlm -------------------------------------------------------------------------------- /wavlm/__init__.py: -------------------------------------------------------------------------------- 1 | from wavlm.WavLM import WavLM, WavLMConfig --------------------------------------------------------------------------------