├── LICENSE
├── README.md
├── commons.py
├── configs
│   └── hubert-neuraldec-vits.json
├── conv.py
├── convert.ipynb
├── data_utils.py
├── downsample.py
├── extra
│   ├── DSConv.py
│   ├── attentions.py
│   ├── commons.py
│   └── modules.py
├── filelists
│   ├── test.txt
│   ├── train.txt
│   └── val.txt
├── hifigan
│   ├── __init__.py
│   ├── config.json
│   ├── generator_v1.txt
│   └── models.py
├── losses.py
├── lstm.py
├── mel_processing.py
├── models.py
├── modules.py
├── norm.py
├── preprocess_code.py
├── preprocess_flist.py
├── preprocess_spk.py
├── preprocess_sr.py
├── requirements.txt
├── resources
│   └── NeurlVC.png
├── speaker_encoder
│   ├── __init__.py
│   ├── audio.py
│   ├── ckpt
│   │   ├── pretrained_bak_5805000.pt
│   │   └── pretrained_bak_5805000.pt.txt
│   ├── compute_embed.py
│   ├── config.py
│   ├── data_objects
│   │   ├── __init__.py
│   │   ├── random_cycler.py
│   │   ├── speaker.py
│   │   ├── speaker_batch.py
│   │   ├── speaker_verification_dataset.py
│   │   └── utterance.py
│   ├── hparams.py
│   ├── inference.py
│   ├── model.py
│   ├── params_data.py
│   ├── params_model.py
│   ├── preprocess.py
│   ├── train.py
│   ├── visualizations.py
│   └── voice_encoder.py
├── train.py
├── utils.py
└── wavlm
    ├── WavLM-Large.pt.txt
    ├── WavLM.py
    ├── __init__.py
    └── modules.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Jingyi Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NeuralVC: Any-to-Any Voice Conversion Using a Neural Network Decoder for Real-Time Voice Conversion
2 |
3 |
4 | In this paper, we adopt the end-to-end [VITS](https://arxiv.org/abs/2106.06103) framework for high-quality waveform reconstruction. By introducing HuBERT-Soft, we extract clean speech content information, and by incorporating a pre-trained speaker encoder, we extract speaker characteristics from the speech. Inspired by the structure of [speech compression models](https://arxiv.org/abs/2210.13438), we propose a **neural decoder** that synthesizes converted speech with the target speaker's voice by adding preprocessing and conditioning networks to receive and interpret speaker information. Additionally, we significantly improve the model's inference speed, achieving real-time voice conversion.
5 |
6 | Audio samples: https://jinyuanzhang999.github.io/NeuralVC_Demo.github.io/
7 |
8 | We also provide the [pretrained models](https://1drv.ms/f/c/87587ec0bae9be5a/Ek_2ur6Uwr5Lq1g-C5-5FFUB5JkhHHhLPg9iQxKxFvHm0w?e=Zpcxec).
9 |
10 |
11 |
12 | ![Model Framework](resources/NeurlVC.png)
13 |
14 |
15 | *Model Framework*
16 |
17 |
18 |
19 |
20 | ## Pre-requisites
21 |
22 | 1. Clone this repo: `git clone https://github.com/zzy1hjq/NeutralVC.git`
23 |
24 | 2. CD into this repo: `cd NeuralVC`
25 |
26 | 3. Install python requirements: `pip install -r requirements.txt`
27 |
28 | 4. Download the [VCTK](https://datashare.ed.ac.uk/handle/10283/3443) dataset (for training only)
29 |
30 |
31 | ## Inference Example
32 |
33 | Download the pretrained checkpoints and run:
34 |
35 | ```python
36 | # inference with NeuralVC
37 | # Replace the corresponding parameters
38 | convert.ipynb
39 | ```
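
For reference, the notebook boils down to roughly the following script (a condensed sketch of `convert.ipynb`; the config/checkpoint names follow the notebook's defaults and the wav paths are placeholders):

```python
import torch
import librosa
import soundfile as sf

import utils
from models import HuBERT_NeuralDec_VITS
from speaker_encoder.voice_encoder import SpeakerEncoder

hps = utils.get_hparams_from_file("logs/neuralvc/config.json")
net_g = HuBERT_NeuralDec_VITS(
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model).eval()
utils.load_checkpoint("logs/neuralvc/G_990000.pth", net_g, None, True)

hubert = torch.hub.load("bshall/hubert:main", "hubert_soft").eval()
smodel = SpeakerEncoder("speaker_encoder/ckpt/pretrained_bak_5805000.pt")

# target speaker embedding
wav_tgt, _ = librosa.load("target.wav", sr=hps.data.sampling_rate)
wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
g_tgt = torch.from_numpy(smodel.embed_utterance(wav_tgt)).unsqueeze(0)

# source content -> converted waveform
wav_src, _ = librosa.load("source.wav", sr=hps.data.sampling_rate)
with torch.no_grad():
    c = hubert.units(torch.from_numpy(wav_src)[None, None, :])  # [1, frames, dim]
    audio = net_g.infer(c.transpose(1, 2), g=g_tgt)[0][0].cpu().numpy()
sf.write("converted.wav", audio, hps.data.sampling_rate)
```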
40 |
41 | ## Training Example
42 |
43 | 1. Preprocess
44 |
45 | ```python
46 |
47 | # run this if you want a different train-val-test split
48 | python preprocess_flist.py
49 |
50 | # run this if you want to use pretrained speaker encoder
51 | python preprocess_spk.py
52 |
53 | # run this if you want to use a different content feature extractor.
54 | python preprocess_code.py
55 |
56 | ```
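
The dataloader in `data_utils.py` then looks for several cached artifacts next to each wav listed in `filelists/*.txt`. A sketch of the naming scheme (the example path is hypothetical; the string replacements are copied from `TextAudioSpeakerLoader.get_audio`):

```python
wav_path = "dataset/vctk-mini-16k/p225/p225_001.wav"  # hypothetical entry from filelists/train.txt

spec_path = wav_path.replace(".wav", ".spec.pt")  # linear spectrogram, cached on first load
code_path = wav_path.replace(".wav", ".pt")       # content features (see preprocess_code.py)
spk_path = wav_path.replace(".wav", ".npy").replace("vctk-mini-16k", "spk")  # speaker embedding (see preprocess_spk.py)
```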
57 |
58 | 2. Train
59 |
60 | ```python
61 | # train NeuralVC
62 | python train.py
63 |
64 |
65 | ```
66 |
67 | ## References
68 |
69 | - https://github.com/jaywalnut310/vits
70 | - https://github.com/OlaWod/FreeVC
71 | - https://github.com/quickvc/QuickVC-VoiceConversion
72 | - https://github.com/facebookresearch/encodec
73 |
--------------------------------------------------------------------------------
/commons.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 |
8 | def init_weights(m, mean=0.0, std=0.01):
9 | classname = m.__class__.__name__
10 | if classname.find("Conv") != -1:
11 | m.weight.data.normal_(mean, std)
12 |
13 |
14 | def get_padding(kernel_size, dilation=1):
15 | return int((kernel_size*dilation - dilation)/2)
16 |
17 |
18 | def convert_pad_shape(pad_shape):
19 | l = pad_shape[::-1]
20 | pad_shape = [item for sublist in l for item in sublist]
21 | return pad_shape
22 |
23 |
24 | def intersperse(lst, item):
25 | result = [item] * (len(lst) * 2 + 1)
26 | result[1::2] = lst
27 | return result
28 |
29 |
30 | def kl_divergence(m_p, logs_p, m_q, logs_q):
31 | """KL(P||Q)"""
32 | kl = (logs_q - logs_p) - 0.5
33 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
34 | return kl
35 |
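
For reference, `kl_divergence` above is the closed-form KL between two diagonal Gaussians parameterized by means and log standard deviations (a standard identity, restated here only for readability):

```latex
\mathrm{KL}\big(\mathcal{N}(m_p,\,e^{2\,logs_p})\,\|\,\mathcal{N}(m_q,\,e^{2\,logs_q})\big)
= (logs_q - logs_p) - \tfrac{1}{2}
+ \tfrac{1}{2}\,\big(e^{2\,logs_p} + (m_p - m_q)^2\big)\,e^{-2\,logs_q}
```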
36 |
37 | def rand_gumbel(shape):
38 | """Sample from the Gumbel distribution, protect from overflows."""
39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40 | return -torch.log(-torch.log(uniform_samples))
41 |
42 |
43 | def rand_gumbel_like(x):
44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45 | return g
46 |
47 |
48 | def slice_segments(x, ids_str, segment_size=4):
49 | ret = torch.zeros_like(x[:, :, :segment_size])
50 | for i in range(x.size(0)):
51 | idx_str = ids_str[i]
52 | idx_end = idx_str + segment_size
53 | ret[i] = x[i, :, idx_str:idx_end]
54 | return ret
55 |
56 |
57 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
58 | b, d, t = x.size()
59 | if x_lengths is None:
60 | x_lengths = t
61 | ids_str_max = x_lengths - segment_size + 1
62 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63 | ret = slice_segments(x, ids_str, segment_size)
64 | return ret, ids_str
65 |
66 |
67 | def rand_spec_segments(x, x_lengths=None, segment_size=4):
68 | b, d, t = x.size()
69 | if x_lengths is None:
70 | x_lengths = t
71 | ids_str_max = x_lengths - segment_size
72 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
73 | ret = slice_segments(x, ids_str, segment_size)
74 | return ret, ids_str
75 |
76 |
77 | def get_timing_signal_1d(
78 | length, channels, min_timescale=1.0, max_timescale=1.0e4):
79 | position = torch.arange(length, dtype=torch.float)
80 | num_timescales = channels // 2
81 | log_timescale_increment = (
82 | math.log(float(max_timescale) / float(min_timescale)) /
83 | (num_timescales - 1))
84 | inv_timescales = min_timescale * torch.exp(
85 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
86 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
87 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
88 | signal = F.pad(signal, [0, 0, 0, channels % 2])
89 | signal = signal.view(1, channels, length)
90 | return signal
91 |
92 |
93 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
94 | b, channels, length = x.size()
95 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
96 | return x + signal.to(dtype=x.dtype, device=x.device)
97 |
98 |
99 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
100 | b, channels, length = x.size()
101 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
102 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
103 |
104 |
105 | def subsequent_mask(length):
106 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
107 | return mask
108 |
109 |
110 | @torch.jit.script
111 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
112 | n_channels_int = n_channels[0]
113 | in_act = input_a + input_b
114 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
115 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
116 | acts = t_act * s_act
117 | return acts
118 |
119 |
120 | def convert_pad_shape(pad_shape):
121 | l = pad_shape[::-1]
122 | pad_shape = [item for sublist in l for item in sublist]
123 | return pad_shape
124 |
125 |
126 | def shift_1d(x):
127 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
128 | return x
129 |
130 |
131 | def sequence_mask(length, max_length=None):
132 | if max_length is None:
133 | max_length = length.max()
134 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
135 | return x.unsqueeze(0) < length.unsqueeze(1)
136 |
137 |
138 | def generate_path(duration, mask):
139 | """
140 | duration: [b, 1, t_x]
141 | mask: [b, 1, t_y, t_x]
142 | """
143 | device = duration.device
144 |
145 | b, _, t_y, t_x = mask.shape
146 | cum_duration = torch.cumsum(duration, -1)
147 |
148 | cum_duration_flat = cum_duration.view(b * t_x)
149 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
150 | path = path.view(b, t_x, t_y)
151 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
152 | path = path.unsqueeze(1).transpose(2,3) * mask
153 | return path
154 |
155 |
156 | def clip_grad_value_(parameters, clip_value, norm_type=2):
157 | if isinstance(parameters, torch.Tensor):
158 | parameters = [parameters]
159 | parameters = list(filter(lambda p: p.grad is not None, parameters))
160 | norm_type = float(norm_type)
161 | if clip_value is not None:
162 | clip_value = float(clip_value)
163 |
164 | total_norm = 0
165 | for p in parameters:
166 | param_norm = p.grad.data.norm(norm_type)
167 | total_norm += param_norm.item() ** norm_type
168 | if clip_value is not None:
169 | p.grad.data.clamp_(min=-clip_value, max=clip_value)
170 | total_norm = total_norm ** (1. / norm_type)
171 | return total_norm
172 |
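
A quick standalone shape check for the slicing helpers above (not part of the repo; assumes this file is importable as `commons`):

```python
import torch
from commons import rand_slice_segments, slice_segments, sequence_mask

x = torch.randn(2, 80, 100)                 # [batch, channels, frames]
lengths = torch.tensor([100, 60])

seg, ids = rand_slice_segments(x, lengths, segment_size=32)
print(seg.shape)                            # torch.Size([2, 80, 32])

# The same start indices can be reused to slice an aligned tensor,
# e.g. waveforms are sliced with ids * hop_length in data_utils.py.
wav = torch.randn(2, 1, 100 * 320)
wav_seg = slice_segments(wav, ids * 320, 32 * 320)
print(wav_seg.shape)                        # torch.Size([2, 1, 10240])

print(sequence_mask(lengths, 100).shape)    # torch.Size([2, 100]) boolean mask
```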
--------------------------------------------------------------------------------
/configs/hubert-neuraldec-vits.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 10000,
5 | "seed": 1234,
6 | "epochs": 10000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 64,
11 | "fp16_run": false,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8960,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0,
18 | "use_sr": false,
19 | "max_speclen": 128,
20 | "port": "8001"
21 | },
22 | "data": {
23 | "training_files":"filelists/train.txt",
24 | "validation_files":"filelists/val.txt",
25 | "max_wav_value": 32768.0,
26 | "sampling_rate": 16000,
27 | "filter_length": 1280,
28 | "hop_length": 320,
29 | "win_length": 1280,
30 | "n_mel_channels": 80,
31 | "mel_fmin": 0.0,
32 | "mel_fmax": null
33 | },
34 | "model": {
35 | "inter_channels": 192,
36 | "hidden_channels": 192,
37 | "filter_channels": 768,
38 | "n_heads": 2,
39 | "n_layers": 6,
40 | "kernel_size": 3,
41 | "p_dropout": 0.1,
42 | "resblock": "1",
43 | "resblock_kernel_sizes": [3,7,11],
44 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
45 | "upsample_rates": [10,8,2,2],
46 | "upsample_initial_channel": 512,
47 | "upsample_kernel_sizes": [16,16,4,4],
48 | "n_layers_q": 3,
49 | "use_spectral_norm": false,
50 | "gin_channels": 256,
51 | "ssl_dim": 256,
52 | "use_spk": true
53 | }
54 | }
55 |
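
A small sanity check on relationships implied by this config (the interpretation follows the usual VITS/HiFi-GAN conventions; the numbers below are copied from the JSON):

```python
import math

sampling_rate, hop_length = 16000, 320
upsample_rates = [10, 8, 2, 2]
segment_size = 8960

assert math.prod(upsample_rates) == hop_length  # the decoder upsamples one frame back to one hop
print(1000 * hop_length / sampling_rate)        # 20.0 -> each spectrogram frame covers 20 ms
print(segment_size // hop_length)               # 28   -> frames per random training segment
```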
--------------------------------------------------------------------------------
/conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Convolutional layers wrappers and utilities."""
8 |
9 | import math
10 | import typing as tp
11 | import warnings
12 |
13 | import torch
14 | from torch import nn
15 | from torch.nn import functional as F
16 | from torch.nn.utils import spectral_norm, weight_norm
17 |
18 | from norm import ConvLayerNorm
19 |
20 |
21 | CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
22 | 'time_layer_norm', 'layer_norm', 'time_group_norm'])
23 |
24 |
25 | def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
26 | assert norm in CONV_NORMALIZATIONS
27 | if norm == 'weight_norm':
28 | return weight_norm(module)
29 | elif norm == 'spectral_norm':
30 | return spectral_norm(module)
31 | else:
32 | # We already checked that norm is in CONV_NORMALIZATIONS, so any other
33 | # choice doesn't need reparametrization.
34 | return module
35 |
36 |
37 | def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
38 | """Return the proper normalization module. If causal is True, this will ensure the returned
39 | module is causal, or return an error if the normalization doesn't support causal evaluation.
40 | """
41 | assert norm in CONV_NORMALIZATIONS
42 | if norm == 'layer_norm':
43 | assert isinstance(module, nn.modules.conv._ConvNd)
44 | return ConvLayerNorm(module.out_channels, **norm_kwargs)
45 | elif norm == 'time_group_norm':
46 | if causal:
47 | raise ValueError("GroupNorm doesn't support causal evaluation.")
48 | assert isinstance(module, nn.modules.conv._ConvNd)
49 | return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
50 | else:
51 | return nn.Identity()
52 |
53 |
54 | def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
55 | padding_total: int = 0) -> int:
56 | """See `pad_for_conv1d`.
57 | """
58 | length = x.shape[-1]
59 | n_frames = (length - kernel_size + padding_total) / stride + 1
60 | ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
61 | return ideal_length - length
62 |
63 |
64 | def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
65 | """Pad for a convolution to make sure that the last window is full.
66 | Extra padding is added at the end. This is required to ensure that we can rebuild
67 | an output of the same length, as otherwise, even with padding, some time steps
68 | might get removed.
69 | For instance, with total padding = 4, kernel size = 4, stride = 2:
70 | 0 0 1 2 3 4 5 0 0 # (0s are padding)
71 | 1 2 3 # (output frames of a convolution, last 0 is never used)
72 | 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
73 | 1 2 3 4 # once padding is removed, we are missing one time step!
74 | """
75 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
76 | return F.pad(x, (0, extra_padding))
77 |
78 |
79 | def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
80 | """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
81 | If the input is too small, we insert extra 0 padding to the right before the reflection happens.
82 | """
83 | length = x.shape[-1]
84 | padding_left, padding_right = paddings
85 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
86 | if mode == 'reflect':
87 | max_pad = max(padding_left, padding_right)
88 | extra_pad = 0
89 | if length <= max_pad:
90 | extra_pad = max_pad - length + 1
91 | x = F.pad(x, (0, extra_pad))
92 | padded = F.pad(x, paddings, mode, value)
93 | end = padded.shape[-1] - extra_pad
94 | return padded[..., :end]
95 | else:
96 | return F.pad(x, paddings, mode, value)
97 |
98 |
99 | def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
100 | """Remove padding from x, handling properly zero padding. Only for 1d!"""
101 | padding_left, padding_right = paddings
102 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
103 | assert (padding_left + padding_right) <= x.shape[-1]
104 | end = x.shape[-1] - padding_right
105 | return x[..., padding_left: end]
106 |
107 |
108 | class NormConv1d(nn.Module):
109 | """Wrapper around Conv1d and normalization applied to this conv
110 | to provide a uniform interface across normalization approaches.
111 | """
112 | def __init__(self, *args, causal: bool = False, norm: str = 'none',
113 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
114 | super().__init__()
115 | self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
116 | self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
117 | self.norm_type = norm
118 |
119 | def forward(self, x):
120 | x = self.conv(x)
121 | x = self.norm(x)
122 | return x
123 |
124 |
125 | class NormConv2d(nn.Module):
126 | """Wrapper around Conv2d and normalization applied to this conv
127 | to provide a uniform interface across normalization approaches.
128 | """
129 | def __init__(self, *args, norm: str = 'none',
130 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
131 | super().__init__()
132 | self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
133 | self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
134 | self.norm_type = norm
135 |
136 | def forward(self, x):
137 | x = self.conv(x)
138 | x = self.norm(x)
139 | return x
140 |
141 |
142 | class NormConvTranspose1d(nn.Module):
143 | """Wrapper around ConvTranspose1d and normalization applied to this conv
144 | to provide a uniform interface across normalization approaches.
145 | """
146 | def __init__(self, *args, causal: bool = False, norm: str = 'none',
147 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
148 | super().__init__()
149 | self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
150 | self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
151 | self.norm_type = norm
152 |
153 | def forward(self, x):
154 | x = self.convtr(x)
155 | x = self.norm(x)
156 | return x
157 |
158 |
159 | class NormConvTranspose2d(nn.Module):
160 | """Wrapper around ConvTranspose2d and normalization applied to this conv
161 | to provide a uniform interface across normalization approaches.
162 | """
163 | def __init__(self, *args, norm: str = 'none',
164 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
165 | super().__init__()
166 | self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
167 | self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
168 |
169 | def forward(self, x):
170 | x = self.convtr(x)
171 | x = self.norm(x)
172 | return x
173 |
174 |
175 | class SConv1d(nn.Module):
176 | """Conv1d with some builtin handling of asymmetric or causal padding
177 | and normalization.
178 | """
179 | def __init__(self, in_channels: int, out_channels: int,
180 | kernel_size: int, stride: int = 1, dilation: int = 1,
181 | groups: int = 1, bias: bool = True, causal: bool = False,
182 | norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
183 | pad_mode: str = 'reflect'):
184 | super().__init__()
185 | # warn user on unusual setup between dilation and stride
186 | if stride > 1 and dilation > 1:
187 | warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
188 | f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
189 | self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
190 | dilation=dilation, groups=groups, bias=bias, causal=causal,
191 | norm=norm, norm_kwargs=norm_kwargs)
192 | self.causal = causal
193 | self.pad_mode = pad_mode
194 |
195 | def forward(self, x):
196 | B, C, T = x.shape
197 | kernel_size = self.conv.conv.kernel_size[0]
198 | stride = self.conv.conv.stride[0]
199 | dilation = self.conv.conv.dilation[0]
200 | kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
201 | padding_total = kernel_size - stride
202 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
203 | if self.causal:
204 | # Left padding for causal
205 | x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
206 | else:
207 | # Asymmetric padding required for odd strides
208 | padding_right = padding_total // 2
209 | padding_left = padding_total - padding_right
210 | x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
211 | return self.conv(x)
212 |
213 |
214 | class SConvTranspose1d(nn.Module):
215 | """ConvTranspose1d with some builtin handling of asymmetric or causal padding
216 | and normalization.
217 | """
218 | def __init__(self, in_channels: int, out_channels: int,
219 | kernel_size: int, stride: int = 1, causal: bool = False,
220 | norm: str = 'none', trim_right_ratio: float = 1.,
221 | norm_kwargs: tp.Dict[str, tp.Any] = {}):
222 | super().__init__()
223 | self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
224 | causal=causal, norm=norm, norm_kwargs=norm_kwargs)
225 | self.causal = causal
226 | self.trim_right_ratio = trim_right_ratio
227 | assert self.causal or self.trim_right_ratio == 1., \
228 | "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
229 | assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
230 |
231 | def forward(self, x):
232 | kernel_size = self.convtr.convtr.kernel_size[0]
233 | stride = self.convtr.convtr.stride[0]
234 | padding_total = kernel_size - stride
235 |
236 | y = self.convtr(x)
237 |
238 | # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
239 | # removed at the very end, when keeping only the right length for the output,
240 | # as removing it here would require also passing the length at the matching layer
241 | # in the encoder.
242 | if self.causal:
243 | # Trim the padding on the right according to the specified ratio
244 | # if trim_right_ratio = 1.0, trim everything from right
245 | padding_right = math.ceil(padding_total * self.trim_right_ratio)
246 | padding_left = padding_total - padding_right
247 | y = unpad1d(y, (padding_left, padding_right))
248 | else:
249 | # Asymmetric padding required for odd strides
250 | padding_right = padding_total // 2
251 | padding_left = padding_total - padding_right
252 | y = unpad1d(y, (padding_left, padding_right))
253 | return y
254 |
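
A standalone sketch of how the padded conv wrappers behave (not part of the repo; assumes `conv.py` and `norm.py` are importable from the repo root):

```python
import torch
from conv import SConv1d, SConvTranspose1d

x = torch.randn(1, 4, 100)

# Non-causal: symmetric padding keeps ceil(T / stride) frames.
conv = SConv1d(4, 8, kernel_size=7, stride=2, causal=False, norm='weight_norm')
print(conv(x).shape)        # torch.Size([1, 8, 50])

# Causal: all padding goes to the left, so frame t only sees inputs <= t.
causal = SConv1d(4, 8, kernel_size=7, stride=2, causal=True)
print(causal(x).shape)      # torch.Size([1, 8, 50])

# Transposed counterpart upsamples back by the same stride.
up = SConvTranspose1d(8, 4, kernel_size=4, stride=2)
print(up(causal(x)).shape)  # torch.Size([1, 4, 100])
```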
--------------------------------------------------------------------------------
/convert.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import argparse\n",
11 | "import torch\n",
12 | "import librosa\n",
13 | "import time\n",
14 | "from scipy.io.wavfile import write\n",
15 | "from tqdm import tqdm\n",
16 | "import soundfile as sf\n",
17 | "import utils\n",
18 | "import time\n",
19 | "from models import HuBERT_NeuralDec_VITS\n",
20 | "from mel_processing import mel_spectrogram_torch\n",
21 | "import logging\n",
22 | "\n",
23 | "from speaker_encoder.voice_encoder import SpeakerEncoder\n",
24 | "logging.getLogger('numba').setLevel(logging.WARNING)"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "class Parameters:\n",
34 | " def __init__(self):\n",
35 | " self.hpfile = \"logs/neuralvc/config.json\"\n",
36 | " self.ptfile = \"logs/neuralvc/G_990000.pth\"\n",
37 | " self.model_name = \"hubert-neuraldec-vits\"\n",
38 | " self.outdir = \"output/temp\"\n",
39 | " self.use_timestamp = False\n",
40 | "args = Parameters()"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "if not os.path.exists(args.outdir):\n",
50 | " os.makedirs(args.outdir)\n",
51 | "\n",
52 | "# hps = utils.get_hparams_from_file(args.hpfile)"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "os.makedirs(args.outdir, exist_ok=True)\n",
62 | "hps = utils.get_hparams_from_file(args.hpfile)\n",
63 | "\n",
64 | "print(\"Loading model...\")\n",
65 | "net_g = HuBERT_NeuralDec_VITS(\n",
66 | " hps.data.filter_length // 2 + 1,\n",
67 | " hps.train.segment_size // hps.data.hop_length,\n",
68 | " **hps.model)\n",
69 | "_ = net_g.eval()\n",
70 | "\n",
71 | "print(\"Loading checkpoint...\")\n",
72 | "_ = utils.load_checkpoint(args.ptfile, net_g, None, True)\n",
73 | "\n",
74 | "print(\"Loading hubert...\")\n",
75 | "hubert = torch.hub.load(\"bshall/hubert:main\", f\"hubert_soft\").eval() \n",
76 | "\n",
77 | "if hps.model.use_spk:\n",
78 | " print(\"Loading speaker encoder...\")\n",
79 | " smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt')\n",
80 | "print(\"ok\")"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "from tqdm import tqdm\n",
90 | "\n",
91 | "def convert(src_list, tgt):\n",
92 | " tgtname = tgt.split(\"/\")[-1].split(\".\")[0]\n",
93 | " wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate)\n",
94 | " if not os.path.exists(os.path.join(args.outdir, tgtname)):\n",
95 | " os.makedirs(os.path.join(args.outdir, tgtname))\n",
96 | " sf.write(os.path.join(args.outdir, tgtname, f\"tgt_{tgtname}.wav\"), wav_tgt, hps.data.sampling_rate)\n",
97 | " wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)\n",
98 | " g_tgt = smodel.embed_utterance(wav_tgt)\n",
99 | " g_tgt = torch.from_numpy(g_tgt).unsqueeze(0)\n",
100 | " for src in tqdm(src_list):\n",
101 | " srcname = src.split(\"/\")[-1].split(\".\")[0]\n",
102 | " title = srcname + \"-\" + tgtname\n",
103 | " with torch.no_grad():\n",
104 | " wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate)\n",
105 | " sf.write(os.path.join(args.outdir, tgtname, f\"src_{srcname}.wav\"), wav_src, hps.data.sampling_rate)\n",
106 | " wav_src = torch.from_numpy(wav_src).unsqueeze(0).unsqueeze(0)\n",
107 | " c = hubert.units(wav_src)\n",
108 | " c = c.transpose(1,2)\n",
109 | " audio = net_g.infer(c, g=g_tgt)\n",
110 | " audio = audio[0][0].data.cpu().float().numpy()\n",
111 | " write(os.path.join(args.outdir, tgtname, f\"{title}.wav\"), hps.data.sampling_rate, audio)"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "# Test\n",
121 | "import time\n",
122 | "\n",
123 | "tgt1 = \"/mnt/hd/cma/zzy/dataset/test/M_5105_28233_000016_000001.wav\"\n",
124 | "\n",
125 | "src_list1 = [\"/mnt/hd/cma/zzy/dataset/test/F_3575_170457_000032_000001.wav\"]\n",
126 | "\n",
127 | "convert(src_list1, tgt1)"
128 | ]
129 | }
130 | ],
131 | "metadata": {
132 | "language_info": {
133 | "name": "python"
134 | }
135 | },
136 | "nbformat": 4,
137 | "nbformat_minor": 2
138 | }
139 |
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import random
4 | import numpy as np
5 | import torch
6 | import torch.utils.data
7 |
8 | import commons
9 | from mel_processing import spectrogram_torch, spec_to_mel_torch
10 | from utils import load_wav_to_torch, load_filepaths_and_text, transform
11 | #import h5py
12 |
13 |
14 | """Multi speaker version"""
15 | class TextAudioSpeakerLoader(torch.utils.data.Dataset):
16 | """
17 | 1) loads audio and pre-extracted content features
18 | 2) loads speaker embeddings (when use_spk is enabled)
19 | 3) computes spectrograms from the audio files.
20 | """
21 | def __init__(self, audiopaths, hparams):
22 | self.audiopaths = load_filepaths_and_text(audiopaths)
23 | self.max_wav_value = hparams.data.max_wav_value
24 | self.sampling_rate = hparams.data.sampling_rate
25 | self.filter_length = hparams.data.filter_length
26 | self.hop_length = hparams.data.hop_length
27 | self.win_length = hparams.data.win_length
28 | self.sampling_rate = hparams.data.sampling_rate
29 | self.use_sr = hparams.train.use_sr
30 | self.use_spk = hparams.model.use_spk
31 | self.spec_len = hparams.train.max_speclen
32 |
33 | random.seed(1235)
34 | random.shuffle(self.audiopaths)
35 | self._filter()
36 |
37 | def _filter(self):
38 | """
39 | Filter text & store spec lengths
40 | """
41 | # Store spectrogram lengths for Bucketing
42 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
43 | # spec_length = wav_length // hop_length
44 |
45 | lengths = []
46 | for audiopath in self.audiopaths:
47 | lengths.append(os.path.getsize(audiopath[0]) // (2 * self.hop_length))
48 | self.lengths = lengths
49 |
50 | def get_audio(self, filename):
51 | audio, sampling_rate = load_wav_to_torch(filename)
52 | if sampling_rate != self.sampling_rate:
53 | raise ValueError("{} SR doesn't match target {} SR, file: {}".format(
54 | sampling_rate, self.sampling_rate, filename))
55 | audio_norm = audio / self.max_wav_value
56 | audio_norm = audio_norm.unsqueeze(0)
57 | spec_filename = filename.replace(".wav", ".spec.pt")
58 | if os.path.exists(spec_filename):
59 | spec = torch.load(spec_filename)
60 | else:
61 | spec = spectrogram_torch(audio_norm, self.filter_length,
62 | self.sampling_rate, self.hop_length, self.win_length,
63 | center=False)
64 | spec = torch.squeeze(spec, 0)
65 | torch.save(spec, spec_filename)
66 |
67 | if self.use_spk:
68 | spk_filename = filename.replace(".wav", ".npy").replace("vctk-mini-16k", "spk")
69 | spk = torch.from_numpy(np.load(spk_filename))
70 |
71 | c_filename = filename.replace(".wav", ".pt")
72 | c=torch.load(c_filename)
73 | c=c.transpose(1,0).squeeze(1)
74 |
75 | return c, spec, audio_norm, spk
76 |
77 | def __getitem__(self, index):
78 | return self.get_audio(self.audiopaths[index][0])
79 |
80 | def __len__(self):
81 | return len(self.audiopaths)
82 |
83 |
84 |
85 | class TextAudioSpeakerCollate():
86 | """ Zero-pads model inputs and targets
87 | """
88 | def __init__(self, hps):
89 | self.hps = hps
90 | self.use_sr = hps.train.use_sr
91 | self.use_spk = hps.model.use_spk
92 |
93 | def __call__(self, batch):
94 | """Collate's training batch from normalized text, audio and speaker identities
95 | PARAMS
96 | ------
97 | batch: [text_normalized, spec_normalized, wav_normalized, sid]
98 | """
99 | # Right zero-pad all one-hot text sequences to max input length
100 | _, ids_sorted_decreasing = torch.sort(
101 | torch.LongTensor([x[0].size(1) for x in batch]),
102 | dim=0, descending=True)
103 |
104 | max_c_len = max([x[0].size(1) for x in batch])
105 | max_spec_len = max([x[1].size(1) for x in batch])
106 | max_wav_len = max([x[2].size(1) for x in batch])
107 |
108 | spec_lengths = torch.LongTensor(len(batch))
109 | wav_lengths = torch.LongTensor(len(batch))
110 | if self.use_spk:
111 | spks = torch.FloatTensor(len(batch), batch[0][3].size(0))
112 | else:
113 | spks = None
114 |
115 | c_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_c_len)
116 | spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
117 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
118 | c_padded.zero_()
119 | spec_padded.zero_()
120 | wav_padded.zero_()
121 |
122 | for i in range(len(ids_sorted_decreasing)):
123 | row = batch[ids_sorted_decreasing[i]]
124 |
125 | c = row[0]
126 | c_padded[i, :, :c.size(1)] = c
127 |
128 | spec = row[1]
129 | spec_padded[i, :, :spec.size(1)] = spec
130 | spec_lengths[i] = spec.size(1)
131 |
132 | wav = row[2]
133 | wav_padded[i, :, :wav.size(1)] = wav
134 | wav_lengths[i] = wav.size(1)
135 |
136 | if self.use_spk:
137 | spks[i] = row[3]
138 |
139 | spec_seglen = spec_lengths[-1] if spec_lengths[-1] < self.hps.train.max_speclen + 1 else self.hps.train.max_speclen + 1
140 | wav_seglen = spec_seglen * self.hps.data.hop_length
141 |
142 | spec_padded, ids_slice = commons.rand_spec_segments(spec_padded, spec_lengths, spec_seglen)
143 | wav_padded = commons.slice_segments(wav_padded, ids_slice * self.hps.data.hop_length, wav_seglen)
144 |
145 | c_padded = commons.slice_segments(c_padded, ids_slice, spec_seglen)[:,:,:-1]
146 |
147 | spec_padded = spec_padded[:,:,:-1]
148 | wav_padded = wav_padded[:,:,:-self.hps.data.hop_length]
149 |
150 | if self.use_spk:
151 | return c_padded, spec_padded, wav_padded, spks
152 | else:
153 | return c_padded, spec_padded, wav_padded
154 |
155 |
156 |
157 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
158 | """
159 | Maintain similar input lengths in a batch.
160 | Length groups are specified by boundaries.
161 | Ex) boundaries = [b1, b2, b3] -> each batch contains samples from either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
162 |
163 | It removes samples which are not included in the boundaries.
164 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 is discarded.
165 | """
166 | def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
167 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
168 | self.lengths = dataset.lengths
169 | self.batch_size = batch_size
170 | self.boundaries = boundaries
171 |
172 | self.buckets, self.num_samples_per_bucket = self._create_buckets()
173 | self.total_size = sum(self.num_samples_per_bucket)
174 | self.num_samples = self.total_size // self.num_replicas
175 |
176 | def _create_buckets(self):
177 | buckets = [[] for _ in range(len(self.boundaries) - 1)]
178 | for i in range(len(self.lengths)):
179 | length = self.lengths[i]
180 | idx_bucket = self._bisect(length)
181 | if idx_bucket != -1:
182 | buckets[idx_bucket].append(i)
183 |
184 | for i in range(len(buckets) - 1, 0, -1):
185 | if len(buckets[i]) == 0:
186 | buckets.pop(i)
187 | self.boundaries.pop(i+1)
188 |
189 | num_samples_per_bucket = []
190 | for i in range(len(buckets)):
191 | len_bucket = len(buckets[i])
192 | total_batch_size = self.num_replicas * self.batch_size
193 | rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
194 | num_samples_per_bucket.append(len_bucket + rem)
195 | return buckets, num_samples_per_bucket
196 |
197 | def __iter__(self):
198 | # deterministically shuffle based on epoch
199 | g = torch.Generator()
200 | g.manual_seed(self.epoch)
201 |
202 | indices = []
203 | if self.shuffle:
204 | for bucket in self.buckets:
205 | indices.append(torch.randperm(len(bucket), generator=g).tolist())
206 | else:
207 | for bucket in self.buckets:
208 | indices.append(list(range(len(bucket))))
209 |
210 | batches = []
211 | for i in range(len(self.buckets)):
212 | bucket = self.buckets[i]
213 | len_bucket = len(bucket)
214 | ids_bucket = indices[i]
215 | num_samples_bucket = self.num_samples_per_bucket[i]
216 |
217 | # add extra samples to make it evenly divisible
218 | rem = num_samples_bucket - len_bucket
219 | ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
220 |
221 | # subsample
222 | ids_bucket = ids_bucket[self.rank::self.num_replicas]
223 |
224 | # batching
225 | for j in range(len(ids_bucket) // self.batch_size):
226 | batch = [bucket[idx] for idx in ids_bucket[j*self.batch_size:(j+1)*self.batch_size]]
227 | batches.append(batch)
228 |
229 | if self.shuffle:
230 | batch_ids = torch.randperm(len(batches), generator=g).tolist()
231 | batches = [batches[i] for i in batch_ids]
232 | self.batches = batches
233 |
234 | assert len(self.batches) * self.batch_size == self.num_samples
235 | return iter(self.batches)
236 |
237 | def _bisect(self, x, lo=0, hi=None):
238 | if hi is None:
239 | hi = len(self.boundaries) - 1
240 |
241 | if hi > lo:
242 | mid = (hi + lo) // 2
243 | if self.boundaries[mid] < x and x <= self.boundaries[mid+1]:
244 | return mid
245 | elif x <= self.boundaries[mid]:
246 | return self._bisect(x, lo, mid)
247 | else:
248 | return self._bisect(x, mid + 1, hi)
249 | else:
250 | return -1
251 |
252 | def __len__(self):
253 | return self.num_samples // self.batch_size
254 |
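
A dry run of the bucketing logic with dummy lengths (a standalone sketch; `DummyDataset` is hypothetical and only provides the `.lengths` attribute the sampler reads):

```python
import torch
from data_utils import DistributedBucketSampler

class DummyDataset:
    def __init__(self, lengths):
        self.lengths = lengths
    def __len__(self):
        return len(self.lengths)

# 260 exceeds the last boundary (256), so that sample is discarded.
dataset = DummyDataset(lengths=[40, 55, 70, 90, 120, 150, 200, 260])
sampler = DistributedBucketSampler(dataset, batch_size=2,
                                   boundaries=[32, 64, 128, 256],
                                   num_replicas=1, rank=0, shuffle=True)
sampler.set_epoch(0)
for batch in sampler:
    # indices in the same batch come from the same length bucket
    print(batch, [dataset.lengths[i] for i in batch])
```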
--------------------------------------------------------------------------------
/downsample.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import librosa
4 | import numpy as np
5 | from multiprocessing import Pool, cpu_count
6 | from scipy.io import wavfile
7 | from tqdm import tqdm
8 |
9 |
10 | def process(wav_name):
11 | # speakers 's5', 'p280' and 'p315' are excluded
12 | speaker = wav_name[:4]
13 | wav_path = os.path.join(args.in_dir, speaker, wav_name)
14 | if os.path.exists(wav_path) and '_mic2.flac' in wav_path:
15 | os.makedirs(os.path.join(args.out_dir1, speaker), exist_ok=True)
16 | os.makedirs(os.path.join(args.out_dir2, speaker), exist_ok=True)
17 | wav, sr = librosa.load(wav_path)
18 | wav, _ = librosa.effects.trim(wav, top_db=20)
19 | peak = np.abs(wav).max()
20 | if peak > 1.0:
21 | wav = 0.98 * wav / peak
22 | wav1 = librosa.resample(wav, orig_sr=sr, target_sr=args.sr1)
23 | wav2 = librosa.resample(wav, orig_sr=sr, target_sr=args.sr2)
24 | save_name = wav_name.replace("_mic2.flac", ".wav")
25 | save_path1 = os.path.join(args.out_dir1, speaker, save_name)
26 | save_path2 = os.path.join(args.out_dir2, speaker, save_name)
27 | wavfile.write(
28 | save_path1,
29 | args.sr1,
30 | (wav1 * np.iinfo(np.int16).max).astype(np.int16)
31 | )
32 | wavfile.write(
33 | save_path2,
34 | args.sr2,
35 | (wav2 * np.iinfo(np.int16).max).astype(np.int16)
36 | )
37 |
38 |
39 | if __name__ == "__main__":
40 | parser = argparse.ArgumentParser()
41 | parser.add_argument("--sr1", type=int, default=16000, help="sampling rate")
42 | parser.add_argument("--sr2", type=int, default=22050, help="sampling rate")
43 | parser.add_argument("--in_dir", type=str, default="/home/Datasets/lijingyi/data/vctk/wav48_silence_trimmed/", help="path to source dir")
44 | parser.add_argument("--out_dir1", type=str, default="./dataset/vctk-16k", help="path to target dir")
45 | parser.add_argument("--out_dir2", type=str, default="./dataset/vctk-22k", help="path to target dir")
46 | args = parser.parse_args()
47 |
48 | pool = Pool(processes=cpu_count()-2)
49 |
50 | for speaker in os.listdir(args.in_dir):
51 | spk_dir = os.path.join(args.in_dir, speaker)
52 | if os.path.isdir(spk_dir):
53 | for _ in tqdm(pool.imap_unordered(process, os.listdir(spk_dir))):
54 | pass
55 |
56 |
--------------------------------------------------------------------------------
/extra/DSConv.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn.utils import remove_weight_norm, weight_norm
3 |
4 |
5 | class Depthwise_Separable_Conv1D(nn.Module):
6 | def __init__(
7 | self,
8 | in_channels,
9 | out_channels,
10 | kernel_size,
11 | stride=1,
12 | padding=0,
13 | dilation=1,
14 | bias=True,
15 | padding_mode='zeros', # TODO: refine this type
16 | device=None,
17 | dtype=None
18 | ):
19 | super().__init__()
20 | self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
21 | groups=in_channels, stride=stride, padding=padding, dilation=dilation, bias=bias,
22 | padding_mode=padding_mode, device=device, dtype=dtype)
23 | self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias,
24 | device=device, dtype=dtype)
25 |
26 | def forward(self, input):
27 | return self.point_conv(self.depth_conv(input))
28 |
29 | def weight_norm(self):
30 | self.depth_conv = weight_norm(self.depth_conv, name='weight')
31 | self.point_conv = weight_norm(self.point_conv, name='weight')
32 |
33 | def remove_weight_norm(self):
34 | self.depth_conv = remove_weight_norm(self.depth_conv, name='weight')
35 | self.point_conv = remove_weight_norm(self.point_conv, name='weight')
36 |
37 |
38 | class Depthwise_Separable_TransposeConv1D(nn.Module):
39 | def __init__(
40 | self,
41 | in_channels,
42 | out_channels,
43 | kernel_size,
44 | stride=1,
45 | padding=0,
46 | output_padding=0,
47 | bias=True,
48 | dilation=1,
49 | padding_mode='zeros', # TODO: refine this type
50 | device=None,
51 | dtype=None
52 | ):
53 | super().__init__()
54 | self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
55 | groups=in_channels, stride=stride, output_padding=output_padding,
56 | padding=padding, dilation=dilation, bias=bias, padding_mode=padding_mode,
57 | device=device, dtype=dtype)
58 | self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias,
59 | device=device, dtype=dtype)
60 |
61 | def forward(self, input):
62 | return self.point_conv(self.depth_conv(input))
63 |
64 | def weight_norm(self):
65 | self.depth_conv = weight_norm(self.depth_conv, name='weight')
66 | self.point_conv = weight_norm(self.point_conv, name='weight')
67 |
68 | def remove_weight_norm(self):
69 | remove_weight_norm(self.depth_conv, name='weight')
70 | remove_weight_norm(self.point_conv, name='weight')
71 |
72 |
73 | def weight_norm_modules(module, name='weight', dim=0):
74 | if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
75 | module.weight_norm()
76 | return module
77 | else:
78 | return weight_norm(module, name, dim)
79 |
80 |
81 | def remove_weight_norm_modules(module, name='weight'):
82 | if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
83 | module.remove_weight_norm()
84 | else:
85 | remove_weight_norm(module, name)
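
A standalone comparison of the depthwise-separable block above against a dense `nn.Conv1d` (sketch only; the channel size of 192 is illustrative and matches the model's hidden_channels, and it assumes the repo root is on PYTHONPATH):

```python
import torch
import torch.nn as nn
from extra.DSConv import Depthwise_Separable_Conv1D

x = torch.randn(1, 192, 50)
dense = nn.Conv1d(192, 192, kernel_size=5, padding=2)
separable = Depthwise_Separable_Conv1D(192, 192, kernel_size=5, padding=2)

print(dense(x).shape, separable(x).shape)    # both torch.Size([1, 192, 50])

n_params = lambda m: sum(p.numel() for p in m.parameters())
print(n_params(dense), n_params(separable))  # 184512 vs 38208: depthwise + pointwise is ~5x smaller
```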
--------------------------------------------------------------------------------
/extra/attentions.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | import extra.commons as commons
8 | from extra.DSConv import weight_norm_modules
9 | from extra.modules import LayerNorm
10 |
11 |
12 | class FFT(nn.Module):
13 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0.,
14 | proximal_bias=False, proximal_init=True, isflow=False, **kwargs):
15 | super().__init__()
16 | self.hidden_channels = hidden_channels
17 | self.filter_channels = filter_channels
18 | self.n_heads = n_heads
19 | self.n_layers = n_layers
20 | self.kernel_size = kernel_size
21 | self.p_dropout = p_dropout
22 | self.proximal_bias = proximal_bias
23 | self.proximal_init = proximal_init
24 | if isflow:
25 | cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
26 | self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
27 | self.cond_layer = weight_norm_modules(cond_layer, name='weight')
28 | self.gin_channels = kwargs["gin_channels"]
29 | self.drop = nn.Dropout(p_dropout)
30 | self.self_attn_layers = nn.ModuleList()
31 | self.norm_layers_0 = nn.ModuleList()
32 | self.ffn_layers = nn.ModuleList()
33 | self.norm_layers_1 = nn.ModuleList()
34 | for i in range(self.n_layers):
35 | self.self_attn_layers.append(
36 | MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout,
37 | proximal_bias=proximal_bias,
38 | proximal_init=proximal_init))
39 | self.norm_layers_0.append(LayerNorm(hidden_channels))
40 | self.ffn_layers.append(
41 | FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
42 | self.norm_layers_1.append(LayerNorm(hidden_channels))
43 |
44 | def forward(self, x, x_mask, g=None):
45 | """
46 | x: decoder input
47 | h: encoder output
48 | """
49 | if g is not None:
50 | g = self.cond_layer(g)
51 |
52 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
53 | x = x * x_mask
54 | for i in range(self.n_layers):
55 | if g is not None:
56 | x = self.cond_pre(x)
57 | cond_offset = i * 2 * self.hidden_channels
58 | g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
59 | x = commons.fused_add_tanh_sigmoid_multiply(
60 | x,
61 | g_l,
62 | torch.IntTensor([self.hidden_channels]))
63 | y = self.self_attn_layers[i](x, x, self_attn_mask)
64 | y = self.drop(y)
65 | x = self.norm_layers_0[i](x + y)
66 |
67 | y = self.ffn_layers[i](x, x_mask)
68 | y = self.drop(y)
69 | x = self.norm_layers_1[i](x + y)
70 | x = x * x_mask
71 | return x
72 |
73 |
74 | class Encoder(nn.Module):
75 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4,
76 | **kwargs):
77 | super().__init__()
78 | self.hidden_channels = hidden_channels
79 | self.filter_channels = filter_channels
80 | self.n_heads = n_heads
81 | self.n_layers = n_layers
82 | self.kernel_size = kernel_size
83 | self.p_dropout = p_dropout
84 | self.window_size = window_size
85 |
86 | self.drop = nn.Dropout(p_dropout)
87 | self.attn_layers = nn.ModuleList()
88 | self.norm_layers_1 = nn.ModuleList()
89 | self.ffn_layers = nn.ModuleList()
90 | self.norm_layers_2 = nn.ModuleList()
91 | for i in range(self.n_layers):
92 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout,
93 | window_size=window_size))
94 | self.norm_layers_1.append(LayerNorm(hidden_channels))
95 | self.ffn_layers.append(
96 | FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
97 | self.norm_layers_2.append(LayerNorm(hidden_channels))
98 |
99 | def forward(self, x, x_mask):
100 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
101 | x = x * x_mask
102 | for i in range(self.n_layers):
103 | y = self.attn_layers[i](x, x, attn_mask)
104 | y = self.drop(y)
105 | x = self.norm_layers_1[i](x + y)
106 |
107 | y = self.ffn_layers[i](x, x_mask)
108 | y = self.drop(y)
109 | x = self.norm_layers_2[i](x + y)
110 | x = x * x_mask
111 | return x
112 |
113 |
114 | class Decoder(nn.Module):
115 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
116 | proximal_bias=False, proximal_init=True, **kwargs):
117 | super().__init__()
118 | self.hidden_channels = hidden_channels
119 | self.filter_channels = filter_channels
120 | self.n_heads = n_heads
121 | self.n_layers = n_layers
122 | self.kernel_size = kernel_size
123 | self.p_dropout = p_dropout
124 | self.proximal_bias = proximal_bias
125 | self.proximal_init = proximal_init
126 |
127 | self.drop = nn.Dropout(p_dropout)
128 | self.self_attn_layers = nn.ModuleList()
129 | self.norm_layers_0 = nn.ModuleList()
130 | self.encdec_attn_layers = nn.ModuleList()
131 | self.norm_layers_1 = nn.ModuleList()
132 | self.ffn_layers = nn.ModuleList()
133 | self.norm_layers_2 = nn.ModuleList()
134 | for i in range(self.n_layers):
135 | self.self_attn_layers.append(
136 | MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout,
137 | proximal_bias=proximal_bias, proximal_init=proximal_init))
138 | self.norm_layers_0.append(LayerNorm(hidden_channels))
139 | self.encdec_attn_layers.append(
140 | MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
141 | self.norm_layers_1.append(LayerNorm(hidden_channels))
142 | self.ffn_layers.append(
143 | FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
144 | self.norm_layers_2.append(LayerNorm(hidden_channels))
145 |
146 | def forward(self, x, x_mask, h, h_mask):
147 | """
148 | x: decoder input
149 | h: encoder output
150 | """
151 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
152 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
153 | x = x * x_mask
154 | for i in range(self.n_layers):
155 | y = self.self_attn_layers[i](x, x, self_attn_mask)
156 | y = self.drop(y)
157 | x = self.norm_layers_0[i](x + y)
158 |
159 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
160 | y = self.drop(y)
161 | x = self.norm_layers_1[i](x + y)
162 |
163 | y = self.ffn_layers[i](x, x_mask)
164 | y = self.drop(y)
165 | x = self.norm_layers_2[i](x + y)
166 | x = x * x_mask
167 | return x
168 |
169 |
170 | class MultiHeadAttention(nn.Module):
171 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True,
172 | block_length=None, proximal_bias=False, proximal_init=False):
173 | super().__init__()
174 | assert channels % n_heads == 0
175 |
176 | self.channels = channels
177 | self.out_channels = out_channels
178 | self.n_heads = n_heads
179 | self.p_dropout = p_dropout
180 | self.window_size = window_size
181 | self.heads_share = heads_share
182 | self.block_length = block_length
183 | self.proximal_bias = proximal_bias
184 | self.proximal_init = proximal_init
185 | self.attn = None
186 |
187 | self.k_channels = channels // n_heads
188 | self.conv_q = nn.Conv1d(channels, channels, 1)
189 | self.conv_k = nn.Conv1d(channels, channels, 1)
190 | self.conv_v = nn.Conv1d(channels, channels, 1)
191 | self.conv_o = nn.Conv1d(channels, out_channels, 1)
192 | self.drop = nn.Dropout(p_dropout)
193 |
194 | if window_size is not None:
195 | n_heads_rel = 1 if heads_share else n_heads
196 | rel_stddev = self.k_channels ** -0.5
197 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
198 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
199 |
200 | nn.init.xavier_uniform_(self.conv_q.weight)
201 | nn.init.xavier_uniform_(self.conv_k.weight)
202 | nn.init.xavier_uniform_(self.conv_v.weight)
203 | if proximal_init:
204 | with torch.no_grad():
205 | self.conv_k.weight.copy_(self.conv_q.weight)
206 | self.conv_k.bias.copy_(self.conv_q.bias)
207 |
208 | def forward(self, x, c, attn_mask=None):
209 | q = self.conv_q(x)
210 | k = self.conv_k(c)
211 | v = self.conv_v(c)
212 |
213 | x, self.attn = self.attention(q, k, v, mask=attn_mask)
214 |
215 | x = self.conv_o(x)
216 | return x
217 |
218 | def attention(self, query, key, value, mask=None):
219 | # reshape [b, d, t] -> [b, n_h, t, d_k]
220 | b, d, t_s, t_t = (*key.size(), query.size(2))
221 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
222 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
223 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
224 |
225 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
226 | if self.window_size is not None:
227 | assert t_s == t_t, "Relative attention is only available for self-attention."
228 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
229 | rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
230 | scores_local = self._relative_position_to_absolute_position(rel_logits)
231 | scores = scores + scores_local
232 | if self.proximal_bias:
233 | assert t_s == t_t, "Proximal bias is only available for self-attention."
234 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
235 | if mask is not None:
236 | scores = scores.masked_fill(mask == 0, -1e4)
237 | if self.block_length is not None:
238 | assert t_s == t_t, "Local attention is only available for self-attention."
239 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
240 | scores = scores.masked_fill(block_mask == 0, -1e4)
241 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
242 | p_attn = self.drop(p_attn)
243 | output = torch.matmul(p_attn, value)
244 | if self.window_size is not None:
245 | relative_weights = self._absolute_position_to_relative_position(p_attn)
246 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
247 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
248 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
249 | return output, p_attn
250 |
251 | def _matmul_with_relative_values(self, x, y):
252 | """
253 | x: [b, h, l, m]
254 | y: [h or 1, m, d]
255 | ret: [b, h, l, d]
256 | """
257 | ret = torch.matmul(x, y.unsqueeze(0))
258 | return ret
259 |
260 | def _matmul_with_relative_keys(self, x, y):
261 | """
262 | x: [b, h, l, d]
263 | y: [h or 1, m, d]
264 | ret: [b, h, l, m]
265 | """
266 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
267 | return ret
268 |
269 | def _get_relative_embeddings(self, relative_embeddings, length):
270 | max_relative_position = 2 * self.window_size + 1  # size of the relative-position table (unused below)
271 | # Pad first before slice to avoid using cond ops.
272 | pad_length = max(length - (self.window_size + 1), 0)
273 | slice_start_position = max((self.window_size + 1) - length, 0)
274 | slice_end_position = slice_start_position + 2 * length - 1
275 | if pad_length > 0:
276 | padded_relative_embeddings = F.pad(
277 | relative_embeddings,
278 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
279 | else:
280 | padded_relative_embeddings = relative_embeddings
281 | used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
282 | return used_relative_embeddings
283 |
284 | def _relative_position_to_absolute_position(self, x):
285 | """
286 | x: [b, h, l, 2*l-1]
287 | ret: [b, h, l, l]
288 | """
289 | batch, heads, length, _ = x.size()
290 | # Concat columns of pad to shift from relative to absolute indexing.
291 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
292 |
293 | # Concat extra elements so to add up to shape (len+1, 2*len-1).
294 | x_flat = x.view([batch, heads, length * 2 * length])
295 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
296 |
297 | # Reshape and slice out the padded elements.
298 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
299 | return x_final
300 |
301 | def _absolute_position_to_relative_position(self, x):
302 | """
303 | x: [b, h, l, l]
304 | ret: [b, h, l, 2*l-1]
305 | """
306 | batch, heads, length, _ = x.size()
307 | # pad along column
308 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
309 | x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
310 | # add 0's in the beginning that will skew the elements after reshape
311 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
312 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
313 | return x_final
314 |
315 | def _attention_bias_proximal(self, length):
316 | """Bias for self-attention to encourage attention to close positions.
317 | Args:
318 | length: an integer scalar.
319 | Returns:
320 | a Tensor with shape [1, 1, length, length]
321 | """
322 | r = torch.arange(length, dtype=torch.float32)
323 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
324 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
325 |
326 |
327 | class FFN(nn.Module):
328 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None,
329 | causal=False):
330 | super().__init__()
331 | self.in_channels = in_channels
332 | self.out_channels = out_channels
333 | self.filter_channels = filter_channels
334 | self.kernel_size = kernel_size
335 | self.p_dropout = p_dropout
336 | self.activation = activation
337 | self.causal = causal
338 |
339 | if causal:
340 | self.padding = self._causal_padding
341 | else:
342 | self.padding = self._same_padding
343 |
344 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
345 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
346 | self.drop = nn.Dropout(p_dropout)
347 |
348 | def forward(self, x, x_mask):
349 | x = self.conv_1(self.padding(x * x_mask))
350 | if self.activation == "gelu":
351 | x = x * torch.sigmoid(1.702 * x)
352 | else:
353 | x = torch.relu(x)
354 | x = self.drop(x)
355 | x = self.conv_2(self.padding(x * x_mask))
356 | return x * x_mask
357 |
358 | def _causal_padding(self, x):
359 | if self.kernel_size == 1:
360 | return x
361 | pad_l = self.kernel_size - 1
362 | pad_r = 0
363 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
364 | x = F.pad(x, commons.convert_pad_shape(padding))
365 | return x
366 |
367 | def _same_padding(self, x):
368 | if self.kernel_size == 1:
369 | return x
370 | pad_l = (self.kernel_size - 1) // 2
371 | pad_r = self.kernel_size // 2
372 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
373 | x = F.pad(x, commons.convert_pad_shape(padding))
374 | return x
375 |
--------------------------------------------------------------------------------
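Note: a minimal standalone sketch (not part of the repository) of the padding trick used by _relative_position_to_absolute_position above. It re-states the method with F.pad's native argument list instead of commons.convert_pad_shape, purely to make the shape bookkeeping visible; rel_to_abs is a hypothetical helper name.

    import torch
    import torch.nn.functional as F

    def rel_to_abs(x):
        # x: [b, h, l, 2*l-1] relative scores -> [b, h, l, l] absolute scores
        b, h, l, _ = x.size()
        x = F.pad(x, [0, 1])                      # [b, h, l, 2*l]
        x = x.view(b, h, l * 2 * l)               # flatten the two time axes
        x = F.pad(x, [0, l - 1])                  # total length (l+1)*(2*l-1)
        return x.view(b, h, l + 1, 2 * l - 1)[:, :, :l, l - 1:]

    scores = torch.randn(2, 4, 6, 11)             # l = 6, so 2*l-1 = 11
    print(rel_to_abs(scores).shape)               # torch.Size([2, 4, 6, 6])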
/extra/commons.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch.nn import functional as F
5 |
6 |
7 | def slice_pitch_segments(x, ids_str, segment_size=4):
8 | ret = torch.zeros_like(x[:, :segment_size])
9 | for i in range(x.size(0)):
10 | idx_str = ids_str[i]
11 | idx_end = idx_str + segment_size
12 | ret[i] = x[i, idx_str:idx_end]
13 | return ret
14 |
15 |
16 | def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4):
17 | b, d, t = x.size()
18 | if x_lengths is None:
19 | x_lengths = t
20 | ids_str_max = x_lengths - segment_size + 1
21 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
22 | ret = slice_segments(x, ids_str, segment_size)
23 | ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size)
24 | return ret, ret_pitch, ids_str
25 |
26 |
27 | def init_weights(m, mean=0.0, std=0.01):
28 | classname = m.__class__.__name__
29 | if "Depthwise_Separable" in classname:
30 | m.depth_conv.weight.data.normal_(mean, std)
31 | m.point_conv.weight.data.normal_(mean, std)
32 | elif classname.find("Conv") != -1:
33 | m.weight.data.normal_(mean, std)
34 |
35 |
36 | def get_padding(kernel_size, dilation=1):
37 | return int((kernel_size * dilation - dilation) / 2)
38 |
39 |
40 | def convert_pad_shape(pad_shape):
41 | l = pad_shape[::-1]
42 | pad_shape = [item for sublist in l for item in sublist]
43 | return pad_shape
44 |
45 |
46 | def intersperse(lst, item):
47 | result = [item] * (len(lst) * 2 + 1)
48 | result[1::2] = lst
49 | return result
50 |
51 |
52 | def kl_divergence(m_p, logs_p, m_q, logs_q):
53 | """KL(P||Q)"""
54 | kl = (logs_q - logs_p) - 0.5
55 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2. * logs_q)
56 | return kl
57 |
58 |
59 | def rand_gumbel(shape):
60 | """Sample from the Gumbel distribution, protect from overflows."""
61 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
62 | return -torch.log(-torch.log(uniform_samples))
63 |
64 |
65 | def rand_gumbel_like(x):
66 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
67 | return g
68 |
69 |
70 | def slice_segments(x, ids_str, segment_size=4):
71 | ret = torch.zeros_like(x[:, :, :segment_size])
72 | for i in range(x.size(0)):
73 | idx_str = ids_str[i]
74 | idx_end = idx_str + segment_size
75 | ret[i] = x[i, :, idx_str:idx_end]
76 | return ret
77 |
78 |
79 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
80 | b, d, t = x.size()
81 | if x_lengths is None:
82 | x_lengths = t
83 | ids_str_max = x_lengths - segment_size + 1
84 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
85 | ret = slice_segments(x, ids_str, segment_size)
86 | return ret, ids_str
87 |
88 |
89 | def rand_spec_segments(x, x_lengths=None, segment_size=4):
90 | b, d, t = x.size()
91 | if x_lengths is None:
92 | x_lengths = t
93 | ids_str_max = x_lengths - segment_size
94 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
95 | ret = slice_segments(x, ids_str, segment_size)
96 | return ret, ids_str
97 |
98 |
99 | def get_timing_signal_1d(
100 | length, channels, min_timescale=1.0, max_timescale=1.0e4):
101 | position = torch.arange(length, dtype=torch.float)
102 | num_timescales = channels // 2
103 | log_timescale_increment = (
104 | math.log(float(max_timescale) / float(min_timescale)) /
105 | (num_timescales - 1))
106 | inv_timescales = min_timescale * torch.exp(
107 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
108 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
109 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
110 | signal = F.pad(signal, [0, 0, 0, channels % 2])
111 | signal = signal.view(1, channels, length)
112 | return signal
113 |
114 |
115 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
116 | b, channels, length = x.size()
117 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
118 | return x + signal.to(dtype=x.dtype, device=x.device)
119 |
120 |
121 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
122 | b, channels, length = x.size()
123 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
124 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
125 |
126 |
127 | def subsequent_mask(length):
128 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
129 | return mask
130 |
131 |
132 | @torch.jit.script
133 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
134 | n_channels_int = n_channels[0]
135 | in_act = input_a + input_b
136 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
137 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
138 | acts = t_act * s_act
139 | return acts
140 |
141 |
142 | def shift_1d(x):
143 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
144 | return x
145 |
146 |
147 | def sequence_mask(length, max_length=None):
148 | if max_length is None:
149 | max_length = length.max()
150 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
151 | return x.unsqueeze(0) < length.unsqueeze(1)
152 |
153 |
154 | def generate_path(duration, mask):
155 | """
156 | duration: [b, 1, t_x]
157 | mask: [b, 1, t_y, t_x]
158 | """
159 |
160 | b, _, t_y, t_x = mask.shape
161 | cum_duration = torch.cumsum(duration, -1)
162 |
163 | cum_duration_flat = cum_duration.view(b * t_x)
164 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
165 | path = path.view(b, t_x, t_y)
166 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
167 | path = path.unsqueeze(1).transpose(2, 3) * mask
168 | return path
169 |
170 |
171 | def clip_grad_value_(parameters, clip_value, norm_type=2):
172 | if isinstance(parameters, torch.Tensor):
173 | parameters = [parameters]
174 | parameters = list(filter(lambda p: p.grad is not None, parameters))
175 | norm_type = float(norm_type)
176 | if clip_value is not None:
177 | clip_value = float(clip_value)
178 |
179 | total_norm = 0
180 | for p in parameters:
181 | param_norm = p.grad.data.norm(norm_type)
182 | total_norm += param_norm.item() ** norm_type
183 | if clip_value is not None:
184 | p.grad.data.clamp_(min=-clip_value, max=clip_value)
185 | total_norm = total_norm ** (1. / norm_type)
186 | return total_norm
187 |
--------------------------------------------------------------------------------
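Note: a small, hypothetical sketch of the masking and slicing helpers above, assuming the repository root is on PYTHONPATH so that extra.commons imports cleanly.

    import torch
    from extra import commons

    lengths = torch.tensor([5, 3, 8])
    mask = commons.sequence_mask(lengths, max_length=8)   # [3, 8] boolean mask
    print(mask.int())

    x = torch.randn(3, 192, 100)                          # [batch, channels, frames]
    seg, ids = commons.rand_slice_segments(
        x, x_lengths=torch.tensor([100, 80, 60]), segment_size=32)
    print(seg.shape, ids)                                 # [3, 192, 32] and the random start frames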
/extra/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | import extra.attentions as attentions
6 | import extra.commons as commons
7 | from extra.commons import get_padding, init_weights
8 | from extra.DSConv import (
9 | Depthwise_Separable_Conv1D,
10 | remove_weight_norm_modules,
11 | weight_norm_modules,
12 | )
13 |
14 | LRELU_SLOPE = 0.1
15 |
16 | Conv1dModel = nn.Conv1d
17 |
18 |
19 | def set_Conv1dModel(use_depthwise_conv):
20 | global Conv1dModel
21 | Conv1dModel = Depthwise_Separable_Conv1D if use_depthwise_conv else nn.Conv1d
22 |
23 |
24 | class LayerNorm(nn.Module):
25 | def __init__(self, channels, eps=1e-5):
26 | super().__init__()
27 | self.channels = channels
28 | self.eps = eps
29 |
30 | self.gamma = nn.Parameter(torch.ones(channels))
31 | self.beta = nn.Parameter(torch.zeros(channels))
32 |
33 | def forward(self, x):
34 | x = x.transpose(1, -1)
35 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
36 | return x.transpose(1, -1)
37 |
38 |
39 | class ConvReluNorm(nn.Module):
40 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
41 | super().__init__()
42 | self.in_channels = in_channels
43 | self.hidden_channels = hidden_channels
44 | self.out_channels = out_channels
45 | self.kernel_size = kernel_size
46 | self.n_layers = n_layers
47 | self.p_dropout = p_dropout
48 | assert n_layers > 1, "Number of layers should be larger than 1."
49 |
50 | self.conv_layers = nn.ModuleList()
51 | self.norm_layers = nn.ModuleList()
52 | self.conv_layers.append(Conv1dModel(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
53 | self.norm_layers.append(LayerNorm(hidden_channels))
54 | self.relu_drop = nn.Sequential(
55 | nn.ReLU(),
56 | nn.Dropout(p_dropout))
57 | for _ in range(n_layers - 1):
58 | self.conv_layers.append(
59 | Conv1dModel(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
60 | self.norm_layers.append(LayerNorm(hidden_channels))
61 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
62 | self.proj.weight.data.zero_()
63 | self.proj.bias.data.zero_()
64 |
65 | def forward(self, x, x_mask):
66 | x_org = x
67 | for i in range(self.n_layers):
68 | x = self.conv_layers[i](x * x_mask)
69 | x = self.norm_layers[i](x)
70 | x = self.relu_drop(x)
71 | x = x_org + self.proj(x)
72 | return x * x_mask
73 |
74 |
75 | class WN(torch.nn.Module):
76 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
77 | super(WN, self).__init__()
78 | assert (kernel_size % 2 == 1)
79 | self.hidden_channels = hidden_channels
80 | self.kernel_size = kernel_size
81 | self.dilation_rate = dilation_rate
82 | self.n_layers = n_layers
83 | self.gin_channels = gin_channels
84 | self.p_dropout = p_dropout
85 |
86 | self.in_layers = torch.nn.ModuleList()
87 | self.res_skip_layers = torch.nn.ModuleList()
88 | self.drop = nn.Dropout(p_dropout)
89 |
90 | if gin_channels != 0:
91 | cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
92 | self.cond_layer = weight_norm_modules(cond_layer, name='weight')
93 |
94 | for i in range(n_layers):
95 | dilation = dilation_rate ** i
96 | padding = int((kernel_size * dilation - dilation) / 2)
97 | in_layer = Conv1dModel(hidden_channels, 2 * hidden_channels, kernel_size,
98 | dilation=dilation, padding=padding)
99 | in_layer = weight_norm_modules(in_layer, name='weight')
100 | self.in_layers.append(in_layer)
101 |
102 | # last one is not necessary
103 | if i < n_layers - 1:
104 | res_skip_channels = 2 * hidden_channels
105 | else:
106 | res_skip_channels = hidden_channels
107 |
108 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
109 | res_skip_layer = weight_norm_modules(res_skip_layer, name='weight')
110 | self.res_skip_layers.append(res_skip_layer)
111 |
112 | def forward(self, x, x_mask, g=None, **kwargs):
113 | output = torch.zeros_like(x)
114 | n_channels_tensor = torch.IntTensor([self.hidden_channels])
115 |
116 | if g is not None:
117 | g = self.cond_layer(g)
118 |
119 | for i in range(self.n_layers):
120 | x_in = self.in_layers[i](x)
121 | if g is not None:
122 | cond_offset = i * 2 * self.hidden_channels
123 | g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
124 | else:
125 | g_l = torch.zeros_like(x_in)
126 |
127 | acts = commons.fused_add_tanh_sigmoid_multiply(
128 | x_in,
129 | g_l,
130 | n_channels_tensor)
131 | acts = self.drop(acts)
132 |
133 | res_skip_acts = self.res_skip_layers[i](acts)
134 | if i < self.n_layers - 1:
135 | res_acts = res_skip_acts[:, :self.hidden_channels, :]
136 | x = (x + res_acts) * x_mask
137 | output = output + res_skip_acts[:, self.hidden_channels:, :]
138 | else:
139 | output = output + res_skip_acts
140 | return output * x_mask
141 |
142 | def remove_weight_norm(self):
143 | if self.gin_channels != 0:
144 | remove_weight_norm_modules(self.cond_layer)
145 | for l in self.in_layers:
146 | remove_weight_norm_modules(l)
147 | for l in self.res_skip_layers:
148 | remove_weight_norm_modules(l)
149 |
150 |
151 | class ResBlock1(torch.nn.Module):
152 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
153 | super(ResBlock1, self).__init__()
154 | self.convs1 = nn.ModuleList([
155 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0],
156 | padding=get_padding(kernel_size, dilation[0]))),
157 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1],
158 | padding=get_padding(kernel_size, dilation[1]))),
159 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[2],
160 | padding=get_padding(kernel_size, dilation[2])))
161 | ])
162 | self.convs1.apply(init_weights)
163 |
164 | self.convs2 = nn.ModuleList([
165 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
166 | padding=get_padding(kernel_size, 1))),
167 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
168 | padding=get_padding(kernel_size, 1))),
169 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
170 | padding=get_padding(kernel_size, 1)))
171 | ])
172 | self.convs2.apply(init_weights)
173 |
174 | def forward(self, x, x_mask=None):
175 | for c1, c2 in zip(self.convs1, self.convs2):
176 | xt = F.leaky_relu(x, LRELU_SLOPE)
177 | if x_mask is not None:
178 | xt = xt * x_mask
179 | xt = c1(xt)
180 | xt = F.leaky_relu(xt, LRELU_SLOPE)
181 | if x_mask is not None:
182 | xt = xt * x_mask
183 | xt = c2(xt)
184 | x = xt + x
185 | if x_mask is not None:
186 | x = x * x_mask
187 | return x
188 |
189 | def remove_weight_norm(self):
190 | for l in self.convs1:
191 | remove_weight_norm_modules(l)
192 | for l in self.convs2:
193 | remove_weight_norm_modules(l)
194 |
195 |
196 | class ResBlock2(torch.nn.Module):
197 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
198 | super(ResBlock2, self).__init__()
199 | self.convs = nn.ModuleList([
200 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0],
201 | padding=get_padding(kernel_size, dilation[0]))),
202 | weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1],
203 | padding=get_padding(kernel_size, dilation[1])))
204 | ])
205 | self.convs.apply(init_weights)
206 |
207 | def forward(self, x, x_mask=None):
208 | for c in self.convs:
209 | xt = F.leaky_relu(x, LRELU_SLOPE)
210 | if x_mask is not None:
211 | xt = xt * x_mask
212 | xt = c(xt)
213 | x = xt + x
214 | if x_mask is not None:
215 | x = x * x_mask
216 | return x
217 |
218 | def remove_weight_norm(self):
219 | for l in self.convs:
220 | remove_weight_norm_modules(l)
221 |
222 |
223 | class Log(nn.Module):
224 | def forward(self, x, x_mask, reverse=False, **kwargs):
225 | if not reverse:
226 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
227 | logdet = torch.sum(-y, [1, 2])
228 | return y, logdet
229 | else:
230 | x = torch.exp(x) * x_mask
231 | return x
232 |
233 |
234 | class Flip(nn.Module):
235 | def forward(self, x, *args, reverse=False, **kwargs):
236 | x = torch.flip(x, [1])
237 | if not reverse:
238 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
239 | return x, logdet
240 | else:
241 | return x
242 |
243 |
244 | class ElementwiseAffine(nn.Module):
245 | def __init__(self, channels):
246 | super().__init__()
247 | self.channels = channels
248 | self.m = nn.Parameter(torch.zeros(channels, 1))
249 | self.logs = nn.Parameter(torch.zeros(channels, 1))
250 |
251 | def forward(self, x, x_mask, reverse=False, **kwargs):
252 | if not reverse:
253 | y = self.m + torch.exp(self.logs) * x
254 | y = y * x_mask
255 | logdet = torch.sum(self.logs * x_mask, [1, 2])
256 | return y, logdet
257 | else:
258 | x = (x - self.m) * torch.exp(-self.logs) * x_mask
259 | return x
260 |
261 |
262 | class ResidualCouplingLayer(nn.Module):
263 | def __init__(self,
264 | channels,
265 | hidden_channels,
266 | kernel_size,
267 | dilation_rate,
268 | n_layers,
269 | p_dropout=0,
270 | gin_channels=0,
271 | mean_only=False,
272 | wn_sharing_parameter=None
273 | ):
274 | assert channels % 2 == 0, "channels should be divisible by 2"
275 | super().__init__()
276 | self.channels = channels
277 | self.hidden_channels = hidden_channels
278 | self.kernel_size = kernel_size
279 | self.dilation_rate = dilation_rate
280 | self.n_layers = n_layers
281 | self.half_channels = channels // 2
282 | self.mean_only = mean_only
283 |
284 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
285 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout,
286 | gin_channels=gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter
287 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
288 | self.post.weight.data.zero_()
289 | self.post.bias.data.zero_()
290 |
291 | def forward(self, x, x_mask, g=None, reverse=False):
292 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
293 | h = self.pre(x0) * x_mask
294 | h = self.enc(h, x_mask, g=g)
295 | stats = self.post(h) * x_mask
296 | if not self.mean_only:
297 | m, logs = torch.split(stats, [self.half_channels] * 2, 1)
298 | else:
299 | m = stats
300 | logs = torch.zeros_like(m)
301 |
302 | if not reverse:
303 | x1 = m + x1 * torch.exp(logs) * x_mask
304 | x = torch.cat([x0, x1], 1)
305 | logdet = torch.sum(logs, [1, 2])
306 | return x, logdet
307 | else:
308 | x1 = (x1 - m) * torch.exp(-logs) * x_mask
309 | x = torch.cat([x0, x1], 1)
310 | return x
311 |
312 |
313 | class TransformerCouplingLayer(nn.Module):
314 | def __init__(self,
315 | channels,
316 | hidden_channels,
317 | kernel_size,
318 | n_layers,
319 | n_heads,
320 | p_dropout=0,
321 | filter_channels=0,
322 | mean_only=False,
323 | wn_sharing_parameter=None,
324 | gin_channels=0
325 | ):
326 | assert channels % 2 == 0, "channels should be divisible by 2"
327 | super().__init__()
328 | self.channels = channels
329 | self.hidden_channels = hidden_channels
330 | self.kernel_size = kernel_size
331 | self.n_layers = n_layers
332 | self.half_channels = channels // 2
333 | self.mean_only = mean_only
334 |
335 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
336 | self.enc = attentions.FFT(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout,
337 | isflow=True,
338 | gin_channels=gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter
339 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
340 | self.post.weight.data.zero_()
341 | self.post.bias.data.zero_()
342 |
343 | def forward(self, x, x_mask, g=None, reverse=False):
344 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
345 | h = self.pre(x0) * x_mask
346 | h = self.enc(h, x_mask, g=g)
347 | stats = self.post(h) * x_mask
348 | if not self.mean_only:
349 | m, logs = torch.split(stats, [self.half_channels] * 2, 1)
350 | else:
351 | m = stats
352 | logs = torch.zeros_like(m)
353 |
354 | if not reverse:
355 | x1 = m + x1 * torch.exp(logs) * x_mask
356 | x = torch.cat([x0, x1], 1)
357 | logdet = torch.sum(logs, [1, 2])
358 | return x, logdet
359 | else:
360 | x1 = (x1 - m) * torch.exp(-logs) * x_mask
361 | x = torch.cat([x0, x1], 1)
362 | return x
363 |
--------------------------------------------------------------------------------
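Note: set_Conv1dModel swaps the convolution class that is read at construction time, so it has to be called before any of the modules above are built. Below is a hypothetical sketch with the default nn.Conv1d backend (passing True would swap in Depthwise_Separable_Conv1D from extra.DSConv), assuming the extra package imports cleanly.

    import torch
    import extra.modules as modules

    modules.set_Conv1dModel(use_depthwise_conv=False)     # must run before building blocks

    block = modules.ResBlock1(channels=192, kernel_size=3, dilation=(1, 3, 5))
    x = torch.randn(2, 192, 50)                           # [batch, channels, frames]
    print(block(x).shape)                                 # torch.Size([2, 192, 50])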
/filelists/val.txt:
--------------------------------------------------------------------------------
1 | DUMMY/p376/p376_392.wav
2 | DUMMY/p286/p286_273.wav
3 | DUMMY/p234/p234_034.wav
4 | DUMMY/p374/p374_392.wav
5 | DUMMY/p287/p287_200.wav
6 | DUMMY/p271/p271_053.wav
7 | DUMMY/p227/p227_159.wav
8 | DUMMY/p261/p261_019.wav
9 | DUMMY/p268/p268_174.wav
10 | DUMMY/p225/p225_021.wav
11 | DUMMY/p361/p361_283.wav
12 | DUMMY/p230/p230_016.wav
13 | DUMMY/p238/p238_193.wav
14 | DUMMY/p257/p257_063.wav
15 | DUMMY/p237/p237_058.wav
16 | DUMMY/p272/p272_381.wav
17 | DUMMY/p343/p343_313.wav
18 | DUMMY/p313/p313_201.wav
19 | DUMMY/p243/p243_041.wav
20 | DUMMY/p363/p363_231.wav
21 | DUMMY/p292/p292_394.wav
22 | DUMMY/p260/p260_027.wav
23 | DUMMY/p273/p273_206.wav
24 | DUMMY/p329/p329_355.wav
25 | DUMMY/p318/p318_197.wav
26 | DUMMY/p293/p293_223.wav
27 | DUMMY/p274/p274_289.wav
28 | DUMMY/p362/p362_019.wav
29 | DUMMY/p229/p229_226.wav
30 | DUMMY/p228/p228_063.wav
31 | DUMMY/p288/p288_171.wav
32 | DUMMY/p243/p243_130.wav
33 | DUMMY/p360/p360_250.wav
34 | DUMMY/p264/p264_078.wav
35 | DUMMY/p301/p301_268.wav
36 | DUMMY/p239/p239_269.wav
37 | DUMMY/p330/p330_306.wav
38 | DUMMY/p273/p273_311.wav
39 | DUMMY/p329/p329_094.wav
40 | DUMMY/p305/p305_163.wav
41 | DUMMY/p347/p347_226.wav
42 | DUMMY/p265/p265_051.wav
43 | DUMMY/p282/p282_078.wav
44 | DUMMY/p226/p226_181.wav
45 | DUMMY/p326/p326_173.wav
46 | DUMMY/p310/p310_254.wav
47 | DUMMY/p313/p313_410.wav
48 | DUMMY/p301/p301_209.wav
49 | DUMMY/p239/p239_343.wav
50 | DUMMY/p249/p249_025.wav
51 | DUMMY/p267/p267_333.wav
52 | DUMMY/p312/p312_337.wav
53 | DUMMY/p340/p340_020.wav
54 | DUMMY/p229/p229_368.wav
55 | DUMMY/p270/p270_450.wav
56 | DUMMY/p298/p298_145.wav
57 | DUMMY/p316/p316_034.wav
58 | DUMMY/p253/p253_148.wav
59 | DUMMY/p279/p279_301.wav
60 | DUMMY/p300/p300_231.wav
61 | DUMMY/p270/p270_068.wav
62 | DUMMY/p258/p258_061.wav
63 | DUMMY/p282/p282_231.wav
64 | DUMMY/p277/p277_315.wav
65 | DUMMY/p362/p362_134.wav
66 | DUMMY/p244/p244_063.wav
67 | DUMMY/p275/p275_212.wav
68 | DUMMY/p233/p233_044.wav
69 | DUMMY/p284/p284_192.wav
70 | DUMMY/p304/p304_156.wav
71 | DUMMY/p249/p249_102.wav
72 | DUMMY/p236/p236_076.wav
73 | DUMMY/p312/p312_373.wav
74 | DUMMY/p259/p259_299.wav
75 | DUMMY/p347/p347_164.wav
76 | DUMMY/p330/p330_363.wav
77 | DUMMY/p303/p303_046.wav
78 | DUMMY/p304/p304_167.wav
79 | DUMMY/p314/p314_235.wav
80 | DUMMY/p336/p336_040.wav
81 | DUMMY/p317/p317_077.wav
82 | DUMMY/p281/p281_402.wav
83 | DUMMY/p241/p241_345.wav
84 | DUMMY/p292/p292_022.wav
85 | DUMMY/p262/p262_240.wav
86 | DUMMY/p263/p263_098.wav
87 | DUMMY/p250/p250_304.wav
88 | DUMMY/p376/p376_062.wav
89 | DUMMY/p264/p264_181.wav
90 | DUMMY/p260/p260_109.wav
91 | DUMMY/p333/p333_282.wav
92 | DUMMY/p310/p310_073.wav
93 | DUMMY/p343/p343_189.wav
94 | DUMMY/p257/p257_217.wav
95 | DUMMY/p288/p288_135.wav
96 | DUMMY/p285/p285_076.wav
97 | DUMMY/p265/p265_242.wav
98 | DUMMY/p226/p226_337.wav
99 | DUMMY/p302/p302_125.wav
100 | DUMMY/p341/p341_020.wav
101 | DUMMY/p246/p246_060.wav
102 | DUMMY/p244/p244_381.wav
103 | DUMMY/p283/p283_222.wav
104 | DUMMY/p266/p266_335.wav
105 | DUMMY/p297/p297_374.wav
106 | DUMMY/p245/p245_190.wav
107 | DUMMY/p231/p231_471.wav
108 | DUMMY/p284/p284_217.wav
109 | DUMMY/p245/p245_118.wav
110 | DUMMY/p240/p240_196.wav
111 | DUMMY/p236/p236_205.wav
112 | DUMMY/p256/p256_042.wav
113 | DUMMY/p326/p326_064.wav
114 | DUMMY/p255/p255_353.wav
115 | DUMMY/p311/p311_138.wav
116 | DUMMY/p345/p345_057.wav
117 | DUMMY/p351/p351_418.wav
118 | DUMMY/p234/p234_023.wav
119 | DUMMY/p307/p307_067.wav
120 | DUMMY/p283/p283_137.wav
121 | DUMMY/p268/p268_058.wav
122 | DUMMY/p339/p339_143.wav
123 | DUMMY/p258/p258_287.wav
124 | DUMMY/p363/p363_322.wav
125 | DUMMY/p237/p237_196.wav
126 | DUMMY/p341/p341_008.wav
127 | DUMMY/p323/p323_278.wav
128 | DUMMY/p231/p231_453.wav
129 | DUMMY/p307/p307_412.wav
130 | DUMMY/p267/p267_266.wav
131 | DUMMY/p293/p293_272.wav
132 | DUMMY/p306/p306_321.wav
133 | DUMMY/p262/p262_393.wav
134 | DUMMY/p314/p314_267.wav
135 | DUMMY/p274/p274_373.wav
136 | DUMMY/p250/p250_336.wav
137 | DUMMY/p334/p334_148.wav
138 | DUMMY/p251/p251_294.wav
139 | DUMMY/p255/p255_042.wav
140 | DUMMY/p294/p294_251.wav
141 | DUMMY/p254/p254_217.wav
142 | DUMMY/p299/p299_122.wav
143 | DUMMY/p269/p269_150.wav
144 | DUMMY/p272/p272_084.wav
145 | DUMMY/p345/p345_320.wav
146 | DUMMY/p300/p300_267.wav
147 | DUMMY/p299/p299_109.wav
148 | DUMMY/p246/p246_160.wav
149 | DUMMY/p278/p278_366.wav
150 | DUMMY/p241/p241_306.wav
151 | DUMMY/p240/p240_213.wav
152 | DUMMY/p311/p311_376.wav
153 | DUMMY/p256/p256_006.wav
154 | DUMMY/p254/p254_356.wav
155 | DUMMY/p276/p276_111.wav
156 | DUMMY/p263/p263_270.wav
157 | DUMMY/p295/p295_371.wav
158 | DUMMY/p230/p230_184.wav
159 | DUMMY/p286/p286_341.wav
160 | DUMMY/p302/p302_160.wav
161 | DUMMY/p232/p232_396.wav
162 | DUMMY/p278/p278_077.wav
163 | DUMMY/p281/p281_367.wav
164 | DUMMY/p336/p336_316.wav
165 | DUMMY/p335/p335_365.wav
166 | DUMMY/p233/p233_041.wav
167 | DUMMY/p225/p225_038.wav
168 | DUMMY/p248/p248_160.wav
169 | DUMMY/p228/p228_230.wav
170 | DUMMY/p285/p285_197.wav
171 | DUMMY/p360/p360_157.wav
172 | DUMMY/p333/p333_259.wav
173 | DUMMY/p308/p308_335.wav
174 | DUMMY/p339/p339_218.wav
175 | DUMMY/p247/p247_320.wav
176 | DUMMY/p364/p364_014.wav
177 | DUMMY/p227/p227_255.wav
178 | DUMMY/p238/p238_060.wav
179 | DUMMY/p323/p323_373.wav
180 | DUMMY/p277/p277_045.wav
181 | DUMMY/p361/p361_152.wav
182 | DUMMY/p275/p275_380.wav
183 | DUMMY/p232/p232_180.wav
184 | DUMMY/p269/p269_130.wav
185 | DUMMY/p316/p316_055.wav
186 | DUMMY/p252/p252_247.wav
187 | DUMMY/p340/p340_036.wav
188 | DUMMY/p294/p294_414.wav
189 | DUMMY/p298/p298_228.wav
190 | DUMMY/p287/p287_348.wav
191 | DUMMY/p295/p295_214.wav
192 | DUMMY/p251/p251_222.wav
193 | DUMMY/p253/p253_339.wav
194 | DUMMY/p305/p305_327.wav
195 | DUMMY/p279/p279_283.wav
196 | DUMMY/p318/p318_342.wav
197 | DUMMY/p351/p351_194.wav
198 | DUMMY/p248/p248_016.wav
199 | DUMMY/p276/p276_321.wav
200 | DUMMY/p259/p259_262.wav
201 | DUMMY/p261/p261_018.wav
202 | DUMMY/p303/p303_320.wav
203 | DUMMY/p297/p297_122.wav
204 | DUMMY/p374/p374_106.wav
205 | DUMMY/p271/p271_200.wav
206 | DUMMY/p247/p247_315.wav
207 | DUMMY/p252/p252_402.wav
208 | DUMMY/p335/p335_308.wav
209 | DUMMY/p308/p308_104.wav
210 | DUMMY/p266/p266_040.wav
211 | DUMMY/p306/p306_312.wav
212 | DUMMY/p317/p317_201.wav
213 | DUMMY/p334/p334_357.wav
214 | DUMMY/p364/p364_027.wav
215 |
--------------------------------------------------------------------------------
/hifigan/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import Generator
2 |
3 |
4 | class AttrDict(dict):
5 | def __init__(self, *args, **kwargs):
6 | super(AttrDict, self).__init__(*args, **kwargs)
7 | self.__dict__ = self
--------------------------------------------------------------------------------
/hifigan/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 0,
4 | "batch_size": 16,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,8,2,2],
12 | "upsample_kernel_sizes": [16,16,4,4],
13 | "upsample_initial_channel": 512,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 | "resblock_initial_channel": 256,
17 |
18 | "segment_size": 8192,
19 | "num_mels": 80,
20 | "num_freq": 1025,
21 | "n_fft": 1024,
22 | "hop_size": 256,
23 | "win_size": 1024,
24 |
25 | "sampling_rate": 22050,
26 |
27 | "fmin": 0,
28 | "fmax": 8000,
29 | "fmax_loss": null,
30 |
31 | "num_workers": 4,
32 |
33 | "dist_config": {
34 | "dist_backend": "nccl",
35 | "dist_url": "tcp://localhost:54321",
36 | "world_size": 1
37 | }
38 | }
--------------------------------------------------------------------------------
/hifigan/generator_v1.txt:
--------------------------------------------------------------------------------
1 | https://github.com/jik876/hifi-gan
--------------------------------------------------------------------------------
/hifigan/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import Conv1d, ConvTranspose1d
5 | from torch.nn.utils import weight_norm, remove_weight_norm
6 |
7 | LRELU_SLOPE = 0.1
8 |
9 |
10 | def init_weights(m, mean=0.0, std=0.01):
11 | classname = m.__class__.__name__
12 | if classname.find("Conv") != -1:
13 | m.weight.data.normal_(mean, std)
14 |
15 |
16 | def get_padding(kernel_size, dilation=1):
17 | return int((kernel_size * dilation - dilation) / 2)
18 |
19 |
20 | class ResBlock(torch.nn.Module):
21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22 | super(ResBlock, self).__init__()
23 | self.h = h
24 | self.convs1 = nn.ModuleList(
25 | [
26 | weight_norm(
27 | Conv1d(
28 | channels,
29 | channels,
30 | kernel_size,
31 | 1,
32 | dilation=dilation[0],
33 | padding=get_padding(kernel_size, dilation[0]),
34 | )
35 | ),
36 | weight_norm(
37 | Conv1d(
38 | channels,
39 | channels,
40 | kernel_size,
41 | 1,
42 | dilation=dilation[1],
43 | padding=get_padding(kernel_size, dilation[1]),
44 | )
45 | ),
46 | weight_norm(
47 | Conv1d(
48 | channels,
49 | channels,
50 | kernel_size,
51 | 1,
52 | dilation=dilation[2],
53 | padding=get_padding(kernel_size, dilation[2]),
54 | )
55 | ),
56 | ]
57 | )
58 | self.convs1.apply(init_weights)
59 |
60 | self.convs2 = nn.ModuleList(
61 | [
62 | weight_norm(
63 | Conv1d(
64 | channels,
65 | channels,
66 | kernel_size,
67 | 1,
68 | dilation=1,
69 | padding=get_padding(kernel_size, 1),
70 | )
71 | ),
72 | weight_norm(
73 | Conv1d(
74 | channels,
75 | channels,
76 | kernel_size,
77 | 1,
78 | dilation=1,
79 | padding=get_padding(kernel_size, 1),
80 | )
81 | ),
82 | weight_norm(
83 | Conv1d(
84 | channels,
85 | channels,
86 | kernel_size,
87 | 1,
88 | dilation=1,
89 | padding=get_padding(kernel_size, 1),
90 | )
91 | ),
92 | ]
93 | )
94 | self.convs2.apply(init_weights)
95 |
96 | def forward(self, x):
97 | for c1, c2 in zip(self.convs1, self.convs2):
98 | xt = F.leaky_relu(x, LRELU_SLOPE)
99 | xt = c1(xt)
100 | xt = F.leaky_relu(xt, LRELU_SLOPE)
101 | xt = c2(xt)
102 | x = xt + x
103 | return x
104 |
105 | def remove_weight_norm(self):
106 | for l in self.convs1:
107 | remove_weight_norm(l)
108 | for l in self.convs2:
109 | remove_weight_norm(l)
110 |
111 |
112 | class Generator(torch.nn.Module):
113 | def __init__(self, h):
114 | super(Generator, self).__init__()
115 | self.h = h
116 | self.num_kernels = len(h.resblock_kernel_sizes)
117 | self.num_upsamples = len(h.upsample_rates)
118 | self.conv_pre = weight_norm(
119 | Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
120 | )
121 | resblock = ResBlock
122 |
123 | self.ups = nn.ModuleList()
124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
125 | self.ups.append(
126 | weight_norm(
127 | ConvTranspose1d(
128 | h.upsample_initial_channel // (2 ** i),
129 | h.upsample_initial_channel // (2 ** (i + 1)),
130 | k,
131 | u,
132 | padding=(k - u) // 2,
133 | )
134 | )
135 | )
136 |
137 | self.resblocks = nn.ModuleList()
138 | for i in range(len(self.ups)):
139 | ch = h.upsample_initial_channel // (2 ** (i + 1))
140 | for j, (k, d) in enumerate(
141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
142 | ):
143 | self.resblocks.append(resblock(h, ch, k, d))
144 |
145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
146 | self.ups.apply(init_weights)
147 | self.conv_post.apply(init_weights)
148 |
149 | def forward(self, x):
150 | x = self.conv_pre(x)
151 | for i in range(self.num_upsamples):
152 | x = F.leaky_relu(x, LRELU_SLOPE)
153 | x = self.ups[i](x)
154 | xs = None
155 | for j in range(self.num_kernels):
156 | if xs is None:
157 | xs = self.resblocks[i * self.num_kernels + j](x)
158 | else:
159 | xs += self.resblocks[i * self.num_kernels + j](x)
160 | x = xs / self.num_kernels
161 | x = F.leaky_relu(x)
162 | x = self.conv_post(x)
163 | x = torch.tanh(x)
164 |
165 | return x
166 |
167 | def remove_weight_norm(self):
168 | print("Removing weight norm...")
169 | for l in self.ups:
170 | remove_weight_norm(l)
171 | for l in self.resblocks:
172 | l.remove_weight_norm()
173 | remove_weight_norm(self.conv_pre)
174 | remove_weight_norm(self.conv_post)
--------------------------------------------------------------------------------
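Note: a minimal, hypothetical sketch of building the vocoder from hifigan/config.json and running a dummy mel batch through it, assuming it is run from the repository root. The checkpoint path is a placeholder; the pretrained generator is not shipped in this tree.

    import json
    import torch
    from hifigan import AttrDict, Generator

    with open("hifigan/config.json") as f:
        h = AttrDict(json.load(f))

    vocoder = Generator(h).eval()
    # vocoder.load_state_dict(
    #     torch.load("path/to/generator_v1", map_location="cpu")["generator"])  # upstream hifi-gan checkpoint format
    vocoder.remove_weight_norm()

    mel = torch.randn(1, 80, 100)                         # [batch, num_mels, frames]
    with torch.no_grad():
        wav = vocoder(mel)                                # [1, 1, 100 * 256] with hop_size 256
    print(wav.shape)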
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import commons
5 |
6 |
7 | def feature_loss(fmap_r, fmap_g):
8 | loss = 0
9 | for dr, dg in zip(fmap_r, fmap_g):
10 | for rl, gl in zip(dr, dg):
11 | rl = rl.float().detach()
12 | gl = gl.float()
13 | loss += torch.mean(torch.abs(rl - gl))
14 |
15 | return loss * 2
16 |
17 |
18 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
19 | loss = 0
20 | r_losses = []
21 | g_losses = []
22 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
23 | dr = dr.float()
24 | dg = dg.float()
25 | r_loss = torch.mean((1-dr)**2)
26 | g_loss = torch.mean(dg**2)
27 | loss += (r_loss + g_loss)
28 | r_losses.append(r_loss.item())
29 | g_losses.append(g_loss.item())
30 |
31 | return loss, r_losses, g_losses
32 |
33 |
34 | def generator_loss(disc_outputs):
35 | loss = 0
36 | gen_losses = []
37 | for dg in disc_outputs:
38 | dg = dg.float()
39 | l = torch.mean((1-dg)**2)
40 | gen_losses.append(l)
41 | loss += l
42 |
43 | return loss, gen_losses
44 |
45 |
46 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
47 | """
48 | z_p, logs_q: [b, h, t_t]
49 | m_p, logs_p: [b, h, t_t]
50 | """
51 | z_p = z_p.float()
52 | logs_q = logs_q.float()
53 | m_p = m_p.float()
54 | logs_p = logs_p.float()
55 | z_mask = z_mask.float()
56 | #print(logs_p)
57 | kl = logs_p - logs_q - 0.5
58 | kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
59 | kl = torch.sum(kl * z_mask)
60 | l = kl / torch.sum(z_mask)
61 | return l
62 |
--------------------------------------------------------------------------------
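Note: a toy, hypothetical call of the LSGAN-style losses above, with random score maps standing in for two discriminator scales; only the shapes matter here.

    import torch
    from losses import discriminator_loss, generator_loss

    d_real = [torch.rand(4, 100), torch.rand(4, 50)]      # per-scale scores on real audio
    d_fake = [torch.rand(4, 100), torch.rand(4, 50)]      # per-scale scores on generated audio

    loss_d, r_losses, g_losses = discriminator_loss(d_real, d_fake)
    loss_g, gen_losses = generator_loss(d_fake)
    print(loss_d.item(), loss_g.item())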
/lstm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """LSTM layers module."""
8 |
9 | from torch import nn
10 |
11 |
12 | class SLSTM(nn.Module):
13 | """
14 | LSTM that does not expose the hidden state and that keeps the convolutional
15 | (batch, channels, time) layout for both input and output.
16 | """
17 | def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
18 | super().__init__()
19 | self.skip = skip
20 | self.lstm = nn.LSTM(dimension, dimension, num_layers)
21 |
22 | def forward(self, x):
23 | x = x.permute(2, 0, 1)
24 | y, _ = self.lstm(x)
25 | if self.skip:
26 | y = y + x
27 | y = y.permute(1, 2, 0)
28 | return y
29 |
--------------------------------------------------------------------------------
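Note: a short, hypothetical sketch showing that SLSTM keeps the convolutional (batch, channels, time) layout at its boundary; internally it permutes to (time, batch, channels) for nn.LSTM and permutes back.

    import torch
    from lstm import SLSTM

    layer = SLSTM(dimension=512, num_layers=2, skip=True)
    x = torch.randn(8, 512, 120)                          # [batch, channels, frames]
    print(layer(x).shape)                                 # torch.Size([8, 512, 120])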
/mel_processing.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | import torch.utils.data
8 | import numpy as np
9 | import librosa
10 | import librosa.util as librosa_util
11 | from librosa.util import normalize, pad_center, tiny
12 | from scipy.signal import get_window
13 | from scipy.io.wavfile import read
14 | from librosa.filters import mel as librosa_mel_fn
15 |
16 | MAX_WAV_VALUE = 32768.0
17 |
18 |
19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20 | """
21 | PARAMS
22 | ------
23 | C: compression factor
24 | """
25 | return torch.log(torch.clamp(x, min=clip_val) * C)
26 |
27 |
28 | def dynamic_range_decompression_torch(x, C=1):
29 | """
30 | PARAMS
31 | ------
32 | C: compression factor used to compress
33 | """
34 | return torch.exp(x) / C
35 |
36 |
37 | def spectral_normalize_torch(magnitudes):
38 | output = dynamic_range_compression_torch(magnitudes)
39 | return output
40 |
41 |
42 | def spectral_de_normalize_torch(magnitudes):
43 | output = dynamic_range_decompression_torch(magnitudes)
44 | return output
45 |
46 |
47 | mel_basis = {}
48 | hann_window = {}
49 |
50 |
51 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
52 | if torch.min(y) < -1.:
53 | print('min value is ', torch.min(y))
54 | if torch.max(y) > 1.:
55 | print('max value is ', torch.max(y))
56 |
57 | global hann_window
58 | dtype_device = str(y.dtype) + '_' + str(y.device)
59 | wnsize_dtype_device = str(win_size) + '_' + dtype_device
60 | if wnsize_dtype_device not in hann_window:
61 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
62 |
63 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
64 | y = y.squeeze(1)
65 |
66 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
67 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
68 |
69 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
70 | return spec
71 |
72 |
73 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
74 | global mel_basis
75 | dtype_device = str(spec.dtype) + '_' + str(spec.device)
76 | fmax_dtype_device = str(fmax) + '_' + dtype_device
77 | if fmax_dtype_device not in mel_basis:
78 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
79 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
80 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
81 | spec = spectral_normalize_torch(spec)
82 | return spec
83 |
84 |
85 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
86 | if torch.min(y) < -1.:
87 | print('min value is ', torch.min(y))
88 | if torch.max(y) > 1.:
89 | print('max value is ', torch.max(y))
90 |
91 | global mel_basis, hann_window
92 | dtype_device = str(y.dtype) + '_' + str(y.device)
93 | fmax_dtype_device = str(fmax) + '_' + dtype_device
94 | wnsize_dtype_device = str(win_size) + '_' + dtype_device
95 | if fmax_dtype_device not in mel_basis:
96 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
97 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
98 | if wnsize_dtype_device not in hann_window:
99 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
100 |
101 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
102 | y = y.squeeze(1)
103 |
104 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
105 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
106 |
107 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
108 |
109 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
110 | spec = spectral_normalize_torch(spec)
111 |
112 | return spec
113 |
--------------------------------------------------------------------------------
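Note: a hypothetical call of mel_spectrogram_torch with the values from hifigan/config.json; the waveform is random noise in [-1, 1] purely to illustrate the expected input range and output shape.

    import torch
    from mel_processing import mel_spectrogram_torch

    wav = torch.rand(1, 22050 * 2) * 2 - 1                # [batch, samples], 2 s at 22.05 kHz
    mel = mel_spectrogram_torch(wav, n_fft=1024, num_mels=80, sampling_rate=22050,
                                hop_size=256, win_size=1024, fmin=0, fmax=8000)
    print(mel.shape)                                      # roughly [1, 80, samples // 256]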
/modules.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import numpy as np
4 | import scipy
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 |
9 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10 | from torch.nn.utils import weight_norm, remove_weight_norm
11 |
12 | import commons
13 | from commons import init_weights, get_padding
14 |
15 |
16 | LRELU_SLOPE = 0.1
17 |
18 |
19 | class LayerNorm(nn.Module):
20 | def __init__(self, channels, eps=1e-5):
21 | super().__init__()
22 | self.channels = channels
23 | self.eps = eps
24 |
25 | self.gamma = nn.Parameter(torch.ones(channels))
26 | self.beta = nn.Parameter(torch.zeros(channels))
27 |
28 | def forward(self, x):
29 | x = x.transpose(1, -1)
30 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
31 | return x.transpose(1, -1)
32 |
33 |
34 | class ConvReluNorm(nn.Module):
35 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
36 | super().__init__()
37 | self.in_channels = in_channels
38 | self.hidden_channels = hidden_channels
39 | self.out_channels = out_channels
40 | self.kernel_size = kernel_size
41 | self.n_layers = n_layers
42 | self.p_dropout = p_dropout
43 | assert n_layers > 1, "Number of layers should be larger than 1."
44 |
45 | self.conv_layers = nn.ModuleList()
46 | self.norm_layers = nn.ModuleList()
47 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
48 | self.norm_layers.append(LayerNorm(hidden_channels))
49 | self.relu_drop = nn.Sequential(
50 | nn.ReLU(),
51 | nn.Dropout(p_dropout))
52 | for _ in range(n_layers-1):
53 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
54 | self.norm_layers.append(LayerNorm(hidden_channels))
55 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
56 | self.proj.weight.data.zero_()
57 | self.proj.bias.data.zero_()
58 |
59 | def forward(self, x, x_mask):
60 | x_org = x
61 | for i in range(self.n_layers):
62 | x = self.conv_layers[i](x * x_mask)
63 | x = self.norm_layers[i](x)
64 | x = self.relu_drop(x)
65 | x = x_org + self.proj(x)
66 | return x * x_mask
67 |
68 |
69 | class DDSConv(nn.Module):
70 | """
71 | Dilated and Depth-Separable Convolution
72 | """
73 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
74 | super().__init__()
75 | self.channels = channels
76 | self.kernel_size = kernel_size
77 | self.n_layers = n_layers
78 | self.p_dropout = p_dropout
79 |
80 | self.drop = nn.Dropout(p_dropout)
81 | self.convs_sep = nn.ModuleList()
82 | self.convs_1x1 = nn.ModuleList()
83 | self.norms_1 = nn.ModuleList()
84 | self.norms_2 = nn.ModuleList()
85 | for i in range(n_layers):
86 | dilation = kernel_size ** i
87 | padding = (kernel_size * dilation - dilation) // 2
88 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
89 | groups=channels, dilation=dilation, padding=padding
90 | ))
91 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
92 | self.norms_1.append(LayerNorm(channels))
93 | self.norms_2.append(LayerNorm(channels))
94 |
95 | def forward(self, x, x_mask, g=None):
96 | if g is not None:
97 | x = x + g
98 | for i in range(self.n_layers):
99 | y = self.convs_sep[i](x * x_mask)
100 | y = self.norms_1[i](y)
101 | y = F.gelu(y)
102 | y = self.convs_1x1[i](y)
103 | y = self.norms_2[i](y)
104 | y = F.gelu(y)
105 | y = self.drop(y)
106 | x = x + y
107 | return x * x_mask
108 |
109 |
110 | class WN(torch.nn.Module):
111 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
112 | super(WN, self).__init__()
113 | assert(kernel_size % 2 == 1)
114 | self.hidden_channels = hidden_channels
115 | self.kernel_size = kernel_size
116 | self.dilation_rate = dilation_rate
117 | self.n_layers = n_layers
118 | self.gin_channels = gin_channels
119 | self.p_dropout = p_dropout
120 |
121 | self.in_layers = torch.nn.ModuleList()
122 | self.res_skip_layers = torch.nn.ModuleList()
123 | self.drop = nn.Dropout(p_dropout)
124 |
125 | if gin_channels != 0:
126 | cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
127 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
128 |
129 | for i in range(n_layers):
130 | dilation = dilation_rate ** i
131 | padding = int((kernel_size * dilation - dilation) / 2)
132 | in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
133 | dilation=dilation, padding=padding)
134 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
135 | self.in_layers.append(in_layer)
136 |
137 | # last one is not necessary
138 | if i < n_layers - 1:
139 | res_skip_channels = 2 * hidden_channels
140 | else:
141 | res_skip_channels = hidden_channels
142 |
143 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
144 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
145 | self.res_skip_layers.append(res_skip_layer)
146 |
147 | def forward(self, x, x_mask, g=None, **kwargs):
148 | output = torch.zeros_like(x)
149 | n_channels_tensor = torch.IntTensor([self.hidden_channels])
150 |
151 | if g is not None:
152 | g = self.cond_layer(g)
153 |
154 | for i in range(self.n_layers):
155 | x_in = self.in_layers[i](x)
156 | if g is not None:
157 | cond_offset = i * 2 * self.hidden_channels
158 | g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
159 | else:
160 | g_l = torch.zeros_like(x_in)
161 |
162 | acts = commons.fused_add_tanh_sigmoid_multiply(
163 | x_in,
164 | g_l,
165 | n_channels_tensor)
166 | acts = self.drop(acts)
167 |
168 | res_skip_acts = self.res_skip_layers[i](acts)
169 | if i < self.n_layers - 1:
170 | res_acts = res_skip_acts[:,:self.hidden_channels,:]
171 | x = (x + res_acts) * x_mask
172 | output = output + res_skip_acts[:,self.hidden_channels:,:]
173 | else:
174 | output = output + res_skip_acts
175 | return output * x_mask
176 |
177 | def remove_weight_norm(self):
178 | if self.gin_channels != 0:
179 | torch.nn.utils.remove_weight_norm(self.cond_layer)
180 | for l in self.in_layers:
181 | torch.nn.utils.remove_weight_norm(l)
182 | for l in self.res_skip_layers:
183 | torch.nn.utils.remove_weight_norm(l)
184 |
185 |
186 | class ResBlock1(torch.nn.Module):
187 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
188 | super(ResBlock1, self).__init__()
189 | self.convs1 = nn.ModuleList([
190 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
191 | padding=get_padding(kernel_size, dilation[0]))),
192 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
193 | padding=get_padding(kernel_size, dilation[1]))),
194 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
195 | padding=get_padding(kernel_size, dilation[2])))
196 | ])
197 | self.convs1.apply(init_weights)
198 |
199 | self.convs2 = nn.ModuleList([
200 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
201 | padding=get_padding(kernel_size, 1))),
202 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
203 | padding=get_padding(kernel_size, 1))),
204 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
205 | padding=get_padding(kernel_size, 1)))
206 | ])
207 | self.convs2.apply(init_weights)
208 |
209 | def forward(self, x, x_mask=None):
210 | for c1, c2 in zip(self.convs1, self.convs2):
211 | xt = F.leaky_relu(x, LRELU_SLOPE)
212 | if x_mask is not None:
213 | xt = xt * x_mask
214 | xt = c1(xt)
215 | xt = F.leaky_relu(xt, LRELU_SLOPE)
216 | if x_mask is not None:
217 | xt = xt * x_mask
218 | xt = c2(xt)
219 | x = xt + x
220 | if x_mask is not None:
221 | x = x * x_mask
222 | return x
223 |
224 | def remove_weight_norm(self):
225 | for l in self.convs1:
226 | remove_weight_norm(l)
227 | for l in self.convs2:
228 | remove_weight_norm(l)
229 |
230 |
231 | class ResBlock2(torch.nn.Module):
232 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
233 | super(ResBlock2, self).__init__()
234 | self.convs = nn.ModuleList([
235 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
236 | padding=get_padding(kernel_size, dilation[0]))),
237 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
238 | padding=get_padding(kernel_size, dilation[1])))
239 | ])
240 | self.convs.apply(init_weights)
241 |
242 | def forward(self, x, x_mask=None):
243 | for c in self.convs:
244 | xt = F.leaky_relu(x, LRELU_SLOPE)
245 | if x_mask is not None:
246 | xt = xt * x_mask
247 | xt = c(xt)
248 | x = xt + x
249 | if x_mask is not None:
250 | x = x * x_mask
251 | return x
252 |
253 | def remove_weight_norm(self):
254 | for l in self.convs:
255 | remove_weight_norm(l)
256 |
257 |
258 | class Log(nn.Module):
259 | def forward(self, x, x_mask, reverse=False, **kwargs):
260 | if not reverse:
261 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
262 | logdet = torch.sum(-y, [1, 2])
263 | return y, logdet
264 | else:
265 | x = torch.exp(x) * x_mask
266 | return x
267 |
268 |
269 | class Flip(nn.Module):
270 | def forward(self, x, *args, reverse=False, **kwargs):
271 | x = torch.flip(x, [1])
272 | if not reverse:
273 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
274 | return x, logdet
275 | else:
276 | return x
277 |
278 |
279 | class ElementwiseAffine(nn.Module):
280 | def __init__(self, channels):
281 | super().__init__()
282 | self.channels = channels
283 | self.m = nn.Parameter(torch.zeros(channels,1))
284 | self.logs = nn.Parameter(torch.zeros(channels,1))
285 |
286 | def forward(self, x, x_mask, reverse=False, **kwargs):
287 | if not reverse:
288 | y = self.m + torch.exp(self.logs) * x
289 | y = y * x_mask
290 | logdet = torch.sum(self.logs * x_mask, [1,2])
291 | return y, logdet
292 | else:
293 | x = (x - self.m) * torch.exp(-self.logs) * x_mask
294 | return x
295 |
296 |
297 | class ResidualCouplingLayer(nn.Module):
298 | def __init__(self,
299 | channels,
300 | hidden_channels,
301 | kernel_size,
302 | dilation_rate,
303 | n_layers,
304 | p_dropout=0,
305 | gin_channels=0,
306 | mean_only=False):
307 | assert channels % 2 == 0, "channels should be divisible by 2"
308 | super().__init__()
309 | self.channels = channels
310 | self.hidden_channels = hidden_channels
311 | self.kernel_size = kernel_size
312 | self.dilation_rate = dilation_rate
313 | self.n_layers = n_layers
314 | self.half_channels = channels // 2
315 | self.mean_only = mean_only
316 |
317 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
318 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
319 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
320 | self.post.weight.data.zero_()
321 | self.post.bias.data.zero_()
322 |
323 | def forward(self, x, x_mask, g=None, reverse=False):
324 | x0, x1 = torch.split(x, [self.half_channels]*2, 1)
325 | h = self.pre(x0) * x_mask
326 | h = self.enc(h, x_mask, g=g)
327 | stats = self.post(h) * x_mask
328 | if not self.mean_only:
329 | m, logs = torch.split(stats, [self.half_channels]*2, 1)
330 | else:
331 | m = stats
332 | logs = torch.zeros_like(m)
333 |
334 | if not reverse:
335 | x1 = m + x1 * torch.exp(logs) * x_mask
336 | x = torch.cat([x0, x1], 1)
337 | logdet = torch.sum(logs, [1,2])
338 | return x, logdet
339 | else:
340 | x1 = (x1 - m) * torch.exp(-logs) * x_mask
341 | x = torch.cat([x0, x1], 1)
342 | return x
343 |
--------------------------------------------------------------------------------
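Note: a quick, hypothetical round-trip check of the affine coupling layer above; running forward and then reverse with the same mask and conditioning should recover the input (exactly at initialization, since the post projection starts at zero).

    import torch
    from modules import ResidualCouplingLayer

    layer = ResidualCouplingLayer(channels=192, hidden_channels=192, kernel_size=5,
                                  dilation_rate=1, n_layers=4, gin_channels=256)
    x = torch.randn(2, 192, 40)
    x_mask = torch.ones(2, 1, 40)
    g = torch.randn(2, 256, 1)                            # speaker conditioning, broadcast over time

    z, logdet = layer(x, x_mask, g=g)                     # forward pass returns the log-determinant
    x_rec = layer(z, x_mask, g=g, reverse=True)
    print((x - x_rec).abs().max())                        # ~0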
/norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Normalization modules."""
8 |
9 | import typing as tp
10 |
11 | import einops
12 | import torch
13 | from torch import nn
14 |
15 |
16 | class ConvLayerNorm(nn.LayerNorm):
17 | """
18 | Convolution-friendly LayerNorm that moves the channel dimension to the end
19 | before running the normalization and moves it back to its original position right after.
20 | """
21 | def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
22 | super().__init__(normalized_shape, **kwargs)
23 |
24 | def forward(self, x):
25 | x = einops.rearrange(x, 'b ... t -> b t ...')
26 | x = super().forward(x)
27 | x = einops.rearrange(x, 'b t ... -> b ... t')
28 | return x
29 |
--------------------------------------------------------------------------------
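Note: a hypothetical sketch of ConvLayerNorm, which normalizes over the channel dimension of a (batch, channels, time) tensor while preserving that layout (einops must be installed).

    import torch
    from norm import ConvLayerNorm

    ln = ConvLayerNorm(normalized_shape=128)
    x = torch.randn(4, 128, 50)
    y = ln(x)
    print(y.shape)                                        # torch.Size([4, 128, 50])
    print(y.mean(dim=1).abs().max())                      # per-(batch, time) channel mean ~ 0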
/preprocess_code.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import librosa
4 |
5 | hubert = torch.hub.load("bshall/hubert:main", "hubert_soft").eval()
6 |
7 | def get_code(vctk_path):
8 | speakers = os.listdir(vctk_path)
9 | for spk in speakers:
10 | files_path = f"{vctk_path}/{spk}"
11 | wavs_path = os.listdir(files_path)
12 | for wav_name in wavs_path:
13 | wav_path = f"{files_path}/{wav_name}"
14 | wav, sr = librosa.load(wav_path, sr=16000)  # HuBERT-Soft expects 16 kHz input
15 | wav = torch.from_numpy(wav).unsqueeze(0).unsqueeze(0)
16 | c = hubert.units(wav)
17 | c = c.transpose(1,2)
18 | torch.save(c, wav_path.replace(".wav", ".pt"))
19 | c_path = wav_path.replace(".wav", ".pt")
20 | print(f"content code saved in {c_path}")
21 |
22 | if __name__ == "__main__":
23 | vctk_path = "./dataset/vctk-16k"
24 | get_code(vctk_path)
--------------------------------------------------------------------------------
/preprocess_flist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from tqdm import tqdm
4 | from random import shuffle
5 |
6 |
7 | if __name__ == "__main__":
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list")
10 | parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list")
11 | parser.add_argument("--test_list", type=str, default="./filelists/test.txt", help="path to test list")
12 | parser.add_argument("--source_dir", type=str, default="./dataset/vctk-16k", help="path to source dir")
13 | args = parser.parse_args()
14 |
15 | train = []
16 | val = []
17 | test = []
18 | idx = 0
19 |
20 | for speaker in tqdm(os.listdir(args.source_dir)):
21 | wavs = os.listdir(os.path.join(args.source_dir, speaker))
22 | shuffle(wavs)
23 | train += wavs[2:-10]
24 | val += wavs[:2]
25 | test += wavs[-10:]
26 |
27 | shuffle(train)
28 | shuffle(val)
29 | shuffle(test)
30 |
31 | print("Writing", args.train_list)
32 | with open(args.train_list, "w") as f:
33 | for fname in tqdm(train):
34 | speaker = fname[:4]
35 | wavpath = os.path.join("DUMMY", speaker, fname)
36 | f.write(wavpath + "\n")
37 |
38 | print("Writing", args.val_list)
39 | with open(args.val_list, "w") as f:
40 | for fname in tqdm(val):
41 | speaker = fname[:4]
42 | wavpath = os.path.join("DUMMY", speaker, fname)
43 | f.write(wavpath + "\n")
44 |
45 | print("Writing", args.test_list)
46 | with open(args.test_list, "w") as f:
47 | for fname in tqdm(test):
48 | speaker = fname[:4]
49 | wavpath = os.path.join("DUMMY", speaker, fname)
50 | f.write(wavpath + "\n")
51 |
--------------------------------------------------------------------------------
/preprocess_spk.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | from speaker_encoder.voice_encoder import SpeakerEncoder
3 | from speaker_encoder.audio import preprocess_wav
4 | from pathlib import Path
5 | import numpy as np
6 | from os.path import join, basename, split
7 | from tqdm import tqdm
8 | from multiprocessing import cpu_count
9 | from concurrent.futures import ProcessPoolExecutor
10 | from functools import partial
11 | import glob
12 | import argparse
13 |
14 |
15 | def build_from_path(in_dir, out_dir, weights_fpath, num_workers=1):
16 | executor = ProcessPoolExecutor(max_workers=num_workers)
17 | futures = []
18 | wavfile_paths = glob.glob(os.path.join(in_dir, '*.wav'))
19 | wavfile_paths = sorted(wavfile_paths)
20 | for wav_path in wavfile_paths:
21 | futures.append(executor.submit(
22 | partial(_compute_spkEmbed, out_dir, wav_path, weights_fpath)))
23 | return [future.result() for future in tqdm(futures)]
24 |
25 | def _compute_spkEmbed(out_dir, wav_path, weights_fpath):
26 | utt_id = os.path.splitext(os.path.basename(wav_path))[0]  # splitext avoids rstrip(".wav") stripping trailing w/a/v characters
27 | fpath = Path(wav_path)
28 | wav = preprocess_wav(fpath)
29 |
30 | encoder = SpeakerEncoder(weights_fpath)
31 | embed = encoder.embed_utterance(wav)
32 | fname_save = os.path.join(out_dir, f"{utt_id}.npy")
33 | np.save(fname_save, embed, allow_pickle=False)
34 | return os.path.basename(fname_save)
35 |
36 | def preprocess(in_dir, out_dir_root, spk, weights_fpath, num_workers):
37 | out_dir = os.path.join(out_dir_root, spk)
38 | os.makedirs(out_dir, exist_ok=True)
39 | metadata = build_from_path(in_dir, out_dir, weights_fpath, num_workers)
40 |
41 | if __name__ == "__main__":
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument('--in_dir', type=str,
44 | default='dataset/vctk-16k/')
45 | parser.add_argument('--num_workers', type=int, default=12)
46 | parser.add_argument('--out_dir_root', type=str,
47 | default='dataset/vctk-16k/')
48 | parser.add_argument('--spk_encoder_ckpt', type=str, \
49 | default='speaker_encoder/ckpt/pretrained_bak_5805000.pt')
50 |
51 | args = parser.parse_args()
52 |
53 | #split_list = ['train-clean-100', 'train-clean-360']
54 |
55 | sub_folder_list = os.listdir(args.in_dir)
56 | sub_folder_list.sort()
57 |
58 | args.num_workers = args.num_workers if args.num_workers is not None else cpu_count()
59 | print("Number of workers: ", args.num_workers)
60 | ckpt_step = os.path.basename(args.spk_encoder_ckpt).split('.')[0].split('_')[-1]
61 | spk_embed_out_dir = os.path.join(args.out_dir_root, "spk")
62 | print("[INFO] spk_embed_out_dir: ", spk_embed_out_dir)
63 | os.makedirs(spk_embed_out_dir, exist_ok=True)
64 |
65 | #for data_split in split_list:
66 | # sub_folder_list = os.listdir(args.in_dir, data_split)
67 | for spk in sub_folder_list:
68 | print("Preprocessing {} ...".format(spk))
69 | in_dir = os.path.join(args.in_dir, spk)
70 | if not os.path.isdir(in_dir):
71 | continue
72 | #out_dir = os.path.join(args.out_dir, spk)
73 | preprocess(in_dir, spk_embed_out_dir, spk, args.spk_encoder_ckpt, args.num_workers)
74 | '''
75 | for data_split in split_list:
76 | in_dir = os.path.join(args.in_dir, data_split)
77 | preprocess(in_dir, spk_embed_out_dir, args.spk_encoder_ckpt, args.num_workers)
78 | '''
79 |
80 | print("DONE!")
81 | sys.exit(0)
82 |
83 |
84 |
--------------------------------------------------------------------------------
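With the defaults above, each utterance is mapped to one speaker-embedding file under dataset/vctk-16k/spk/<speaker>/<utt_id>.npy. A hedged sanity check; the specific filename below is an assumption:

import numpy as np

# Path is illustrative; pick any .npy written by the script.
embed = np.load("dataset/vctk-16k/spk/p225/p225_001.npy")
print(embed.shape, np.linalg.norm(embed))  # expected: (256,) with a norm close to 1.0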
/preprocess_sr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import librosa
5 | import json
6 | from glob import glob
7 | from tqdm import tqdm
8 | from scipy.io import wavfile
9 |
10 | import utils
11 | from mel_processing import mel_spectrogram_torch
12 | from wavlm import WavLM, WavLMConfig
13 | #import h5py
14 | import logging
15 | logging.getLogger('numba').setLevel(logging.WARNING)
16 |
17 |
18 | def process(filename):
19 | basename = os.path.basename(filename)
20 | speaker = filename.split("/")[-2]#basename[:4]
21 | wav_dir = os.path.join(args.wav_dir, speaker)
22 | ssl_dir = os.path.join(args.ssl_dir, speaker)
23 | os.makedirs(wav_dir, exist_ok=True)
24 | os.makedirs(ssl_dir, exist_ok=True)
25 | wav, _ = librosa.load(filename, sr=hps.sampling_rate)
26 | wav = torch.from_numpy(wav).unsqueeze(0).cuda()
27 | mel = mel_spectrogram_torch(
28 | wav,
29 | hps.n_fft,
30 | hps.num_mels,
31 | hps.sampling_rate,
32 | hps.hop_size,
33 | hps.win_size,
34 | hps.fmin,
35 | hps.fmax
36 | )
37 | '''
38 | f = {}
39 | for i in range(args.min, args.max+1):
40 | fpath = os.path.join(ssl_dir, f"{i}.hdf5")
41 | f[i] = h5py.File(fpath, "a")
42 | '''
43 | for i in range(args.min, args.max+1):
44 | mel_rs = utils.transform(mel, i)
45 | wav_rs = vocoder(mel_rs)[0][0].detach().cpu().numpy()
46 | _wav_rs = librosa.resample(wav_rs, orig_sr=hps.sampling_rate, target_sr=args.sr)
47 | wav_rs = torch.from_numpy(_wav_rs).cuda().unsqueeze(0)
48 | c = utils.get_content(cmodel, wav_rs)
49 | ssl_path = os.path.join(ssl_dir, basename.replace(".wav", f"_{i}.pt"))
50 | torch.save(c.cpu(), ssl_path)
51 | #print(wav_rs.size(), c.size())
52 | wav_path = os.path.join(wav_dir, basename.replace(".wav", f"_{i}.wav"))
53 | wavfile.write(
54 | wav_path,
55 | args.sr,
56 | _wav_rs
57 | )
58 | '''
59 | f[i][basename[:-4]] = c.cpu()
60 | for i in range(args.min, args.max+1):
61 | f[i].close()
62 | '''
63 |
64 |
65 |
66 | if __name__ == "__main__":
67 | parser = argparse.ArgumentParser()
68 | parser.add_argument("--sr", type=int, default=16000, help="sampling rate")
69 | parser.add_argument("--min", type=int, default=68, help="min")
70 | parser.add_argument("--max", type=int, default=92, help="max")
71 | parser.add_argument("--config", type=str, default="hifigan/config.json", help="path to config file")
72 | parser.add_argument("--in_dir", type=str, default="dataset/vctk-22k", help="path to input dir")
73 | parser.add_argument("--wav_dir", type=str, default="dataset/sr/wav", help="path to output wav dir")
74 | parser.add_argument("--ssl_dir", type=str, default="dataset/sr/wavlm", help="path to output ssl dir")
75 | args = parser.parse_args()
76 |
77 | print("Loading WavLM for content...")
78 | checkpoint = torch.load('wavlm/WavLM-Large.pt')
79 | cfg = WavLMConfig(checkpoint['cfg'])
80 | cmodel = WavLM(cfg).cuda()
81 | cmodel.load_state_dict(checkpoint['model'])
82 | cmodel.eval()
83 | print("Loaded WavLM.")
84 |
85 | print("Loading vocoder...")
86 | vocoder = utils.get_vocoder(0)
87 | vocoder.eval()
88 | print("Loaded vocoder.")
89 |
90 | config_path = args.config
91 | with open(config_path, "r") as f:
92 | data = f.read()
93 | config = json.loads(data)
94 | hps = utils.HParams(**config)
95 |
96 | filenames = glob(f'{args.in_dir}/*/*.wav', recursive=True)#[:10]
97 |
98 | for filename in tqdm(filenames):
99 | process(filename)
100 |
101 |
--------------------------------------------------------------------------------
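For every source wav and every index i in [--min, --max], the script above writes a resampled wav and the matching WavLM content features into parallel directory trees. A sketch of how a downstream loader might pair them; the speaker, utterance and index are assumptions:

import librosa
import torch

i = 75  # assumed index inside the default [68, 92] range
wav, _ = librosa.load(f"dataset/sr/wav/p225/p225_001_{i}.wav", sr=16000)
content = torch.load(f"dataset/sr/wavlm/p225/p225_001_{i}.pt")
print(wav.shape, content.shape)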
/requirements.txt:
--------------------------------------------------------------------------------
1 | glob2==0.7
2 | tqdm==4.62.3
3 | librosa==0.8.1
4 | numpy==1.21.6
5 | scipy==1.7.2
6 | tensorboard==2.7.0
7 | torch==1.10.0
8 | torchvision==0.9.0
9 | webrtcvad==2.0.10
10 |
--------------------------------------------------------------------------------
/resources/NeurlVC.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zzy1hjq/NeuralVC/9b562356ac008c76ead4b31251ea956a0c15eda2/resources/NeurlVC.png
--------------------------------------------------------------------------------
/speaker_encoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zzy1hjq/NeuralVC/9b562356ac008c76ead4b31251ea956a0c15eda2/speaker_encoder/__init__.py
--------------------------------------------------------------------------------
/speaker_encoder/audio.py:
--------------------------------------------------------------------------------
1 | from scipy.ndimage.morphology import binary_dilation
2 | from speaker_encoder.params_data import *
3 | from pathlib import Path
4 | from typing import Optional, Union
5 | import numpy as np
6 | import webrtcvad
7 | import librosa
8 | import struct
9 |
10 | int16_max = (2 ** 15) - 1
11 |
12 |
13 | def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
14 | source_sr: Optional[int] = None):
15 | """
16 | Applies the preprocessing operations used in training the Speaker Encoder to a waveform
17 | either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
18 |
19 | :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
20 |     just .wav), or the waveform as a numpy array of floats.
21 | :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
22 | preprocessing. After preprocessing, the waveform's sampling rate will match the data
23 | hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
24 | this argument will be ignored.
25 | """
26 | # Load the wav from disk if needed
27 | if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
28 | wav, source_sr = librosa.load(fpath_or_wav, sr=None)
29 | else:
30 | wav = fpath_or_wav
31 |
32 | # Resample the wav if needed
33 | if source_sr is not None and source_sr != sampling_rate:
34 | wav = librosa.resample(wav, source_sr, sampling_rate)
35 |
36 | # Apply the preprocessing: normalize volume and shorten long silences
37 | wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
38 | wav = trim_long_silences(wav)
39 |
40 | return wav
41 |
42 |
43 | def wav_to_mel_spectrogram(wav):
44 | """
45 | Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
46 |     Note: this is not a log-mel spectrogram.
47 | """
48 | frames = librosa.feature.melspectrogram(
49 | y=wav,
50 | sr=sampling_rate,
51 | n_fft=int(sampling_rate * mel_window_length / 1000),
52 | hop_length=int(sampling_rate * mel_window_step / 1000),
53 | n_mels=mel_n_channels
54 | )
55 | return frames.astype(np.float32).T
56 |
57 |
58 | def trim_long_silences(wav):
59 | """
60 | Ensures that segments without voice in the waveform remain no longer than a
61 | threshold determined by the VAD parameters in params.py.
62 |
63 | :param wav: the raw waveform as a numpy array of floats
64 | :return: the same waveform with silences trimmed away (length <= original wav length)
65 | """
66 | # Compute the voice detection window size
67 | samples_per_window = (vad_window_length * sampling_rate) // 1000
68 |
69 | # Trim the end of the audio to have a multiple of the window size
70 | wav = wav[:len(wav) - (len(wav) % samples_per_window)]
71 |
72 | # Convert the float waveform to 16-bit mono PCM
73 | pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
74 |
75 | # Perform voice activation detection
76 | voice_flags = []
77 | vad = webrtcvad.Vad(mode=3)
78 | for window_start in range(0, len(wav), samples_per_window):
79 | window_end = window_start + samples_per_window
80 | voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
81 | sample_rate=sampling_rate))
82 | voice_flags = np.array(voice_flags)
83 |
84 | # Smooth the voice detection with a moving average
85 | def moving_average(array, width):
86 | array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
87 | ret = np.cumsum(array_padded, dtype=float)
88 | ret[width:] = ret[width:] - ret[:-width]
89 | return ret[width - 1:] / width
90 |
91 | audio_mask = moving_average(voice_flags, vad_moving_average_width)
92 |     audio_mask = np.round(audio_mask).astype(bool)
93 |
94 | # Dilate the voiced regions
95 | audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
96 | audio_mask = np.repeat(audio_mask, samples_per_window)
97 |
98 | return wav[audio_mask == True]
99 |
100 |
101 | def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
102 | if increase_only and decrease_only:
103 | raise ValueError("Both increase only and decrease only are set")
104 | dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
105 | if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
106 | return wav
107 | return wav * (10 ** (dBFS_change / 20))
108 |
--------------------------------------------------------------------------------
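A minimal sketch of the preprocessing chain above: load, resample to 16 kHz, volume-normalize, VAD-trim, then compute the 40-channel mel filterbank. The example path is an assumption:

from speaker_encoder.audio import preprocess_wav, wav_to_mel_spectrogram

wav = preprocess_wav("dataset/vctk-16k/p225/p225_001.wav")  # float32 at 16 kHz, silences trimmed
frames = wav_to_mel_spectrogram(wav)                        # shape (n_frames, 40), not log-scaled
print(frames.shape)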
/speaker_encoder/ckpt/pretrained_bak_5805000.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zzy1hjq/NeuralVC/9b562356ac008c76ead4b31251ea956a0c15eda2/speaker_encoder/ckpt/pretrained_bak_5805000.pt
--------------------------------------------------------------------------------
/speaker_encoder/ckpt/pretrained_bak_5805000.pt.txt:
--------------------------------------------------------------------------------
1 | https://github.com/liusongxiang/ppg-vc/tree/main/speaker_encoder/ckpt
--------------------------------------------------------------------------------
/speaker_encoder/compute_embed.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder import inference as encoder
2 | from multiprocessing.pool import Pool
3 | from functools import partial
4 | from pathlib import Path
5 | # from utils import logmmse
6 | from tqdm import tqdm
7 | import numpy as np
8 | # import librosa
9 |
10 |
11 | def embed_utterance(fpaths, encoder_model_fpath):
12 | if not encoder.is_loaded():
13 | encoder.load_model(encoder_model_fpath)
14 |
15 | # Compute the speaker embedding of the utterance
16 | wav_fpath, embed_fpath = fpaths
17 | wav = np.load(wav_fpath)
18 | wav = encoder.preprocess_wav(wav)
19 | embed = encoder.embed_utterance(wav)
20 | np.save(embed_fpath, embed, allow_pickle=False)
21 |
22 |
23 | def create_embeddings(outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int):
24 |
25 | wav_dir = outdir_root.joinpath("audio")
26 |     metadata_fpath = outdir_root.joinpath("train.txt")
27 |     assert wav_dir.exists() and metadata_fpath.exists()
28 |     embed_dir = outdir_root.joinpath("embeds")
29 | embed_dir.mkdir(exist_ok=True)
30 |
31 | # Gather the input wave filepath and the target output embed filepath
32 | with metadata_fpath.open("r") as metadata_file:
33 | metadata = [line.split("|") for line in metadata_file]
34 | fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
35 |
36 | # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
37 | # Embed the utterances in separate threads
38 | func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
39 | job = Pool(n_processes).imap(func, fpaths)
40 | list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
--------------------------------------------------------------------------------
/speaker_encoder/config.py:
--------------------------------------------------------------------------------
1 | librispeech_datasets = {
2 | "train": {
3 | "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
4 | "other": ["LibriSpeech/train-other-500"]
5 | },
6 | "test": {
7 | "clean": ["LibriSpeech/test-clean"],
8 | "other": ["LibriSpeech/test-other"]
9 | },
10 | "dev": {
11 | "clean": ["LibriSpeech/dev-clean"],
12 | "other": ["LibriSpeech/dev-other"]
13 | },
14 | }
15 | libritts_datasets = {
16 | "train": {
17 | "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
18 | "other": ["LibriTTS/train-other-500"]
19 | },
20 | "test": {
21 | "clean": ["LibriTTS/test-clean"],
22 | "other": ["LibriTTS/test-other"]
23 | },
24 | "dev": {
25 | "clean": ["LibriTTS/dev-clean"],
26 | "other": ["LibriTTS/dev-other"]
27 | },
28 | }
29 | voxceleb_datasets = {
30 | "voxceleb1" : {
31 | "train": ["VoxCeleb1/wav"],
32 | "test": ["VoxCeleb1/test_wav"]
33 | },
34 | "voxceleb2" : {
35 | "train": ["VoxCeleb2/dev/aac"],
36 | "test": ["VoxCeleb2/test_wav"]
37 | }
38 | }
39 |
40 | other_datasets = [
41 | "LJSpeech-1.1",
42 | "VCTK-Corpus/wav48",
43 | ]
44 |
45 | anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
46 |
--------------------------------------------------------------------------------
/speaker_encoder/data_objects/__init__.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2 | from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
3 |
--------------------------------------------------------------------------------
/speaker_encoder/data_objects/random_cycler.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | class RandomCycler:
4 | """
5 | Creates an internal copy of a sequence and allows access to its items in a constrained random
6 | order. For a source sequence of n items and one or several consecutive queries of a total
7 | of m items, the following guarantees hold (one implies the other):
8 | - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
9 | - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
10 | """
11 |
12 | def __init__(self, source):
13 | if len(source) == 0:
14 | raise Exception("Can't create RandomCycler from an empty collection")
15 | self.all_items = list(source)
16 | self.next_items = []
17 |
18 | def sample(self, count: int):
19 | shuffle = lambda l: random.sample(l, len(l))
20 |
21 | out = []
22 | while count > 0:
23 | if count >= len(self.all_items):
24 | out.extend(shuffle(list(self.all_items)))
25 | count -= len(self.all_items)
26 | continue
27 | n = min(count, len(self.next_items))
28 | out.extend(self.next_items[:n])
29 | count -= n
30 | self.next_items = self.next_items[n:]
31 | if len(self.next_items) == 0:
32 | self.next_items = shuffle(list(self.all_items))
33 | return out
34 |
35 | def __next__(self):
36 | return self.sample(1)[0]
37 |
38 |
--------------------------------------------------------------------------------
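A tiny illustration of the guarantee documented above: drawing exactly 2n items from an n-item source returns every item exactly twice (illustrative only, not part of the repository):

from speaker_encoder.data_objects.random_cycler import RandomCycler

cycler = RandomCycler(["a", "b", "c"])
draws = cycler.sample(6)            # m = 6, n = 3  ->  each item appears m // n = 2 times
assert sorted(draws) == ["a", "a", "b", "b", "c", "c"]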
/speaker_encoder/data_objects/speaker.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder.data_objects.random_cycler import RandomCycler
2 | from speaker_encoder.data_objects.utterance import Utterance
3 | from pathlib import Path
4 |
5 | # Contains the set of utterances of a single speaker
6 | class Speaker:
7 | def __init__(self, root: Path):
8 | self.root = root
9 | self.name = root.name
10 | self.utterances = None
11 | self.utterance_cycler = None
12 |
13 | def _load_utterances(self):
14 | with self.root.joinpath("_sources.txt").open("r") as sources_file:
15 | sources = [l.split(",") for l in sources_file]
16 | sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
17 | self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
18 | self.utterance_cycler = RandomCycler(self.utterances)
19 |
20 | def random_partial(self, count, n_frames):
21 | """
22 | Samples a batch of unique partial utterances from the disk in a way that all
23 | utterances come up at least once every two cycles and in a random order every time.
24 |
25 | :param count: The number of partial utterances to sample from the set of utterances from
26 |         that speaker. Utterances are guaranteed not to be repeated if count is not larger than
27 | the number of utterances available.
28 | :param n_frames: The number of frames in the partial utterance.
29 | :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
30 | frames are the frames of the partial utterances and range is the range of the partial
31 | utterance with regard to the complete utterance.
32 | """
33 | if self.utterances is None:
34 | self._load_utterances()
35 |
36 | utterances = self.utterance_cycler.sample(count)
37 |
38 | a = [(u,) + u.random_partial(n_frames) for u in utterances]
39 |
40 | return a
41 |
--------------------------------------------------------------------------------
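A sketch of sampling partial utterances from one preprocessed speaker directory; the directory name is an assumption and must contain the _sources.txt and .npy files produced by speaker_encoder/preprocess.py:

from pathlib import Path
from speaker_encoder.data_objects.speaker import Speaker
from speaker_encoder.params_data import partials_n_frames

spk = Speaker(Path("encoder_preprocessed/p225"))           # hypothetical speaker directory
partials = spk.random_partial(count=4, n_frames=partials_n_frames)
utterance, frames, (start, end) = partials[0]
print(frames.shape)                                        # (160, 40)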
/speaker_encoder/data_objects/speaker_batch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import List
3 | from speaker_encoder.data_objects.speaker import Speaker
4 |
5 | class SpeakerBatch:
6 | def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
7 | self.speakers = speakers
8 | self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
9 |
10 | # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
11 | # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
12 | self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
13 |
--------------------------------------------------------------------------------
/speaker_encoder/data_objects/speaker_verification_dataset.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder.data_objects.random_cycler import RandomCycler
2 | from speaker_encoder.data_objects.speaker_batch import SpeakerBatch
3 | from speaker_encoder.data_objects.speaker import Speaker
4 | from speaker_encoder.params_data import partials_n_frames
5 | from torch.utils.data import Dataset, DataLoader
6 | from pathlib import Path
7 |
8 | # TODO: improve with a pool of speakers for data efficiency
9 |
10 | class SpeakerVerificationDataset(Dataset):
11 | def __init__(self, datasets_root: Path):
12 | self.root = datasets_root
13 | speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
14 | if len(speaker_dirs) == 0:
15 | raise Exception("No speakers found. Make sure you are pointing to the directory "
16 | "containing all preprocessed speaker directories.")
17 | self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
18 | self.speaker_cycler = RandomCycler(self.speakers)
19 |
20 | def __len__(self):
21 | return int(1e10)
22 |
23 | def __getitem__(self, index):
24 | return next(self.speaker_cycler)
25 |
26 | def get_logs(self):
27 | log_string = ""
28 | for log_fpath in self.root.glob("*.txt"):
29 | with log_fpath.open("r") as log_file:
30 | log_string += "".join(log_file.readlines())
31 | return log_string
32 |
33 |
34 | class SpeakerVerificationDataLoader(DataLoader):
35 | def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
36 | batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
37 | worker_init_fn=None):
38 | self.utterances_per_speaker = utterances_per_speaker
39 |
40 | super().__init__(
41 | dataset=dataset,
42 | batch_size=speakers_per_batch,
43 | shuffle=False,
44 | sampler=sampler,
45 | batch_sampler=batch_sampler,
46 | num_workers=num_workers,
47 | collate_fn=self.collate,
48 | pin_memory=pin_memory,
49 | drop_last=False,
50 | timeout=timeout,
51 | worker_init_fn=worker_init_fn
52 | )
53 |
54 | def collate(self, speakers):
55 | return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
56 |
--------------------------------------------------------------------------------
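How the dataset and loader above fit together during encoder training, as a sketch; the preprocessed root is an assumption, and each batch stacks speakers_per_batch * utterances_per_speaker partials:

from pathlib import Path
from speaker_encoder.data_objects import (SpeakerVerificationDataset,
                                          SpeakerVerificationDataLoader)

dataset = SpeakerVerificationDataset(Path("encoder_preprocessed"))   # hypothetical root
loader = SpeakerVerificationDataLoader(dataset, speakers_per_batch=64,
                                       utterances_per_speaker=10)
batch = next(iter(loader))
print(batch.data.shape)   # (640, 160, 40)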
/speaker_encoder/data_objects/utterance.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Utterance:
5 | def __init__(self, frames_fpath, wave_fpath):
6 | self.frames_fpath = frames_fpath
7 | self.wave_fpath = wave_fpath
8 |
9 | def get_frames(self):
10 | return np.load(self.frames_fpath)
11 |
12 | def random_partial(self, n_frames):
13 | """
14 | Crops the frames into a partial utterance of n_frames
15 |
16 | :param n_frames: The number of frames of the partial utterance
17 | :return: the partial utterance frames and a tuple indicating the start and end of the
18 | partial utterance in the complete utterance.
19 | """
20 | frames = self.get_frames()
21 | if frames.shape[0] == n_frames:
22 | start = 0
23 | else:
24 | start = np.random.randint(0, frames.shape[0] - n_frames)
25 | end = start + n_frames
26 | return frames[start:end], (start, end)
--------------------------------------------------------------------------------
/speaker_encoder/hparams.py:
--------------------------------------------------------------------------------
1 | ## Mel-filterbank
2 | mel_window_length = 25 # In milliseconds
3 | mel_window_step = 10 # In milliseconds
4 | mel_n_channels = 40
5 |
6 |
7 | ## Audio
8 | sampling_rate = 16000
9 | # Number of spectrogram frames in a partial utterance
10 | partials_n_frames = 160 # 1600 ms
11 |
12 |
13 | ## Voice Activation Detection
14 | # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
15 | # This sets the granularity of the VAD. Should not need to be changed.
16 | vad_window_length = 30 # In milliseconds
17 | # Number of frames to average together when performing the moving average smoothing.
18 | # The larger this value, the larger the VAD variations must be to not get smoothed out.
19 | vad_moving_average_width = 8
20 | # Maximum number of consecutive silent frames a segment can have.
21 | vad_max_silence_length = 6
22 |
23 |
24 | ## Audio volume normalization
25 | audio_norm_target_dBFS = -30
26 |
27 |
28 | ## Model parameters
29 | model_hidden_size = 256
30 | model_embedding_size = 256
31 | model_num_layers = 3
--------------------------------------------------------------------------------
/speaker_encoder/inference.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder.params_data import *
2 | from speaker_encoder.model import SpeakerEncoder
3 | from speaker_encoder.audio import preprocess_wav # We want to expose this function from here
4 | from matplotlib import cm
5 | from speaker_encoder import audio
6 | from pathlib import Path
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import torch
10 |
11 | _model = None # type: SpeakerEncoder
12 | _device = None # type: torch.device
13 |
14 |
15 | def load_model(weights_fpath: Path, device=None):
16 | """
17 |     Loads the model in memory. If this function is not explicitly called, it will be run on the
18 | first call to embed_frames() with the default weights file.
19 |
20 | :param weights_fpath: the path to saved model weights.
21 | :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
22 | model will be loaded and will run on this device. Outputs will however always be on the cpu.
23 |     If None, will default to your GPU if it's available, otherwise your CPU.
24 | """
25 | # TODO: I think the slow loading of the encoder might have something to do with the device it
26 | # was saved on. Worth investigating.
27 | global _model, _device
28 | if device is None:
29 | _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30 | elif isinstance(device, str):
31 | _device = torch.device(device)
32 | _model = SpeakerEncoder(_device, torch.device("cpu"))
33 | checkpoint = torch.load(weights_fpath)
34 | _model.load_state_dict(checkpoint["model_state"])
35 | _model.eval()
36 | print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
37 |
38 |
39 | def is_loaded():
40 | return _model is not None
41 |
42 |
43 | def embed_frames_batch(frames_batch):
44 | """
45 |     Computes embeddings for a batch of mel spectrograms.
46 |
47 |     :param frames_batch: a batch of mel spectrograms as a numpy array of float32 of shape
48 | (batch_size, n_frames, n_channels)
49 | :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
50 | """
51 | if _model is None:
52 | raise Exception("Model was not loaded. Call load_model() before inference.")
53 |
54 | frames = torch.from_numpy(frames_batch).to(_device)
55 | embed = _model.forward(frames).detach().cpu().numpy()
56 | return embed
57 |
58 |
59 | def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
60 | min_pad_coverage=0.75, overlap=0.5):
61 | """
62 | Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
63 |     partial utterances of partial_utterance_n_frames frames each. Both the waveform and the mel
64 | spectrogram slices are returned, so as to make each partial utterance waveform correspond to
65 | its spectrogram. This function assumes that the mel spectrogram parameters used are those
66 | defined in params_data.py.
67 |
68 | The returned ranges may be indexing further than the length of the waveform. It is
69 | recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
70 |
71 | :param n_samples: the number of samples in the waveform
72 | :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
73 | utterance
74 | :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
75 |     enough frames. If at least min_pad_coverage of partial_utterance_n_frames are present,
76 | then the last partial utterance will be considered, as if we padded the audio. Otherwise,
77 | it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
78 | utterance, this parameter is ignored so that the function always returns at least 1 slice.
79 | :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
80 | utterances are entirely disjoint.
81 | :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
82 | respectively the waveform and the mel spectrogram with these slices to obtain the partial
83 | utterances.
84 | """
85 | assert 0 <= overlap < 1
86 | assert 0 < min_pad_coverage <= 1
87 |
88 | samples_per_frame = int((sampling_rate * mel_window_step / 1000))
89 | n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
90 | frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
91 |
92 | # Compute the slices
93 | wav_slices, mel_slices = [], []
94 | steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
95 | for i in range(0, steps, frame_step):
96 | mel_range = np.array([i, i + partial_utterance_n_frames])
97 | wav_range = mel_range * samples_per_frame
98 | mel_slices.append(slice(*mel_range))
99 | wav_slices.append(slice(*wav_range))
100 |
101 | # Evaluate whether extra padding is warranted or not
102 | last_wav_range = wav_slices[-1]
103 | coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
104 | if coverage < min_pad_coverage and len(mel_slices) > 1:
105 | mel_slices = mel_slices[:-1]
106 | wav_slices = wav_slices[:-1]
107 |
108 | return wav_slices, mel_slices
109 |
110 |
111 | def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
112 | """
113 | Computes an embedding for a single utterance.
114 |
115 | # TODO: handle multiple wavs to benefit from batching on GPU
116 | :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
117 |     :param using_partials: if True, then the utterance is split in partial utterances of
118 |     partial_utterance_n_frames frames and the utterance embedding is computed from their
119 |     normalized average. If False, the utterance embedding is instead computed from feeding the entire
120 |     spectrogram to the network.
121 | :param return_partials: if True, the partial embeddings will also be returned along with the
122 | wav slices that correspond to the partial embeddings.
123 | :param kwargs: additional arguments to compute_partial_splits()
124 |     :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
125 |     return_partials is True, the partial embeddings as a numpy array of float32 of shape
126 |     (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
127 |     returned. If using_partials is simultaneously set to False, both these values will be None
128 |     instead.
129 | """
130 | # Process the entire utterance if not using partials
131 | if not using_partials:
132 | frames = audio.wav_to_mel_spectrogram(wav)
133 | embed = embed_frames_batch(frames[None, ...])[0]
134 | if return_partials:
135 | return embed, None, None
136 | return embed
137 |
138 | # Compute where to split the utterance into partials and pad if necessary
139 | wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
140 | max_wave_length = wave_slices[-1].stop
141 | if max_wave_length >= len(wav):
142 | wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
143 |
144 | # Split the utterance into partials
145 | frames = audio.wav_to_mel_spectrogram(wav)
146 | frames_batch = np.array([frames[s] for s in mel_slices])
147 | partial_embeds = embed_frames_batch(frames_batch)
148 |
149 | # Compute the utterance embedding from the partial embeddings
150 | raw_embed = np.mean(partial_embeds, axis=0)
151 | embed = raw_embed / np.linalg.norm(raw_embed, 2)
152 |
153 | if return_partials:
154 | return embed, partial_embeds, wave_slices
155 | return embed
156 |
157 |
158 | def embed_speaker(wavs, **kwargs):
159 |     raise NotImplementedError()
160 |
161 |
162 | def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
163 | if ax is None:
164 | ax = plt.gca()
165 |
166 | if shape is None:
167 | height = int(np.sqrt(len(embed)))
168 | shape = (height, -1)
169 | embed = embed.reshape(shape)
170 |
171 | cmap = cm.get_cmap()
172 | mappable = ax.imshow(embed, cmap=cmap)
173 | cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
174 | cbar.set_clim(*color_range)
175 |
176 | ax.set_xticks([]), ax.set_yticks([])
177 | ax.set_title(title)
178 |
--------------------------------------------------------------------------------
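End-to-end sketch of the inference helpers above; the input wav is an assumption, and the checkpoint path matches the file shipped under speaker_encoder/ckpt:

from pathlib import Path
from speaker_encoder import inference as encoder

encoder.load_model(Path("speaker_encoder/ckpt/pretrained_bak_5805000.pt"))
wav = encoder.preprocess_wav("some_utterance.wav")                    # hypothetical file
embed, partial_embeds, wav_slices = encoder.embed_utterance(wav, return_partials=True)
print(embed.shape)                                                    # (256,), L2-normalized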
/speaker_encoder/model.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder.params_model import *
2 | from speaker_encoder.params_data import *
3 | from scipy.interpolate import interp1d
4 | from sklearn.metrics import roc_curve
5 | from torch.nn.utils import clip_grad_norm_
6 | from scipy.optimize import brentq
7 | from torch import nn
8 | import numpy as np
9 | import torch
10 |
11 |
12 | class SpeakerEncoder(nn.Module):
13 | def __init__(self, device, loss_device):
14 | super().__init__()
15 | self.loss_device = loss_device
16 |
17 |         # Network definition
18 | self.lstm = nn.LSTM(input_size=mel_n_channels, # 40
19 | hidden_size=model_hidden_size, # 256
20 | num_layers=model_num_layers, # 3
21 | batch_first=True).to(device)
22 | self.linear = nn.Linear(in_features=model_hidden_size,
23 | out_features=model_embedding_size).to(device)
24 | self.relu = torch.nn.ReLU().to(device)
25 |
26 | # Cosine similarity scaling (with fixed initial parameter values)
27 | self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
28 | self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
29 |
30 | # Loss
31 | self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
32 |
33 | def do_gradient_ops(self):
34 | # Gradient scale
35 | self.similarity_weight.grad *= 0.01
36 | self.similarity_bias.grad *= 0.01
37 |
38 | # Gradient clipping
39 | clip_grad_norm_(self.parameters(), 3, norm_type=2)
40 |
41 | def forward(self, utterances, hidden_init=None):
42 | """
43 | Computes the embeddings of a batch of utterance spectrograms.
44 |
45 | :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
46 | (batch_size, n_frames, n_channels)
47 | :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
48 | batch_size, hidden_size). Will default to a tensor of zeros if None.
49 | :return: the embeddings as a tensor of shape (batch_size, embedding_size)
50 | """
51 | # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
52 | # and the final cell state.
53 | out, (hidden, cell) = self.lstm(utterances, hidden_init)
54 |
55 | # We take only the hidden state of the last layer
56 | embeds_raw = self.relu(self.linear(hidden[-1]))
57 |
58 | # L2-normalize it
59 | embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
60 |
61 | return embeds
62 |
63 | def similarity_matrix(self, embeds):
64 | """
65 |         Computes the similarity matrix according to section 2.1 of GE2E.
66 |
67 | :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
68 | utterances_per_speaker, embedding_size)
69 | :return: the similarity matrix as a tensor of shape (speakers_per_batch,
70 | utterances_per_speaker, speakers_per_batch)
71 | """
72 | speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
73 |
74 | # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
75 | centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
76 | centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)
77 |
78 | # Exclusive centroids (1 per utterance)
79 | centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
80 | centroids_excl /= (utterances_per_speaker - 1)
81 | centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)
82 |
83 | # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
84 | # product of these vectors (which is just an element-wise multiplication reduced by a sum).
85 | # We vectorize the computation for efficiency.
86 | sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
87 | speakers_per_batch).to(self.loss_device)
88 |         mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
89 | for j in range(speakers_per_batch):
90 | mask = np.where(mask_matrix[j])[0]
91 | sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
92 | sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
93 |
94 | ## Even more vectorized version (slower maybe because of transpose)
95 | # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
96 | # ).to(self.loss_device)
97 | # eye = np.eye(speakers_per_batch, dtype=np.int)
98 | # mask = np.where(1 - eye)
99 | # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
100 | # mask = np.where(eye)
101 | # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
102 | # sim_matrix2 = sim_matrix2.transpose(1, 2)
103 |
104 | sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
105 | return sim_matrix
106 |
107 | def loss(self, embeds):
108 | """
109 |         Computes the softmax loss according to section 2.1 of GE2E.
110 |
111 | :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
112 | utterances_per_speaker, embedding_size)
113 | :return: the loss and the EER for this batch of embeddings.
114 | """
115 | speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
116 |
117 | # Loss
118 | sim_matrix = self.similarity_matrix(embeds)
119 | sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
120 | speakers_per_batch))
121 | ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
122 | target = torch.from_numpy(ground_truth).long().to(self.loss_device)
123 | loss = self.loss_fn(sim_matrix, target)
124 |
125 | # EER (not backpropagated)
126 | with torch.no_grad():
127 |             inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
128 | labels = np.array([inv_argmax(i) for i in ground_truth])
129 | preds = sim_matrix.detach().cpu().numpy()
130 |
131 | # Snippet from https://yangcha.github.io/EER-ROC/
132 | fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
133 | eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
134 |
135 | return loss, eer
--------------------------------------------------------------------------------
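A shape-only sketch of the GE2E similarity computation above, using random, pre-normalized embeddings (illustrative; no real data involved):

import torch
from speaker_encoder.model import SpeakerEncoder

model = SpeakerEncoder(torch.device("cpu"), torch.device("cpu"))
embeds = torch.randn(4, 5, 256)                           # (speakers, utterances, embedding_size)
embeds = embeds / torch.norm(embeds, dim=2, keepdim=True)
sim = model.similarity_matrix(embeds)
print(sim.shape)                                          # torch.Size([4, 5, 4])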
/speaker_encoder/params_data.py:
--------------------------------------------------------------------------------
1 |
2 | ## Mel-filterbank
3 | mel_window_length = 25 # In milliseconds
4 | mel_window_step = 10 # In milliseconds
5 | mel_n_channels = 40
6 |
7 |
8 | ## Audio
9 | sampling_rate = 16000
10 | # Number of spectrogram frames in a partial utterance
11 | partials_n_frames = 160 # 1600 ms
12 | # Number of spectrogram frames at inference
13 | inference_n_frames = 80 # 800 ms
14 |
15 |
16 | ## Voice Activation Detection
17 | # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
18 | # This sets the granularity of the VAD. Should not need to be changed.
19 | vad_window_length = 30 # In milliseconds
20 | # Number of frames to average together when performing the moving average smoothing.
21 | # The larger this value, the larger the VAD variations must be to not get smoothed out.
22 | vad_moving_average_width = 8
23 | # Maximum number of consecutive silent frames a segment can have.
24 | vad_max_silence_length = 6
25 |
26 |
27 | ## Audio volume normalization
28 | audio_norm_target_dBFS = -30
29 |
30 |
--------------------------------------------------------------------------------
/speaker_encoder/params_model.py:
--------------------------------------------------------------------------------
1 |
2 | ## Model parameters
3 | model_hidden_size = 256
4 | model_embedding_size = 256
5 | model_num_layers = 3
6 |
7 |
8 | ## Training parameters
9 | learning_rate_init = 1e-4
10 | speakers_per_batch = 64
11 | utterances_per_speaker = 10
12 |
--------------------------------------------------------------------------------
/speaker_encoder/preprocess.py:
--------------------------------------------------------------------------------
1 | from multiprocess.pool import ThreadPool
2 | from speaker_encoder.params_data import *
3 | from speaker_encoder.config import librispeech_datasets, anglophone_nationalites
4 | from datetime import datetime
5 | from speaker_encoder import audio
6 | from pathlib import Path
7 | from tqdm import tqdm
8 | import numpy as np
9 |
10 |
11 | class DatasetLog:
12 | """
13 | Registers metadata about the dataset in a text file.
14 | """
15 | def __init__(self, root, name):
16 | self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
17 | self.sample_data = dict()
18 |
19 | start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
20 | self.write_line("Creating dataset %s on %s" % (name, start_time))
21 | self.write_line("-----")
22 | self._log_params()
23 |
24 | def _log_params(self):
25 | from speaker_encoder import params_data
26 | self.write_line("Parameter values:")
27 | for param_name in (p for p in dir(params_data) if not p.startswith("__")):
28 | value = getattr(params_data, param_name)
29 | self.write_line("\t%s: %s" % (param_name, value))
30 | self.write_line("-----")
31 |
32 | def write_line(self, line):
33 | self.text_file.write("%s\n" % line)
34 |
35 | def add_sample(self, **kwargs):
36 | for param_name, value in kwargs.items():
37 | if not param_name in self.sample_data:
38 | self.sample_data[param_name] = []
39 | self.sample_data[param_name].append(value)
40 |
41 | def finalize(self):
42 | self.write_line("Statistics:")
43 | for param_name, values in self.sample_data.items():
44 | self.write_line("\t%s:" % param_name)
45 | self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
46 | self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
47 | self.write_line("-----")
48 | end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
49 | self.write_line("Finished on %s" % end_time)
50 | self.text_file.close()
51 |
52 |
53 | def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
54 | dataset_root = datasets_root.joinpath(dataset_name)
55 | if not dataset_root.exists():
56 | print("Couldn\'t find %s, skipping this dataset." % dataset_root)
57 | return None, None
58 | return dataset_root, DatasetLog(out_dir, dataset_name)
59 |
60 |
61 | def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
62 | skip_existing, logger):
63 | print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
64 |
65 | # Function to preprocess utterances for one speaker
66 | def preprocess_speaker(speaker_dir: Path):
67 | # Give a name to the speaker that includes its dataset
68 | speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
69 |
70 | # Create an output directory with that name, as well as a txt file containing a
71 | # reference to each source file.
72 | speaker_out_dir = out_dir.joinpath(speaker_name)
73 | speaker_out_dir.mkdir(exist_ok=True)
74 | sources_fpath = speaker_out_dir.joinpath("_sources.txt")
75 |
76 | # There's a possibility that the preprocessing was interrupted earlier, check if
77 | # there already is a sources file.
78 | if sources_fpath.exists():
79 | try:
80 | with sources_fpath.open("r") as sources_file:
81 | existing_fnames = {line.split(",")[0] for line in sources_file}
82 | except:
83 | existing_fnames = {}
84 | else:
85 | existing_fnames = {}
86 |
87 | # Gather all audio files for that speaker recursively
88 | sources_file = sources_fpath.open("a" if skip_existing else "w")
89 | for in_fpath in speaker_dir.glob("**/*.%s" % extension):
90 | # Check if the target output file already exists
91 | out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
92 | out_fname = out_fname.replace(".%s" % extension, ".npy")
93 | if skip_existing and out_fname in existing_fnames:
94 | continue
95 |
96 | # Load and preprocess the waveform
97 | wav = audio.preprocess_wav(in_fpath)
98 | if len(wav) == 0:
99 | continue
100 |
101 | # Create the mel spectrogram, discard those that are too short
102 | frames = audio.wav_to_mel_spectrogram(wav)
103 | if len(frames) < partials_n_frames:
104 | continue
105 |
106 | out_fpath = speaker_out_dir.joinpath(out_fname)
107 | np.save(out_fpath, frames)
108 | logger.add_sample(duration=len(wav) / sampling_rate)
109 | sources_file.write("%s,%s\n" % (out_fname, in_fpath))
110 |
111 | sources_file.close()
112 |
113 | # Process the utterances for each speaker
114 | with ThreadPool(8) as pool:
115 | list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
116 | unit="speakers"))
117 | logger.finalize()
118 | print("Done preprocessing %s.\n" % dataset_name)
119 |
120 |
121 | # Function to preprocess utterances for one speaker
122 | def __preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, extension: str, skip_existing: bool):
123 | # Give a name to the speaker that includes its dataset
124 | speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
125 |
126 | # Create an output directory with that name, as well as a txt file containing a
127 | # reference to each source file.
128 | speaker_out_dir = out_dir.joinpath(speaker_name)
129 | speaker_out_dir.mkdir(exist_ok=True)
130 | sources_fpath = speaker_out_dir.joinpath("_sources.txt")
131 |
132 | # There's a possibility that the preprocessing was interrupted earlier, check if
133 | # there already is a sources file.
134 | # if sources_fpath.exists():
135 | # try:
136 | # with sources_fpath.open("r") as sources_file:
137 | # existing_fnames = {line.split(",")[0] for line in sources_file}
138 | # except:
139 | # existing_fnames = {}
140 | # else:
141 | # existing_fnames = {}
142 | existing_fnames = {}
143 | # Gather all audio files for that speaker recursively
144 | sources_file = sources_fpath.open("a" if skip_existing else "w")
145 |
146 | for in_fpath in speaker_dir.glob("**/*.%s" % extension):
147 | # Check if the target output file already exists
148 | out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
149 | out_fname = out_fname.replace(".%s" % extension, ".npy")
150 | if skip_existing and out_fname in existing_fnames:
151 | continue
152 |
153 | # Load and preprocess the waveform
154 | wav = audio.preprocess_wav(in_fpath)
155 | if len(wav) == 0:
156 | continue
157 |
158 | # Create the mel spectrogram, discard those that are too short
159 | frames = audio.wav_to_mel_spectrogram(wav)
160 | if len(frames) < partials_n_frames:
161 | continue
162 |
163 | out_fpath = speaker_out_dir.joinpath(out_fname)
164 | np.save(out_fpath, frames)
165 | # logger.add_sample(duration=len(wav) / sampling_rate)
166 | sources_file.write("%s,%s\n" % (out_fname, in_fpath))
167 |
168 | sources_file.close()
169 | return len(wav)
170 |
171 | def _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
172 | skip_existing, logger):
173 | # from multiprocessing import Pool, cpu_count
174 | from pathos.multiprocessing import ProcessingPool as Pool
175 | # Function to preprocess utterances for one speaker
176 | def __preprocess_speaker(speaker_dir: Path):
177 | # Give a name to the speaker that includes its dataset
178 | speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
179 |
180 | # Create an output directory with that name, as well as a txt file containing a
181 | # reference to each source file.
182 | speaker_out_dir = out_dir.joinpath(speaker_name)
183 | speaker_out_dir.mkdir(exist_ok=True)
184 | sources_fpath = speaker_out_dir.joinpath("_sources.txt")
185 |
186 | existing_fnames = {}
187 | # Gather all audio files for that speaker recursively
188 | sources_file = sources_fpath.open("a" if skip_existing else "w")
189 | wav_lens = []
190 | for in_fpath in speaker_dir.glob("**/*.%s" % extension):
191 | # Check if the target output file already exists
192 | out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
193 | out_fname = out_fname.replace(".%s" % extension, ".npy")
194 | if skip_existing and out_fname in existing_fnames:
195 | continue
196 |
197 | # Load and preprocess the waveform
198 | wav = audio.preprocess_wav(in_fpath)
199 | if len(wav) == 0:
200 | continue
201 |
202 | # Create the mel spectrogram, discard those that are too short
203 | frames = audio.wav_to_mel_spectrogram(wav)
204 | if len(frames) < partials_n_frames:
205 | continue
206 |
207 | out_fpath = speaker_out_dir.joinpath(out_fname)
208 | np.save(out_fpath, frames)
209 | # logger.add_sample(duration=len(wav) / sampling_rate)
210 | sources_file.write("%s,%s\n" % (out_fname, in_fpath))
211 | wav_lens.append(len(wav))
212 | sources_file.close()
213 | return wav_lens
214 |
215 | print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
216 | # Process the utterances for each speaker
217 | # with ThreadPool(8) as pool:
218 | # list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
219 | # unit="speakers"))
220 | pool = Pool(processes=20)
221 | for i, wav_lens in enumerate(pool.map(__preprocess_speaker, speaker_dirs), 1):
222 | for wav_len in wav_lens:
223 | logger.add_sample(duration=wav_len / sampling_rate)
224 | print(f'{i}/{len(speaker_dirs)} \r')
225 |
226 | logger.finalize()
227 | print("Done preprocessing %s.\n" % dataset_name)
228 |
229 |
230 | def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
231 | for dataset_name in librispeech_datasets["train"]["other"]:
232 | # Initialize the preprocessing
233 | dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
234 | if not dataset_root:
235 | return
236 |
237 | # Preprocess all speakers
238 | speaker_dirs = list(dataset_root.glob("*"))
239 | _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
240 | skip_existing, logger)
241 |
242 |
243 | def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
244 | # Initialize the preprocessing
245 | dataset_name = "VoxCeleb1"
246 | dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
247 | if not dataset_root:
248 | return
249 |
250 | # Get the contents of the meta file
251 | with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
252 | metadata = [line.split("\t") for line in metafile][1:]
253 |
254 | # Select the ID and the nationality, filter out non-anglophone speakers
255 | nationalities = {line[0]: line[3] for line in metadata}
256 | # keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
257 | # nationality.lower() in anglophone_nationalites]
258 | keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items()]
259 | print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
260 | (len(keep_speaker_ids), len(nationalities)))
261 |
262 | # Get the speaker directories for anglophone speakers only
263 | speaker_dirs = dataset_root.joinpath("wav").glob("*")
264 | speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
265 | speaker_dir.name in keep_speaker_ids]
266 | print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
267 | (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
268 |
269 | # Preprocess all speakers
270 | _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
271 | skip_existing, logger)
272 |
273 |
274 | def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
275 | # Initialize the preprocessing
276 | dataset_name = "VoxCeleb2"
277 | dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
278 | if not dataset_root:
279 | return
280 |
281 | # Get the speaker directories
282 | # Preprocess all speakers
283 | speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
284 | _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
285 | skip_existing, logger)
286 |
--------------------------------------------------------------------------------
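Hedged usage sketch for the preprocessing entry points above; the dataset roots are assumptions and must match the directory names listed in speaker_encoder/config.py:

from pathlib import Path
from speaker_encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1

datasets_root = Path("/data/datasets")           # assumed to contain LibriSpeech/, VoxCeleb1/, ...
out_dir = Path("/data/encoder_preprocessed")
out_dir.mkdir(parents=True, exist_ok=True)
preprocess_librispeech(datasets_root, out_dir, skip_existing=True)
preprocess_voxceleb1(datasets_root, out_dir, skip_existing=True)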
/speaker_encoder/train.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder.visualizations import Visualizations
2 | from speaker_encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
3 | from speaker_encoder.params_model import *
4 | from speaker_encoder.model import SpeakerEncoder
5 | from utils.profiler import Profiler
6 | from pathlib import Path
7 | import torch
8 |
9 | def sync(device: torch.device):
10 | # FIXME
11 | return
12 | # For correct profiling (cuda operations are async)
13 | if device.type == "cuda":
14 | torch.cuda.synchronize(device)
15 |
16 | def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
17 | backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
18 | no_visdom: bool):
19 | # Create a dataset and a dataloader
20 | dataset = SpeakerVerificationDataset(clean_data_root)
21 | loader = SpeakerVerificationDataLoader(
22 | dataset,
23 | speakers_per_batch, # 64
24 | utterances_per_speaker, # 10
25 | num_workers=8,
26 | )
27 |
28 | # Setup the device on which to run the forward pass and the loss. These can be different,
29 | # because the forward pass is faster on the GPU whereas the loss is often (depending on your
30 | # hyperparameters) faster on the CPU.
31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32 | # FIXME: currently, the gradient is None if loss_device is cuda
33 | loss_device = torch.device("cpu")
34 |
35 | # Create the model and the optimizer
36 | model = SpeakerEncoder(device, loss_device)
37 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
38 | init_step = 1
39 |
40 | # Configure file path for the model
41 | state_fpath = models_dir.joinpath(run_id + ".pt")
42 | backup_dir = models_dir.joinpath(run_id + "_backups")
43 |
44 | # Load any existing model
45 | if not force_restart:
46 | if state_fpath.exists():
47 | print("Found existing model \"%s\", loading it and resuming training." % run_id)
48 | checkpoint = torch.load(state_fpath)
49 | init_step = checkpoint["step"]
50 | model.load_state_dict(checkpoint["model_state"])
51 | optimizer.load_state_dict(checkpoint["optimizer_state"])
52 | optimizer.param_groups[0]["lr"] = learning_rate_init
53 | else:
54 | print("No model \"%s\" found, starting training from scratch." % run_id)
55 | else:
56 | print("Starting the training from scratch.")
57 | model.train()
58 |
59 | # Initialize the visualization environment
60 | vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
61 | vis.log_dataset(dataset)
62 | vis.log_params()
63 | device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
64 | vis.log_implementation({"Device": device_name})
65 |
66 | # Training loop
67 | profiler = Profiler(summarize_every=10, disabled=False)
68 | for step, speaker_batch in enumerate(loader, init_step):
69 | profiler.tick("Blocking, waiting for batch (threaded)")
70 |
71 | # Forward pass
72 | inputs = torch.from_numpy(speaker_batch.data).to(device)
73 | sync(device)
74 | profiler.tick("Data to %s" % device)
75 | embeds = model(inputs)
76 | sync(device)
77 | profiler.tick("Forward pass")
78 | embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
79 | loss, eer = model.loss(embeds_loss)
80 | sync(loss_device)
81 | profiler.tick("Loss")
82 |
83 | # Backward pass
84 | model.zero_grad()
85 | loss.backward()
86 | profiler.tick("Backward pass")
87 | model.do_gradient_ops()
88 | optimizer.step()
89 | profiler.tick("Parameter update")
90 |
91 | # Update visualizations
92 | # learning_rate = optimizer.param_groups[0]["lr"]
93 | vis.update(loss.item(), eer, step)
94 |
95 | # Draw projections and save them to the backup folder
96 | if umap_every != 0 and step % umap_every == 0:
97 | print("Drawing and saving projections (step %d)" % step)
98 | backup_dir.mkdir(exist_ok=True)
99 | projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
100 | embeds = embeds.detach().cpu().numpy()
101 | vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
102 | vis.save()
103 |
104 | # Overwrite the latest version of the model
105 | if save_every != 0 and step % save_every == 0:
106 | print("Saving the model (step %d)" % step)
107 | torch.save({
108 | "step": step + 1,
109 | "model_state": model.state_dict(),
110 | "optimizer_state": optimizer.state_dict(),
111 | }, state_fpath)
112 |
113 | # Make a backup
114 | if backup_every != 0 and step % backup_every == 0:
115 | print("Making a backup (step %d)" % step)
116 | backup_dir.mkdir(exist_ok=True)
117 | backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
118 | torch.save({
119 | "step": step + 1,
120 | "model_state": model.state_dict(),
121 | "optimizer_state": optimizer.state_dict(),
122 | }, backup_fpath)
123 |
124 | profiler.tick("Extras (visualizations, saving)")
125 |
--------------------------------------------------------------------------------
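A minimal invocation sketch for the training loop above; this file defines no CLI, so the run id, paths and visdom settings below are assumptions:

from pathlib import Path
from speaker_encoder.train import train

train(run_id="ge2e_run", clean_data_root=Path("encoder_preprocessed"),
      models_dir=Path("saved_models"), umap_every=500, save_every=1000,
      backup_every=5000, vis_every=10, force_restart=False,
      visdom_server="http://localhost", no_visdom=True)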
/speaker_encoder/visualizations.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2 | from datetime import datetime
3 | from time import perf_counter as timer
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | # import webbrowser
7 | import visdom
8 | import umap
9 |
10 | colormap = np.array([
11 | [76, 255, 0],
12 | [0, 127, 70],
13 | [255, 0, 0],
14 | [255, 217, 38],
15 | [0, 135, 255],
16 | [165, 0, 165],
17 | [255, 167, 255],
18 | [0, 255, 255],
19 | [255, 96, 38],
20 | [142, 76, 0],
21 | [33, 0, 127],
22 | [0, 0, 0],
23 | [183, 183, 183],
24 | ], dtype=float) / 255
25 |
26 |
27 | class Visualizations:
28 | def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
29 | # Tracking data
30 | self.last_update_timestamp = timer()
31 | self.update_every = update_every
32 | self.step_times = []
33 | self.losses = []
34 | self.eers = []
35 | print("Updating the visualizations every %d steps." % update_every)
36 |
37 | # If visdom is disabled TODO: use a better paradigm for that
38 | self.disabled = disabled
39 | if self.disabled:
40 | return
41 |
42 | # Set the environment name
43 | now = str(datetime.now().strftime("%d-%m %Hh%M"))
44 | if env_name is None:
45 | self.env_name = now
46 | else:
47 | self.env_name = "%s (%s)" % (env_name, now)
48 |
49 | # Connect to visdom and open the corresponding window in the browser
50 | try:
51 | self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
52 | except ConnectionError:
53 | raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
54 | "start it.")
55 | # webbrowser.open("http://localhost:8097/env/" + self.env_name)
56 |
57 | # Create the windows
58 | self.loss_win = None
59 | self.eer_win = None
60 | # self.lr_win = None
61 | self.implementation_win = None
62 | self.projection_win = None
63 | self.implementation_string = ""
64 |
65 | def log_params(self):
66 | if self.disabled:
67 | return
68 | from speaker_encoder import params_data
69 | from speaker_encoder import params_model
70 | param_string = "Model parameters:<br>"
71 | for param_name in (p for p in dir(params_model) if not p.startswith("__")):
72 | value = getattr(params_model, param_name)
73 | param_string += "\t%s: %s<br>" % (param_name, value)
74 | param_string += "Data parameters:<br>"
75 | for param_name in (p for p in dir(params_data) if not p.startswith("__")):
76 | value = getattr(params_data, param_name)
77 | param_string += "\t%s: %s<br>" % (param_name, value)
78 | self.vis.text(param_string, opts={"title": "Parameters"})
79 |
80 | def log_dataset(self, dataset: SpeakerVerificationDataset):
81 | if self.disabled:
82 | return
83 | dataset_string = ""
84 | dataset_string += "Speakers: %s\n" % len(dataset.speakers)
85 | dataset_string += "\n" + dataset.get_logs()
86 | dataset_string = dataset_string.replace("\n", "<br>")
87 | self.vis.text(dataset_string, opts={"title": "Dataset"})
88 |
89 | def log_implementation(self, params):
90 | if self.disabled:
91 | return
92 | implementation_string = ""
93 | for param, value in params.items():
94 | implementation_string += "%s: %s\n" % (param, value)
95 | implementation_string = implementation_string.replace("\n", "<br>")
96 | self.implementation_string = implementation_string
97 | self.implementation_win = self.vis.text(
98 | implementation_string,
99 | opts={"title": "Training implementation"}
100 | )
101 |
102 | def update(self, loss, eer, step):
103 | # Update the tracking data
104 | now = timer()
105 | self.step_times.append(1000 * (now - self.last_update_timestamp))
106 | self.last_update_timestamp = now
107 | self.losses.append(loss)
108 | self.eers.append(eer)
109 | print(".", end="")
110 |
111 | # Update the plots every <update_every> steps
112 | if step % self.update_every != 0:
113 | return
114 | time_string = "Step time: mean: %5dms std: %5dms" % \
115 | (int(np.mean(self.step_times)), int(np.std(self.step_times)))
116 | print("\nStep %6d Loss: %.4f EER: %.4f %s" %
117 | (step, np.mean(self.losses), np.mean(self.eers), time_string))
118 | if not self.disabled:
119 | self.loss_win = self.vis.line(
120 | [np.mean(self.losses)],
121 | [step],
122 | win=self.loss_win,
123 | update="append" if self.loss_win else None,
124 | opts=dict(
125 | legend=["Avg. loss"],
126 | xlabel="Step",
127 | ylabel="Loss",
128 | title="Loss",
129 | )
130 | )
131 | self.eer_win = self.vis.line(
132 | [np.mean(self.eers)],
133 | [step],
134 | win=self.eer_win,
135 | update="append" if self.eer_win else None,
136 | opts=dict(
137 | legend=["Avg. EER"],
138 | xlabel="Step",
139 | ylabel="EER",
140 | title="Equal error rate"
141 | )
142 | )
143 | if self.implementation_win is not None:
144 | self.vis.text(
145 | self.implementation_string + ("%s" % time_string),
146 | win=self.implementation_win,
147 | opts={"title": "Training implementation"},
148 | )
149 |
150 | # Reset the tracking
151 | self.losses.clear()
152 | self.eers.clear()
153 | self.step_times.clear()
154 |
155 | def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
156 | max_speakers=10):
157 | max_speakers = min(max_speakers, len(colormap))
158 | embeds = embeds[:max_speakers * utterances_per_speaker]
159 |
160 | n_speakers = len(embeds) // utterances_per_speaker
161 | ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
162 | colors = [colormap[i] for i in ground_truth]
163 |
164 | reducer = umap.UMAP()
165 | projected = reducer.fit_transform(embeds)
166 | plt.scatter(projected[:, 0], projected[:, 1], c=colors)
167 | plt.gca().set_aspect("equal", "datalim")
168 | plt.title("UMAP projection (step %d)" % step)
169 | if not self.disabled:
170 | self.projection_win = self.vis.matplot(plt, win=self.projection_win)
171 | if out_fpath is not None:
172 | plt.savefig(out_fpath)
173 | plt.clf()
174 |
175 | def save(self):
176 | if not self.disabled:
177 | self.vis.save([self.env_name])
178 |
--------------------------------------------------------------------------------
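
Note: draw_projections above reduces the (n_utterances, embedding_size) embedding matrix to 2-D with UMAP and colors points by speaker. A standalone sketch of the same projection, with random embeddings standing in for real speaker-encoder outputs:

    # Sketch: the UMAP projection used in Visualizations.draw_projections,
    # run on random placeholder embeddings instead of real model outputs.
    import numpy as np
    import umap
    import matplotlib.pyplot as plt

    n_speakers, utterances_per_speaker, embed_size = 4, 10, 256
    embeds = np.random.rand(n_speakers * utterances_per_speaker, embed_size).astype(np.float32)
    ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)

    projected = umap.UMAP().fit_transform(embeds)   # shape: (n_utterances, 2)
    plt.scatter(projected[:, 0], projected[:, 1], c=ground_truth, cmap="tab10")
    plt.gca().set_aspect("equal", "datalim")
    plt.title("UMAP projection of speaker embeddings")
    plt.savefig("umap_sketch.png")
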
/speaker_encoder/voice_encoder.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder.hparams import *
2 | from speaker_encoder import audio
3 | from pathlib import Path
4 | from typing import Union, List
5 | from torch import nn
6 | from time import perf_counter as timer
7 | import numpy as np
8 | import torch
9 |
10 |
11 | class SpeakerEncoder(nn.Module):
12 | def __init__(self, weights_fpath, device: Union[str, torch.device]=None, verbose=True):
13 | """
14 | :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
15 | If None, defaults to cuda if it is available on your machine, otherwise the model will
16 | run on cpu. Outputs are always returned on the cpu, as numpy arrays.
17 | """
18 | super().__init__()
19 |
20 | # Define the network
21 | self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
22 | self.linear = nn.Linear(model_hidden_size, model_embedding_size)
23 | self.relu = nn.ReLU()
24 |
25 | # Get the target device
26 | if device is None:
27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28 | elif isinstance(device, str):
29 | device = torch.device(device)
30 | self.device = device
31 |
32 | # Load the pretrained speaker encoder's weights
33 | # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
34 | # if not weights_fpath.exists():
35 | # raise Exception("Couldn't find the voice encoder pretrained model at %s." %
36 | # weights_fpath)
37 |
38 | start = timer()
39 | checkpoint = torch.load(weights_fpath, map_location="cpu")
40 |
41 | self.load_state_dict(checkpoint["model_state"], strict=False)
42 | self.to(device)
43 |
44 | if verbose:
45 | print("Loaded the voice encoder model on %s in %.2f seconds." %
46 | (device.type, timer() - start))
47 |
48 | def forward(self, mels: torch.FloatTensor):
49 | """
50 | Computes the embeddings of a batch of utterance spectrograms.
51 | :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape
52 | (batch_size, n_frames, n_channels)
53 | :return: the embeddings as a float32 tensor of shape (batch_size, embedding_size).
54 | Embeddings are positive and L2-normed, thus they lay in the range [0, 1].
55 | """
56 | # Pass the input through the LSTM layers and retrieve the final hidden state of the last
57 | # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings.
58 | _, (hidden, _) = self.lstm(mels)
59 | embeds_raw = self.relu(self.linear(hidden[-1]))
60 | return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
61 |
62 | @staticmethod
63 | def compute_partial_slices(n_samples: int, rate, min_coverage):
64 | """
65 | Computes where to split an utterance waveform and its corresponding mel spectrogram to
66 | obtain partial utterances of <partials_n_frames> frames each. Both the waveform and the
67 | mel spectrogram slices are returned, so as to make each partial utterance waveform
68 | correspond to its spectrogram.
69 |
70 | The returned ranges may be indexing further than the length of the waveform. It is
71 | recommended that you pad the waveform with zeros up to wav_slices[-1].stop.
72 |
73 | :param n_samples: the number of samples in the waveform
74 | :param rate: how many partial utterances should occur per second. Partial utterances must
75 | cover the span of the entire utterance, thus the rate should not be lower than the inverse
76 | of the duration of a partial utterance. By default, partial utterances are 1.6s long and
77 | the minimum rate is thus 0.625.
78 | :param min_coverage: when reaching the last partial utterance, it may or may not have
79 | enough frames. If at least <min_coverage> of <partials_n_frames> are present,
80 | then the last partial utterance will be considered by zero-padding the audio. Otherwise,
81 | it will be discarded. If there aren't enough frames for one partial utterance,
82 | this parameter is ignored so that the function always returns at least one slice.
83 | :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
84 | respectively the waveform and the mel spectrogram with these slices to obtain the partial
85 | utterances.
86 | """
87 | assert 0 < min_coverage <= 1
88 |
89 | # Compute how many frames separate two partial utterances
90 | samples_per_frame = int((sampling_rate * mel_window_step / 1000))
91 | n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
92 | frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
93 | assert 0 < frame_step, "The rate is too high"
94 | assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
95 | (sampling_rate / (samples_per_frame * partials_n_frames))
96 |
97 | # Compute the slices
98 | wav_slices, mel_slices = [], []
99 | steps = max(1, n_frames - partials_n_frames + frame_step + 1)
100 | for i in range(0, steps, frame_step):
101 | mel_range = np.array([i, i + partials_n_frames])
102 | wav_range = mel_range * samples_per_frame
103 | mel_slices.append(slice(*mel_range))
104 | wav_slices.append(slice(*wav_range))
105 |
106 | # Evaluate whether extra padding is warranted or not
107 | last_wav_range = wav_slices[-1]
108 | coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
109 | if coverage < min_coverage and len(mel_slices) > 1:
110 | mel_slices = mel_slices[:-1]
111 | wav_slices = wav_slices[:-1]
112 |
113 | return wav_slices, mel_slices
114 |
115 | def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
116 | """
117 | Computes an embedding for a single utterance. The utterance is divided in partial
118 | utterances and an embedding is computed for each. The complete utterance embedding is the
119 | L2-normed average embedding of the partial utterances.
120 |
121 | TODO: independent batched version of this function
122 |
123 | :param wav: a preprocessed utterance waveform as a numpy array of float32
124 | :param return_partials: if True, the partial embeddings will also be returned along with
125 | the wav slices corresponding to each partial utterance.
126 | :param rate: how many partial utterances should occur per second. Partial utterances must
127 | cover the span of the entire utterance, thus the rate should not be lower than the inverse
128 | of the duration of a partial utterance. By default, partial utterances are 1.6s long and
129 | the minimum rate is thus 0.625.
130 | :param min_coverage: when reaching the last partial utterance, it may or may not have
131 | enough frames. If at least <min_coverage> of <partials_n_frames> are present,
132 | then the last partial utterance will be considered by zero-padding the audio. Otherwise,
133 | it will be discarded. If there aren't enough frames for one partial utterance,
134 | this parameter is ignored so that the function always returns at least one slice.
135 | :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
136 | <return_partials> is True, the partial utterances as a numpy array of float32 of shape
137 | (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
138 | returned.
139 | """
140 | # Compute where to split the utterance into partials and pad the waveform with zeros if
141 | # the partial utterances cover a larger range.
142 | wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
143 | max_wave_length = wav_slices[-1].stop
144 | if max_wave_length >= len(wav):
145 | wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
146 |
147 | # Split the utterance into partials and forward them through the model
148 | mel = audio.wav_to_mel_spectrogram(wav)
149 | mels = np.array([mel[s] for s in mel_slices])
150 | with torch.no_grad():
151 | mels = torch.from_numpy(mels).to(self.device)
152 | partial_embeds = self(mels).cpu().numpy()
153 |
154 | # Compute the utterance embedding from the partial embeddings
155 | raw_embed = np.mean(partial_embeds, axis=0)
156 | embed = raw_embed / np.linalg.norm(raw_embed, 2)
157 |
158 | if return_partials:
159 | return embed, partial_embeds, wav_slices
160 | return embed
161 |
162 | def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
163 | """
164 | Compute the embedding of a collection of wavs (presumably from the same speaker) by
165 | averaging their embedding and L2-normalizing it.
166 |
167 | :param wavs: list of wavs as numpy arrays of float32.
168 | :param kwargs: extra arguments to embed_utterance()
169 | :return: the embedding as a numpy array of float32 of shape (model_embedding_size,).
170 | """
171 | raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) \
172 | for wav in wavs], axis=0)
173 | return raw_embed / np.linalg.norm(raw_embed, 2)
--------------------------------------------------------------------------------
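
Note: a minimal usage sketch of the encoder above, assuming the pretrained checkpoint referenced in speaker_encoder/ckpt has been downloaded and that example.wav is any speech file; the waveform is loaded at the 16 kHz rate defined in speaker_encoder/hparams.py:

    # Sketch: compute a single utterance embedding with the SpeakerEncoder above.
    # The checkpoint path matches the one used by train.py; example.wav is a placeholder.
    import librosa
    from speaker_encoder.voice_encoder import SpeakerEncoder

    encoder = SpeakerEncoder("speaker_encoder/ckpt/pretrained_bak_5805000.pt")
    wav, _ = librosa.load("example.wav", sr=16000)   # float32 waveform
    embed = encoder.embed_utterance(wav)             # shape: (model_embedding_size,)
    print(embed.shape, float(embed @ embed))         # L2-normed, so the dot product is ~1.0
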
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import itertools
5 | import math
6 | import torch
7 | from torch import nn, optim
8 | from torch.nn import functional as F
9 | from torch.utils.data import DataLoader
10 | from utils import HParams
11 |
12 | from torch.utils.tensorboard import SummaryWriter
13 | from torch.cuda.amp import autocast, GradScaler
14 |
15 | import commons
16 | import utils
17 | from data_utils import (
18 | TextAudioSpeakerLoader,
19 | TextAudioSpeakerCollate,
20 | DistributedBucketSampler
21 | )
22 | from models import (
23 | HuBERT_NeuralDec_VITS,
24 | MultiPeriodDiscriminator,
25 | )
26 | from losses import (
27 | generator_loss,
28 | discriminator_loss,
29 | feature_loss,
30 | kl_loss
31 | )
32 | from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
33 | from speaker_encoder.voice_encoder import SpeakerEncoder
34 | smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt')
35 | torch.backends.cudnn.benchmark = True
36 | global_step = 0
37 |
38 |
39 | class Parameters:
40 | def __init__(self):
41 | self.config = "./configs/hubert-neuraldec-vits.json"
42 | self.model = "neuralvc-temp"
43 |
44 |
45 | args = Parameters()
46 |
47 | def get_hparams(init=True):
48 |
49 | model_dir = os.path.join("./logs", args.model)
50 |
51 | if not os.path.exists(model_dir):
52 | os.makedirs(model_dir)
53 |
54 | config_path = args.config
55 | config_save_path = os.path.join(model_dir, "config.json")
56 | if init:
57 | with open(config_path, "r") as f:
58 | data = f.read()
59 | with open(config_save_path, "w") as f:
60 | f.write(data)
61 | else:
62 | with open(config_save_path, "r") as f:
63 | data = f.read()
64 | config = json.loads(data)
65 |
66 | hparams = HParams(**config)
67 | hparams.model_dir = model_dir
68 | return hparams
69 |
70 | hps = get_hparams()
71 |
72 |
73 | def spk_loss(tgt, gen, batch_size):
74 | loss = 0.0
75 | for i in range(batch_size):
76 | tgt_emb = smodel.embed_utterance(tgt[i][0])
77 | gen_emb = smodel.embed_utterance(gen[i][0])
78 |
79 | loss += F.l1_loss(torch.from_numpy(tgt_emb), torch.from_numpy(gen_emb))
80 |
81 | return loss/batch_size
82 |
83 | def main():
84 | """Assume Single Node Multi GPUs Training Only"""
85 | assert torch.cuda.is_available(), "CPU training is not allowed."
86 | n_gpus = torch.cuda.device_count()
87 | run(0, n_gpus, hps)
88 |
89 | def run(rank, n_gpus, hps):
90 | global global_step
91 | if rank == 0:
92 | logger = utils.get_logger(hps.model_dir)
93 | logger.info(hps)
94 | utils.check_git_hash(hps.model_dir)
95 | writer = SummaryWriter(log_dir=hps.model_dir)
96 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
97 |
98 | #dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank)
99 | torch.manual_seed(hps.train.seed)
100 | torch.cuda.set_device(rank)
101 |
102 | train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps)
103 | train_sampler = DistributedBucketSampler(
104 | train_dataset,
105 | hps.train.batch_size,
106 | [75,100,125,150,175,200,225,250,300,350,400,450,500,550,600,650,700,750,800,850,900,950,1000,1100,1200,1300,1400,1500,2000,3000,4000,5000],
107 | num_replicas=n_gpus,
108 | rank=rank,
109 | shuffle=True)
110 | collate_fn = TextAudioSpeakerCollate(hps)
111 | train_loader = DataLoader(train_dataset, num_workers=0, shuffle=False, pin_memory=True,
112 | collate_fn=collate_fn, batch_sampler=train_sampler)
113 | if rank == 0:
114 | eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps)
115 | eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=True,
116 | batch_size=hps.train.batch_size, pin_memory=False,
117 | drop_last=False, collate_fn=collate_fn)
118 |
119 | net_g = HuBERT_NeuralDec_VITS(
120 | hps.data.filter_length // 2 + 1,
121 | hps.train.segment_size // hps.data.hop_length,
122 | **hps.model).cuda(rank)
123 | net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
124 | optim_g = torch.optim.AdamW(
125 | net_g.parameters(),
126 | hps.train.learning_rate,
127 | betas=hps.train.betas,
128 | eps=hps.train.eps)
129 | optim_d = torch.optim.AdamW(
130 | net_d.parameters(),
131 | hps.train.learning_rate,
132 | betas=hps.train.betas,
133 | eps=hps.train.eps)
134 |
135 | try:
136 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g)
137 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d)
138 | global_step = (epoch_str - 1) * len(train_loader)
139 | except Exception:
140 | epoch_str = 1
141 | global_step = 0
142 |
143 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str-2)
144 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str-2)
145 |
146 | scaler = GradScaler(enabled=hps.train.fp16_run)
147 |
148 | for epoch in range(epoch_str, hps.train.epochs + 1):
149 | if rank==0:
150 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval])
151 | else:
152 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None)
153 | scheduler_g.step()
154 | scheduler_d.step()
155 |
156 | def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
157 |
158 | net_g, net_d = nets
159 | optim_g, optim_d = optims
160 | scheduler_g, scheduler_d = schedulers
161 | train_loader, eval_loader = loaders
162 | if writers is not None:
163 | writer, writer_eval = writers
164 |
165 | train_loader.batch_sampler.set_epoch(epoch)
166 | global global_step
167 |
168 | net_g.train()
169 | net_d.train()
170 | for batch_idx, items in enumerate(train_loader):
171 | if hps.model.use_spk:
172 | c, spec, y, spk = items
173 | g = spk.cuda(rank, non_blocking=True)
174 | else:
175 | c, spec, y = items
176 | g = None
177 | spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True)
178 | c = c.cuda(rank, non_blocking=True)
179 | # print("***" + spec.shape)
180 | mel = spec_to_mel_torch(
181 | spec,
182 | hps.data.filter_length,
183 | hps.data.n_mel_channels,
184 | hps.data.sampling_rate,
185 | hps.data.mel_fmin,
186 | hps.data.mel_fmax)
187 | real_mel = mel_spectrogram_torch(
188 | y.squeeze(1),
189 | hps.data.filter_length,
190 | hps.data.n_mel_channels,
191 | hps.data.sampling_rate,
192 | hps.data.hop_length,
193 | hps.data.win_length,
194 | hps.data.mel_fmin,
195 | hps.data.mel_fmax
196 | )
197 | #print(torch.max(mel),torch.min(mel),torch.max(real_mel),torch.min(real_mel))
198 | with autocast(enabled=hps.train.fp16_run):
199 | y_hat, ids_slice, z_mask,\
200 | (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(c, spec, g=g, mel=mel)
201 | #print(torch.max(y),torch.min(y),torch.max(y_hat),torch.min(y_hat))
202 | y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
203 | y_hat_mel = mel_spectrogram_torch(
204 | y_hat.squeeze(1),
205 | hps.data.filter_length,
206 | hps.data.n_mel_channels,
207 | hps.data.sampling_rate,
208 | hps.data.hop_length,
209 | hps.data.win_length,
210 | hps.data.mel_fmin,
211 | hps.data.mel_fmax
212 | )
213 | y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
214 |
215 | # Discriminator
216 | y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
217 | with autocast(enabled=False):
218 | loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
219 | loss_disc_all = loss_disc
220 | optim_d.zero_grad()
221 | scaler.scale(loss_disc_all).backward()
222 | scaler.unscale_(optim_d)
223 | grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
224 | scaler.step(optim_d)
225 | with autocast(enabled=hps.train.fp16_run):
226 | # Generator
227 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
228 | with autocast(enabled=False):
229 | loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
230 | loss_mel = F.l1_loss(y_hat_mel, y_mel) * hps.train.c_mel
231 | loss_fm = feature_loss(fmap_r, fmap_g)
232 | loss_gen, losses_gen = generator_loss(y_d_hat_g)
233 | loss_spk = spk_loss(y.detach().cpu().numpy(), y_hat.detach().cpu().numpy(), hps.train.batch_size)
234 | loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + loss_spk
235 | optim_g.zero_grad()
236 | scaler.scale(loss_gen_all).backward()
237 | scaler.unscale_(optim_g)
238 | grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
239 | scaler.step(optim_g)
240 | scaler.update()
241 |
242 | if rank==0:
243 | if global_step % hps.train.log_interval == 0:
244 | lr = optim_g.param_groups[0]['lr']
245 | losses = [loss_disc, loss_gen, loss_fm, loss_mel]
246 | logger.info('Train Epoch: {} [{:.0f}%]'.format(
247 | epoch,
248 | 100. * batch_idx / len(train_loader)))
249 | logger.info([x.item() for x in losses] + [global_step, lr])
250 |
251 | scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
252 | scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel})
253 |
254 | scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
255 | scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
256 | scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
257 | image_dict = {
258 | "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
259 | "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
260 | "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
261 | }
262 | utils.summarize(
263 | writer=writer,
264 | global_step=global_step,
265 | images=image_dict,
266 | scalars=scalar_dict)
267 |
268 | if global_step % hps.train.eval_interval == 0:
269 | # evaluate(hps, net_g, eval_loader, writer_eval)
270 | utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
271 | utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
272 | global_step += 1
273 |
274 | if rank == 0:
275 | logger.info('====> Epoch: {}'.format(epoch))
276 |
277 |
278 | if __name__ == "__main__":
279 | main()
280 |
--------------------------------------------------------------------------------
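
Note: checkpoints written here through utils.save_checkpoint (G_*.pth / D_*.pth) store the keys "model", "iteration", "optimizer", and "learning_rate". A hedged sketch of restoring the latest generator with the helpers defined in utils.py below; the log directory matches the hardcoded "neuralvc-temp" run name above:

    # Sketch: reload the most recent generator checkpoint produced by train.py.
    # Uses utils.get_hparams_from_dir / latest_checkpoint_path / load_checkpoint
    # and mirrors the model construction in run(); requires a CUDA device.
    import utils
    from models import HuBERT_NeuralDec_VITS

    hps = utils.get_hparams_from_dir("./logs/neuralvc-temp")
    net_g = HuBERT_NeuralDec_VITS(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model).cuda()
    ckpt_path = utils.latest_checkpoint_path(hps.model_dir, "G_*.pth")
    net_g, _, lr, iteration = utils.load_checkpoint(ckpt_path, net_g, optimizer=None)
    net_g.eval()
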
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import sys
4 | import argparse
5 | import logging
6 | import json
7 | import subprocess
8 | import numpy as np
9 | from scipy.io.wavfile import read
10 | import torch
11 | import torchvision
12 | from torch.nn import functional as F
13 | from commons import sequence_mask
14 | import hifigan
15 | from wavlm import WavLM, WavLMConfig
16 |
17 | MATPLOTLIB_FLAG = False
18 |
19 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
20 | logger = logging
21 |
22 |
23 | def get_cmodel(rank):
24 | checkpoint = torch.load('wavlm/WavLM-Large.pt')
25 | cfg = WavLMConfig(checkpoint['cfg'])
26 | cmodel = WavLM(cfg).cuda(rank)
27 | cmodel.load_state_dict(checkpoint['model'])
28 | cmodel.eval()
29 | return cmodel
30 |
31 |
32 | def get_content(cmodel, y):
33 | with torch.no_grad():
34 | c = cmodel.extract_features(y.squeeze(1))[0]
35 | c = c.transpose(1, 2)
36 | return c
37 |
38 |
39 | def get_vocoder(rank):
40 | with open("hifigan/config.json", "r") as f:
41 | config = json.load(f)
42 | config = hifigan.AttrDict(config)
43 | vocoder = hifigan.Generator(config)
44 | ckpt = torch.load("hifigan/generator_v1")
45 | vocoder.load_state_dict(ckpt["generator"])
46 | vocoder.eval()
47 | vocoder.remove_weight_norm()
48 | vocoder.cuda(rank)
49 | return vocoder
50 |
51 |
52 | def transform(mel, height): # 68-92
53 | #r = np.random.random()
54 | #rate = r * 0.3 + 0.85 # 0.85-1.15
55 | #height = int(mel.size(-2) * rate)
56 | tgt = torchvision.transforms.functional.resize(mel, (height, mel.size(-1)))
57 | if height >= mel.size(-2):
58 | return tgt[:, :mel.size(-2), :]
59 | else:
60 | silence = tgt[:,-1:,:].repeat(1,mel.size(-2)-height,1)
61 | silence += torch.randn_like(silence) / 10
62 | return torch.cat((tgt, silence), 1)
63 |
64 |
65 | def stretch(mel, width): # 0.5-2
66 | return torchvision.transforms.functional.resize(mel, (mel.size(-2), width))
67 |
68 |
69 | def load_checkpoint(checkpoint_path, model, optimizer=None, strict=False):
70 | assert os.path.isfile(checkpoint_path)
71 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
72 | iteration = checkpoint_dict['iteration']
73 | learning_rate = checkpoint_dict['learning_rate']
74 | if optimizer is not None:
75 | optimizer.load_state_dict(checkpoint_dict['optimizer'])
76 | saved_state_dict = checkpoint_dict['model']
77 | if hasattr(model, 'module'):
78 | state_dict = model.module.state_dict()
79 | else:
80 | state_dict = model.state_dict()
81 | if strict:
82 | assert state_dict.keys() == saved_state_dict.keys(), "Mismatched model config and checkpoint."
83 | new_state_dict = {}
84 | for k, v in state_dict.items():
85 | try:
86 | new_state_dict[k] = saved_state_dict[k]
87 | except KeyError:
88 | logger.info("%s is not in the checkpoint" % k)
89 | new_state_dict[k] = v
90 | if hasattr(model, 'module'):
91 | model.module.load_state_dict(new_state_dict)
92 | else:
93 | model.load_state_dict(new_state_dict)
94 | logger.info("Loaded checkpoint '{}' (iteration {})" .format(
95 | checkpoint_path, iteration))
96 | return model, optimizer, learning_rate, iteration
97 |
98 |
99 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
100 | logger.info("Saving model and optimizer state at iteration {} to {}".format(
101 | iteration, checkpoint_path))
102 | if hasattr(model, 'module'):
103 | state_dict = model.module.state_dict()
104 | else:
105 | state_dict = model.state_dict()
106 | torch.save({'model': state_dict,
107 | 'iteration': iteration,
108 | 'optimizer': optimizer.state_dict(),
109 | 'learning_rate': learning_rate}, checkpoint_path)
110 |
111 |
112 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
113 | for k, v in scalars.items():
114 | writer.add_scalar(k, v, global_step)
115 | for k, v in histograms.items():
116 | writer.add_histogram(k, v, global_step)
117 | for k, v in images.items():
118 | writer.add_image(k, v, global_step, dataformats='HWC')
119 | for k, v in audios.items():
120 | writer.add_audio(k, v, global_step, audio_sampling_rate)
121 |
122 |
123 | def latest_checkpoint_path(dir_path, regex="G_*.pth"):
124 | f_list = glob.glob(os.path.join(dir_path, regex))
125 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
126 | x = f_list[-1]
127 | print(x)
128 | return x
129 |
130 |
131 | def plot_spectrogram_to_numpy(spectrogram):
132 | global MATPLOTLIB_FLAG
133 | if not MATPLOTLIB_FLAG:
134 | import matplotlib
135 | matplotlib.use("Agg")
136 | MATPLOTLIB_FLAG = True
137 | mpl_logger = logging.getLogger('matplotlib')
138 | mpl_logger.setLevel(logging.WARNING)
139 | import matplotlib.pylab as plt
140 | import numpy as np
141 |
142 | fig, ax = plt.subplots(figsize=(10,2))
143 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
144 | interpolation='none')
145 | plt.colorbar(im, ax=ax)
146 | plt.xlabel("Frames")
147 | plt.ylabel("Channels")
148 | plt.tight_layout()
149 |
150 | fig.canvas.draw()
151 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
152 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
153 | plt.close()
154 | return data
155 |
156 |
157 | def plot_alignment_to_numpy(alignment, info=None):
158 | global MATPLOTLIB_FLAG
159 | if not MATPLOTLIB_FLAG:
160 | import matplotlib
161 | matplotlib.use("Agg")
162 | MATPLOTLIB_FLAG = True
163 | mpl_logger = logging.getLogger('matplotlib')
164 | mpl_logger.setLevel(logging.WARNING)
165 | import matplotlib.pylab as plt
166 | import numpy as np
167 |
168 | fig, ax = plt.subplots(figsize=(6, 4))
169 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
170 | interpolation='none')
171 | fig.colorbar(im, ax=ax)
172 | xlabel = 'Decoder timestep'
173 | if info is not None:
174 | xlabel += '\n\n' + info
175 | plt.xlabel(xlabel)
176 | plt.ylabel('Encoder timestep')
177 | plt.tight_layout()
178 |
179 | fig.canvas.draw()
180 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
181 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
182 | plt.close()
183 | return data
184 |
185 |
186 | def load_wav_to_torch(full_path):
187 | sampling_rate, data = read(full_path)
188 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate
189 |
190 |
191 | def load_filepaths_and_text(filename, split="|"):
192 | with open(filename, encoding='utf-8') as f:
193 | filepaths_and_text = [line.strip().split(split) for line in f]
194 | return filepaths_and_text
195 |
196 |
197 | def get_hparams(init=True):
198 | parser = argparse.ArgumentParser()
199 | parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
200 | help='JSON file for configuration')
201 | parser.add_argument('-m', '--model', type=str, required=True,
202 | help='Model name')
203 |
204 | args = parser.parse_args()
205 | model_dir = os.path.join("./logs", args.model)
206 |
207 | if not os.path.exists(model_dir):
208 | os.makedirs(model_dir)
209 |
210 | config_path = args.config
211 | config_save_path = os.path.join(model_dir, "config.json")
212 | if init:
213 | with open(config_path, "r") as f:
214 | data = f.read()
215 | with open(config_save_path, "w") as f:
216 | f.write(data)
217 | else:
218 | with open(config_save_path, "r") as f:
219 | data = f.read()
220 | config = json.loads(data)
221 |
222 | hparams = HParams(**config)
223 | hparams.model_dir = model_dir
224 | return hparams
225 |
226 |
227 | def get_hparams_from_dir(model_dir):
228 | config_save_path = os.path.join(model_dir, "config.json")
229 | with open(config_save_path, "r") as f:
230 | data = f.read()
231 | config = json.loads(data)
232 |
233 | hparams = HParams(**config)
234 | hparams.model_dir = model_dir
235 | return hparams
236 |
237 |
238 | def get_hparams_from_file(config_path):
239 | with open(config_path, "r") as f:
240 | data = f.read()
241 | config = json.loads(data)
242 |
243 | hparams = HParams(**config)
244 | return hparams
245 |
246 |
247 | def check_git_hash(model_dir):
248 | source_dir = os.path.dirname(os.path.realpath(__file__))
249 | if not os.path.exists(os.path.join(source_dir, ".git")):
250 | logger.warning("{} is not a git repository, therefore hash value comparison will be ignored.".format(
251 | source_dir
252 | ))
253 | return
254 |
255 | cur_hash = subprocess.getoutput("git rev-parse HEAD")
256 |
257 | path = os.path.join(model_dir, "githash")
258 | if os.path.exists(path):
259 | saved_hash = open(path).read()
260 | if saved_hash != cur_hash:
261 | logger.warning("git hash values are different. {}(saved) != {}(current)".format(
262 | saved_hash[:8], cur_hash[:8]))
263 | else:
264 | open(path, "w").write(cur_hash)
265 |
266 |
267 | def get_logger(model_dir, filename="train.log"):
268 | global logger
269 | logger = logging.getLogger(os.path.basename(model_dir))
270 | logger.setLevel(logging.DEBUG)
271 |
272 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
273 | if not os.path.exists(model_dir):
274 | os.makedirs(model_dir)
275 | h = logging.FileHandler(os.path.join(model_dir, filename))
276 | h.setLevel(logging.DEBUG)
277 | h.setFormatter(formatter)
278 | logger.addHandler(h)
279 | return logger
280 |
281 |
282 | class HParams():
283 | def __init__(self, **kwargs):
284 | for k, v in kwargs.items():
285 | if isinstance(v, dict):
286 | v = HParams(**v)
287 | self[k] = v
288 |
289 | def keys(self):
290 | return self.__dict__.keys()
291 |
292 | def items(self):
293 | return self.__dict__.items()
294 |
295 | def values(self):
296 | return self.__dict__.values()
297 |
298 | def __len__(self):
299 | return len(self.__dict__)
300 |
301 | def __getitem__(self, key):
302 | return getattr(self, key)
303 |
304 | def __setitem__(self, key, value):
305 | return setattr(self, key, value)
306 |
307 | def __contains__(self, key):
308 | return key in self.__dict__
309 |
310 | def __repr__(self):
311 | return self.__dict__.__repr__()
312 |
--------------------------------------------------------------------------------
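
Note: the HParams class above wraps nested config dicts so they can be read with attribute access (hps.train.learning_rate and so on). A tiny sketch with a made-up config fragment, not the repo's actual JSON:

    # Sketch: HParams turns nested dicts into attribute-accessible objects.
    # The config fragment is illustrative only.
    from utils import HParams

    config = {"train": {"learning_rate": 2e-4, "batch_size": 16},
              "data": {"sampling_rate": 16000}}
    hps = HParams(**config)

    print(hps.train.learning_rate)       # 0.0002
    print("sampling_rate" in hps.data)   # True
    print(len(hps))                      # 2 top-level groups
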
/wavlm/WavLM-Large.pt.txt:
--------------------------------------------------------------------------------
1 | https://github.com/microsoft/unilm/tree/master/wavlm
--------------------------------------------------------------------------------
/wavlm/__init__.py:
--------------------------------------------------------------------------------
1 | from wavlm.WavLM import WavLM, WavLMConfig
--------------------------------------------------------------------------------
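
Note: the WavLM checkpoint linked above is consumed by utils.get_cmodel / utils.get_content to extract content features. A hedged usage sketch, assuming wavlm/WavLM-Large.pt has been downloaded and a CUDA device is available (both helpers call .cuda()):

    # Sketch: extract WavLM content features with the helpers defined in utils.py.
    # Requires wavlm/WavLM-Large.pt (see the link above) and a GPU.
    import torch
    import utils

    cmodel = utils.get_cmodel(rank=0)
    wav = torch.randn(1, 1, 16000).cuda(0)   # 1 s of placeholder 16 kHz audio, shape (B, 1, T)
    c = utils.get_content(cmodel, wav)       # (B, feature_dim, n_frames)
    print(c.shape)
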