├── README.md
├── alias_free_torch
    ├── __init__.py
    ├── act.py
    ├── filter.py
    └── resample.py
├── augmentation
    ├── aug.py
    └── peq.py
├── ckpt
    ├── config.json
    └── config_bigvgan.json
├── configs
    └── config_16k.json
├── infer.sh
├── inference.py
├── model
    ├── base.py
    ├── diffhiervc.py
    ├── diffusion_f0.py
    ├── diffusion_mel.py
    ├── diffusion_module.py
    └── styleencoder.py
├── module
    ├── __init__.py
    ├── attentions.py
    ├── commons.py
    ├── modules.py
    ├── transforms.py
    └── utils.py
├── requirements.txt
├── sample
    ├── src_p241_004.wav
    └── tar_p239_022.wav
├── train.py
├── utils
    ├── data_loader.py
    └── utils.py
└── vocoder
    ├── activations.py
    ├── bigvgan.py
    ├── hifigan.py
    └── modules.py

/README.md:
--------------------------------------------------------------------------------
## Diff-HierVC: Diffusion-based Hierarchical Voice Conversion with Robust Pitch Generation and Masked Prior for Zero-shot Speaker Adaptation

The official PyTorch implementation of Diff-HierVC (Interspeech 2023, Oral)

Ha-Yeong Choi, Sang-Hoon Lee, Seong-Whan Lee

![image](https://github.com/hayeong0/Diff-HierVC/assets/47182864/e8a22c5f-af6f-43e8-92b0-0aac839cb0b6)

Overall architecture

> Although voice conversion (VC) systems have shown a remarkable ability to transfer voice style, existing methods still have an inaccurate pitch and low speaker adaptation quality. To address these challenges, we introduce Diff-HierVC, a hierarchical VC system based on two diffusion models. We first introduce DiffPitch, which can effectively generate $F_0$ with the target voice style. Subsequently, the generated $F_0$ is fed to DiffVoice to convert the speech with a target voice style. Furthermore, using the source-filter encoder, we disentangle the speech and use the converted Mel-spectrogram as a data-driven prior in DiffVoice to improve the voice style transfer capacity. Finally, by using the masked prior in diffusion models, our model can improve the speaker adaptation quality. Experimental results verify the superiority of our model in pitch generation and voice style transfer performance, and our model also achieves a CER of 0.83\% and EER of 3.29\% in zero-shot VC scenarios.

## 🎧 Audio Demo
https://diff-hiervc.github.io/audio_demo/

## 📑 Pre-trained Model
Our model checkpoints can be downloaded [here](https://drive.google.com/drive/folders/1THkeyDlA7EbZxwnuuxGsUOftV70Fb7h4?usp=sharing).

- model_diffhier.pth
- voc_hifigan.pth
- voc_bigvgan.pth

## 🔨 Usage

1. Clone this repository and install the Python requirements:

```
git clone https://github.com/hayeong0/Diff-HierVC.git
pip install -r requirements.txt
```

2. Download the pre-trained model checkpoints from the drive link above and place them in the following paths:
```
.
├── ckpt
│   ├── config.json
│   └── model_diffhier.pth ✅
├── inference.py
├── infer.sh
├── model
├── module
├── requirements.txt
├── utils
└── vocoder
    ├── hifigan.py
    ├── modules.py
    ├── voc_hifigan.pth ✅
    └── voc_bigvgan.pth ✅
```
3. Run `infer.sh`

`diffpitch_ts` is the number of diffusion time steps for the pitch generator (DiffPitch), and `diffvoice_ts` is the number of time steps for the Mel generator (DiffVoice).

Empirically, if the DiffPitch time step count is too small, noise remains in the generated pitch; if it is too large, the output becomes excessively diverse.

Tune these values to suit your dataset!
```
bash infer.sh

python3 inference.py \
    --src_path './sample/src_p241_004.wav' \
    --trg_path './sample/tar_p239_022.wav' \
    --ckpt_model './ckpt/model_diffhier.pth' \
    --ckpt_voc './vocoder/voc_bigvgan.pth' \
    --output_dir './converted' \
    --diffpitch_ts 30 \
    --diffvoice_ts 6
```
🎧 Test it on your own dataset and share your interesting results! :)



## 🎓 Citation
```
@inproceedings{choi23d_interspeech,
  author={Ha-Yeong Choi and Sang-Hoon Lee and Seong-Whan Lee},
  title={{Diff-HierVC: Diffusion-based Hierarchical Voice Conversion with Robust Pitch Generation and Masked Prior for Zero-shot Speaker Adaptation}},
  year=2023,
  booktitle={Proc. INTERSPEECH 2023},
  pages={2283--2287},
  doi={10.21437/Interspeech.2023-817}
}
```


## 💎 Acknowledgements
- Our code is based on [DiffVC](https://github.com/huawei-noah/Speech-Backbones/tree/main/DiffVC), [HiFiGAN](https://github.com/jik876/hifi-gan), and [BigVGAN](https://github.com/NVIDIA/BigVGAN).
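
## 🧮 Note: the diffused data-driven prior (sketch)

For readers of the code: both decoders diffuse toward an encoder-predicted mean rather than toward zero. The summary below is our reading of `compute_diffused_z_pr` in `model/diffusion_mel.py` and `model/diffusion_f0.py` (not an excerpt from the paper), where $\bar{x}$ is the data-driven prior (the converted Mel-spectrogram or the predicted $F_0$) and $\beta(u) = \beta_{min} + (\beta_{max} - \beta_{min})\,u$:

$$
x_t = \gamma_t\,x_0 + (1 - \gamma_t)\,\bar{x} + \sqrt{1 - \gamma_t^2}\,\epsilon,
\qquad
\gamma_t = \exp\!\Big(-\tfrac{1}{2}\int_0^t \beta(u)\,du\Big),
\qquad
\epsilon \sim \mathcal{N}(0, I).
$$

At $t = 1$, $\gamma_t \approx 0$, so the reverse process starts from noise centered on the prior mean; this is why `infer_vc` adds `torch.randn_like` noise to the diffused prior before calling `reverse`.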


## License
This work is licensed under a
[Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License][cc-by-nc-sa].

[![CC BY-NC-SA 4.0][cc-by-nc-sa-image]][cc-by-nc-sa]

[cc-by-nc-sa]: http://creativecommons.org/licenses/by-nc-sa/4.0/
[cc-by-nc-sa-image]: https://licensebuttons.net/l/by-nc-sa/4.0/88x31.png
[cc-by-nc-sa-shield]: https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg
--------------------------------------------------------------------------------
/alias_free_torch/__init__.py:
--------------------------------------------------------------------------------
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

from .filter import *
from .resample import *
from .act import *
--------------------------------------------------------------------------------
/alias_free_torch/act.py:
--------------------------------------------------------------------------------
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(self,
                 activation,
                 up_ratio: int = 2,
                 down_ratio: int = 2,
                 up_kernel_size: int = 12,
                 down_kernel_size: int = 12):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B,C,T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)

        return x
--------------------------------------------------------------------------------
/alias_free_torch/filter.py:
--------------------------------------------------------------------------------
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

if 'sinc' in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(x == 0,
                           torch.tensor(1., device=x.device, dtype=x.dtype),
                           torch.sin(math.pi * x) / math.pi / x)


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1, 1, kernel_size]
    even = (kernel_size % 2 == 0)
    half_size = kernel_size // 2

    # For Kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.:
        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
    else:
        beta = 0.
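    # Note: the piecewise choice of beta above follows Kaiser's standard
    # empirical window-design formulas, with A an estimate of the stopband
    # attenuation achievable for the given kernel size and transition
    # bandwidth delta_f (our annotation; not from the upstream source).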
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = (torch.arange(-half_size, half_size) + 0.5)
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()
    filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    def __init__(self,
                 cutoff=0.5,
                 half_width=0.6,
                 stride: int = 1,
                 padding: bool = True,
                 padding_mode: str = 'replicate',
                 kernel_size: int = 12):
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < -0.:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = (kernel_size % 2 == 0)
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right),
                      mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1),
                       stride=self.stride, groups=C)

        return out
--------------------------------------------------------------------------------
/alias_free_torch/resample.py:
--------------------------------------------------------------------------------
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
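# Note (annotation): UpSample1d interpolates by a strided transposed
# convolution with a Kaiser-windowed sinc kernel (zero-stuffing + low-pass),
# and DownSample1d low-pass filters before decimating, so the signal stays
# band-limited below the Nyquist rate of the lower sample rate in both
# directions.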

import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
                                      half_width=0.6 / ratio,
                                      kernel_size=self.kernel_size)
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode='replicate')
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
        x = x[..., self.pad_left:-self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
                                       half_width=0.6 / ratio,
                                       stride=ratio,
                                       kernel_size=self.kernel_size)

    def forward(self, x):
        xx = self.lowpass(x)

        return xx
--------------------------------------------------------------------------------
/augmentation/aug.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.functional as AF
from .peq import ParametricEqualizer

class Augment(nn.Module):
    def __init__(self, h):
        super().__init__()
        self.config = h
        self.coder = LinearPredictiveCoding(
            32, h.data.win_length, h.data.hop_length)
        self.peq = ParametricEqualizer(
            h.data.sampling_rate, h.data.win_length)
        self.register_buffer(
            'window',
            torch.hann_window(h.data.win_length),
            persistent=False)
        f_min, f_max, peaks = 60, 10000, 8
        self.register_buffer(
            'peak_centers',
            f_min * (f_max / f_min) ** (torch.arange(peaks) / (peaks - 1)),
            persistent=False)

    def forward(self,
                wavs: torch.Tensor,
                mode: str = 'linear',
                ):
        """Augment the audio signal, random pitch, formant shift and PEQ.
        Args:
            wavs: [torch.float32; [B, T]], audio signal.
            mode: interpolation mode, `linear` or `nearest`.
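        Returns:
            [torch.float32; [B, T]], augmented signal, peak-normalized.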
33 | """ 34 | auxs = {} 35 | fft = torch.stft( 36 | wavs, 37 | self.config.data.filter_length, 38 | self.config.data.hop_length, 39 | self.config.data.win_length, 40 | self.window, 41 | return_complex=True) 42 | 43 | power, gain = self.sample(wavs) # for fs, ps 44 | 45 | if power is not None: 46 | q_min, q_max = 2, 5 47 | q = q_min * (q_max / q_min) ** power 48 | 49 | if gain is None: 50 | gain = torch.zeros_like(q[:, :-2]) 51 | 52 | bsize = wavs.shape[0] 53 | center = self.peak_centers[None].repeat(bsize, 1) 54 | peaks = torch.prod( 55 | self.peq.peaking_equalizer(center, gain, q[:, :-2]), dim=1) 56 | lowpass = self.peq.low_shelving(60, q[:, -2]) 57 | highpass = self.peq.high_shelving(10000, q[:, -1]) 58 | 59 | filters = peaks * highpass * lowpass 60 | fft = fft * filters[..., None] 61 | auxs.update({'peaks': peaks, 'highpass': highpass, 'lowpass': lowpass}) 62 | 63 | # Formant shifting and Pitch shifting 64 | fs_ratio = 1.4 65 | ps_ratio = 2.0 66 | 67 | code = self.coder.from_stft(fft / fft.abs().mean(dim=1)[:, None].clamp_min(1e-7)) 68 | filter_ = self.coder.envelope(code) 69 | source = fft.transpose(1, 2) / (filter_ + 1e-7) 70 | 71 | bsize = wavs.shape[0] 72 | def sampler(ratio): 73 | shifts = torch.rand(bsize, device=wavs.device) * (ratio - 1.) + 1. 74 | flip = torch.rand(bsize) < 0.5 75 | shifts[flip] = shifts[flip] ** -1 76 | return shifts 77 | 78 | fs_shift = sampler(fs_ratio) 79 | ps_shift = sampler(ps_ratio) 80 | 81 | source = fft.transpose(1, 2) / (filter_ + 1e-7) 82 | 83 | filter_ = self.interp(filter_, fs_shift, mode=mode) 84 | source = self.interp(source, ps_shift, mode=mode) 85 | 86 | fft = (source * filter_).transpose(1, 2) 87 | out = torch.istft( 88 | fft, 89 | self.config.data.filter_length, 90 | self.config.data.hop_length, 91 | self.config.data.win_length, 92 | self.window) 93 | out = out / out.max(dim=-1, keepdim=True).values.clamp_min(1e-7) 94 | 95 | return out 96 | 97 | def sample(self, wavs: torch.Tensor): 98 | bsize, _ = wavs.shape 99 | 100 | # parametric equalizer 101 | peaks = 8 102 | # quality factor 103 | power = torch.rand(bsize, peaks + 2, device=wavs.device) 104 | # gains 105 | g_min, g_max = -12, 12 106 | gain = torch.rand(bsize, peaks, device=wavs.device) * (g_max - g_min) + g_min 107 | 108 | return power, gain 109 | 110 | @staticmethod 111 | def complex_interp(inputs: torch.Tensor, *args, **kwargs): 112 | mag = F.interpolate(inputs.abs(), *args, **kwargs) 113 | angle = F.interpolate(inputs.angle(), *args, **kwargs) 114 | return torch.polar(mag, angle) 115 | 116 | def interp(self, inputs: torch.Tensor, shifts: torch.Tensor, mode: str): 117 | """Interpolate the channel axis with dynamic shifts. 118 | Args: 119 | inputs: [torch.complex64; [B, T, C]], input tensor. 120 | shifts: [torch.float32; [B]], shift factor. 121 | mode: interpolation mode. 122 | Returns: 123 | [torch.complex64; [B, T, C]], interpolated. 124 | """ 125 | INTERPOLATION = { 126 | torch.float32: F.interpolate, 127 | torch.complex64: Augment.complex_interp} 128 | assert inputs.dtype in INTERPOLATION, 'unsupported interpolation' 129 | interp_fn = INTERPOLATION[inputs.dtype] 130 | 131 | _, _, channels = inputs.shape 132 | 133 | interp = [ 134 | interp_fn( 135 | f[None], scale_factor=s.item(), mode=mode)[..., :channels] 136 | for f, s in zip(inputs, shifts)] 137 | 138 | return torch.cat([ 139 | F.pad(f, [0, channels - f.shape[-1]]) 140 | for f in interp], dim=0) 141 | 142 | 143 | class LinearPredictiveCoding(nn.Module): 144 | """LPC: Linear-predictive coding supports. 
145 | """ 146 | 147 | def __init__(self, num_code: int, windows: int, strides: int): 148 | """Initializer. 149 | Args: 150 | num_code: the number of the coefficients. 151 | windows: size of the windows. 152 | strides: the number of the frames between adjacent windows. 153 | """ 154 | super().__init__() 155 | self.num_code = num_code 156 | self.windows = windows 157 | self.strides = strides 158 | 159 | def forward(self, inputs: torch.Tensor): 160 | """Compute the linear-predictive coefficients from inputs. 161 | Args: 162 | inputs: [torch.float32; [B, T]], audio signal. 163 | Returns: 164 | [torch.float32; [B, T / strides, num_code]], coefficients. 165 | """ 166 | w = self.windows 167 | frames = F.pad(inputs, [0, w]).unfold(-1, w, self.strides) 168 | corrcoef = LinearPredictiveCoding.autocorr(frames) 169 | 170 | return LinearPredictiveCoding.solve_toeplitz( 171 | corrcoef[..., :self.num_code + 1]) 172 | 173 | def from_stft(self, inputs: torch.Tensor): 174 | """Compute the linear-predictive coefficients from STFT. 175 | Args: 176 | inputs: [torch.complex64; [B, windows // 2 + 1, T / strides]], fourier features. 177 | Returns: 178 | [torch.float32; [B, T / strides, num_code]], linear-predictive coefficient. 179 | """ 180 | corrcoef = torch.fft.irfft(inputs.abs().square(), dim=1) 181 | 182 | return LinearPredictiveCoding.solve_toeplitz( 183 | corrcoef[:, :self.num_code + 1].transpose(1, 2)) 184 | 185 | def envelope(self, lpc: torch.Tensor): 186 | """LPC to spectral envelope. 187 | Args: 188 | lpc: [torch.float32; [..., num_code]], coefficients. 189 | Returns: 190 | [torch.float32; [..., windows // 2 + 1]], filters. 191 | """ 192 | denom = torch.fft.rfft(-F.pad(lpc, [1, 0], value=1.), self.windows, dim=-1).abs() 193 | # for preventing zero-division 194 | denom[(denom.abs() - 1e-7) < 0] = 1. 195 | return denom ** -1 196 | 197 | @staticmethod 198 | def autocorr(wavs: torch.Tensor): 199 | """Compute the autocorrelation. 200 | Args: audio signal. 201 | Returns: auto-correlation. 202 | """ 203 | fft = torch.fft.rfft(wavs, dim=-1) 204 | return torch.fft.irfft(fft.abs().square(), dim=-1) 205 | 206 | @staticmethod 207 | def solve_toeplitz(corrcoef: torch.Tensor): 208 | """Solve the toeplitz matrix. 209 | Args: 210 | corrcoef: [torch.float32; [..., num_code + 1]], auto-correlation. 211 | Returns: 212 | [torch.float32; [..., num_code]], solutions. 213 | """ 214 | 215 | solutions = F.pad( 216 | (-corrcoef[..., 1] / corrcoef[..., 0].clamp_min(1e-7))[..., None], 217 | [1, 0], value=1.) 218 | 219 | extra = corrcoef[..., 0] + corrcoef[..., 1] * solutions[..., 1] 220 | 221 | ## solve residuals 222 | num_code = corrcoef.shape[-1] - 1 223 | for k in range(1, num_code): 224 | lambda_value = ( 225 | -solutions[..., :k + 1] 226 | * torch.flip(corrcoef[..., 1:k + 2], dims=[-1]) 227 | ).sum(dim=-1) / extra.clamp_min(1e-7) 228 | aug = F.pad(solutions, [0, 1]) 229 | solutions = aug + lambda_value[..., None] * torch.flip(aug, dims=[-1]) 230 | extra = (1. - lambda_value ** 2) * extra 231 | 232 | return solutions[..., 1:] -------------------------------------------------------------------------------- /augmentation/peq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | class ParametricEqualizer(nn.Module): 6 | """Fast-parametric equalizer for approximation of Biquad IIR filter. 7 | """ 8 | def __init__(self, sr: int, windows: int): 9 | """Initializer. 10 | Args: 11 | sr: sample rate. 
            windows: size of the fft window.
        """
        super().__init__()
        self.sr = sr
        self.windows = windows

    def biquad(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Construct frequency level biquad filter.
        Args:
            a: [torch.float32; [..., 3]], recursive filter, iir.
            b: [torch.float32; [..., 3]], finite impulse filter.
        Returns:
            [torch.float32; [..., windows // 2 + 1]], biquad filter.
        """
        iir = torch.fft.rfft(a, self.windows, dim=-1)
        fir = torch.fft.rfft(b, self.windows, dim=-1)
        return fir / iir

    def low_shelving(self, cutoff: float, q: torch.Tensor) -> torch.Tensor:
        """Frequency level low-shelving filter.
        Args:
            cutoff: cutoff frequency.
            q: [torch.float32; [B]], quality factor.
        Returns:
            [torch.float32; [B, windows // 2 + 1]], frequency filter.
        """
        bsize, = q.shape
        # ref: torchaudio.functional.lowpass_biquad
        w0 = 2 * np.pi * cutoff / self.sr
        cos_w0 = np.cos(w0)
        # [B]
        alpha = np.sin(w0) / 2 / q
        cos_w0 = torch.tensor(
            [np.cos(w0)] * bsize, dtype=torch.float32, device=q.device)
        # [B, windows // 2 + 1]
        return self.biquad(
            a=torch.stack([1 + alpha, -2 * cos_w0, 1 - alpha], dim=-1),
            b=torch.stack([(1 - cos_w0) / 2, 1 - cos_w0, (1 - cos_w0) / 2], dim=-1))

    def high_shelving(self, cutoff: float, q: torch.Tensor) -> torch.Tensor:
        """Frequency level high-shelving filter.
        Args:
            cutoff: cutoff frequency.
            q: [torch.float32; [B]], quality factor.
        Returns:
            [torch.float32; [B, windows // 2 + 1]], frequency filter.
        """
        bsize, = q.shape
        w0 = 2 * np.pi * cutoff / self.sr

        alpha = np.sin(w0) / 2 / q
        cos_w0 = torch.tensor(
            [np.cos(w0)] * bsize, dtype=torch.float32, device=q.device)

        return self.biquad(
            a=torch.stack([1 + alpha, -2 * cos_w0, 1 - alpha], dim=-1),
            b=torch.stack([(1 + cos_w0) / 2, -1 - cos_w0, (1 + cos_w0) / 2], dim=-1))

    def peaking_equalizer(self,
                          center: torch.Tensor,
                          gain: torch.Tensor,
                          q: torch.Tensor) -> torch.Tensor:
        """Frequency level peaking equalizer.
        Args:
            center: [torch.float32; [...]], center frequency.
            gain: [torch.float32; [...]], boost or attenuation in decibel.
            q: [torch.float32; [...]], quality factor.
        Returns:
            [torch.float32; [..., windows // 2 + 1]], frequency filter.
        """
        w0 = 2 * np.pi * center / self.sr
        alpha = torch.sin(w0) / 2 / q
        cos_w0 = torch.cos(w0)
        A = (gain / 40. * np.log(10)).exp()
        return self.biquad(
            a=torch.stack([1 + alpha / A, -2 * cos_w0, 1 - alpha / A], dim=-1),
            b=torch.stack([1 + alpha * A, -2 * cos_w0, 1 - alpha * A], dim=-1))
--------------------------------------------------------------------------------
/ckpt/config.json:
--------------------------------------------------------------------------------
{
  "train": {
    "log_interval": 1000,
    "eval_interval": 10000,
    "save_interval": 10000,
    "seed": 1234,
    "epochs": 1000,
    "optimizer": "adamw",
    "lr_decay_on": true,
    "learning_rate": 5e-5,
    "betas": [0.8, 0.99],
    "eps": 1e-9,
    "batch_size": 32,
    "fp16_run": false,
    "lr_decay": 0.999875,
    "segment_size": 35840,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 1,
    "aug": true,
    "lambda_commit": 0.02
  },
  "data": {
    "sampling_rate": 16000,
    "filter_length": 1280,
    "hop_length": 320,
    "win_length": 1280,
    "n_mel_channels": 80,
    "mel_fmin": 0,
    "mel_fmax": 8000
  },
  "model": {
    "inter_channels": 192,
    "hidden_channels": 192,
    "filter_channels": 768,
    "n_heads": 2,
    "n_layers": 6,
    "kernel_size": 3,
    "p_dropout": 0.1,
    "resblock": "1",
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
    "upsample_rates": [5,4,4,2,2],
    "upsample_initial_channel": 512,
    "upsample_kernel_sizes": [11,8,8,4,4],
    "mixup_ratio": 0.6,
    "n_layers_q": 3,
    "use_spectral_norm": false,
    "hidden_size": 128
  },
  "diffusion" : {
    "dec_dim" : 64,
    "spk_dim" : 128,
    "beta_min" : 0.05,
    "beta_max" : 20.0
  }
}
--------------------------------------------------------------------------------
/ckpt/config_bigvgan.json:
--------------------------------------------------------------------------------
{
  "train": {
    "log_interval": 1000,
    "eval_interval": 10000,
    "save_interval": 10000,
    "seed": 1234,
    "epochs": 1000,
    "optimizer": "adamw",
    "lr_decay_on": true,
    "learning_rate": 5e-5,
    "betas": [0.8, 0.99],
    "eps": 1e-9,
    "batch_size": 32,
    "fp16_run": false,
    "lr_decay": 0.999875,
    "segment_size": 35840,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 1,
    "aug": true,
    "lambda_commit": 0.02
  },
  "data": {
    "sampling_rate": 16000,
    "filter_length": 1280,
    "hop_length": 320,
    "win_length": 1280,
    "n_mel_channels": 80,
    "mel_fmin": 0,
    "mel_fmax": 8000
  },
  "model": {
    "inter_channels": 192,
    "hidden_channels": 192,
    "filter_channels": 768,
    "n_heads": 2,
    "n_layers": 8,
    "kernel_size": 3,
    "p_dropout": 0.1,
    "resblock": "1",
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
    "upsample_rates": [5,4,2,2,2,2],
    "upsample_initial_channel": 1024,
    "upsample_kernel_sizes": [11,8,4,4,4,4],
    "hidden_size": 128
  },
  "diffusion" : {
    "dec_dim" : 64,
    "spk_dim" : 128,
    "beta_min" : 0.05,
    "beta_max" : 20.0
  }
}
--------------------------------------------------------------------------------
/configs/config_16k.json:
--------------------------------------------------------------------------------
{
  "train": {
    "log_interval": 1000,
    "eval_interval": 10000,
    "save_interval": 10000,
    "seed": 1234,
    "epochs": 1000,
    "optimizer": "adamw",
"lr_decay_on": true, 10 | "learning_rate": 5e-5, 11 | "betas": [0.8, 0.99], 12 | "eps": 1e-9, 13 | "batch_size": 32, 14 | "fp16_run": false, 15 | "lr_decay": 0.999875, 16 | "segment_size": 35840, 17 | "init_lr_ratio": 1, 18 | "warmup_epochs": 0, 19 | "c_mel": 1, 20 | "aug": true, 21 | "lambda_commit": 0.02 22 | }, 23 | "data": { 24 | "train_filelist_path": "fp_16k/train_wav.txt", 25 | "test_filelist_path": "fp_16k/test_wav.txt", 26 | "sampling_rate": 16000, 27 | "filter_length": 1280, 28 | "hop_length": 320, 29 | "win_length": 1280, 30 | "n_mel_channels": 80, 31 | "mel_fmin": 0, 32 | "mel_fmax": 8000 33 | }, 34 | "model": { 35 | "inter_channels": 192, 36 | "hidden_channels": 192, 37 | "filter_channels": 768, 38 | "n_heads": 2, 39 | "n_layers": 6, 40 | "kernel_size": 3, 41 | "p_dropout": 0.1, 42 | "resblock": "1", 43 | "resblock_kernel_sizes": [3,7,11], 44 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 45 | "upsample_rates": [5,4,4,2,2], 46 | "upsample_initial_channel": 512, 47 | "upsample_kernel_sizes": [11,8,8,4,4], 48 | "mixup_ratio": 0.6, 49 | "n_layers_q": 3, 50 | "use_spectral_norm": false, 51 | "hidden_size": 128 52 | }, 53 | "diffusion" : { 54 | "dec_dim" : 64, 55 | "spk_dim" : 128, 56 | "beta_min" : 0.05, 57 | "beta_max" : 20.0 58 | } 59 | } -------------------------------------------------------------------------------- /infer.sh: -------------------------------------------------------------------------------- 1 | python3 inference.py \ 2 | --src_path './sample/src_p241_004.wav' \ 3 | --trg_path './sample/tar_p239_022.wav' \ 4 | --ckpt_model './ckpt/model_diffhier.pth' \ 5 | --voc 'bigvgan' \ 6 | --ckpt_voc './vocoder/voc_bigvgan.pth' \ 7 | --output_dir './converted' \ 8 | --diffpitch_ts 30 \ 9 | --diffvoice_ts 6 10 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import tqdm 5 | import numpy as np 6 | from glob import glob 7 | from scipy.io.wavfile import write 8 | from torch.nn import functional as F 9 | import torchaudio 10 | import copy 11 | import utils.utils as utils 12 | import amfm_decompy.pYAAPT as pYAAPT 13 | import amfm_decompy.basic_tools as basic 14 | from vocoder.hifigan import HiFi 15 | from vocoder.bigvgan import BigvGAN 16 | from model.diffhiervc import DiffHierVC, Wav2vec2 17 | from utils.utils import MelSpectrogramFixed 18 | 19 | h = None 20 | device = None 21 | seed = 1234 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | 25 | def load_audio(path): 26 | audio, sr = torchaudio.load(path) 27 | audio = audio[:1] 28 | if sr != 16000: 29 | audio = torchaudio.functional.resample(audio, sr, 16000, resampling_method="kaiser_window") 30 | 31 | p = (audio.shape[-1] // 1280 + 1) * 1280 - audio.shape[-1] 32 | audio = torch.nn.functional.pad(audio, (0, p)) 33 | 34 | return audio 35 | 36 | def save_audio(wav, out_file, syn_sr=16000): 37 | wav = (wav.squeeze() / wav.abs().max() * 0.999 * 32767.0).cpu().numpy().astype('int16') 38 | write(out_file, syn_sr, wav) 39 | 40 | def get_yaapt_f0(audio, sr=16000, interp=False): 41 | to_pad = int(20.0 / 1000 * sr) // 2 42 | f0s = [] 43 | for y in audio.astype(np.float64): 44 | y_pad = np.pad(y.squeeze(), (to_pad, to_pad), "constant", constant_values=0) 45 | pitch = pYAAPT.yaapt(basic.SignalObj(y_pad, sr), 46 | **{'frame_length': 20.0, 'frame_space': 5.0, 'nccf_thresh1': 0.25, 'tda_frame_length': 25.0}) 47 | 
        f0s.append(pitch.samp_interp[None, None, :] if interp else pitch.samp_values[None, None, :])

    return np.vstack(f0s)

def inference(a):
    os.makedirs(a.output_dir, exist_ok=True)
    mel_fn = MelSpectrogramFixed(
        sample_rate=hps.data.sampling_rate,
        n_fft=hps.data.filter_length,
        win_length=hps.data.win_length,
        hop_length=hps.data.hop_length,
        f_min=hps.data.mel_fmin,
        f_max=hps.data.mel_fmax,
        n_mels=hps.data.n_mel_channels,
        window_fn=torch.hann_window
    ).cuda()

    # Load pre-trained w2v (XLS-R)
    w2v = Wav2vec2().cuda()

    # Load model
    model = DiffHierVC(hps.data.n_mel_channels, hps.diffusion.spk_dim,
                       hps.diffusion.dec_dim, hps.diffusion.beta_min, hps.diffusion.beta_max, hps).cuda()
    model.load_state_dict(torch.load(a.ckpt_model))
    model.eval()

    # Load vocoder
    if a.voc == "hifigan":
        net_v = HiFi(hps.data.n_mel_channels, hps.train.segment_size // hps.data.hop_length, **hps.model).cuda()
        utils.load_checkpoint(a.ckpt_voc, net_v, None)
    elif a.voc == "bigvgan":
        net_v = BigvGAN(hps.data.n_mel_channels, hps.train.segment_size // hps.data.hop_length, **hps.model).cuda()
        utils.load_checkpoint(a.ckpt_voc, net_v, None)
    net_v.eval().dec.remove_weight_norm()

    # Convert audio
    print('>> Converting each utterance...')
    src_name = os.path.splitext(os.path.basename(a.src_path))[0]
    audio = load_audio(a.src_path)

    src_mel = mel_fn(audio.cuda())
    src_length = torch.LongTensor([src_mel.size(-1)]).cuda()
    w2v_x = w2v(F.pad(audio, (40, 40), "reflect").cuda())

    try:
        f0 = get_yaapt_f0(audio.numpy())
    except:
        f0 = np.zeros((1, audio.shape[-1] // 80), dtype=np.float32)

    f0_x = f0.copy()
    f0_x = torch.log(torch.FloatTensor(f0_x + 1)).cuda()
    ii = f0 != 0
    f0[ii] = (f0[ii] - f0[ii].mean()) / f0[ii].std()
    f0_norm_x = torch.FloatTensor(f0).cuda()

    trg_name = os.path.splitext(os.path.basename(a.trg_path))[0]
    trg_audio = load_audio(a.trg_path)

    trg_mel = mel_fn(trg_audio.cuda())
    trg_length = torch.LongTensor([trg_mel.size(-1)]).to(device)

    with torch.no_grad():
        c = model.infer_vc(src_mel, w2v_x, f0_norm_x, f0_x, src_length, trg_mel, trg_length,
                           diffpitch_ts=a.diffpitch_ts, diffvoice_ts=a.diffvoice_ts)
        converted_audio = net_v(c)

    f_name = f'{src_name}_to_{trg_name}.wav'
    out = os.path.join(a.output_dir, f_name)
    save_audio(converted_audio, out)


def main():
    print('>> Initializing Inference Process...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--src_path', type=str, default='/workspace/ha0/data/src.wav')
    parser.add_argument('--trg_path', type=str, default='/workspace/ha0/data/tar.wav')
    parser.add_argument('--ckpt_model', type=str, default='./ckpt/model_diffhier.pth')
    parser.add_argument('--voc', type=str, default='bigvgan')
    parser.add_argument('--ckpt_voc', type=str, default='./vocoder/voc_bigvgan.pth')
    parser.add_argument('--output_dir', '-o', type=str, default='./converted')
    parser.add_argument('--diffpitch_ts', '-dpts', type=int, default=30)
    parser.add_argument('--diffvoice_ts', '-dvts', type=int, default=6)

    global hps, hps_voc, device, a
    a = parser.parse_args()
    config = os.path.join(os.path.split(a.ckpt_model)[0], 'config_bigvgan.json')
    hps = utils.get_hparams_from_file(config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    inference(a)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/model/base.py:
--------------------------------------------------------------------------------
import numpy as np
import torch


class BaseModule(torch.nn.Module):
    def __init__(self):
        super(BaseModule, self).__init__()

    @property
    def nparams(self):
        num_params = 0
        for name, param in self.named_parameters():
            if param.requires_grad:
                num_params += np.prod(param.detach().cpu().numpy().shape)
        return num_params

    def relocate_input(self, x: list):
        device = next(self.parameters()).device
        for i in range(len(x)):
            if isinstance(x[i], torch.Tensor) and x[i].device != device:
                x[i] = x[i].to(device)
        return x
--------------------------------------------------------------------------------
/model/diffhiervc.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F

from model.base import BaseModule
from model.diffusion_mel import Diffusion as Mel_Diffusion
from model.diffusion_f0 import Diffusion as F0_Diffusion
from model.styleencoder import StyleEncoder

import copy
import transformers
import typing as tp

from module.modules import *
from module.utils import *


class Wav2vec2(torch.nn.Module):
    def __init__(self, layer=12):
        super().__init__()
        self.wav2vec2 = transformers.Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-xls-r-300m")
        for param in self.wav2vec2.parameters():
            param.requires_grad = False
            param.grad = None
        self.wav2vec2.eval()
        self.feature_layer = layer

    @torch.no_grad()
    def forward(self, x):
        outputs = self.wav2vec2(x.squeeze(1), output_hidden_states=True)
        y = outputs.hidden_states[self.feature_layer]

        return y.permute((0, 2, 1))

class Encoder(nn.Module):
    def __init__(self,
                 in_channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 mel_size=80,
                 gin_channels=0,
                 p_dropout=0):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, p_dropout=p_dropout)
        self.proj = nn.Conv1d(hidden_channels, mel_size, 1)

    def forward(self, x, x_mask, g=None):
        x = self.pre(x * x_mask) * x_mask
        x = self.enc(x, x_mask, g=g)
        x = self.proj(x) * x_mask

        return x


class SynthesizerTrn(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.emb_c = nn.Conv1d(1024, hidden_size, 1)
        self.emb_c_f0 = nn.Conv1d(1024, hidden_size, 1)
        self.emb_f0 = nn.Conv1d(1, hidden_size, kernel_size=9, stride=4, padding=4)
        self.emb_norm_f0 = nn.Conv1d(1, hidden_size, 1)
        self.emb_g = StyleEncoder(in_dim=80, hidden_dim=256, out_dim=256)

        self.mel_enc_c = Encoder(hidden_size, hidden_size, 5, 1, 8, 80, gin_channels=256, p_dropout=0)
        self.mel_enc_f = Encoder(hidden_size, hidden_size, 5, 1, 8, 80, gin_channels=256, p_dropout=0)
        self.f0_enc = Encoder(hidden_size, hidden_size, 5, 1, 8, 128, gin_channels=256, p_dropout=0)
        self.proj = nn.Conv1d(hidden_size, 1, 1)

    def forward(self, x_mel, w2v, norm_f0, f0, x_mask, f0_mask):
        content = self.emb_c(w2v)
        content_f = self.emb_c_f0(w2v)
        f0 = self.emb_f0(f0)
        norm_f0 = self.emb_norm_f0(norm_f0)

        g = self.emb_g(x_mel, x_mask).unsqueeze(-1)
        y_cont = self.mel_enc_c(F.relu(content), x_mask, g=g)
        y_f0 = self.mel_enc_f(F.relu(f0), x_mask, g=g)
        y_mel = y_cont + y_f0

        content_f = F.interpolate(content_f, norm_f0.shape[-1])
        enc_f0 = self.f0_enc(F.relu(content_f + norm_f0), f0_mask, g=g)
        y_f0_hat = self.proj(enc_f0)

        return g, y_mel, enc_f0, y_f0_hat

    def spk_embedding(self, mel, length):
        x_mask = torch.unsqueeze(commons.sequence_mask(length, mel.size(-1)), 1).to(mel.dtype)

        return self.emb_g(mel, x_mask).unsqueeze(-1)

    def mel_predictor(self, w2v, x_mask, spk, pred_f0):
        content = self.emb_c(w2v)
        pred_f0 = self.emb_f0(pred_f0)

        y_cont = self.mel_enc_c(F.relu(content), x_mask, g=spk)
        y_f0 = self.mel_enc_f(F.relu(pred_f0), x_mask, g=spk)
        y_mel = y_cont + y_f0

        return y_mel

    def f0_predictor(self, w2v, x_f0_norm, y_mel, y_mask, f0_mask):
        content_f = self.emb_c_f0(w2v)
        norm_f0 = self.emb_norm_f0(x_f0_norm)
        g = self.emb_g(y_mel, y_mask).unsqueeze(-1)
        content_f = F.interpolate(content_f, norm_f0.shape[-1])

        enc_f0 = self.f0_enc(F.relu(content_f + norm_f0), f0_mask, g=g)
        y_f0_hat = self.proj(enc_f0)

        return g, y_f0_hat, enc_f0


class DiffHierVC(BaseModule):
    def __init__(self, n_feats, spk_dim, dec_dim, beta_min, beta_max, hps):
        super(DiffHierVC, self).__init__()
        self.n_feats = n_feats
        self.spk_dim = spk_dim
        self.dec_dim = dec_dim
        self.beta_min = beta_min
        self.beta_max = beta_max

        self.encoder = SynthesizerTrn(hps.model.hidden_size)
        self.f0_dec = F0_Diffusion(n_feats, 64, spk_dim, beta_min, beta_max)
        self.mel_dec = Mel_Diffusion(n_feats, dec_dim, spk_dim, beta_min, beta_max)

    @torch.no_grad()
    def forward(self, x, w2v, norm_y_f0, f0_x, x_length, n_timesteps, mode='ml'):
        x_mask = sequence_mask(x_length, x.size(2)).unsqueeze(1).to(x.dtype)
        f0_mask = sequence_mask(x_length * 4, x.size(2) * 4).unsqueeze(1).to(x.dtype)

        max_length = int(x_length.max())
        spk, y_mel, h_f0, y_f0_hat = self.encoder(x, w2v, norm_y_f0, f0_x, x_mask, f0_mask)
        f0_mean_x = self.f0_dec.compute_diffused_z_pr(f0_x, f0_mask, y_f0_hat, 1.0)

        z_f0 = f0_mean_x * f0_mask
        z_f0 += torch.randn_like(z_f0, device=z_f0.device)
        o_f0 = self.f0_dec.reverse(z_f0, f0_mask, y_f0_hat * f0_mask, h_f0 * f0_mask, spk, n_timesteps)

        z_mel = self.mel_dec.compute_diffused_z_pr(x, x_mask, y_mel, 1.0)
        z_mel += torch.randn_like(z_mel, device=z_mel.device)

        o_mel = self.mel_dec.reverse(z_mel, x_mask, y_mel, spk, n_timesteps)

        return y_f0_hat, y_mel, o_f0, o_mel[:, :, :max_length]

    def infer_vc(self, x, x_w2v, x_f0_norm, x_f0, x_length, y, y_length, diffpitch_ts, diffvoice_ts):
        x_mask = sequence_mask(x_length, x.size(2)).unsqueeze(1).to(x.dtype)
        y_mask = sequence_mask(y_length, y.size(2)).unsqueeze(1).to(y.dtype)
        f0_mask = sequence_mask(x_length * 4, x.size(2) * 4).unsqueeze(1).to(x.dtype)

        spk, y_f0_hat, enc_f0 = self.encoder.f0_predictor(x_w2v, x_f0_norm, y, y_mask, f0_mask)

        # Diff-Pitch
        z_f0 = self.f0_dec.compute_diffused_z_pr(x_f0, f0_mask, y_f0_hat, 1.0)
        z_f0 += torch.randn_like(z_f0, device=z_f0.device)
        pred_f0 = self.f0_dec.reverse(z_f0, f0_mask, y_f0_hat * f0_mask, enc_f0 * f0_mask, spk, ts=diffpitch_ts)
        f0_zeros_mask = (x_f0 == 0)
        pred_f0[f0_zeros_mask.expand_as(pred_f0)] = 0

        # Diff-Voice
        y_mel = self.encoder.mel_predictor(x_w2v, x_mask, spk, pred_f0)
        z_mel = self.mel_dec.compute_diffused_z_pr(x, x_mask, y_mel, 1.0)
        z_mel += torch.randn_like(z_mel, device=z_mel.device)
        o_mel = self.mel_dec.reverse(z_mel, x_mask, y_mel, spk, ts=diffvoice_ts)

        return o_mel[:, :, :x_length]

    def compute_loss(self, x, w2v_x, norm_f0_x, f0_x, x_length):
        x_mask = sequence_mask(x_length, x.size(2)).unsqueeze(1).to(x.dtype)
        f0_mask = sequence_mask(x_length * 4, x.size(2) * 4).unsqueeze(1).to(x.dtype)

        spk, y_mel, y_f0, y_f0_hat = self.encoder(x, w2v_x, norm_f0_x, f0_x, x_mask, f0_mask)

        f0_loss = torch.sum(torch.abs(f0_x - y_f0_hat) * f0_mask) / (torch.sum(f0_mask))
        mel_loss = torch.sum(torch.abs(x - y_mel) * x_mask) / (torch.sum(x_mask) * self.n_feats)

        f0_diff_loss = self.f0_dec.compute_t(f0_x, f0_mask, y_f0_hat, y_f0, spk)
        mel_diff_loss, mel_recon_loss = self.mel_dec.compute_t(x, x_mask, y_mel, spk)

        return mel_diff_loss, mel_recon_loss, f0_diff_loss, mel_loss, f0_loss
--------------------------------------------------------------------------------
/model/diffusion_f0.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.base import BaseModule
from model.diffusion_module import *
from math import sqrt

Linear = nn.Linear
ConvTranspose2d = nn.ConvTranspose2d

def Conv1d(*args, **kwargs):
    layer = nn.Conv1d(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer

class ResidualBlock(nn.Module):
    def __init__(self, n_mels, residual_channels, dilation, dim_base):
        super().__init__()
        self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation)
        self.diffusion_projection = Linear(dim_base, residual_channels)
        self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1)
        self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)

    def forward(self, x, diffusion_step, conditioner, x_mask):
        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
        y = x + diffusion_step

        conditioner = self.conditioner_projection(conditioner)
        y = self.dilated_conv(y * x_mask) + conditioner

        gate, filter = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter)

        y = self.output_projection(y * x_mask)
        residual, skip = torch.chunk(y, 2, dim=1)
        return (x + residual) / sqrt(2.0), skip


class GradLogPEstimator(BaseModule):
    def __init__(self, dim_base, dim_cond, res_layer=30, res_ch=64, dilation_cycle=10):
        super(GradLogPEstimator, self).__init__()

        self.time_pos_emb = SinusoidalPosEmb(dim_base)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim_base, dim_base * 4),
                                       Mish(),
                                       torch.nn.Linear(dim_base * 4, dim_base),
                                       Mish())

        cond_total = dim_base + 256 + 128
        self.cond_block = torch.nn.Sequential(Conv1d(cond_total, 4 * dim_cond, 1),
                                              Mish(),
                                              Conv1d(4 * dim_cond, dim_cond, 1),
                                              Mish())

        self.input_projection = torch.nn.Sequential(Conv1d(1, res_ch, 1), Mish())
        self.residual_layers = nn.ModuleList([
            ResidualBlock(dim_cond, res_ch, 2 ** (i % dilation_cycle), dim_base)
            for i in range(res_layer)
        ])
        self.skip_projection = torch.nn.Sequential(Conv1d(res_ch, res_ch, 1), Mish())
        self.output_projection = Conv1d(res_ch, 1, 1)
        nn.init.zeros_(self.output_projection.weight)

    def forward(self, x, x_mask, f0, spk, t):
        condition = self.time_pos_emb(t)
        t = self.mlp(condition)
        x = self.input_projection(x) * x_mask

        condition = torch.cat([f0, condition.unsqueeze(-1).expand(-1, -1, f0.size(2)), spk.expand(-1, -1, f0.size(2))], 1)
        condition = self.cond_block(condition) * x_mask

        skip = None
        for layer in self.residual_layers:
            x, skip_connection = layer(x, t, condition, x_mask)
            skip = skip_connection * x_mask if skip is None else (skip_connection + skip) * x_mask

        x = skip / sqrt(len(self.residual_layers))
        x = self.skip_projection(x) * x_mask
        x = self.output_projection(x) * x_mask

        return x

    @torch.no_grad()
    def infer(self, x, x_mask, f0, spk, t):
        condition = self.time_pos_emb(t)
        t = self.mlp(condition)
        x = self.input_projection(x) * x_mask

        condition = torch.cat([f0, condition.unsqueeze(-1).expand(-1, -1, f0.size(2)), spk.expand(-1, -1, f0.size(2))], 1)
        condition = self.cond_block(condition) * x_mask

        skip = None
        for layer in self.residual_layers:
            x, skip_connection = layer(x, t, condition, x_mask)
            skip = skip_connection * x_mask if skip is None else (skip_connection + skip) * x_mask

        x = skip / sqrt(len(self.residual_layers))
        x = self.skip_projection(x) * x_mask
        x = self.output_projection(x) * x_mask

        return x

class Diffusion(BaseModule):
    def __init__(self, n_feats, dim, dim_spk, beta_min, beta_max):
        super(Diffusion, self).__init__()
        self.estimator_f0 = GradLogPEstimator(dim, dim_spk)

        self.n_feats = n_feats
        self.dim_unet = dim
        self.dim_spk = dim_spk
        self.beta_min = beta_min
        self.beta_max = beta_max

    def get_beta(self, t):
        beta = self.beta_min + (self.beta_max - self.beta_min) * t
        return beta

    def get_gamma(self, s, t, p=1.0, use_torch=False):
        beta_integral = self.beta_min + 0.5 * (self.beta_max - self.beta_min) * (t + s)
        beta_integral *= (t - s)
        if use_torch:
            gamma = torch.exp(-0.5 * p * beta_integral).unsqueeze(-1).unsqueeze(-1)
        else:
            gamma = math.exp(-0.5 * p * beta_integral)
        return gamma

    def get_mu(self, s, t):
        a = self.get_gamma(s, t)
        b = 1.0 - self.get_gamma(0, s, p=2.0)
        c = 1.0 - self.get_gamma(0, t, p=2.0)
        return a * b / c

    def get_nu(self, s, t):
        a = self.get_gamma(0, s)
        b = 1.0 - self.get_gamma(s, t, p=2.0)
        c = 1.0 - self.get_gamma(0, t, p=2.0)
        return a * b / c

    def get_sigma(self, s, t):
        a = 1.0 - self.get_gamma(0, s, p=2.0)
        b = 1.0 - self.get_gamma(s, t, p=2.0)
        c = 1.0 - self.get_gamma(0, t, p=2.0)
        return math.sqrt(a * b / c)

    def compute_diffused_z_pr(self, x0, mask, z_pr, t, use_torch=False):
        x0_weight = self.get_gamma(0, t, use_torch=use_torch)
        z_pr_weight = 1.0 - x0_weight
        xt_z_pr = x0 * x0_weight + z_pr * z_pr_weight
        return xt_z_pr * mask

    def forward_diffusion(self, x0, mask, src_out, t):
        xt_src = self.compute_diffused_z_pr(x0, mask, src_out, t, use_torch=True)
        variance = 1.0 - self.get_gamma(0, t, p=2.0, use_torch=True)
        z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False)
        xt_src = xt_src + z * torch.sqrt(variance)

        return xt_src * mask, z * mask

    @torch.no_grad()
    def reverse(self, z, mask, y_hat, z_f0, spk, ts):
        h = 1.0 / ts
        xt = z * mask

        for i in range(ts):
            t = 1.0 - i * h
            time = t * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
            beta_t = self.get_beta(t)

            kappa = self.get_gamma(0, t - h) * (1.0 - self.get_gamma(t - h, t, p=2.0))
            kappa /= (self.get_gamma(0, t) * beta_t * h)
            kappa -= 1.0
            omega = self.get_nu(t - h, t) / self.get_gamma(0, t)
            omega += self.get_mu(t - h, t)
            omega -= (0.5 * beta_t * h + 1.0)
            sigma = self.get_sigma(t - h, t)

            dxt = (y_hat - xt) * (0.5 * beta_t * h + omega)
            dxt -= (self.estimator_f0.infer(xt, mask, z_f0, spk, time)) * (1.0 + kappa) * (beta_t * h)
            dxt += torch.randn_like(z, device=z.device) * sigma
            xt = (xt - dxt) * mask

        return xt

    def compute_loss(self, x0, mask, x0_hat, spk, f0, t):
        xt, z = self.forward_diffusion(x0, mask, x0_hat, t)
        z_estimation = self.estimator_f0(xt, mask, f0, spk, t)
        z_estimation *= torch.sqrt(1.0 - self.get_gamma(0, t, p=2.0, use_torch=True))
        loss = torch.sum((z_estimation + z) ** 2) / (torch.sum(mask))

        return loss

    def compute_t(self, x0, mask, x0_hat, f0, spk, offset=1e-5):
        b = x0.shape[0]
        t = torch.rand(b, dtype=x0.dtype, device=x0.device, requires_grad=False)
        t = torch.clamp(t, offset, 1.0 - offset)

        return self.compute_loss(x0, mask, x0_hat, spk, f0, t)
--------------------------------------------------------------------------------
/model/diffusion_mel.py:
--------------------------------------------------------------------------------
import math
import torch
import random
import numpy as np
from torch.nn import functional as F

from model.base import BaseModule
from model.diffusion_module import *


class GradLogPEstimator(BaseModule):
    def __init__(self, dim_base, dim_cond, dim_mults=(1, 2, 4)):
        super(GradLogPEstimator, self).__init__()

        dims = [2 + dim_cond, *map(lambda m: dim_base * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        self.time_pos_emb = SinusoidalPosEmb(dim_base)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim_base, dim_base * 4),
                                       Mish(), torch.nn.Linear(dim_base * 4, dim_base))
        cond_total = dim_base + 256
        self.cond_block = torch.nn.Sequential(torch.nn.Linear(cond_total, 4 * dim_cond),
                                              Mish(), torch.nn.Linear(4 * dim_cond, dim_cond))

        self.downs = torch.nn.ModuleList([])
        self.ups = torch.nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(torch.nn.ModuleList([
                ResnetBlock(dim_in, dim_out, time_emb_dim=dim_base),
                ResnetBlock(dim_out, dim_out, time_emb_dim=dim_base),
                Residual(Rezero(LinearAttention(dim_out))),
                Downsample(dim_out) if not is_last else torch.nn.Identity()]))

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim_base)
        self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim_base)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            self.ups.append(torch.nn.ModuleList([
                ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim_base),
                ResnetBlock(dim_in, dim_in, time_emb_dim=dim_base),
                Residual(Rezero(LinearAttention(dim_in))),
                Upsample(dim_in)]))

        self.m_final_block = Block(dim_base, dim_base)
        self.m_final_conv = torch.nn.Conv2d(dim_base, 1, 1)

        self.z_final_block = Block(dim_base, dim_base)
        self.z_final_conv = torch.nn.Conv2d(dim_base, 1, 1)

    def forward(self, x, x_mask, enc_out, spk, t):
        condition = self.time_pos_emb(t)
        t = self.mlp(condition)

        x = torch.stack([enc_out, x], 1)
        x_mask = x_mask.unsqueeze(1)

        condition = torch.cat([condition, spk.squeeze(2)], 1)
        condition = self.cond_block(condition).unsqueeze(-1).unsqueeze(-1)

        condition = torch.cat(x.shape[2] * [condition], 2)
        condition = torch.cat(x.shape[3] * [condition], 3)
        x = torch.cat([x, condition], 1)

        hiddens = []
        masks = [x_mask]

        for resnet1, resnet2, attn, downsample in self.downs:
            mask_down = masks[-1]
            x = resnet1(x, mask_down, t)
            x = resnet2(x, mask_down, t)
            x = attn(x)
            hiddens.append(x)
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]
        x = self.mid_block1(x, mask_mid, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, mask_mid, t)

        for resnet1, resnet2, attn, upsample in self.ups:
            mask_up = masks.pop()
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = resnet1(x, mask_up, t)
            x = resnet2(x, mask_up, t)
            x = attn(x)
            x = upsample(x * mask_up)

        m_x = self.m_final_block(x, x_mask)
        m_output = self.m_final_conv(m_x * x_mask)

        z_x = self.z_final_block(x, x_mask)
        z_output = self.z_final_conv(z_x * x_mask)

        return (m_output * x_mask).squeeze(1), (z_output * x_mask).squeeze(1)


class Diffusion(BaseModule):
    def __init__(self, n_feats, dim_unet, dim_spk, beta_min, beta_max):
        super(Diffusion, self).__init__()
        self.estimator = GradLogPEstimator(dim_unet, dim_spk)

        self.n_feats = n_feats
        self.dim_unet = dim_unet
        self.dim_spk = dim_spk
        self.beta_min = beta_min
        self.beta_max = beta_max

    def get_beta(self, t):
        beta = self.beta_min + (self.beta_max - self.beta_min) * t
        return beta

    def get_gamma(self, s, t, p=1.0, use_torch=False):
        beta_integral = self.beta_min + 0.5 * (self.beta_max - self.beta_min) * (t + s)
        beta_integral *= (t - s)
        if use_torch:
            gamma = torch.exp(-0.5 * p * beta_integral).unsqueeze(-1).unsqueeze(-1)
        else:
            gamma = math.exp(-0.5 * p * beta_integral)
        return gamma

    def get_mu(self, s, t):
        a = self.get_gamma(s, t)
        b = 1.0 - self.get_gamma(0, s, p=2.0)
        c = 1.0 - self.get_gamma(0, t, p=2.0)
        return a * b / c

    def get_nu(self, s, t):
        a = self.get_gamma(0, s)
        b = 1.0 - self.get_gamma(s, t, p=2.0)
        c = 1.0 - self.get_gamma(0, t, p=2.0)
        return a * b / c

    def get_sigma(self, s, t):
        a = 1.0 - self.get_gamma(0, s, p=2.0)
        b = 1.0 - self.get_gamma(s, t, p=2.0)
        c = 1.0 - self.get_gamma(0, t, p=2.0)
        return math.sqrt(a * b / c)

    def compute_diffused_z_pr(self, x0, mask, z_pr, t, use_torch=False):
        x0_weight = self.get_gamma(0, t, use_torch=use_torch)
        z_pr_weight = 1.0 - x0_weight
        xt_z_pr = x0 * x0_weight + z_pr * z_pr_weight
        return xt_z_pr * mask

    @torch.no_grad()
    def reverse(self, z, mask, z_pr, spk, ts):
        h = 1.0 / ts
        xt = z * mask

        for i in range(ts):
            t = 1.0 - i * h
            time = t * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
            beta_t = self.get_beta(t)

            kappa = self.get_gamma(0, t - h) * (1.0 - self.get_gamma(t - h, t, p=2.0))
            kappa /= (self.get_gamma(0, t) * beta_t * h)
            kappa -= 1.0
            omega = self.get_nu(t - h, t) / self.get_gamma(0, t)
            omega += self.get_mu(t - h, t)
            omega -= (0.5 * beta_t * h + 1.0)
            sigma = self.get_sigma(t - h, t)

            dxt = (z_pr - xt) * (0.5 * beta_t * h + omega)
            tmp, dxt_ = self.estimator(xt, mask, z_pr, spk, time)
            dxt -= dxt_ * (1.0 + kappa) * (beta_t * h)
            dxt += torch.randn_like(z, device=z.device) * sigma
            xt = (xt - dxt) * mask

        return xt

    @torch.no_grad()
    def forward(self, z, mask, enc_out, spk, n_timesteps, mode):
        return self.reverse(z, mask, enc_out, spk, n_timesteps)

    def random_masking(self, xt, num, frame):
        xt_mask = torch.ones_like(xt)
        x0_mask = torch.ones_like(xt)
        for _ in range(num):
            idx = random.randint(0, xt.size(1) - frame)
            xt[:, idx:idx + frame, :] = 0
            xt_mask[:, idx:idx + frame, :] = 0
        x0_mask -= xt_mask

        return xt, xt_mask, x0_mask

    def forward_diffusion(self, x0, mask, enc_out, t):
        xt = self.compute_diffused_z_pr(x0, mask, enc_out, t, use_torch=True)
        variance = 1.0 - self.get_gamma(0, t, p=2.0, use_torch=True)
        z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False)
        xt = xt + z * torch.sqrt(variance)

        return xt * mask, z * mask

    def compute_loss(self, x0, mask, enc_out, spk, t):
        xt, z = self.forward_diffusion(x0, mask, enc_out, t)
        masked_xt, xt_mask, x0_mask = self.random_masking(xt, num=4, frame=8)

        m_estimation, z_estimation = self.estimator(masked_xt, mask, enc_out, spk, t)
        m_estimation *= torch.sqrt(1.0 - self.get_gamma(0, t, p=2.0, use_torch=True))
        z_estimation *= torch.sqrt(1.0 - self.get_gamma(0, t, p=2.0, use_torch=True))
        diff_loss = torch.sum((z_estimation * xt_mask + z) ** 2) / (torch.sum(mask) * self.n_feats)
        recon_loss = F.l1_loss(x0 * x0_mask, m_estimation * x0_mask)

        return diff_loss, recon_loss

    def compute_t(self, x0, mask, enc_out, spk, offset=1e-5):
        b = x0.shape[0]
        t = torch.rand(b, dtype=x0.dtype, device=x0.device, requires_grad=False)
        t = torch.clamp(t, offset, 1.0 - offset)

        return self.compute_loss(x0, mask, enc_out, spk, t)
self.compute_loss(x0, mask, enc_out, spk, t) 228 | -------------------------------------------------------------------------------- /model/diffusion_module.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from einops import rearrange 4 | 5 | from model.base import BaseModule 6 | 7 | 8 | class Mish(BaseModule): 9 | def forward(self, x): 10 | return x * torch.tanh(torch.nn.functional.softplus(x)) 11 | 12 | 13 | class Upsample(BaseModule): 14 | def __init__(self, dim): 15 | super(Upsample, self).__init__() 16 | self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) 17 | 18 | def forward(self, x): 19 | return self.conv(x) 20 | 21 | 22 | class Downsample(BaseModule): 23 | def __init__(self, dim): 24 | super(Downsample, self).__init__() 25 | self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1) 26 | 27 | def forward(self, x): 28 | return self.conv(x) 29 | 30 | 31 | class Rezero(BaseModule): 32 | def __init__(self, fn): 33 | super(Rezero, self).__init__() 34 | self.fn = fn 35 | self.g = torch.nn.Parameter(torch.zeros(1)) 36 | 37 | def forward(self, x): 38 | return self.fn(x) * self.g 39 | 40 | 41 | class Block(BaseModule): 42 | def __init__(self, dim, dim_out, groups=8): 43 | super(Block, self).__init__() 44 | self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3, 45 | padding=1), torch.nn.GroupNorm( 46 | groups, dim_out), Mish()) 47 | 48 | def forward(self, x, mask): 49 | output = self.block(x * mask) 50 | return output * mask 51 | 52 | 53 | class ResnetBlock(BaseModule): 54 | def __init__(self, dim, dim_out, time_emb_dim, groups=8): 55 | super(ResnetBlock, self).__init__() 56 | self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, 57 | dim_out)) 58 | 59 | self.block1 = Block(dim, dim_out, groups=groups) 60 | self.block2 = Block(dim_out, dim_out, groups=groups) 61 | if dim != dim_out: 62 | self.res_conv = torch.nn.Conv2d(dim, dim_out, 1) 63 | else: 64 | self.res_conv = torch.nn.Identity() 65 | 66 | def forward(self, x, mask, time_emb): 67 | h = self.block1(x, mask) 68 | h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1) 69 | h = self.block2(h, mask) 70 | output = h + self.res_conv(x * mask) 71 | return output 72 | 73 | 74 | class LinearAttention(BaseModule): 75 | def __init__(self, dim, heads=4, dim_head=32): 76 | super(LinearAttention, self).__init__() 77 | self.heads = heads 78 | hidden_dim = dim_head * heads 79 | self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 80 | self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1) 81 | 82 | def forward(self, x): 83 | b, c, h, w = x.shape 84 | qkv = self.to_qkv(x) 85 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', 86 | heads = self.heads, qkv=3) 87 | k = k.softmax(dim=-1) 88 | context = torch.einsum('bhdn,bhen->bhde', k, v) 89 | out = torch.einsum('bhde,bhdn->bhen', context, q) 90 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', 91 | heads=self.heads, h=h, w=w) 92 | return self.to_out(out) 93 | 94 | 95 | class Residual(BaseModule): 96 | def __init__(self, fn): 97 | super(Residual, self).__init__() 98 | self.fn = fn 99 | 100 | def forward(self, x, *args, **kwargs): 101 | output = self.fn(x, *args, **kwargs) + x 102 | return output 103 | 104 | 105 | class SinusoidalPosEmb(BaseModule): 106 | def __init__(self, dim): 107 | super(SinusoidalPosEmb, self).__init__() 108 | self.dim = dim 109 | 110 | def forward(self, x): 111 | device = x.device 112 | half_dim = self.dim // 2 113 | emb = math.log(10000) / (half_dim - 1) 
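        # What the next lines compute (the standard sinusoidal embedding; the 1000x
        # scale is used because diffusion time t lies in [0, 1] rather than being an
        # integer position index):
        #   freqs[k] = exp(-k * log(10000) / (half_dim - 1)),  k = 0 .. half_dim - 1
        #   emb      = concat(sin(1000 * t * freqs), cos(1000 * t * freqs))  # [B, dim]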
114 | emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) 115 | emb = 1000.0 * x.unsqueeze(1) * emb.unsqueeze(0) 116 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 117 | return emb 118 | 119 | 120 | class RefBlock(BaseModule): 121 | def __init__(self, out_dim, time_emb_dim): 122 | super(RefBlock, self).__init__() 123 | base_dim = out_dim // 4 124 | self.mlp1 = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, 125 | base_dim)) 126 | self.mlp2 = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, 127 | 2 * base_dim)) 128 | self.block11 = torch.nn.Sequential(torch.nn.Conv2d(1, 2 * base_dim, 129 | 3, 1, 1), torch.nn.InstanceNorm2d(2 * base_dim, affine=True), 130 | torch.nn.GLU(dim=1)) 131 | self.block12 = torch.nn.Sequential(torch.nn.Conv2d(base_dim, 2 * base_dim, 132 | 3, 1, 1), torch.nn.InstanceNorm2d(2 * base_dim, affine=True), 133 | torch.nn.GLU(dim=1)) 134 | self.block21 = torch.nn.Sequential(torch.nn.Conv2d(base_dim, 4 * base_dim, 135 | 3, 1, 1), torch.nn.InstanceNorm2d(4 * base_dim, affine=True), 136 | torch.nn.GLU(dim=1)) 137 | self.block22 = torch.nn.Sequential(torch.nn.Conv2d(2 * base_dim, 4 * base_dim, 138 | 3, 1, 1), torch.nn.InstanceNorm2d(4 * base_dim, affine=True), 139 | torch.nn.GLU(dim=1)) 140 | self.block31 = torch.nn.Sequential(torch.nn.Conv2d(2 * base_dim, 8 * base_dim, 141 | 3, 1, 1), torch.nn.InstanceNorm2d(8 * base_dim, affine=True), 142 | torch.nn.GLU(dim=1)) 143 | self.block32 = torch.nn.Sequential(torch.nn.Conv2d(4 * base_dim, 8 * base_dim, 144 | 3, 1, 1), torch.nn.InstanceNorm2d(8 * base_dim, affine=True), 145 | torch.nn.GLU(dim=1)) 146 | self.final_conv = torch.nn.Conv2d(4 * base_dim, out_dim, 1) 147 | 148 | def forward(self, x, mask, time_emb): 149 | y = self.block11(x * mask) 150 | y = self.block12(y * mask) 151 | y += self.mlp1(time_emb).unsqueeze(-1).unsqueeze(-1) 152 | y = self.block21(y * mask) 153 | y = self.block22(y * mask) 154 | y += self.mlp2(time_emb).unsqueeze(-1).unsqueeze(-1) 155 | y = self.block31(y * mask) 156 | y = self.block32(y * mask) 157 | y = self.final_conv(y * mask) 158 | return (y * mask).sum((2, 3)) / (mask.sum((2, 3)) * x.shape[2]) 159 | -------------------------------------------------------------------------------- /model/styleencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from module.attentions import * 4 | 5 | class Mish(nn.Module): 6 | def __init__(self): 7 | super(Mish, self).__init__() 8 | def forward(self, x): 9 | return x * torch.tanh(torch.nn.functional.softplus(x)) 10 | 11 | 12 | class Conv1dGLU(nn.Module): 13 | def __init__(self, in_channels, out_channels, kernel_size, dropout): 14 | super(Conv1dGLU, self).__init__() 15 | self.out_channels = out_channels 16 | self.conv1 = nn.Conv1d(in_channels, 2 * out_channels, kernel_size=kernel_size, padding=2) 17 | self.dropout = nn.Dropout(dropout) 18 | 19 | def forward(self, x): 20 | residual = x 21 | x = self.conv1(x) 22 | x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1) 23 | x = x1 * torch.sigmoid(x2) 24 | x = residual + self.dropout(x) 25 | 26 | return x 27 | 28 | 29 | class StyleEncoder(torch.nn.Module): 30 | def __init__(self, in_dim, hidden_dim, out_dim): 31 | super().__init__() 32 | 33 | self.in_dim = in_dim 34 | self.hidden_dim = hidden_dim 35 | self.out_dim = out_dim 36 | self.kernel_size = 5 37 | self.n_head = 2 38 | self.dropout = 0.1 39 | 40 | self.spectral = nn.Sequential( 41 | nn.Conv1d(self.in_dim, self.hidden_dim, 1), 42 | 
Mish(), 43 | nn.Dropout(self.dropout), 44 | nn.Conv1d(self.hidden_dim, self.hidden_dim, 1), 45 | Mish(), 46 | nn.Dropout(self.dropout) 47 | ) 48 | 49 | self.temporal = nn.Sequential( 50 | Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), 51 | Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), 52 | ) 53 | 54 | self.slf_attn = MultiHeadAttention(self.hidden_dim, self.hidden_dim, self.n_head, p_dropout=self.dropout, proximal_bias=False, proximal_init=True) 55 | self.atten_drop = nn.Dropout(self.dropout) 56 | self.fc = nn.Conv1d(self.hidden_dim, self.out_dim, 1) 57 | 58 | def forward(self, x, mask=None): 59 | x = self.spectral(x)*mask 60 | x = self.temporal(x)*mask 61 | 62 | attn_mask = mask.unsqueeze(2) * mask.unsqueeze(-1) 63 | y = self.slf_attn(x,x, attn_mask=attn_mask) 64 | x = x + self.atten_drop(y) 65 | x = self.fc(x) 66 | 67 | return self.temporal_avg_pool(x, mask=mask) 68 | 69 | def temporal_avg_pool(self, x, mask=None): 70 | if mask is None: 71 | out = torch.mean(x, dim=2) 72 | else: 73 | x = x.sum(dim=2) 74 | out = torch.div(x, mask.sum(dim=2)) 75 | 76 | return out -------------------------------------------------------------------------------- /module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hayeong0/Diff-HierVC/3ed2253cbe5bdd13f3934eae7e40ab5102ab2bde/module/__init__.py -------------------------------------------------------------------------------- /module/attentions.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from module.commons import * 9 | from module.modules import LayerNorm 10 | 11 | class Encoder(nn.Module): 12 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, 13 | **kwargs): 14 | super().__init__() 15 | self.hidden_channels = hidden_channels 16 | self.filter_channels = filter_channels 17 | self.n_heads = n_heads 18 | self.n_layers = n_layers 19 | self.kernel_size = kernel_size 20 | self.p_dropout = p_dropout 21 | self.window_size = window_size 22 | 23 | self.drop = nn.Dropout(p_dropout) 24 | self.attn_layers = nn.ModuleList() 25 | self.norm_layers_1 = nn.ModuleList() 26 | self.ffn_layers = nn.ModuleList() 27 | self.norm_layers_2 = nn.ModuleList() 28 | for i in range(self.n_layers): 29 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, 30 | window_size=window_size)) 31 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 32 | self.ffn_layers.append( 33 | FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) 34 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 35 | 36 | def forward(self, x, x_mask): 37 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 38 | x = x * x_mask 39 | for i in range(self.n_layers): 40 | y = self.attn_layers[i](x, x, attn_mask) 41 | y = self.drop(y) 42 | x = self.norm_layers_1[i](x + y) 43 | 44 | y = self.ffn_layers[i](x, x_mask) 45 | y = self.drop(y) 46 | x = self.norm_layers_2[i](x + y) 47 | x = x * x_mask 48 | return x 49 | 50 | 51 | class Decoder(nn.Module): 52 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., 53 | proximal_bias=False, proximal_init=True, **kwargs): 54 | super().__init__() 55 | 
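        # Layer layout (summary): each of the n_layers blocks built below is the
        # standard Transformer decoder step -- masked self-attention, then
        # encoder-decoder attention, then a causally padded FFN -- each followed by
        # dropout, a residual connection, and LayerNorm. proximal_bias and
        # proximal_init apply only to the self-attention layers.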
self.hidden_channels = hidden_channels 56 | self.filter_channels = filter_channels 57 | self.n_heads = n_heads 58 | self.n_layers = n_layers 59 | self.kernel_size = kernel_size 60 | self.p_dropout = p_dropout 61 | self.proximal_bias = proximal_bias 62 | self.proximal_init = proximal_init 63 | 64 | self.drop = nn.Dropout(p_dropout) 65 | self.self_attn_layers = nn.ModuleList() 66 | self.norm_layers_0 = nn.ModuleList() 67 | self.encdec_attn_layers = nn.ModuleList() 68 | self.norm_layers_1 = nn.ModuleList() 69 | self.ffn_layers = nn.ModuleList() 70 | self.norm_layers_2 = nn.ModuleList() 71 | for i in range(self.n_layers): 72 | self.self_attn_layers.append( 73 | MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, 74 | proximal_bias=proximal_bias, proximal_init=proximal_init)) 75 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 76 | self.encdec_attn_layers.append( 77 | MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) 78 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 79 | self.ffn_layers.append( 80 | FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) 81 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 82 | 83 | def forward(self, x, x_mask, h, h_mask): 84 | """ 85 | x: decoder input 86 | h: encoder output 87 | """ 88 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) 89 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 90 | x = x * x_mask 91 | for i in range(self.n_layers): 92 | y = self.self_attn_layers[i](x, x, self_attn_mask) 93 | y = self.drop(y) 94 | x = self.norm_layers_0[i](x + y) 95 | 96 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 97 | y = self.drop(y) 98 | x = self.norm_layers_1[i](x + y) 99 | 100 | y = self.ffn_layers[i](x, x_mask) 101 | y = self.drop(y) 102 | x = self.norm_layers_2[i](x + y) 103 | x = x * x_mask 104 | return x 105 | 106 | 107 | class MultiHeadAttention(nn.Module): 108 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, 109 | block_length=None, proximal_bias=False, proximal_init=False): 110 | super().__init__() 111 | assert channels % n_heads == 0 112 | 113 | self.channels = channels 114 | self.out_channels = out_channels 115 | self.n_heads = n_heads 116 | self.p_dropout = p_dropout 117 | self.window_size = window_size 118 | self.heads_share = heads_share 119 | self.block_length = block_length 120 | self.proximal_bias = proximal_bias 121 | self.proximal_init = proximal_init 122 | self.attn = None 123 | 124 | self.k_channels = channels // n_heads 125 | self.conv_q = nn.Conv1d(channels, channels, 1) 126 | self.conv_k = nn.Conv1d(channels, channels, 1) 127 | self.conv_v = nn.Conv1d(channels, channels, 1) 128 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 129 | self.drop = nn.Dropout(p_dropout) 130 | 131 | if window_size is not None: 132 | n_heads_rel = 1 if heads_share else n_heads 133 | rel_stddev = self.k_channels ** -0.5 134 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 135 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 136 | 137 | nn.init.xavier_uniform_(self.conv_q.weight) 138 | nn.init.xavier_uniform_(self.conv_k.weight) 139 | nn.init.xavier_uniform_(self.conv_v.weight) 140 | if proximal_init: 141 | with torch.no_grad(): 142 | self.conv_k.weight.copy_(self.conv_q.weight) 143 | 
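                # proximal_init: the key projection is initialized to equal the
                # query projection (weight above, bias below), so initial
                # self-attention scores q.k tend to peak on the diagonal -- a
                # warm start in the style of Glow-TTS attention.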
self.conv_k.bias.copy_(self.conv_q.bias) 144 | 145 | def forward(self, x, c, attn_mask=None): 146 | q = self.conv_q(x) 147 | k = self.conv_k(c) 148 | v = self.conv_v(c) 149 | 150 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 151 | 152 | x = self.conv_o(x) 153 | return x 154 | 155 | def attention(self, query, key, value, mask=None): 156 | # reshape [b, d, t] -> [b, n_h, t, d_k] 157 | b, d, t_s, t_t = (*key.size(), query.size(2)) 158 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 159 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 160 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 161 | 162 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 163 | if self.window_size is not None: 164 | assert t_s == t_t, "Relative attention is only available for self-attention." 165 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 166 | rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) 167 | scores_local = self._relative_position_to_absolute_position(rel_logits) 168 | scores = scores + scores_local 169 | if self.proximal_bias: 170 | assert t_s == t_t, "Proximal bias is only available for self-attention." 171 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) 172 | if mask is not None: 173 | scores = scores.masked_fill(mask == 0, -1e4) 174 | if self.block_length is not None: 175 | assert t_s == t_t, "Local attention is only available for self-attention." 176 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) 177 | scores = scores.masked_fill(block_mask == 0, -1e4) 178 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 179 | p_attn = self.drop(p_attn) 180 | output = torch.matmul(p_attn, value) 181 | if self.window_size is not None: 182 | relative_weights = self._absolute_position_to_relative_position(p_attn) 183 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 184 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) 185 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] 186 | return output, p_attn 187 | 188 | def _matmul_with_relative_values(self, x, y): 189 | """ 190 | x: [b, h, l, m] 191 | y: [h or 1, m, d] 192 | ret: [b, h, l, d] 193 | """ 194 | ret = torch.matmul(x, y.unsqueeze(0)) 195 | return ret 196 | 197 | def _matmul_with_relative_keys(self, x, y): 198 | """ 199 | x: [b, h, l, d] 200 | y: [h or 1, m, d] 201 | ret: [b, h, l, m] 202 | """ 203 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 204 | return ret 205 | 206 | def _get_relative_embeddings(self, relative_embeddings, length): 207 | max_relative_position = 2 * self.window_size + 1 208 | # Pad first before slice to avoid using cond ops. 
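        # Worked example (illustrative numbers): with window_size=4 the table holds
        # 2*4 + 1 = 9 relative positions. For length=6: pad_length = max(6-5, 0) = 1,
        # slice_start = max(5-6, 0) = 0, slice_end = 0 + 2*6 - 1 = 11, so the padded
        # 11-row table is sliced down to the 2*length - 1 = 11 offsets actually used.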
209 | pad_length = max(length - (self.window_size + 1), 0) 210 | slice_start_position = max((self.window_size + 1) - length, 0) 211 | slice_end_position = slice_start_position + 2 * length - 1 212 | if pad_length > 0: 213 | padded_relative_embeddings = F.pad( 214 | relative_embeddings, 215 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) 216 | else: 217 | padded_relative_embeddings = relative_embeddings 218 | used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] 219 | return used_relative_embeddings 220 | 221 | def _relative_position_to_absolute_position(self, x): 222 | """ 223 | x: [b, h, l, 2*l-1] 224 | ret: [b, h, l, l] 225 | """ 226 | batch, heads, length, _ = x.size() 227 | # Concat columns of pad to shift from relative to absolute indexing. 228 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) 229 | 230 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 231 | x_flat = x.view([batch, heads, length * 2 * length]) 232 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) 233 | 234 | # Reshape and slice out the padded elements. 235 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:] 236 | return x_final 237 | 238 | def _absolute_position_to_relative_position(self, x): 239 | """ 240 | x: [b, h, l, l] 241 | ret: [b, h, l, 2*l-1] 242 | """ 243 | batch, heads, length, _ = x.size() 244 | # padd along column 245 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) 246 | x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)]) 247 | # add 0's in the beginning that will skew the elements after reshape 248 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 249 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] 250 | return x_final 251 | 252 | def _attention_bias_proximal(self, length): 253 | """Bias for self-attention to encourage attention to close positions. 254 | Args: 255 | length: an integer scalar. 
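        Example (illustrative): for length = 3 the bias is -log1p(|i - j|), i.e.
        [[ 0.000, -0.693, -1.099],
         [-0.693,  0.000, -0.693],
         [-1.099, -0.693,  0.000]], broadcast to shape [1, 1, 3, 3].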
256 | Returns: 257 | a Tensor with shape [1, 1, length, length] 258 | """ 259 | r = torch.arange(length, dtype=torch.float32) 260 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 261 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 262 | 263 | 264 | class FFN(nn.Module): 265 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, 266 | causal=False): 267 | super().__init__() 268 | self.in_channels = in_channels 269 | self.out_channels = out_channels 270 | self.filter_channels = filter_channels 271 | self.kernel_size = kernel_size 272 | self.p_dropout = p_dropout 273 | self.activation = activation 274 | self.causal = causal 275 | 276 | if causal: 277 | self.padding = self._causal_padding 278 | else: 279 | self.padding = self._same_padding 280 | 281 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 282 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 283 | self.drop = nn.Dropout(p_dropout) 284 | 285 | def forward(self, x, x_mask): 286 | x = self.conv_1(self.padding(x * x_mask)) 287 | if self.activation == "gelu": 288 | x = x * torch.sigmoid(1.702 * x) 289 | else: 290 | x = torch.relu(x) 291 | x = self.drop(x) 292 | x = self.conv_2(self.padding(x * x_mask)) 293 | return x * x_mask 294 | 295 | def _causal_padding(self, x): 296 | if self.kernel_size == 1: 297 | return x 298 | pad_l = self.kernel_size - 1 299 | pad_r = 0 300 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 301 | x = F.pad(x, commons.convert_pad_shape(padding)) 302 | return x 303 | 304 | def _same_padding(self, x): 305 | if self.kernel_size == 1: 306 | return x 307 | pad_l = (self.kernel_size - 1) // 2 308 | pad_r = self.kernel_size // 2 309 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 310 | x = F.pad(x, commons.convert_pad_shape(padding)) 311 | return x 312 | -------------------------------------------------------------------------------- /module/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size*dilation - dilation)/2) 16 | 17 | 18 | def convert_pad_shape(pad_shape): 19 | l = pad_shape[::-1] 20 | pad_shape = [item for sublist in l for item in sublist] 21 | return pad_shape 22 | 23 | 24 | def intersperse(lst, item): 25 | result = [item] * (len(lst) * 2 + 1) 26 | result[1::2] = lst 27 | return result 28 | 29 | 30 | def kl_divergence(m_p, logs_p, m_q, logs_q): 31 | """KL(P||Q)""" 32 | kl = (logs_q - logs_p) - 0.5 33 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. 
* logs_q) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | ret = torch.zeros_like(x[:, :, :segment_size]) 50 | for i in range(x.size(0)): 51 | idx_str = ids_str[i] 52 | idx_end = idx_str + segment_size 53 | ret[i] = x[i, :, idx_str:idx_end] 54 | return ret 55 | 56 | def slice_segments_audio(x, ids_str, segment_size=4): 57 | ret = torch.zeros_like(x[:, :segment_size]) 58 | for i in range(x.size(0)): 59 | idx_str = ids_str[i] 60 | idx_end = idx_str + segment_size 61 | ret[i] = x[i, idx_str:idx_end] 62 | return ret 63 | 64 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 65 | b, d, t = x.size() 66 | if x_lengths is None: 67 | x_lengths = t 68 | ids_str_max = x_lengths - segment_size + 1 69 | ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(dtype=torch.long) 70 | ret = slice_segments(x, ids_str, segment_size) 71 | return ret, ids_str 72 | 73 | 74 | def get_timing_signal_1d( 75 | length, channels, min_timescale=1.0, max_timescale=1.0e4): 76 | position = torch.arange(length, dtype=torch.float) 77 | num_timescales = channels // 2 78 | log_timescale_increment = ( 79 | math.log(float(max_timescale) / float(min_timescale)) / 80 | (num_timescales - 1)) 81 | inv_timescales = min_timescale * torch.exp( 82 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 83 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 84 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 85 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 86 | signal = signal.view(1, channels, length) 87 | return signal 88 | 89 | 90 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 91 | b, channels, length = x.size() 92 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 93 | return x + signal.to(dtype=x.dtype, device=x.device) 94 | 95 | 96 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 97 | b, channels, length = x.size() 98 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 99 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 100 | 101 | 102 | def subsequent_mask(length): 103 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 104 | return mask 105 | 106 | 107 | @torch.jit.script 108 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 109 | n_channels_int = n_channels[0] 110 | in_act = input_a + input_b 111 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 112 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 113 | acts = t_act * s_act 114 | return acts 115 | 116 | 117 | def convert_pad_shape(pad_shape): 118 | l = pad_shape[::-1] 119 | pad_shape = [item for sublist in l for item in sublist] 120 | return pad_shape 121 | 122 | 123 | def shift_1d(x): 124 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 125 | return x 126 | 127 | 128 | def sequence_mask(length, max_length=None): 129 | if max_length is None: 130 | max_length = length.max() 131 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 132 | return x.unsqueeze(0) < length.unsqueeze(1) 133 | 134 | 
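# Illustrative usage of sequence_mask (tensor values are a made-up example):
#   >>> sequence_mask(torch.tensor([2, 4]), max_length=5)
#   tensor([[ True,  True, False, False, False],
#           [ True,  True,  True,  True, False]])
# generate_path below applies the same idea to cumulative durations to build a
# monotonic duration-to-frame alignment path.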
135 | def generate_path(duration, mask): 136 | """ 137 | duration: [b, 1, t_x] 138 | mask: [b, 1, t_y, t_x] 139 | """ 140 | device = duration.device 141 | 142 | b, _, t_y, t_x = mask.shape 143 | cum_duration = torch.cumsum(duration, -1) 144 | 145 | cum_duration_flat = cum_duration.view(b * t_x) 146 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 147 | path = path.view(b, t_x, t_y) 148 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 149 | path = path.unsqueeze(1).transpose(2,3) * mask 150 | return path 151 | 152 | 153 | def clip_grad_value_(parameters, clip_value, norm_type=2): 154 | if isinstance(parameters, torch.Tensor): 155 | parameters = [parameters] 156 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 157 | norm_type = float(norm_type) 158 | if clip_value is not None: 159 | clip_value = float(clip_value) 160 | 161 | total_norm = 0 162 | for p in parameters: 163 | param_norm = p.grad.data.norm(norm_type) 164 | total_norm += param_norm.item() ** norm_type 165 | if clip_value is not None: 166 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 167 | total_norm = total_norm ** (1. / norm_type) 168 | return total_norm 169 | -------------------------------------------------------------------------------- /module/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | import torchaudio.transforms as T 7 | 8 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 9 | from torch.nn.utils import weight_norm, remove_weight_norm 10 | 11 | from module.commons import * 12 | from module.transforms import piecewise_rational_quadratic_transform 13 | from torch.cuda.amp import autocast 14 | 15 | LRELU_SLOPE = 0.1 16 | 17 | class LayerNorm(nn.Module): 18 | def __init__(self, channels, eps=1e-5): 19 | super().__init__() 20 | self.channels = channels 21 | self.eps = eps 22 | 23 | self.gamma = nn.Parameter(torch.ones(channels)) 24 | self.beta = nn.Parameter(torch.zeros(channels)) 25 | 26 | def forward(self, x): 27 | x = x.transpose(1, -1) 28 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 29 | return x.transpose(1, -1) 30 | 31 | 32 | class ConvReluNorm(nn.Module): 33 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): 34 | super().__init__() 35 | self.in_channels = in_channels 36 | self.hidden_channels = hidden_channels 37 | self.out_channels = out_channels 38 | self.kernel_size = kernel_size 39 | self.n_layers = n_layers 40 | self.p_dropout = p_dropout 41 | assert n_layers > 1, "Number of layers should be larger than 0." 
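        # Note (hedged): self.proj below is zero-initialized, so at initialization
        # forward() returns x_org * x_mask -- the block starts as an identity
        # mapping (assuming in_channels == out_channels) and learns a residual
        # correction on top.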
42 | 43 | self.conv_layers = nn.ModuleList() 44 | self.norm_layers = nn.ModuleList() 45 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) 46 | self.norm_layers.append(LayerNorm(hidden_channels)) 47 | self.relu_drop = nn.Sequential( 48 | nn.ReLU(), 49 | nn.Dropout(p_dropout)) 50 | for _ in range(n_layers - 1): 51 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) 52 | self.norm_layers.append(LayerNorm(hidden_channels)) 53 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 54 | self.proj.weight.data.zero_() 55 | self.proj.bias.data.zero_() 56 | 57 | def forward(self, x, x_mask): 58 | x_org = x 59 | for i in range(self.n_layers): 60 | x = self.conv_layers[i](x * x_mask) 61 | x = self.norm_layers[i](x) 62 | x = self.relu_drop(x) 63 | x = x_org + self.proj(x) 64 | return x * x_mask 65 | 66 | 67 | class DDSConv(nn.Module): 68 | """ 69 | Dilated and Depth-Separable Convolution 70 | """ 71 | 72 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): 73 | super().__init__() 74 | self.channels = channels 75 | self.kernel_size = kernel_size 76 | self.n_layers = n_layers 77 | self.p_dropout = p_dropout 78 | 79 | self.drop = nn.Dropout(p_dropout) 80 | self.convs_sep = nn.ModuleList() 81 | self.convs_1x1 = nn.ModuleList() 82 | self.norms_1 = nn.ModuleList() 83 | self.norms_2 = nn.ModuleList() 84 | for i in range(n_layers): 85 | dilation = kernel_size ** i 86 | padding = (kernel_size * dilation - dilation) // 2 87 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, 88 | groups=channels, dilation=dilation, padding=padding 89 | )) 90 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 91 | self.norms_1.append(LayerNorm(channels)) 92 | self.norms_2.append(LayerNorm(channels)) 93 | 94 | def forward(self, x, x_mask, g=None): 95 | if g is not None: 96 | x = x + g 97 | for i in range(self.n_layers): 98 | y = self.convs_sep[i](x * x_mask) 99 | y = self.norms_1[i](y) 100 | y = F.gelu(y) 101 | y = self.convs_1x1[i](y) 102 | y = self.norms_2[i](y) 103 | y = F.gelu(y) 104 | y = self.drop(y) 105 | x = x + y 106 | return x * x_mask 107 | 108 | 109 | class WN(torch.nn.Module): 110 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 111 | super(WN, self).__init__() 112 | assert (kernel_size % 2 == 1) 113 | self.hidden_channels = hidden_channels 114 | self.kernel_size = kernel_size 115 | self.dilation_rate = dilation_rate 116 | self.n_layers = n_layers 117 | self.gin_channels = gin_channels 118 | self.p_dropout = p_dropout 119 | 120 | self.in_layers = torch.nn.ModuleList() 121 | self.res_skip_layers = torch.nn.ModuleList() 122 | self.drop = nn.Dropout(p_dropout) 123 | 124 | if gin_channels != 0: 125 | cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) 126 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 127 | 128 | for i in range(n_layers): 129 | dilation = dilation_rate ** i 130 | padding = int((kernel_size * dilation - dilation) / 2) 131 | in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, 132 | dilation=dilation, padding=padding) 133 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 134 | self.in_layers.append(in_layer) 135 | 136 | # last one is not necessary 137 | if i < n_layers - 1: 138 | res_skip_channels = 2 * hidden_channels 139 | else: 140 | res_skip_channels = hidden_channels 141 | 142 | res_skip_layer =
torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 143 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 144 | self.res_skip_layers.append(res_skip_layer) 145 | 146 | def forward(self, x, x_mask, g=None, **kwargs): 147 | output = torch.zeros_like(x) 148 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 149 | 150 | if g is not None: 151 | g = self.cond_layer(g) 152 | 153 | for i in range(self.n_layers): 154 | x_in = self.in_layers[i](x) 155 | if g is not None: 156 | cond_offset = i * 2 * self.hidden_channels 157 | g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] 158 | else: 159 | g_l = torch.zeros_like(x_in) 160 | 161 | acts = fused_add_tanh_sigmoid_multiply( 162 | x_in, 163 | g_l, 164 | n_channels_tensor) 165 | acts = self.drop(acts) 166 | 167 | res_skip_acts = self.res_skip_layers[i](acts) 168 | if i < self.n_layers - 1: 169 | res_acts = res_skip_acts[:, :self.hidden_channels, :] 170 | x = (x + res_acts) * x_mask 171 | output = output + res_skip_acts[:, self.hidden_channels:, :] 172 | else: 173 | output = output + res_skip_acts 174 | return output * x_mask 175 | 176 | def remove_weight_norm(self): 177 | if self.gin_channels != 0: 178 | torch.nn.utils.remove_weight_norm(self.cond_layer) 179 | for l in self.in_layers: 180 | torch.nn.utils.remove_weight_norm(l) 181 | for l in self.res_skip_layers: 182 | torch.nn.utils.remove_weight_norm(l) 183 | 184 | class ResBlock1(torch.nn.Module): 185 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 186 | super(ResBlock1, self).__init__() 187 | self.convs1 = nn.ModuleList([ 188 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 189 | padding=get_padding(kernel_size, dilation[0]))), 190 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 191 | padding=get_padding(kernel_size, dilation[1]))), 192 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 193 | padding=get_padding(kernel_size, dilation[2]))) 194 | ]) 195 | self.convs1.apply(init_weights) 196 | 197 | self.convs2 = nn.ModuleList([ 198 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 199 | padding=get_padding(kernel_size, 1))), 200 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 201 | padding=get_padding(kernel_size, 1))), 202 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 203 | padding=get_padding(kernel_size, 1))) 204 | ]) 205 | self.convs2.apply(init_weights) 206 | 207 | def forward(self, x, x_mask=None): 208 | for c1, c2 in zip(self.convs1, self.convs2): 209 | xt = F.leaky_relu(x, LRELU_SLOPE) 210 | if x_mask is not None: 211 | xt = xt * x_mask 212 | xt = c1(xt) 213 | xt = F.leaky_relu(xt, LRELU_SLOPE) 214 | if x_mask is not None: 215 | xt = xt * x_mask 216 | xt = c2(xt) 217 | x = xt + x 218 | if x_mask is not None: 219 | x = x * x_mask 220 | return x 221 | 222 | def remove_weight_norm(self): 223 | for l in self.convs1: 224 | remove_weight_norm(l) 225 | for l in self.convs2: 226 | remove_weight_norm(l) 227 | 228 | 229 | class ResBlock2(torch.nn.Module): 230 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 231 | super(ResBlock2, self).__init__() 232 | self.convs = nn.ModuleList([ 233 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 234 | padding=get_padding(kernel_size, dilation[0]))), 235 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 236 | padding=get_padding(kernel_size, dilation[1]))) 237 | ]) 238 | 
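        # Summary: ResBlock2 is the lighter HiFi-GAN residual block -- one
        # leaky-ReLU + dilated conv per step over two dilations -- whereas
        # ResBlock1 above pairs each dilated conv with a second plain conv and
        # uses three dilations.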
self.convs.apply(init_weights) 239 | 240 | def forward(self, x, x_mask=None): 241 | for c in self.convs: 242 | xt = F.leaky_relu(x, LRELU_SLOPE) 243 | if x_mask is not None: 244 | xt = xt * x_mask 245 | xt = c(xt) 246 | x = xt + x 247 | if x_mask is not None: 248 | x = x * x_mask 249 | return x 250 | 251 | def remove_weight_norm(self): 252 | for l in self.convs: 253 | remove_weight_norm(l) 254 | 255 | 256 | class Log(nn.Module): 257 | def forward(self, x, x_mask, reverse=False, **kwargs): 258 | if not reverse: 259 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 260 | logdet = torch.sum(-y, [1, 2]) 261 | return y, logdet 262 | else: 263 | x = torch.exp(x) * x_mask 264 | return x 265 | 266 | 267 | class Flip(nn.Module): 268 | def forward(self, x, *args, reverse=False, **kwargs): 269 | x = torch.flip(x, [1]) 270 | if not reverse: 271 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 272 | return x, logdet 273 | else: 274 | return x 275 | 276 | 277 | class ElementwiseAffine(nn.Module): 278 | def __init__(self, channels): 279 | super().__init__() 280 | self.channels = channels 281 | self.m = nn.Parameter(torch.zeros(channels, 1)) 282 | self.logs = nn.Parameter(torch.zeros(channels, 1)) 283 | 284 | def forward(self, x, x_mask, reverse=False, **kwargs): 285 | if not reverse: 286 | y = self.m + torch.exp(self.logs) * x 287 | y = y * x_mask 288 | logdet = torch.sum(self.logs * x_mask, [1, 2]) 289 | return y, logdet 290 | else: 291 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 292 | return x 293 | 294 | 295 | class ResidualCouplingLayer(nn.Module): 296 | def __init__(self, 297 | channels, 298 | hidden_channels, 299 | kernel_size, 300 | dilation_rate, 301 | n_layers, 302 | p_dropout=0, 303 | gin_channels=0, 304 | mean_only=False): 305 | assert channels % 2 == 0, "channels should be divisible by 2" 306 | super().__init__() 307 | self.channels = channels 308 | self.hidden_channels = hidden_channels 309 | self.kernel_size = kernel_size 310 | self.dilation_rate = dilation_rate 311 | self.n_layers = n_layers 312 | self.half_channels = channels // 2 313 | self.mean_only = mean_only 314 | 315 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 316 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, 317 | gin_channels=gin_channels) 318 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 319 | self.post.weight.data.zero_() 320 | self.post.bias.data.zero_() 321 | 322 | def forward(self, x, x_mask, g=None, reverse=False): 323 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 324 | h = self.pre(x0) * x_mask 325 | h = self.enc(h, x_mask, g=g) 326 | stats = self.post(h) * x_mask 327 | if not self.mean_only: 328 | m, logs = torch.split(stats, [self.half_channels] * 2, 1) 329 | else: 330 | m = stats 331 | logs = torch.zeros_like(m) 332 | 333 | if not reverse: 334 | x1 = m + x1 * torch.exp(logs) * x_mask 335 | x = torch.cat([x0, x1], 1) 336 | logdet = torch.sum(logs, [1, 2]) 337 | return x, logdet 338 | else: 339 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 340 | x = torch.cat([x0, x1], 1) 341 | return x 342 | 343 | 344 | class ConvFlow(nn.Module): 345 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): 346 | super().__init__() 347 | self.in_channels = in_channels 348 | self.filter_channels = filter_channels 349 | self.kernel_size = kernel_size 350 | self.n_layers = n_layers 351 | self.num_bins = num_bins 352 | self.tail_bound = tail_bound 353 | 
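        # Sketch of the parameterization (hedged): for each of the half_channels
        # dims being transformed, the projection below emits num_bins widths +
        # num_bins heights + (num_bins - 1) interior derivatives = num_bins*3 - 1
        # rational-quadratic spline parameters; proj is zero-initialized so the
        # flow starts near the identity map.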
self.half_channels = in_channels // 2 354 | 355 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 356 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) 357 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) 358 | self.proj.weight.data.zero_() 359 | self.proj.bias.data.zero_() 360 | 361 | def forward(self, x, x_mask, g=None, reverse=False): 362 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 363 | h = self.pre(x0) 364 | h = self.convs(h, x_mask, g=g) 365 | h = self.proj(h) * x_mask 366 | 367 | b, c, t = x0.shape 368 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 369 | 370 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) 371 | unnormalized_heights = h[..., self.num_bins:2 * self.num_bins] / math.sqrt(self.filter_channels) 372 | unnormalized_derivatives = h[..., 2 * self.num_bins:] 373 | 374 | x1, logabsdet = piecewise_rational_quadratic_transform(x1, 375 | unnormalized_widths, 376 | unnormalized_heights, 377 | unnormalized_derivatives, 378 | inverse=reverse, 379 | tails='linear', 380 | tail_bound=self.tail_bound 381 | ) 382 | 383 | x = torch.cat([x0, x1], 1) * x_mask 384 | logdet = torch.sum(logabsdet * x_mask, [1, 2]) 385 | if not reverse: 386 | return x, logdet 387 | else: 388 | return x 389 | -------------------------------------------------------------------------------- /module/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import numpy as np 4 | 5 | DEFAULT_MIN_BIN_WIDTH = 1e-3 6 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 7 | DEFAULT_MIN_DERIVATIVE = 1e-3 8 | 9 | 10 | def piecewise_rational_quadratic_transform(inputs, 11 | unnormalized_widths, 12 | unnormalized_heights, 13 | unnormalized_derivatives, 14 | inverse=False, 15 | tails=None, 16 | tail_bound=1., 17 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 18 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 19 | min_derivative=DEFAULT_MIN_DERIVATIVE): 20 | 21 | if tails is None: 22 | spline_fn = rational_quadratic_spline 23 | spline_kwargs = {} 24 | else: 25 | spline_fn = unconstrained_rational_quadratic_spline 26 | spline_kwargs = { 27 | 'tails': tails, 28 | 'tail_bound': tail_bound 29 | } 30 | 31 | outputs, logabsdet = spline_fn( 32 | inputs=inputs, 33 | unnormalized_widths=unnormalized_widths, 34 | unnormalized_heights=unnormalized_heights, 35 | unnormalized_derivatives=unnormalized_derivatives, 36 | inverse=inverse, 37 | min_bin_width=min_bin_width, 38 | min_bin_height=min_bin_height, 39 | min_derivative=min_derivative, 40 | **spline_kwargs 41 | ) 42 | return outputs, logabsdet 43 | 44 | 45 | def searchsorted(bin_locations, inputs, eps=1e-6): 46 | bin_locations[..., -1] += eps 47 | return torch.sum( 48 | inputs[..., None] >= bin_locations, 49 | dim=-1 50 | ) - 1 51 | 52 | 53 | def unconstrained_rational_quadratic_spline(inputs, 54 | unnormalized_widths, 55 | unnormalized_heights, 56 | unnormalized_derivatives, 57 | inverse=False, 58 | tails='linear', 59 | tail_bound=1., 60 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 61 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 62 | min_derivative=DEFAULT_MIN_DERIVATIVE): 63 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 64 | outside_interval_mask = ~inside_interval_mask 65 | 66 | outputs = torch.zeros_like(inputs) 67 | logabsdet = torch.zeros_like(inputs) 68 | 69 | if tails == 'linear': 70 | unnormalized_derivatives = F.pad(unnormalized_derivatives, 
pad=(1, 1)) 71 | constant = np.log(np.exp(1 - min_derivative) - 1) 72 | unnormalized_derivatives[..., 0] = constant 73 | unnormalized_derivatives[..., -1] = constant 74 | 75 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 76 | logabsdet[outside_interval_mask] = 0 77 | else: 78 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 79 | 80 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 81 | inputs=inputs[inside_interval_mask], 82 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 83 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 84 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 85 | inverse=inverse, 86 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 87 | min_bin_width=min_bin_width, 88 | min_bin_height=min_bin_height, 89 | min_derivative=min_derivative 90 | ) 91 | 92 | return outputs, logabsdet 93 | 94 | def rational_quadratic_spline(inputs, 95 | unnormalized_widths, 96 | unnormalized_heights, 97 | unnormalized_derivatives, 98 | inverse=False, 99 | left=0., right=1., bottom=0., top=1., 100 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 101 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 102 | min_derivative=DEFAULT_MIN_DERIVATIVE): 103 | if torch.min(inputs) < left or torch.max(inputs) > right: 104 | raise ValueError('Input to a transform is not within its domain') 105 | 106 | num_bins = unnormalized_widths.shape[-1] 107 | 108 | if min_bin_width * num_bins > 1.0: 109 | raise ValueError('Minimal bin width too large for the number of bins') 110 | if min_bin_height * num_bins > 1.0: 111 | raise ValueError('Minimal bin height too large for the number of bins') 112 | 113 | widths = F.softmax(unnormalized_widths, dim=-1) 114 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 115 | cumwidths = torch.cumsum(widths, dim=-1) 116 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 117 | cumwidths = (right - left) * cumwidths + left 118 | cumwidths[..., 0] = left 119 | cumwidths[..., -1] = right 120 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 121 | 122 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 123 | 124 | heights = F.softmax(unnormalized_heights, dim=-1) 125 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 126 | cumheights = torch.cumsum(heights, dim=-1) 127 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 128 | cumheights = (top - bottom) * cumheights + bottom 129 | cumheights[..., 0] = bottom 130 | cumheights[..., -1] = top 131 | heights = cumheights[..., 1:] - cumheights[..., :-1] 132 | 133 | if inverse: 134 | bin_idx = searchsorted(cumheights, inputs)[..., None] 135 | else: 136 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 137 | 138 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 139 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 140 | 141 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 142 | delta = heights / widths 143 | input_delta = delta.gather(-1, bin_idx)[..., 0] 144 | 145 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 146 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 147 | 148 | input_heights = heights.gather(-1, bin_idx)[..., 0] 149 | 150 | if inverse: 151 | a = (((inputs - input_cumheights) * (input_derivatives 152 | + input_derivatives_plus_one 153 | - 2 * input_delta) 154 | + input_heights * (input_delta - input_derivatives))) 
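            # Inverting the rational-quadratic segment reduces to a quadratic
            # equation a*theta^2 + b*theta + c = 0 in the bin-local coordinate
            # theta; the numerically stable root 2c / (-b - sqrt(b^2 - 4ac)) is
            # taken below (cf. Durkan et al., "Neural Spline Flows").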
155 | b = (input_heights * input_derivatives 156 | - (inputs - input_cumheights) * (input_derivatives 157 | + input_derivatives_plus_one 158 | - 2 * input_delta)) 159 | c = - input_delta * (inputs - input_cumheights) 160 | 161 | discriminant = b.pow(2) - 4 * a * c 162 | assert (discriminant >= 0).all() 163 | 164 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 165 | outputs = root * input_bin_widths + input_cumwidths 166 | 167 | theta_one_minus_theta = root * (1 - root) 168 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 169 | * theta_one_minus_theta) 170 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 171 | + 2 * input_delta * theta_one_minus_theta 172 | + input_derivatives * (1 - root).pow(2)) 173 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 174 | 175 | return outputs, -logabsdet 176 | else: 177 | theta = (inputs - input_cumwidths) / input_bin_widths 178 | theta_one_minus_theta = theta * (1 - theta) 179 | 180 | numerator = input_heights * (input_delta * theta.pow(2) 181 | + input_derivatives * theta_one_minus_theta) 182 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 183 | * theta_one_minus_theta) 184 | outputs = input_cumheights + numerator / denominator 185 | 186 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 187 | + 2 * input_delta * theta_one_minus_theta 188 | + input_derivatives * (1 - theta).pow(2)) 189 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 190 | 191 | return outputs, logabsdet 192 | -------------------------------------------------------------------------------- /module/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | from scipy.io.wavfile import read 10 | import torch 11 | 12 | MATPLOTLIB_FLAG = False 13 | 14 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 15 | logger = logging 16 | 17 | 18 | def load_checkpoint(checkpoint_path, model, optimizer=None): 19 | assert os.path.isfile(checkpoint_path) 20 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 21 | iteration = checkpoint_dict['iteration'] 22 | learning_rate = checkpoint_dict['learning_rate'] 23 | if optimizer is not None: 24 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 25 | saved_state_dict = checkpoint_dict['model'] 26 | if hasattr(model, 'module'): 27 | state_dict = model.module.state_dict() 28 | else: 29 | state_dict = model.state_dict() 30 | new_state_dict = {} 31 | for k, v in state_dict.items(): 32 | try: 33 | new_state_dict[k] = saved_state_dict[k] 34 | except: 35 | logger.info("%s is not in the checkpoint" % k) 36 | new_state_dict[k] = v 37 | if hasattr(model, 'module'): 38 | model.module.load_state_dict(new_state_dict) 39 | else: 40 | model.load_state_dict(new_state_dict) 41 | logger.info("Loaded checkpoint '{}' (iteration {})".format( 42 | checkpoint_path, iteration)) 43 | return model, optimizer, learning_rate, iteration 44 | 45 | 46 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 47 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 48 | iteration, checkpoint_path)) 49 | if hasattr(model, 'module'): 50 | state_dict = model.module.state_dict() 51 | else: 52 | state_dict = 
model.state_dict() 53 | torch.save({'model': state_dict, 54 | 'iteration': iteration, 55 | 'optimizer': optimizer.state_dict(), 56 | 'learning_rate': learning_rate}, checkpoint_path) 57 | 58 | 59 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): 60 | for k, v in scalars.items(): 61 | writer.add_scalar(k, v, global_step) 62 | for k, v in histograms.items(): 63 | writer.add_histogram(k, v, global_step) 64 | for k, v in images.items(): 65 | writer.add_image(k, v, global_step, dataformats='HWC') 66 | for k, v in audios.items(): 67 | writer.add_audio(k, v, global_step, audio_sampling_rate) 68 | 69 | 70 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 71 | f_list = glob.glob(os.path.join(dir_path, regex)) 72 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 73 | x = f_list[-1] 74 | print(x) 75 | return x 76 | 77 | 78 | def load_wav_to_torch(full_path): 79 | sampling_rate, data = read(full_path) 80 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 81 | 82 | 83 | def load_filepaths_and_text(filename, split="|"): 84 | with open(filename, encoding='utf-8') as f: 85 | filepaths_and_text = [line.strip().split(split) for line in f] 86 | return filepaths_and_text 87 | 88 | 89 | def get_hparams(init=True): 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument('-c', '--config', type=str, required=True, 92 | help='JSON file for configuration') 93 | parser.add_argument('-m', '--model', type=str, required=True, 94 | help='Model name') 95 | 96 | args = parser.parse_args() 97 | model_dir = os.path.join("/workspace/raid/data/ha0/logs_rfhiervc", args.model) 98 | 99 | if not os.path.exists(model_dir): 100 | os.makedirs(model_dir) 101 | 102 | config_path = args.config 103 | config_save_path = os.path.join(model_dir, "config.json") 104 | if init: 105 | with open(config_path, "r") as f: 106 | data = f.read() 107 | with open(config_save_path, "w") as f: 108 | f.write(data) 109 | else: 110 | with open(config_save_path, "r") as f: 111 | data = f.read() 112 | config = json.loads(data) 113 | 114 | hparams = HParams(**config) 115 | hparams.model_dir = model_dir 116 | return hparams 117 | 118 | 119 | def get_hparams_from_dir(model_dir): 120 | config_save_path = os.path.join(model_dir, "config.json") 121 | with open(config_save_path, "r") as f: 122 | data = f.read() 123 | config = json.loads(data) 124 | 125 | hparams = HParams(**config) 126 | hparams.model_dir = model_dir 127 | return hparams 128 | 129 | 130 | def get_hparams_from_file(config_path): 131 | with open(config_path, "r") as f: 132 | data = f.read() 133 | config = json.loads(data) 134 | 135 | hparams = HParams(**config) 136 | return hparams 137 | 138 | 139 | def check_git_hash(model_dir): 140 | source_dir = os.path.dirname(os.path.realpath(__file__)) 141 | if not os.path.exists(os.path.join(source_dir, ".git")): 142 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( 143 | source_dir 144 | )) 145 | return 146 | 147 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 148 | 149 | path = os.path.join(model_dir, "githash") 150 | if os.path.exists(path): 151 | saved_hash = open(path).read() 152 | if saved_hash != cur_hash: 153 | logger.warn("git hash values are different. 
{}(saved) != {}(current)".format( 154 | saved_hash[:8], cur_hash[:8])) 155 | else: 156 | open(path, "w").write(cur_hash) 157 | 158 | 159 | def get_logger(model_dir, filename="train.log"): 160 | global logger 161 | logger = logging.getLogger(os.path.basename(model_dir)) 162 | logger.setLevel(logging.DEBUG) 163 | 164 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 165 | if not os.path.exists(model_dir): 166 | os.makedirs(model_dir) 167 | h = logging.FileHandler(os.path.join(model_dir, filename)) 168 | h.setLevel(logging.DEBUG) 169 | h.setFormatter(formatter) 170 | logger.addHandler(h) 171 | return logger 172 | 173 | 174 | def parse_filelist(filelist_path): 175 | with open(filelist_path, 'r') as f: 176 | filelist = [line.strip() for line in f.readlines()] 177 | return filelist 178 | 179 | 180 | def parse_filelist_and_spk_id(filelist_path, split="|"): 181 | with open(filelist_path, encoding='utf-8') as f: 182 | filepaths_and_spkid = [line.strip().split(split) for line in f] 183 | return filepaths_and_spkid 184 | 185 | 186 | class HParams(): 187 | def __init__(self, **kwargs): 188 | for k, v in kwargs.items(): 189 | if type(v) == dict: 190 | v = HParams(**v) 191 | self[k] = v 192 | 193 | def keys(self): 194 | return self.__dict__.keys() 195 | 196 | def items(self): 197 | return self.__dict__.items() 198 | 199 | def values(self): 200 | return self.__dict__.values() 201 | 202 | def __len__(self): 203 | return len(self.__dict__) 204 | 205 | def __getitem__(self, key): 206 | return getattr(self, key) 207 | 208 | def __setitem__(self, key, value): 209 | return setattr(self, key, value) 210 | 211 | def __contains__(self, key): 212 | return key in self.__dict__ 213 | 214 | def __repr__(self): 215 | return self.__dict__.__repr__() 216 | 217 | 218 | def sequence_mask(length, max_length=None): 219 | if max_length is None: 220 | max_length = length.max() 221 | x = torch.arange(int(max_length), dtype=length.dtype, device=length.device) 222 | return x.unsqueeze(0) < length.unsqueeze(1) 223 | 224 | def convert_pad_shape(pad_shape): 225 | l = pad_shape[::-1] 226 | pad_shape = [item for sublist in l for item in sublist] 227 | return pad_shape 228 | 229 | def fix_len_compatibility(length, num_downsamplings_in_unet=2): 230 | while True: 231 | if length % (2**num_downsamplings_in_unet) == 0: 232 | return length 233 | length += 1 234 | 235 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | amfm_decompy==1.0.11 2 | einops==0.7.0 3 | numpy==1.21.4 4 | scipy==1.6.3 5 | torch==1.11.0+cu113 6 | torchaudio==0.11.0+cu113 7 | tqdm==4.62.3 8 | transformers==4.35.0 9 | -------------------------------------------------------------------------------- /sample/src_p241_004.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hayeong0/Diff-HierVC/3ed2253cbe5bdd13f3934eae7e40ab5102ab2bde/sample/src_p241_004.wav -------------------------------------------------------------------------------- /sample/tar_p239_022.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hayeong0/Diff-HierVC/3ed2253cbe5bdd13f3934eae7e40ab5102ab2bde/sample/tar_p239_022.wav -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import 
os 2 | import torch 3 | from torch.nn import functional as F 4 | from torch.nn.parallel import DistributedDataParallel as DDP 5 | 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | from torch.cuda.amp import autocast, GradScaler 9 | 10 | from torch.utils.data import DataLoader 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.utils.data.distributed import DistributedSampler 13 | 14 | import random 15 | import commons 16 | import utils 17 | 18 | from augmentation.aug import Augment 19 | from model.diffhiervc import Wav2vec2, DiffHierVC 20 | from data_loader import AudioDataset, MelSpectrogramFixed 21 | from vocoder.hifigan import HiFi 22 | from torch.utils.data import DataLoader 23 | 24 | torch.backends.cudnn.benchmark = True 25 | global_step = 0 26 | 27 | def get_param_num(model): 28 | num_param = sum(param.numel() for param in model.parameters()) 29 | return num_param 30 | 31 | def main(): 32 | """Assume Single Node Multi GPUs Training Only""" 33 | assert torch.cuda.is_available(), "CPU training is not allowed." 34 | 35 | n_gpus = torch.cuda.device_count() 36 | port = 50000 + random.randint(0, 100) 37 | os.environ['MASTER_ADDR'] = 'localhost' 38 | os.environ['MASTER_PORT'] = str(port) 39 | 40 | hps = utils.get_hparams() 41 | mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,)) 42 | 43 | def run(rank, n_gpus, hps): 44 | global global_step 45 | if rank == 0: 46 | logger = utils.get_logger(hps.model_dir) 47 | logger.info(hps) 48 | utils.check_git_hash(hps.model_dir) 49 | writer = SummaryWriter(log_dir=hps.model_dir) 50 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) 51 | 52 | dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank) 53 | torch.manual_seed(hps.train.seed) 54 | torch.cuda.set_device(rank) 55 | 56 | mel_fn = MelSpectrogramFixed( 57 | sample_rate=hps.data.sampling_rate, 58 | n_fft=hps.data.filter_length, 59 | win_length=hps.data.win_length, 60 | hop_length=hps.data.hop_length, 61 | f_min=hps.data.mel_fmin, 62 | f_max=hps.data.mel_fmax, 63 | n_mels=hps.data.n_mel_channels, 64 | window_fn=torch.hann_window 65 | ).cuda(rank) 66 | 67 | train_dataset = AudioDataset(hps, training=True) 68 | train_sampler = DistributedSampler(train_dataset) if n_gpus > 1 else None 69 | train_loader = DataLoader( 70 | train_dataset, batch_size=hps.train.batch_size, num_workers=32, 71 | sampler=train_sampler, drop_last=True, persistent_workers=True, pin_memory=True 72 | ) 73 | 74 | if rank == 0: 75 | test_dataset = AudioDataset(hps, training=False) 76 | eval_loader = DataLoader(test_dataset, batch_size=1) 77 | 78 | w2v = Wav2vec2().cuda(rank) 79 | aug = Augment(hps).cuda(rank) 80 | 81 | model = DiffHierVC(hps.data.n_mel_channels, hps.diffusion.spk_dim, 82 | hps.diffusion.dec_dim, hps.diffusion.beta_min, hps.diffusion.beta_max, hps).cuda() 83 | 84 | net_v = HiFi( 85 | hps.data.n_mel_channels, 86 | hps.train.segment_size // hps.data.hop_length, 87 | **hps.model).cuda() 88 | path_ckpt = './vocoder/voc_hifigan.pth' 89 | 90 | utils.load_checkpoint(path_ckpt, net_v, None) 91 | net_v.eval() 92 | net_v.dec.remove_weight_norm() 93 | 94 | if rank == 0: 95 | num_param = get_param_num(model.encoder) 96 | print('[Encoder] number of Parameters:', num_param) 97 | num_param = get_param_num(model.f0_dec) 98 | print('[F0 Decoder] number of Parameters:', num_param) 99 | num_param = get_param_num(model.mel_dec) 100 | print('[Mel Decoder] number of Parameters:', num_param) 101 | 102 | optimizer = torch.optim.AdamW( 103 
| model.parameters(), 104 | hps.train.learning_rate, 105 | betas=hps.train.betas, 106 | eps=hps.train.eps) 107 | 108 | model = DDP(model, device_ids=[rank]) 109 | 110 | try: 111 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), model, optimizer) 112 | global_step = (epoch_str - 1) * len(train_loader) 113 | except Exception: # no saved checkpoint found; start training from scratch 114 | epoch_str = 1 115 | global_step = 0 116 | 117 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) 118 | scaler = GradScaler(enabled=hps.train.fp16_run) 119 | 120 | for epoch in range(epoch_str, hps.train.epochs + 1): 121 | if rank == 0: 122 | train_and_evaluate(rank, epoch, hps, [model, mel_fn, w2v, aug, net_v], optimizer, 123 | scheduler_g, scaler, [train_loader, eval_loader], logger, [writer, writer_eval]) 124 | else: 125 | train_and_evaluate(rank, epoch, hps, [model, mel_fn, w2v, aug, net_v], optimizer, 126 | scheduler_g, scaler, [train_loader, None], None, None) 127 | scheduler_g.step() 128 | 129 | def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): 130 | model, mel_fn, w2v, aug, net_v = nets 131 | optimizer = optims 132 | scheduler_g = schedulers 133 | train_loader, eval_loader = loaders 134 | 135 | if writers is not None: 136 | writer, writer_eval = writers 137 | global global_step 138 | 139 | if isinstance(train_loader.sampler, DistributedSampler): train_loader.sampler.set_epoch(epoch) # the default single-GPU sampler has no set_epoch 140 | model.train() 141 | for batch_idx, (x, norm_f0, x_f0, length) in enumerate(train_loader): 142 | x = x.cuda(rank, non_blocking=True) 143 | norm_f0 = norm_f0.cuda(rank, non_blocking=True) 144 | x_f0 = x_f0.cuda(rank, non_blocking=True) 145 | length = length.cuda(rank, non_blocking=True).squeeze() 146 | 147 | mel_x = mel_fn(x) 148 | aug_x = aug(x) 149 | nan_x = torch.isnan(aug_x).any() 150 | x = x if nan_x else aug_x # fall back to the clean waveform if augmentation produced NaNs 151 | x_pad = F.pad(x, (40, 40), "reflect") 152 | 153 | w2v_x = w2v(x_pad) 154 | f0_x = torch.log(x_f0+1) 155 | 156 | optimizer.zero_grad() 157 | loss_mel_diff, loss_mel_diff_rec, loss_f0_diff, loss_mel, loss_f0 = model.module.compute_loss(mel_x, w2v_x, norm_f0, f0_x, length) 158 | loss_gen_all = loss_mel_diff + loss_mel_diff_rec + loss_f0_diff + loss_mel*hps.train.c_mel + loss_f0 159 | 160 | if hps.train.fp16_run: 161 | scaler.scale(loss_gen_all).backward() 162 | scaler.unscale_(optimizer) 163 | grad_norm_g = commons.clip_grad_value_(model.parameters(), None) 164 | scaler.step(optimizer) 165 | scaler.update() 166 | else: 167 | loss_gen_all.backward() 168 | grad_norm_g = commons.clip_grad_value_(model.parameters(), None) 169 | optimizer.step() 170 | 171 | if rank == 0: 172 | if global_step % hps.train.log_interval == 0: 173 | lr = optimizer.param_groups[0]['lr'] 174 | losses = [loss_mel_diff, loss_f0_diff] 175 | logger.info('Train Epoch: {} [{:.0f}%]'.format( 176 | epoch, 177 | 100.
* batch_idx / len(train_loader))) 178 | logger.info([x.item() for x in losses] + [global_step, lr]) 179 | 180 | scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g} 181 | scalar_dict.update({"loss/g/diff": loss_mel_diff, "loss/g/diff_rec": loss_mel_diff_rec, "loss/g/f0_diff": loss_f0_diff, "loss/g/mel": loss_mel, "loss/g/f0": loss_f0}) 182 | 183 | utils.summarize( 184 | writer=writer, 185 | global_step=global_step, 186 | scalars=scalar_dict) 187 | 188 | if global_step % hps.train.eval_interval == 0: 189 | torch.cuda.empty_cache() 190 | evaluate(hps, model, mel_fn, w2v, net_v, eval_loader, writer_eval) 191 | 192 | if global_step % hps.train.save_interval == 0: 193 | utils.save_checkpoint(model, optimizer, hps.train.learning_rate, epoch, 194 | os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) 195 | 196 | global_step += 1 197 | 198 | if rank == 0: 199 | logger.info('====> Epoch: {}'.format(epoch)) 200 | 201 | 202 | def evaluate(hps, model, mel_fn, w2v, net_v, eval_loader, writer_eval): 203 | model.eval() 204 | image_dict = {} 205 | audio_dict = {} 206 | mel_loss = 0 207 | enc_loss = 0 208 | enc_f0_loss = 0 209 | diff_f0_loss = 0 210 | 211 | with torch.no_grad(): 212 | for batch_idx, (y, norm_y_f0, y_f0) in enumerate(eval_loader): 213 | y = y.cuda(0) 214 | norm_y_f0 = norm_y_f0.cuda(0) 215 | y_f0 = y_f0.cuda(0) 216 | 217 | mel_y = mel_fn(y) 218 | f0_y = torch.log(y_f0+1) 219 | length = torch.LongTensor([mel_y.size(2)]).cuda(0) 220 | 221 | y_pad = F.pad(y, (40, 40), "reflect") 222 | w2v_y = w2v(y_pad) 223 | 224 | y_f0_hat, y_mel, o_f0, o_mel = model(mel_y, w2v_y, norm_y_f0, f0_y, length, n_timesteps=6, mode='ml') 225 | 226 | mel_loss += F.l1_loss(mel_y, o_mel).item() 227 | enc_loss += F.l1_loss(mel_y, y_mel).item() 228 | enc_f0_loss += F.l1_loss(f0_y, y_f0_hat).item() 229 | diff_f0_loss += F.l1_loss(f0_y, o_f0).item() 230 | 231 | if batch_idx > 100: 232 | break 233 | if batch_idx <= 4: 234 | y_hat = net_v(o_mel) 235 | enc_hat = net_v(y_mel) 236 | 237 | plot_mel = torch.cat([mel_y, o_mel, y_mel], dim=1) 238 | plot_mel = plot_mel.clip(min=-10, max=10) 239 | 240 | image_dict.update({ 241 | "gen/mel_{}".format(batch_idx): utils.plot_spectrogram_to_numpy(plot_mel.squeeze().cpu().numpy()), 242 | "F0/f0_{}".format(batch_idx): 243 | utils.plot_f0_contour_to_numpy(mel_y.repeat_interleave(repeats=4, dim=2).squeeze().cpu().numpy(), 244 | f0s= {'target_f0': y_f0.squeeze().cpu(), 245 | 'enc_f0': (torch.exp(y_f0_hat)-1).squeeze().cpu(), 246 | 'diff_6_f0': (torch.exp(o_f0)-1).squeeze().cpu() 247 | }) 248 | }) 249 | audio_dict.update({ 250 | "gen/audio_{}".format(batch_idx): y_hat.squeeze(), 251 | "gen/enc_audio_{}".format(batch_idx): enc_hat.squeeze() 252 | }) 253 | if global_step == 0: 254 | audio_dict.update({"gt/audio_{}".format(batch_idx): y.squeeze()}) 255 | 256 | mel_loss /= 100 257 | enc_loss /= 100 258 | enc_f0_loss /= 100 259 | diff_f0_loss /= 100 260 | 261 | scalar_dict = {"val/mel": mel_loss, "val/enc_mel": enc_loss, "val/enc_f0": enc_f0_loss, "val/diff_f0": diff_f0_loss} 262 | utils.summarize( 263 | writer=writer_eval, 264 | global_step=global_step, 265 | images=image_dict, 266 | audios=audio_dict, 267 | audio_sampling_rate=hps.data.sampling_rate, 268 | scalars=scalar_dict 269 | ) 270 | model.train() 271 | 272 | 273 | if __name__ == "__main__": 274 | main() 275 | -------------------------------------------------------------------------------- /utils/data_loader.py: -------------------------------------------------------------------------------- 1 
| import numpy as np 2 | import torch 3 | import torchaudio 4 | from torchaudio.transforms import MelSpectrogram 5 | from module.utils import parse_filelist 6 | from torch.nn import functional as F 7 | np.random.seed(1234) 8 | 9 | class AudioDataset(torch.utils.data.Dataset): 10 | """ 11 | Provides dataset management for a given filelist. 12 | """ 13 | def __init__(self, config, training=True): 14 | super(AudioDataset, self).__init__() 15 | self.config = config 16 | self.hop_length = config.data.hop_length 17 | self.training = training 18 | self.mel_length = config.train.segment_size // config.data.hop_length 19 | self.segment_length = config.train.segment_size 20 | self.sample_rate = config.data.sampling_rate 21 | 22 | self.filelist_path = config.data.train_filelist_path \ 23 | if self.training else config.data.test_filelist_path 24 | self.audio_paths = parse_filelist(self.filelist_path) \ 25 | if self.training else parse_filelist(self.filelist_path)[:101] 26 | 27 | self.f0_norm_paths = parse_filelist(self.filelist_path.replace('_wav', '_f0_norm')) 28 | self.f0_paths = parse_filelist(self.filelist_path.replace('_wav', '_f0')) 29 | 30 | 31 | def load_audio_to_torch(self, audio_path): 32 | audio, sample_rate = torchaudio.load(audio_path) 33 | 34 | if not self.training: 35 | p = (audio.shape[-1] // 1280 + 1) * 1280 - audio.shape[-1] # pad test audio to a multiple of 1280 samples so hop-based frame counts stay aligned 36 | audio = F.pad(audio, (0, p), mode='constant').data 37 | return audio.squeeze(), sample_rate 38 | 39 | def __getitem__(self, index): 40 | audio_path = self.audio_paths[index] 41 | f0_norm_path = self.f0_norm_paths[index] 42 | f0_path = self.f0_paths[index] 43 | 44 | audio, sample_rate = self.load_audio_to_torch(audio_path) 45 | f0_norm = torch.load(f0_norm_path) 46 | f0 = torch.load(f0_path) 47 | 48 | assert sample_rate == self.sample_rate, \ 49 | f"""Got audio with sampling rate {sample_rate}, \ 50 | but {self.sample_rate} is required according to the config.""" 51 | 52 | if not self.training: 53 | return audio, f0_norm, f0 54 | 55 | if audio.shape[-1] > self.segment_length: 56 | max_f0_start = f0.shape[-1] - self.segment_length//80 57 | 58 | f0_start = np.random.randint(0, max_f0_start) 59 | f0_norm_seg = f0_norm[:, f0_start:f0_start + self.segment_length // 80] 60 | f0_seg = f0[:, f0_start:f0_start + self.segment_length // 80] 61 | 62 | audio_start = f0_start*80 63 | segment = audio[audio_start:audio_start + self.segment_length] 64 | 65 | if segment.shape[-1] < self.segment_length: 66 | segment = F.pad(segment, (0, self.segment_length - segment.shape[-1]), 'constant') 67 | length = torch.LongTensor([self.mel_length]) 68 | 69 | else: 70 | segment = F.pad(audio, (0, self.segment_length - audio.shape[-1]), 'constant') 71 | length = torch.LongTensor([audio.shape[-1] // self.hop_length]) 72 | 73 | f0_norm_seg = F.pad(f0_norm, (0, self.segment_length // 80 - f0_norm.shape[-1]), 'constant') 74 | 75 | f0_seg = F.pad(f0, (0, self.segment_length // 80 - f0.shape[-1]), 'constant') 76 | 77 | return segment, f0_norm_seg, f0_seg, length 78 | 79 | def __len__(self): 80 | return len(self.audio_paths) 81 | 82 | def sample_test_batch(self, size): 83 | idx = np.random.choice(range(len(self)), size=size, replace=False) 84 | test_batch = [] 85 | for index in idx: 86 | test_batch.append(self.__getitem__(index)) 87 | return test_batch 88 | 89 | class MelSpectrogramFixed(torch.nn.Module): 90 | """Wraps torchaudio's MelSpectrogram to drop the extra padded frame and apply a log scale.""" 91 | 92 | def __init__(self, **kwargs): 93 | super(MelSpectrogramFixed, self).__init__() 94 | self.torchaudio_backend =
MelSpectrogram(**kwargs) 95 | 96 | def forward(self, x): 97 | outputs = torch.log(self.torchaudio_backend(x) + 0.001) 98 | 99 | return outputs[..., :-1] -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | import torch 10 | from scipy.io.wavfile import read 11 | MATPLOTLIB_FLAG = False 12 | 13 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 14 | logger = logging 15 | 16 | from torchaudio.transforms import MelSpectrogram 17 | 18 | class MelSpectrogramFixed(torch.nn.Module): 19 | def __init__(self, **kwargs): 20 | super(MelSpectrogramFixed, self).__init__() 21 | self.torchaudio_backend = MelSpectrogram(**kwargs) 22 | 23 | def forward(self, x): 24 | outputs = torch.log(self.torchaudio_backend(x) + 0.001) 25 | return outputs[..., :-1] 26 | 27 | def load_checkpoint(checkpoint_path, model, optimizer=None): 28 | assert os.path.isfile(checkpoint_path) 29 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 30 | iteration = checkpoint_dict['iteration'] 31 | learning_rate = checkpoint_dict['learning_rate'] 32 | if optimizer is not None: 33 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 34 | saved_state_dict = checkpoint_dict['model'] 35 | if hasattr(model, 'module'): 36 | state_dict = model.module.state_dict() 37 | else: 38 | state_dict = model.state_dict() 39 | new_state_dict = {} 40 | for k, v in state_dict.items(): 41 | try: 42 | new_state_dict[k] = saved_state_dict[k] 43 | except KeyError: # parameter missing from the checkpoint; keep the model's current value 44 | logger.info("%s is not in the checkpoint" % k) 45 | new_state_dict[k] = v 46 | if hasattr(model, 'module'): 47 | model.module.load_state_dict(new_state_dict) 48 | else: 49 | model.load_state_dict(new_state_dict) 50 | logger.info("Loaded checkpoint '{}' (iteration {})".format( 51 | checkpoint_path, iteration)) 52 | return model, optimizer, learning_rate, iteration 53 | 54 | 55 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 56 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 57 | iteration, checkpoint_path)) 58 | if hasattr(model, 'module'): 59 | state_dict = model.module.state_dict() 60 | else: 61 | state_dict = model.state_dict() 62 | torch.save({'model': state_dict, 63 | 'iteration': iteration, 64 | 'optimizer': optimizer.state_dict(), 65 | 'learning_rate': learning_rate}, checkpoint_path) 66 | 67 | 68 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): 69 | for k, v in scalars.items(): 70 | writer.add_scalar(k, v, global_step) 71 | for k, v in histograms.items(): 72 | writer.add_histogram(k, v, global_step) 73 | for k, v in images.items(): 74 | writer.add_image(k, v, global_step, dataformats='HWC') 75 | for k, v in audios.items(): 76 | writer.add_audio(k, v, global_step, audio_sampling_rate) 77 | 78 | 79 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 80 | f_list = glob.glob(os.path.join(dir_path, regex)) 81 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 82 | x = f_list[-1] 83 | print(x) 84 | return x 85 | 86 | 87 | def load_wav_to_torch(full_path): 88 | sampling_rate, data = read(full_path) 89 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 90 | 91 | 92 | def load_filepaths_and_text(filename, split="|"): 93 | with open(filename,
encoding='utf-8') as f: 94 | filepaths_and_text = [line.strip().split(split) for line in f] 95 | return filepaths_and_text 96 | 97 | def get_hparams(init=True): 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument('-c', '--config', type=str, required=True, 100 | help='JSON file for configuration') 101 | parser.add_argument('-m', '--model', type=str, required=True, 102 | help='Model name') 103 | 104 | args = parser.parse_args() 105 | model_dir = os.path.join("/workspace/raid/ha0/logs_diffhier", args.model) 106 | 107 | if not os.path.exists(model_dir): 108 | os.makedirs(model_dir) 109 | 110 | config_path = args.config 111 | config_save_path = os.path.join(model_dir, "config.json") 112 | if init: 113 | with open(config_path, "r") as f: 114 | data = f.read() 115 | with open(config_save_path, "w") as f: 116 | f.write(data) 117 | else: 118 | with open(config_save_path, "r") as f: 119 | data = f.read() 120 | config = json.loads(data) 121 | 122 | hparams = HParams(**config) 123 | hparams.model_dir = model_dir 124 | return hparams 125 | 126 | def get_hparams_from_dir(model_dir): 127 | config_save_path = os.path.join(model_dir, "config.json") 128 | with open(config_save_path, "r") as f: 129 | data = f.read() 130 | config = json.loads(data) 131 | 132 | hparams = HParams(**config) 133 | hparams.model_dir = model_dir 134 | return hparams 135 | 136 | def get_hparams_from_file(config_path): 137 | with open(config_path, "r") as f: 138 | data = f.read() 139 | config = json.loads(data) 140 | 141 | hparams = HParams(**config) 142 | return hparams 143 | 144 | def check_git_hash(model_dir): 145 | source_dir = os.path.dirname(os.path.realpath(__file__)) 146 | if not os.path.exists(os.path.join(source_dir, ".git")): 147 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( 148 | source_dir 149 | )) 150 | return 151 | 152 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 153 | 154 | path = os.path.join(model_dir, "githash") 155 | if os.path.exists(path): 156 | saved_hash = open(path).read() 157 | if saved_hash != cur_hash: 158 | logger.warn("git hash values are different. 
{}(saved) != {}(current)".format( 159 | saved_hash[:8], cur_hash[:8])) 160 | else: 161 | open(path, "w").write(cur_hash) 162 | 163 | 164 | def get_logger(model_dir, filename="train.log"): 165 | global logger 166 | logger = logging.getLogger(os.path.basename(model_dir)) 167 | logger.setLevel(logging.DEBUG) 168 | 169 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 170 | if not os.path.exists(model_dir): 171 | os.makedirs(model_dir) 172 | h = logging.FileHandler(os.path.join(model_dir, filename)) 173 | h.setLevel(logging.DEBUG) 174 | h.setFormatter(formatter) 175 | logger.addHandler(h) 176 | return logger 177 | 178 | def parse_filelist(filelist_path): 179 | with open(filelist_path, 'r') as f: 180 | filelist = [line.strip() for line in f.readlines()] 181 | return filelist 182 | 183 | 184 | def parse_filelist_and_spk_id(filelist_path, split="|"): 185 | with open(filelist_path, encoding='utf-8') as f: 186 | filepaths_and_spkid = [line.strip().split(split) for line in f] 187 | return filepaths_and_spkid 188 | 189 | 190 | class HParams(): 191 | def __init__(self, **kwargs): 192 | for k, v in kwargs.items(): 193 | if type(v) == dict: 194 | v = HParams(**v) 195 | self[k] = v 196 | 197 | def keys(self): 198 | return self.__dict__.keys() 199 | 200 | def items(self): 201 | return self.__dict__.items() 202 | 203 | def values(self): 204 | return self.__dict__.values() 205 | 206 | def __len__(self): 207 | return len(self.__dict__) 208 | 209 | def __getitem__(self, key): 210 | return getattr(self, key) 211 | 212 | def __setitem__(self, key, value): 213 | return setattr(self, key, value) 214 | 215 | def __contains__(self, key): 216 | return key in self.__dict__ 217 | 218 | def __repr__(self): 219 | return self.__dict__.__repr__() 220 | -------------------------------------------------------------------------------- /vocoder/activations.py: -------------------------------------------------------------------------------- 1 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch 5 | from torch import nn, sin, pow 6 | from torch.nn import Parameter 7 | 8 | 9 | class Snake(nn.Module): 10 | ''' 11 | Implementation of a sine-based periodic activation function 12 | Shape: 13 | - Input: (B, C, T) 14 | - Output: (B, C, T), same shape as the input 15 | Parameters: 16 | - alpha - trainable parameter 17 | References: 18 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 19 | https://arxiv.org/abs/2006.08195 20 | Examples: 21 | >>> a1 = snake(256) 22 | >>> x = torch.randn(256) 23 | >>> x = a1(x) 24 | ''' 25 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 26 | ''' 27 | Initialization. 28 | INPUT: 29 | - in_features: shape of the input 30 | - alpha: trainable parameter 31 | alpha is initialized to 1 by default, higher values = higher-frequency. 32 | alpha will be trained along with the rest of your model. 
33 | ''' 34 | super(Snake, self).__init__() 35 | self.in_features = in_features 36 | 37 | # initialize alpha 38 | self.alpha_logscale = alpha_logscale 39 | if self.alpha_logscale: # log scale alphas initialized to zeros 40 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 41 | else: # linear scale alphas initialized to ones 42 | self.alpha = Parameter(torch.ones(in_features) * alpha) 43 | 44 | self.alpha.requires_grad = alpha_trainable 45 | 46 | self.no_div_by_zero = 0.000000001 47 | 48 | def forward(self, x): 49 | ''' 50 | Forward pass of the function. 51 | Applies the function to the input elementwise. 52 | Snake ∶= x + 1/a * sin^2 (xa) 53 | ''' 54 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 55 | if self.alpha_logscale: 56 | alpha = torch.exp(alpha) 57 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 58 | 59 | return x 60 | 61 | 62 | class SnakeBeta(nn.Module): 63 | ''' 64 | A modified Snake function which uses separate parameters for the magnitude of the periodic components 65 | Shape: 66 | - Input: (B, C, T) 67 | - Output: (B, C, T), same shape as the input 68 | Parameters: 69 | - alpha - trainable parameter that controls frequency 70 | - beta - trainable parameter that controls magnitude 71 | References: 72 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 73 | https://arxiv.org/abs/2006.08195 74 | Examples: 75 | >>> a1 = snakebeta(256) 76 | >>> x = torch.randn(256) 77 | >>> x = a1(x) 78 | ''' 79 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 80 | ''' 81 | Initialization. 82 | INPUT: 83 | - in_features: shape of the input 84 | - alpha - trainable parameter that controls frequency 85 | - beta - trainable parameter that controls magnitude 86 | alpha is initialized to 1 by default, higher values = higher-frequency. 87 | beta is initialized to 1 by default, higher values = higher-magnitude. 88 | alpha will be trained along with the rest of your model. 89 | ''' 90 | super(SnakeBeta, self).__init__() 91 | self.in_features = in_features 92 | 93 | # initialize alpha 94 | self.alpha_logscale = alpha_logscale 95 | if self.alpha_logscale: # log scale alphas initialized to zeros 96 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 97 | self.beta = Parameter(torch.zeros(in_features) * alpha) 98 | else: # linear scale alphas initialized to ones 99 | self.alpha = Parameter(torch.ones(in_features) * alpha) 100 | self.beta = Parameter(torch.ones(in_features) * alpha) 101 | 102 | self.alpha.requires_grad = alpha_trainable 103 | self.beta.requires_grad = alpha_trainable 104 | 105 | self.no_div_by_zero = 0.000000001 106 | 107 | def forward(self, x): 108 | ''' 109 | Forward pass of the function. 110 | Applies the function to the input elementwise. 
111 | SnakeBeta ∶= x + 1/b * sin^2 (xa) 112 | ''' 113 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 114 | beta = self.beta.unsqueeze(0).unsqueeze(-1) 115 | if self.alpha_logscale: 116 | alpha = torch.exp(alpha) 117 | beta = torch.exp(beta) 118 | x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 119 | 120 | return x 121 | -------------------------------------------------------------------------------- /vocoder/bigvgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 6 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 7 | from torch.cuda.amp import autocast 8 | import torchaudio 9 | from einops import rearrange 10 | 11 | from alias_free_torch import * 12 | from module.commons import init_weights, get_padding 13 | import vocoder.modules as modules 14 | import vocoder.activations as activations 15 | 16 | class AMPBlock1(torch.nn.Module): 17 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation=None): 18 | super(AMPBlock1, self).__init__() 19 | 20 | self.convs1 = nn.ModuleList([ 21 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 22 | padding=get_padding(kernel_size, dilation[0]))), 23 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 24 | padding=get_padding(kernel_size, dilation[1]))), 25 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 26 | padding=get_padding(kernel_size, dilation[2]))) 27 | ]) 28 | self.convs1.apply(init_weights) 29 | 30 | self.convs2 = nn.ModuleList([ 31 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 32 | padding=get_padding(kernel_size, 1))), 33 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 34 | padding=get_padding(kernel_size, 1))), 35 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 36 | padding=get_padding(kernel_size, 1))) 37 | ]) 38 | self.convs2.apply(init_weights) 39 | 40 | self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers 41 | 42 | 43 | self.activations = nn.ModuleList([ 44 | Activation1d( 45 | activation=activations.SnakeBeta(channels, alpha_logscale=True)) 46 | for _ in range(self.num_layers) 47 | ]) 48 | 49 | def forward(self, x): 50 | acts1, acts2 = self.activations[::2], self.activations[1::2] 51 | for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): 52 | xt = a1(x) 53 | xt = c1(xt) 54 | xt = a2(xt) 55 | xt = c2(xt) 56 | x = xt + x 57 | 58 | return x 59 | 60 | def remove_weight_norm(self): 61 | for l in self.convs1: 62 | remove_weight_norm(l) 63 | for l in self.convs2: 64 | remove_weight_norm(l) 65 | 66 | 67 | class AMPBlock2(torch.nn.Module): 68 | def __init__(self, channels, kernel_size=3, dilation=(1, 3), activation=None): 69 | super(AMPBlock2, self).__init__() 70 | 71 | 72 | self.convs = nn.ModuleList([ 73 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 74 | padding=get_padding(kernel_size, dilation[0]))), 75 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 76 | padding=get_padding(kernel_size, dilation[1]))) 77 | ]) 78 | self.convs.apply(init_weights) 79 | 80 | self.num_layers = len(self.convs) # total number of conv layers 81 | 82 | if activation == 'snake': # periodic nonlinearity with snake function and 
anti-aliasing 83 | self.activations = nn.ModuleList([ 84 | Activation1d( 85 | activation=activations.Snake(channels, alpha_logscale=True)) 86 | for _ in range(self.num_layers) 87 | ]) 88 | elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing 89 | self.activations = nn.ModuleList([ 90 | Activation1d( 91 | activation=activations.SnakeBeta(channels, alpha_logscale=True)) 92 | for _ in range(self.num_layers) 93 | ]) 94 | else: 95 | raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") 96 | 97 | def forward(self, x): 98 | for c, a in zip (self.convs, self.activations): 99 | xt = a(x) 100 | xt = c(xt) 101 | x = xt + x 102 | 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs: 107 | remove_weight_norm(l) 108 | 109 | class Generator(torch.nn.Module): 110 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): 111 | super(Generator, self).__init__() 112 | self.num_kernels = len(resblock_kernel_sizes) 113 | self.num_upsamples = len(upsample_rates) 114 | 115 | self.conv_pre = weight_norm(Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)) 116 | resblock = AMPBlock1 if resblock == '1' else AMPBlock2 117 | 118 | self.ups = nn.ModuleList() 119 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 120 | self.ups.append(weight_norm( 121 | ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), 122 | k, u, padding=(k-u)//2))) 123 | 124 | self.resblocks = nn.ModuleList() 125 | for i in range(len(self.ups)): 126 | ch = upsample_initial_channel//(2**(i+1)) 127 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): 128 | self.resblocks.append(resblock(ch, k, d, activation="snakebeta")) 129 | 130 | activation_post = activations.SnakeBeta(ch, alpha_logscale=True) 131 | self.activation_post = Activation1d(activation=activation_post) 132 | 133 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) 134 | self.ups.apply(init_weights) 135 | 136 | if gin_channels != 0: 137 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) 138 | 139 | def forward(self, x, g=None): 140 | x = self.conv_pre(x) 141 | if g is not None: 142 | x = x + self.cond(g) 143 | 144 | for i in range(self.num_upsamples): 145 | 146 | x = self.ups[i](x) 147 | xs = None 148 | for j in range(self.num_kernels): 149 | if xs is None: 150 | xs = self.resblocks[i*self.num_kernels+j](x) 151 | else: 152 | xs += self.resblocks[i*self.num_kernels+j](x) 153 | x = xs / self.num_kernels 154 | 155 | x = self.activation_post(x) 156 | x = self.conv_post(x) 157 | x = torch.tanh(x) 158 | 159 | return x 160 | 161 | def remove_weight_norm(self): 162 | print('Removing weight norm...') 163 | for l in self.ups: 164 | remove_weight_norm(l) 165 | for l in self.resblocks: 166 | l.remove_weight_norm() 167 | 168 | class DiscriminatorP(torch.nn.Module): 169 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 170 | super(DiscriminatorP, self).__init__() 171 | self.period = period 172 | self.use_spectral_norm = use_spectral_norm 173 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 174 | self.convs = nn.ModuleList([ 175 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 176 | norm_f(Conv2d(32, 128, (kernel_size, 1), 
(stride, 1), padding=(get_padding(kernel_size, 1), 0))), 177 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 178 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 179 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), 180 | ]) 181 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 182 | 183 | def forward(self, x): 184 | fmap = [] 185 | 186 | b, c, t = x.shape 187 | if t % self.period != 0: 188 | n_pad = self.period - (t % self.period) 189 | x = F.pad(x, (0, n_pad), "reflect") 190 | t = t + n_pad 191 | x = x.view(b, c, t // self.period, self.period) 192 | 193 | for l in self.convs: 194 | x = l(x) 195 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 196 | fmap.append(x) 197 | x = self.conv_post(x) 198 | fmap.append(x) 199 | x = torch.flatten(x, 1, -1) 200 | 201 | return x, fmap 202 | 203 | class DiscriminatorR(torch.nn.Module): 204 | def __init__(self, resolution, use_spectral_norm=False): 205 | super(DiscriminatorR, self).__init__() 206 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 207 | 208 | n_fft, hop_length, win_length = resolution 209 | self.spec_transform = torchaudio.transforms.Spectrogram( 210 | n_fft=n_fft, hop_length=hop_length, win_length=win_length, window_fn=torch.hann_window, 211 | normalized=True, center=False, pad_mode=None, power=None) 212 | 213 | self.convs = nn.ModuleList([ 214 | norm_f(nn.Conv2d(2, 32, (3, 9), padding=(1, 4))), 215 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 216 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(2,1), padding=(2, 4))), 217 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(4,1), padding=(4, 4))), 218 | norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), 219 | ]) 220 | self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) 221 | 222 | def forward(self, y): 223 | fmap = [] 224 | 225 | x = self.spec_transform(y) 226 | x = torch.cat([x.real, x.imag], dim=1) 227 | x = rearrange(x, 'b c w t -> b c t w') 228 | 229 | for l in self.convs: 230 | x = l(x) 231 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 232 | fmap.append(x) 233 | x = self.conv_post(x) 234 | fmap.append(x) 235 | x = torch.flatten(x, 1, -1) 236 | 237 | return x, fmap 238 | 239 | 240 | class MultiPeriodDiscriminator(torch.nn.Module): 241 | def __init__(self, use_spectral_norm=False): 242 | super(MultiPeriodDiscriminator, self).__init__() 243 | periods = [2,3,5,7,11] 244 | resolutions = [[2048, 512, 2048], [1024, 256, 1024], [512, 128, 512], [256, 64, 256], [128, 32, 128]] 245 | 246 | discs = [DiscriminatorR(resolutions[i], use_spectral_norm=use_spectral_norm) for i in range(len(resolutions))] 247 | discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] 248 | self.discriminators = nn.ModuleList(discs) 249 | 250 | def forward(self, y, y_hat): 251 | y_d_rs = [] 252 | y_d_gs = [] 253 | fmap_rs = [] 254 | fmap_gs = [] 255 | for i, d in enumerate(self.discriminators): 256 | y_d_r, fmap_r = d(y) 257 | y_d_g, fmap_g = d(y_hat) 258 | y_d_rs.append(y_d_r) 259 | y_d_gs.append(y_d_g) 260 | fmap_rs.append(fmap_r) 261 | fmap_gs.append(fmap_g) 262 | 263 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 264 | 265 | class BigvGAN(nn.Module): 266 | """ 267 | Synthesizer for Training 268 | """ 269 | 270 | def __init__(self, 271 | 272 | spec_channels, 273 | segment_size, 274 | inter_channels, 275 | hidden_channels, 276 | filter_channels, 277 | 
n_heads, 278 | n_layers, 279 | kernel_size, 280 | p_dropout, 281 | resblock, 282 | resblock_kernel_sizes, 283 | resblock_dilation_sizes, 284 | upsample_rates, 285 | upsample_initial_channel, 286 | upsample_kernel_sizes, 287 | **kwargs): 288 | 289 | super().__init__() 290 | self.spec_channels = spec_channels 291 | self.inter_channels = inter_channels 292 | self.hidden_channels = hidden_channels 293 | self.filter_channels = filter_channels 294 | self.n_heads = n_heads 295 | self.n_layers = n_layers 296 | self.kernel_size = kernel_size 297 | self.p_dropout = p_dropout 298 | self.resblock = resblock 299 | self.resblock_kernel_sizes = resblock_kernel_sizes 300 | self.resblock_dilation_sizes = resblock_dilation_sizes 301 | self.upsample_rates = upsample_rates 302 | self.upsample_initial_channel = upsample_initial_channel 303 | self.upsample_kernel_sizes = upsample_kernel_sizes 304 | self.segment_size = segment_size 305 | 306 | self.dec = Generator(spec_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes) 307 | 308 | def forward(self, x): 309 | 310 | y = self.dec(x) 311 | return y 312 | 313 | def infer(self, x, max_len=None): 314 | 315 | o = self.dec(x[:,:,:max_len]) 316 | return o 317 | 318 | -------------------------------------------------------------------------------- /vocoder/hifigan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import vocoder.modules as modules 5 | 6 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 7 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 8 | from module.commons import * 9 | from torch.cuda.amp import autocast 10 | import torchaudio 11 | from einops import rearrange 12 | import typing as tp 13 | 14 | def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)): 15 | return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2) 16 | 17 | class Generator(torch.nn.Module): 18 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): 19 | super(Generator, self).__init__() 20 | self.num_kernels = len(resblock_kernel_sizes) 21 | self.num_upsamples = len(upsample_rates) 22 | self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) 23 | resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 24 | 25 | self.ups = nn.ModuleList() 26 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 27 | self.ups.append(weight_norm( 28 | ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), 29 | k, u, padding=(k-u)//2))) 30 | 31 | self.resblocks = nn.ModuleList() 32 | for i in range(len(self.ups)): 33 | ch = upsample_initial_channel//(2**(i+1)) 34 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): 35 | self.resblocks.append(resblock(ch, k, d)) 36 | 37 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) 38 | self.ups.apply(init_weights) 39 | 40 | if gin_channels != 0: 41 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) 42 | 43 | def forward(self, x, g=None): 44 | x = self.conv_pre(x) 45 | if g is not None: 46 | x = x + self.cond(g) 47 | 48 | for i in range(self.num_upsamples): 49 | x = F.leaky_relu(x, 
modules.LRELU_SLOPE) 50 | x = self.ups[i](x) 51 | xs = None 52 | for j in range(self.num_kernels): 53 | if xs is None: 54 | xs = self.resblocks[i*self.num_kernels+j](x) 55 | else: 56 | xs += self.resblocks[i*self.num_kernels+j](x) 57 | x = xs / self.num_kernels 58 | x = F.leaky_relu(x) 59 | x = self.conv_post(x) 60 | x = torch.tanh(x) 61 | 62 | return x 63 | 64 | def remove_weight_norm(self): 65 | print('Removing weight norm...') 66 | for l in self.ups: 67 | remove_weight_norm(l) 68 | for l in self.resblocks: 69 | l.remove_weight_norm() 70 | 71 | class DiscriminatorS(torch.nn.Module): 72 | def __init__(self, use_spectral_norm=False): 73 | super(DiscriminatorS, self).__init__() 74 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 75 | self.convs = nn.ModuleList([ 76 | norm_f(Conv1d(1, 16, 15, 1, padding=7)), 77 | norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), 78 | norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), 79 | norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), 80 | norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), 81 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 82 | ]) 83 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 84 | 85 | def forward(self, x): 86 | fmap = [] 87 | 88 | for l in self.convs: 89 | x = l(x) 90 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 91 | fmap.append(x) 92 | x = self.conv_post(x) 93 | fmap.append(x) 94 | x = torch.flatten(x, 1, -1) 95 | 96 | return x, fmap 97 | 98 | class DiscriminatorP(torch.nn.Module): 99 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 100 | super(DiscriminatorP, self).__init__() 101 | self.period = period 102 | self.use_spectral_norm = use_spectral_norm 103 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 104 | self.convs = nn.ModuleList([ 105 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 106 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 107 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 108 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 109 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), 110 | ]) 111 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 112 | 113 | def forward(self, x): 114 | fmap = [] 115 | 116 | # 1d to 2d 117 | b, c, t = x.shape 118 | if t % self.period != 0: # pad first 119 | n_pad = self.period - (t % self.period) 120 | x = F.pad(x, (0, n_pad), "reflect") 121 | t = t + n_pad 122 | x = x.view(b, c, t // self.period, self.period) 123 | 124 | for l in self.convs: 125 | x = l(x) 126 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 127 | fmap.append(x) 128 | x = self.conv_post(x) 129 | fmap.append(x) 130 | x = torch.flatten(x, 1, -1) 131 | 132 | return x, fmap 133 | 134 | class DiscriminatorR(torch.nn.Module): 135 | def __init__(self, resolution, use_spectral_norm=False): 136 | super(DiscriminatorR, self).__init__() 137 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 138 | 139 | n_fft, hop_length, win_length = resolution 140 | self.spec_transform = torchaudio.transforms.Spectrogram( 141 | n_fft=n_fft, hop_length=hop_length, win_length=win_length, window_fn=torch.hann_window, 142 | normalized=True, center=False, pad_mode=None, power=None) 143 | 144 | self.convs = nn.ModuleList([ 145 | norm_f(nn.Conv2d(2, 32, (3, 9), 
padding=(1, 4))), 146 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), 147 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(2,1), padding=(2, 4))), 148 | norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), dilation=(4,1), padding=(4, 4))), 149 | norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), 150 | ]) 151 | self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) 152 | 153 | def forward(self, y): 154 | fmap = [] 155 | 156 | x = self.spec_transform(y) # [B, 2, Freq, Frames, 2] 157 | x = torch.cat([x.real, x.imag], dim=1) 158 | x = rearrange(x, 'b c w t -> b c t w') 159 | 160 | for l in self.convs: 161 | x = l(x) 162 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 163 | fmap.append(x) 164 | x = self.conv_post(x) 165 | fmap.append(x) 166 | x = torch.flatten(x, 1, -1) 167 | 168 | return x, fmap 169 | 170 | 171 | class MultiPeriodDiscriminator(torch.nn.Module): 172 | def __init__(self, use_spectral_norm=False): 173 | super(MultiPeriodDiscriminator, self).__init__() 174 | # periods = [2,3,5,7,11] 175 | # resolutions = [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]] 176 | resolutions = [[2048, 512, 2048], [1024, 256, 1024], [512, 128, 512], [256, 64, 256], [128, 32, 128]] 177 | 178 | discs = [DiscriminatorR(resolutions[i], use_spectral_norm=use_spectral_norm) for i in range(len(resolutions))] 179 | # discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] 180 | # discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] 181 | self.discriminators = nn.ModuleList(discs) 182 | 183 | def forward(self, y, y_hat): 184 | y_d_rs = [] 185 | y_d_gs = [] 186 | fmap_rs = [] 187 | fmap_gs = [] 188 | for i, d in enumerate(self.discriminators): 189 | y_d_r, fmap_r = d(y) 190 | y_d_g, fmap_g = d(y_hat) 191 | y_d_rs.append(y_d_r) 192 | y_d_gs.append(y_d_g) 193 | fmap_rs.append(fmap_r) 194 | fmap_gs.append(fmap_g) 195 | 196 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 197 | 198 | class HiFi(nn.Module): 199 | """ 200 | Synthesizer for Training 201 | """ 202 | 203 | def __init__(self, 204 | 205 | spec_channels, 206 | segment_size, 207 | inter_channels, 208 | hidden_channels, 209 | filter_channels, 210 | n_heads, 211 | n_layers, 212 | kernel_size, 213 | p_dropout, 214 | resblock, 215 | resblock_kernel_sizes, 216 | resblock_dilation_sizes, 217 | upsample_rates, 218 | upsample_initial_channel, 219 | upsample_kernel_sizes, 220 | **kwargs): 221 | 222 | super().__init__() 223 | self.spec_channels = spec_channels 224 | self.inter_channels = inter_channels 225 | self.hidden_channels = hidden_channels 226 | self.filter_channels = filter_channels 227 | self.n_heads = n_heads 228 | self.n_layers = n_layers 229 | self.kernel_size = kernel_size 230 | self.p_dropout = p_dropout 231 | self.resblock = resblock 232 | self.resblock_kernel_sizes = resblock_kernel_sizes 233 | self.resblock_dilation_sizes = resblock_dilation_sizes 234 | self.upsample_rates = upsample_rates 235 | self.upsample_initial_channel = upsample_initial_channel 236 | self.upsample_kernel_sizes = upsample_kernel_sizes 237 | self.segment_size = segment_size 238 | 239 | self.dec = Generator(spec_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes) 240 | 241 | def forward(self, x): 242 | 243 | y = self.dec(x) 244 | return y 245 | 246 | def infer(self, x, max_len=None): 247 | 248 | o = self.dec(x[:,:,:max_len]) 249 | return o 250 | 251 | -------------------------------------------------------------------------------- 
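The `HiFi` class above is only a thin wrapper around its `Generator`, and it is instantiated the same way in `train.py` and `inference.py`. The sketch below shows the intended mel-to-waveform round trip. It is a minimal usage sketch, not part of the repository: it assumes the `configs/config_16k.json` and `./vocoder/voc_hifigan.pth` files named in the README, that `utils/` is put on `sys.path` the way the training scripts expect, and a hypothetical output path `./reconstructed.wav`.

```python
import sys
sys.path.append('./utils')  # so utils.py and data_loader.py resolve as in train.py

import torch
import torchaudio

import utils  # utils/utils.py
from vocoder.hifigan import HiFi

hps = utils.get_hparams_from_file('./configs/config_16k.json')

# Mel analysis with the same fixed front end used for training.
mel_fn = utils.MelSpectrogramFixed(
    sample_rate=hps.data.sampling_rate, n_fft=hps.data.filter_length,
    win_length=hps.data.win_length, hop_length=hps.data.hop_length,
    f_min=hps.data.mel_fmin, f_max=hps.data.mel_fmax,
    n_mels=hps.data.n_mel_channels, window_fn=torch.hann_window).cuda()

# Build the vocoder exactly as train.py does, then load the released weights.
net_v = HiFi(hps.data.n_mel_channels,
             hps.train.segment_size // hps.data.hop_length,
             **hps.model).cuda()
utils.load_checkpoint('./vocoder/voc_hifigan.pth', net_v, None)
net_v.eval()
net_v.dec.remove_weight_norm()

wav, sr = torchaudio.load('./sample/src_p241_004.wav')
with torch.no_grad():
    mel = mel_fn(wav.cuda())         # [1, n_mels, T]
    wav_hat = net_v(mel).squeeze(0)  # [1, samples]
torchaudio.save('./reconstructed.wav', wav_hat.cpu(), hps.data.sampling_rate)
```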
/vocoder/modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import scipy 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | import torchaudio.transforms as T 10 | 11 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 12 | from torch.nn.utils import weight_norm, remove_weight_norm 13 | 14 | from module.commons import * 15 | from module.commons import init_weights, get_padding 16 | from torch.cuda.amp import autocast 17 | 18 | LRELU_SLOPE = 0.1 19 | 20 | DEFAULT_MIN_BIN_WIDTH = 1e-3 21 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 22 | DEFAULT_MIN_DERIVATIVE = 1e-3 23 | 24 | def piecewise_rational_quadratic_transform(inputs, 25 | unnormalized_widths, 26 | unnormalized_heights, 27 | unnormalized_derivatives, 28 | inverse=False, 29 | tails=None, 30 | tail_bound=1., 31 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 32 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 33 | min_derivative=DEFAULT_MIN_DERIVATIVE): 34 | 35 | if tails is None: 36 | spline_fn = rational_quadratic_spline 37 | spline_kwargs = {} 38 | else: 39 | spline_fn = unconstrained_rational_quadratic_spline 40 | spline_kwargs = { 41 | 'tails': tails, 42 | 'tail_bound': tail_bound 43 | } 44 | 45 | outputs, logabsdet = spline_fn( 46 | inputs=inputs, 47 | unnormalized_widths=unnormalized_widths, 48 | unnormalized_heights=unnormalized_heights, 49 | unnormalized_derivatives=unnormalized_derivatives, 50 | inverse=inverse, 51 | min_bin_width=min_bin_width, 52 | min_bin_height=min_bin_height, 53 | min_derivative=min_derivative, 54 | **spline_kwargs 55 | ) 56 | return outputs, logabsdet 57 | 58 | 59 | def searchsorted(bin_locations, inputs, eps=1e-6): 60 | bin_locations[..., -1] += eps 61 | return torch.sum( 62 | inputs[..., None] >= bin_locations, 63 | dim=-1 64 | ) - 1 65 | 66 | 67 | def unconstrained_rational_quadratic_spline(inputs, 68 | unnormalized_widths, 69 | unnormalized_heights, 70 | unnormalized_derivatives, 71 | inverse=False, 72 | tails='linear', 73 | tail_bound=1., 74 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 75 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 76 | min_derivative=DEFAULT_MIN_DERIVATIVE): 77 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 78 | outside_interval_mask = ~inside_interval_mask 79 | 80 | outputs = torch.zeros_like(inputs) 81 | logabsdet = torch.zeros_like(inputs) 82 | 83 | if tails == 'linear': 84 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 85 | constant = np.log(np.exp(1 - min_derivative) - 1) 86 | unnormalized_derivatives[..., 0] = constant 87 | unnormalized_derivatives[..., -1] = constant 88 | 89 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 90 | logabsdet[outside_interval_mask] = 0 91 | else: 92 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 93 | 94 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 95 | inputs=inputs[inside_interval_mask], 96 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 97 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 98 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 99 | inverse=inverse, 100 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 101 | min_bin_width=min_bin_width, 102 | min_bin_height=min_bin_height, 103 | min_derivative=min_derivative 104 | ) 105 | 106 | return outputs, logabsdet 107 | 108 | def 
rational_quadratic_spline(inputs, 109 | unnormalized_widths, 110 | unnormalized_heights, 111 | unnormalized_derivatives, 112 | inverse=False, 113 | left=0., right=1., bottom=0., top=1., 114 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 115 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 116 | min_derivative=DEFAULT_MIN_DERIVATIVE): 117 | if torch.min(inputs) < left or torch.max(inputs) > right: 118 | raise ValueError('Input to a transform is not within its domain') 119 | 120 | num_bins = unnormalized_widths.shape[-1] 121 | 122 | if min_bin_width * num_bins > 1.0: 123 | raise ValueError('Minimal bin width too large for the number of bins') 124 | if min_bin_height * num_bins > 1.0: 125 | raise ValueError('Minimal bin height too large for the number of bins') 126 | 127 | widths = F.softmax(unnormalized_widths, dim=-1) 128 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 129 | cumwidths = torch.cumsum(widths, dim=-1) 130 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 131 | cumwidths = (right - left) * cumwidths + left 132 | cumwidths[..., 0] = left 133 | cumwidths[..., -1] = right 134 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 135 | 136 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 137 | 138 | heights = F.softmax(unnormalized_heights, dim=-1) 139 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 140 | cumheights = torch.cumsum(heights, dim=-1) 141 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 142 | cumheights = (top - bottom) * cumheights + bottom 143 | cumheights[..., 0] = bottom 144 | cumheights[..., -1] = top 145 | heights = cumheights[..., 1:] - cumheights[..., :-1] 146 | 147 | if inverse: 148 | bin_idx = searchsorted(cumheights, inputs)[..., None] 149 | else: 150 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 151 | 152 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 153 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 154 | 155 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 156 | delta = heights / widths 157 | input_delta = delta.gather(-1, bin_idx)[..., 0] 158 | 159 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 160 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 161 | 162 | input_heights = heights.gather(-1, bin_idx)[..., 0] 163 | 164 | if inverse: 165 | a = (((inputs - input_cumheights) * (input_derivatives 166 | + input_derivatives_plus_one 167 | - 2 * input_delta) 168 | + input_heights * (input_delta - input_derivatives))) 169 | b = (input_heights * input_derivatives 170 | - (inputs - input_cumheights) * (input_derivatives 171 | + input_derivatives_plus_one 172 | - 2 * input_delta)) 173 | c = - input_delta * (inputs - input_cumheights) 174 | 175 | discriminant = b.pow(2) - 4 * a * c 176 | assert (discriminant >= 0).all() 177 | 178 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 179 | outputs = root * input_bin_widths + input_cumwidths 180 | 181 | theta_one_minus_theta = root * (1 - root) 182 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 183 | * theta_one_minus_theta) 184 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 185 | + 2 * input_delta * theta_one_minus_theta 186 | + input_derivatives * (1 - root).pow(2)) 187 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 188 | 189 | return outputs, -logabsdet 190 | else: 191 | theta = (inputs - input_cumwidths) / input_bin_widths 
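        # The forward branch below evaluates the monotone rational-quadratic map of
        # Durkan et al., "Neural Spline Flows" (NeurIPS 2019). With bin slope
        # s_k = (y_{k+1} - y_k) / (x_{k+1} - x_k) (input_delta), knot derivatives
        # d_k, d_{k+1}, and theta the normalized position inside bin k:
        #   g(theta) = y_k + (y_{k+1} - y_k) * [s_k*theta^2 + d_k*theta*(1 - theta)]
        #                    / [s_k + (d_k + d_{k+1} - 2*s_k)*theta*(1 - theta)]
        # The numerator/denominator lines below compute the two bracketed terms elementwise.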
theta_one_minus_theta = theta * (1 - theta) 193 | 194 | numerator = input_heights * (input_delta * theta.pow(2) 195 | + input_derivatives * theta_one_minus_theta) 196 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 197 | * theta_one_minus_theta) 198 | outputs = input_cumheights + numerator / denominator 199 | 200 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 201 | + 2 * input_delta * theta_one_minus_theta 202 | + input_derivatives * (1 - theta).pow(2)) 203 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 204 | 205 | return outputs, logabsdet 206 | 207 | class LayerNorm(nn.Module): 208 | def __init__(self, channels, eps=1e-5): 209 | super().__init__() 210 | self.channels = channels 211 | self.eps = eps 212 | 213 | self.gamma = nn.Parameter(torch.ones(channels)) 214 | self.beta = nn.Parameter(torch.zeros(channels)) 215 | 216 | def forward(self, x): 217 | x = x.transpose(1, -1) 218 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 219 | return x.transpose(1, -1) 220 | 221 | 222 | class ConvReluNorm(nn.Module): 223 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): 224 | super().__init__() 225 | self.in_channels = in_channels 226 | self.hidden_channels = hidden_channels 227 | self.out_channels = out_channels 228 | self.kernel_size = kernel_size 229 | self.n_layers = n_layers 230 | self.p_dropout = p_dropout 231 | assert n_layers > 1, "Number of layers should be greater than 1." 232 | 233 | self.conv_layers = nn.ModuleList() 234 | self.norm_layers = nn.ModuleList() 235 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) 236 | self.norm_layers.append(LayerNorm(hidden_channels)) 237 | self.relu_drop = nn.Sequential( 238 | nn.ReLU(), 239 | nn.Dropout(p_dropout)) 240 | for _ in range(n_layers - 1): 241 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) 242 | self.norm_layers.append(LayerNorm(hidden_channels)) 243 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 244 | self.proj.weight.data.zero_() 245 | self.proj.bias.data.zero_() 246 | 247 | def forward(self, x, x_mask): 248 | x_org = x 249 | for i in range(self.n_layers): 250 | x = self.conv_layers[i](x * x_mask) 251 | x = self.norm_layers[i](x) 252 | x = self.relu_drop(x) 253 | x = x_org + self.proj(x) 254 | return x * x_mask 255 | 256 | 257 | class DDSConv(nn.Module): 258 | """ 259 | Dilated and Depth-Separable Convolution 260 | """ 261 | 262 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): 263 | super().__init__() 264 | self.channels = channels 265 | self.kernel_size = kernel_size 266 | self.n_layers = n_layers 267 | self.p_dropout = p_dropout 268 | 269 | self.drop = nn.Dropout(p_dropout) 270 | self.convs_sep = nn.ModuleList() 271 | self.convs_1x1 = nn.ModuleList() 272 | self.norms_1 = nn.ModuleList() 273 | self.norms_2 = nn.ModuleList() 274 | for i in range(n_layers): 275 | dilation = kernel_size ** i 276 | padding = (kernel_size * dilation - dilation) // 2 277 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, 278 | groups=channels, dilation=dilation, padding=padding 279 | )) 280 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 281 | self.norms_1.append(LayerNorm(channels)) 282 | self.norms_2.append(LayerNorm(channels)) 283 | 284 | def forward(self, x, x_mask, g=None): 285 | if g is not 
None: 286 | x = x + g 287 | for i in range(self.n_layers): 288 | y = self.convs_sep[i](x * x_mask) 289 | y = self.norms_1[i](y) 290 | y = F.gelu(y) 291 | y = self.convs_1x1[i](y) 292 | y = self.norms_2[i](y) 293 | y = F.gelu(y) 294 | y = self.drop(y) 295 | x = x + y 296 | return x * x_mask 297 | 298 | 299 | class WN(torch.nn.Module): 300 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 301 | super(WN, self).__init__() 302 | assert (kernel_size % 2 == 1) 303 | self.hidden_channels = hidden_channels 304 | self.kernel_size = kernel_size 305 | self.dilation_rate = dilation_rate 306 | self.n_layers = n_layers 307 | self.gin_channels = gin_channels 308 | self.p_dropout = p_dropout 309 | 310 | self.in_layers = torch.nn.ModuleList() 311 | self.res_skip_layers = torch.nn.ModuleList() 312 | self.drop = nn.Dropout(p_dropout) 313 | 314 | if gin_channels != 0: 315 | cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) 316 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 317 | 318 | for i in range(n_layers): 319 | dilation = dilation_rate ** i 320 | padding = int((kernel_size * dilation - dilation) / 2) 321 | in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, 322 | dilation=dilation, padding=padding) 323 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 324 | self.in_layers.append(in_layer) 325 | 326 | # last one is not necessary 327 | if i < n_layers - 1: 328 | res_skip_channels = 2 * hidden_channels 329 | else: 330 | res_skip_channels = hidden_channels 331 | 332 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 333 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 334 | self.res_skip_layers.append(res_skip_layer) 335 | 336 | def forward(self, x, x_mask, g=None, **kwargs): 337 | output = torch.zeros_like(x) 338 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 339 | 340 | if g is not None: 341 | g = self.cond_layer(g) 342 | 343 | for i in range(self.n_layers): 344 | x_in = self.in_layers[i](x) 345 | if g is not None: 346 | cond_offset = i * 2 * self.hidden_channels 347 | g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] 348 | else: 349 | g_l = torch.zeros_like(x_in) 350 | 351 | acts = fused_add_tanh_sigmoid_multiply( # provided by the star import from module.commons above 352 | x_in, 353 | g_l, 354 | n_channels_tensor) 355 | acts = self.drop(acts) 356 | 357 | res_skip_acts = self.res_skip_layers[i](acts) 358 | if i < self.n_layers - 1: 359 | res_acts = res_skip_acts[:, :self.hidden_channels, :] 360 | x = (x + res_acts) * x_mask 361 | output = output + res_skip_acts[:, self.hidden_channels:, :] 362 | else: 363 | output = output + res_skip_acts 364 | return output * x_mask 365 | 366 | def remove_weight_norm(self): 367 | if self.gin_channels != 0: 368 | torch.nn.utils.remove_weight_norm(self.cond_layer) 369 | for l in self.in_layers: 370 | torch.nn.utils.remove_weight_norm(l) 371 | for l in self.res_skip_layers: 372 | torch.nn.utils.remove_weight_norm(l) 373 | 374 | 375 | class AMPBlock(torch.nn.Module): 376 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), rank=0): 377 | super(AMPBlock, self).__init__() 378 | 379 | self.convs1 = nn.ModuleList([ 380 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 381 | padding=get_padding(kernel_size, dilation[0]))), 382 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 383 | padding=get_padding(kernel_size, 
375 | class AMPBlock(torch.nn.Module):
376 |     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), rank=0):
377 |         super(AMPBlock, self).__init__()
378 | 
379 |         self.convs1 = nn.ModuleList([
380 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
381 |                                padding=get_padding(kernel_size, dilation[0]))),
382 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
383 |                                padding=get_padding(kernel_size, dilation[1]))),
384 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
385 |                                padding=get_padding(kernel_size, dilation[2])))
386 |         ])
387 |         self.convs1.apply(init_weights)
388 |         self.alpha1 = nn.Parameter(torch.ones(1, channels, 1).to(rank))
389 |         self.alpha2 = nn.Parameter(torch.ones(1, channels, 1).to(rank))
390 | 
391 |         self.convs2 = nn.ModuleList([
392 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
393 |                                padding=get_padding(kernel_size, 1))),
394 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
395 |                                padding=get_padding(kernel_size, 1))),
396 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
397 |                                padding=get_padding(kernel_size, 1)))
398 |         ])
399 |         self.convs2.apply(init_weights)
400 | 
401 |     def forward(self, x, x_mask=None):
402 |         for c1, c2 in zip(self.convs1, self.convs2):
403 | 
404 |             xt = x + (1 / self.alpha1) * (torch.sin(self.alpha1 * x) ** 2)  # Snake1D
405 | 
406 |             if x_mask is not None:
407 |                 xt = xt * x_mask
408 |             xt = c1(xt)
409 | 
410 |             xt = xt + (1 / self.alpha2) * (torch.sin(self.alpha2 * xt) ** 2)  # Snake1D
411 | 
412 |             if x_mask is not None:
413 |                 xt = xt * x_mask
414 |             xt = c2(xt)
415 |             x = xt + x
416 |         if x_mask is not None:
417 |             x = x * x_mask
418 |         return x
419 | 
420 |     def remove_weight_norm(self):
421 |         for l in self.convs1:
422 |             remove_weight_norm(l)
423 |         for l in self.convs2:
424 |             remove_weight_norm(l)
425 | 
426 | 
427 | class AMPBlock_filter(torch.nn.Module):
428 |     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), rank=0, orig_freq=None, rolloff=0.25):
429 |         super(AMPBlock_filter, self).__init__()
430 | 
431 |         self.upsampling_with_lfilter = T.Resample(orig_freq=orig_freq, new_freq=orig_freq * 2,
432 |                                                   resampling_method='kaiser_window',
433 |                                                   lowpass_filter_width=12,
434 |                                                   rolloff=rolloff,
435 |                                                   beta=4.663800127934911
436 |                                                   )
437 |         self.downsampling_with_lfilter = T.Resample(orig_freq=orig_freq * 2, new_freq=orig_freq,
438 |                                                     resampling_method='kaiser_window',
439 |                                                     lowpass_filter_width=12,
440 |                                                     rolloff=rolloff,
441 |                                                     beta=4.663800127934911
442 |                                                     )
443 |         self.convs1 = nn.ModuleList([
444 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
445 |                                padding=get_padding(kernel_size, dilation[0]))),
446 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
447 |                                padding=get_padding(kernel_size, dilation[1]))),
448 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
449 |                                padding=get_padding(kernel_size, dilation[2])))
450 |         ])
451 |         self.convs1.apply(init_weights)
452 | 
453 |         self.convs2 = nn.ModuleList([
454 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
455 |                                padding=get_padding(kernel_size, 1))),
456 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
457 |                                padding=get_padding(kernel_size, 1))),
458 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
459 |                                padding=get_padding(kernel_size, 1)))
460 |         ])
461 |         self.convs2.apply(init_weights)
462 | 
463 |         # ParameterList (rather than a plain Python list) so the per-layer
464 |         # alphas are registered, trained, and saved with the module
465 |         self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1).to(rank)) for _ in range(len(self.convs1))])
466 |         self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1).to(rank)) for _ in range(len(self.convs2))])
467 | 
468 |     def forward(self, x, x_mask=None):
469 |         for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.alpha1, self.alpha2):
470 | 
471 |             with autocast(enabled=False):
472 |                 xt = self.upsampling_with_lfilter(x.float())
473 |                 xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2)  # Snake1D
474 |                 xt = self.downsampling_with_lfilter(xt)
475 | 
476 |             if x_mask is not None:
477 |                 xt = xt * x_mask
478 |             xt = c1(xt)
479 | 
480 |             with autocast(enabled=False):
481 |                 xt = self.upsampling_with_lfilter(xt.float())
482 |                 xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2)  # Snake1D
483 |                 xt = self.downsampling_with_lfilter(xt)
484 | 
485 |             if x_mask is not None:
486 |                 xt = xt * x_mask
487 |             xt = c2(xt)
488 |             x = xt + x
489 |         if x_mask is not None:
490 |             x = x * x_mask
491 |         return x
492 | 
493 |     def remove_weight_norm(self):
494 |         for l in self.convs1:
495 |             remove_weight_norm(l)
496 |         for l in self.convs2:
497 |             remove_weight_norm(l)
498 | 
499 | 
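def _demo_snake():
    # Illustrative sketch: the "Snake1D" lines in the AMP blocks above
    # implement the Snake activation used by BigVGAN,
    # snake(x) = x + (1/alpha) * sin(alpha * x)^2, with a learnable
    # per-channel alpha (initialized to 1). For example, at x = pi/2 and
    # alpha = 1, sin(pi/2)^2 = 1, so snake(pi/2) = pi/2 + 1.
    alpha = torch.ones(1, 4, 1)                # per-channel slope, as in AMPBlock
    x = torch.full((1, 4, 2), math.pi / 2)
    return x + (1 / alpha) * torch.sin(alpha * x) ** 2   # -> all pi/2 + 1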
500 | class ResBlock1(torch.nn.Module):
501 |     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
502 |         super(ResBlock1, self).__init__()
503 |         self.convs1 = nn.ModuleList([
504 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
505 |                                padding=get_padding(kernel_size, dilation[0]))),
506 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
507 |                                padding=get_padding(kernel_size, dilation[1]))),
508 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
509 |                                padding=get_padding(kernel_size, dilation[2])))
510 |         ])
511 |         self.convs1.apply(init_weights)
512 | 
513 |         self.convs2 = nn.ModuleList([
514 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
515 |                                padding=get_padding(kernel_size, 1))),
516 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
517 |                                padding=get_padding(kernel_size, 1))),
518 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
519 |                                padding=get_padding(kernel_size, 1)))
520 |         ])
521 |         self.convs2.apply(init_weights)
522 | 
523 |     def forward(self, x, x_mask=None):
524 |         for c1, c2 in zip(self.convs1, self.convs2):
525 |             xt = F.leaky_relu(x, LRELU_SLOPE)
526 |             if x_mask is not None:
527 |                 xt = xt * x_mask
528 |             xt = c1(xt)
529 |             xt = F.leaky_relu(xt, LRELU_SLOPE)
530 |             if x_mask is not None:
531 |                 xt = xt * x_mask
532 |             xt = c2(xt)
533 |             x = xt + x
534 |         if x_mask is not None:
535 |             x = x * x_mask
536 |         return x
537 | 
538 |     def remove_weight_norm(self):
539 |         for l in self.convs1:
540 |             remove_weight_norm(l)
541 |         for l in self.convs2:
542 |             remove_weight_norm(l)
543 | 
544 | 
545 | class ResBlock2(torch.nn.Module):
546 |     def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
547 |         super(ResBlock2, self).__init__()
548 |         self.convs = nn.ModuleList([
549 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
550 |                                padding=get_padding(kernel_size, dilation[0]))),
551 |             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
552 |                                padding=get_padding(kernel_size, dilation[1])))
553 |         ])
554 |         self.convs.apply(init_weights)
555 | 
556 |     def forward(self, x, x_mask=None):
557 |         for c in self.convs:
558 |             xt = F.leaky_relu(x, LRELU_SLOPE)
559 |             if x_mask is not None:
560 |                 xt = xt * x_mask
561 |             xt = c(xt)
562 |             x = xt + x
563 |         if x_mask is not None:
564 |             x = x * x_mask
565 |         return x
566 | 
567 |     def remove_weight_norm(self):
568 |         for l in self.convs:
569 |             remove_weight_norm(l)
570 | 
571 | 
572 | class Log(nn.Module):
573 |     def forward(self, x, x_mask, reverse=False, **kwargs):
574 |         if not reverse:
575 |             y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
576 |             logdet = torch.sum(-y, [1, 2])
577 |             return y, logdet
578 |         else:
579 |             x = torch.exp(x) * x_mask
580 |             return x
581 | 
582 | 
583 | class Flip(nn.Module):
584 |     def forward(self, x, *args, reverse=False, **kwargs):
585 |         x = torch.flip(x, [1])
586 |         if not reverse:
587 |             logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
588 |             return x, logdet
589 |         else:
590 |             return x
591 | 
592 | 
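def _demo_log_flip():
    # Illustrative sketch: Log and Flip are simple invertible flow steps.
    # Log maps x -> log(x) with log-determinant sum(-log x), since
    # d log(x)/dx = 1/x and log|1/x| = -log(x); Flip reverses the channel
    # axis with zero log-determinant. A round-trip check:
    flow = Log()
    x = torch.rand(2, 4, 10) + 0.1             # strictly positive inputs
    x_mask = torch.ones(2, 1, 10)
    y, logdet = flow(x, x_mask)                # forward: y = log(x)
    x_rec = flow(y, x_mask, reverse=True)      # inverse: exp(log(x)) == x
    return torch.allclose(x, x_rec, atol=1e-6)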
593 | class ElementwiseAffine(nn.Module):
594 |     def __init__(self, channels):
595 |         super().__init__()
596 |         self.channels = channels
597 |         self.m = nn.Parameter(torch.zeros(channels, 1))
598 |         self.logs = nn.Parameter(torch.zeros(channels, 1))
599 | 
600 |     def forward(self, x, x_mask, reverse=False, **kwargs):
601 |         if not reverse:
602 |             y = self.m + torch.exp(self.logs) * x
603 |             y = y * x_mask
604 |             logdet = torch.sum(self.logs * x_mask, [1, 2])
605 |             return y, logdet
606 |         else:
607 |             x = (x - self.m) * torch.exp(-self.logs) * x_mask
608 |             return x
609 | 
610 | 
611 | class ResidualCouplingLayer(nn.Module):
612 |     def __init__(self,
613 |                  channels,
614 |                  hidden_channels,
615 |                  kernel_size,
616 |                  dilation_rate,
617 |                  n_layers,
618 |                  p_dropout=0,
619 |                  gin_channels=0,
620 |                  mean_only=False):
621 |         assert channels % 2 == 0, "channels should be divisible by 2"
622 |         super().__init__()
623 |         self.channels = channels
624 |         self.hidden_channels = hidden_channels
625 |         self.kernel_size = kernel_size
626 |         self.dilation_rate = dilation_rate
627 |         self.n_layers = n_layers
628 |         self.half_channels = channels // 2
629 |         self.mean_only = mean_only
630 | 
631 |         self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
632 |         self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout,
633 |                       gin_channels=gin_channels)
634 |         self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
635 |         self.post.weight.data.zero_()
636 |         self.post.bias.data.zero_()
637 | 
638 |     def forward(self, x, x_mask, g=None, reverse=False):
639 |         x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
640 |         h = self.pre(x0) * x_mask
641 |         h = self.enc(h, x_mask, g=g)
642 |         stats = self.post(h) * x_mask
643 |         if not self.mean_only:
644 |             m, logs = torch.split(stats, [self.half_channels] * 2, 1)
645 |         else:
646 |             m = stats
647 |             logs = torch.zeros_like(m)
648 | 
649 |         if not reverse:
650 |             x1 = m + x1 * torch.exp(logs) * x_mask
651 |             x = torch.cat([x0, x1], 1)
652 |             logdet = torch.sum(logs, [1, 2])
653 |             return x, logdet
654 |         else:
655 |             x1 = (x1 - m) * torch.exp(-logs) * x_mask
656 |             x = torch.cat([x0, x1], 1)
657 |             return x
658 | 
659 | 
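def _demo_coupling():
    # Illustrative sketch; channel/time sizes are assumptions, not repository
    # configs. The affine coupling layer leaves the first half of the channels
    # unchanged and uses it to predict a shift m (and log-scale logs) for the
    # second half, so the inverse is available in closed form. With
    # mean_only=True, logs = 0, and the zero-initialized post projection makes
    # the layer start as the identity map.
    layer = ResidualCouplingLayer(channels=192, hidden_channels=192,
                                  kernel_size=5, dilation_rate=1, n_layers=4,
                                  mean_only=True)
    x = torch.randn(2, 192, 100)
    x_mask = torch.ones(2, 1, 100)
    y, logdet = layer(x, x_mask)               # forward pass
    x_rec = layer(y, x_mask, reverse=True)     # inverse round-trip recovers x
    return torch.allclose(x, x_rec, atol=1e-5)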
660 | class ConvFlow(nn.Module):
661 |     def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
662 |         super().__init__()
663 |         self.in_channels = in_channels
664 |         self.filter_channels = filter_channels
665 |         self.kernel_size = kernel_size
666 |         self.n_layers = n_layers
667 |         self.num_bins = num_bins
668 |         self.tail_bound = tail_bound
669 |         self.half_channels = in_channels // 2
670 | 
671 |         self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
672 |         self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
673 |         self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
674 |         self.proj.weight.data.zero_()
675 |         self.proj.bias.data.zero_()
676 | 
677 |     def forward(self, x, x_mask, g=None, reverse=False):
678 |         x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
679 |         h = self.pre(x0)
680 |         h = self.convs(h, x_mask, g=g)
681 |         h = self.proj(h) * x_mask
682 | 
683 |         b, c, t = x0.shape
684 |         h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, c*?, t] -> [b, c, t, ?]
685 | 
686 |         unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
687 |         unnormalized_heights = h[..., self.num_bins:2 * self.num_bins] / math.sqrt(self.filter_channels)
688 |         unnormalized_derivatives = h[..., 2 * self.num_bins:]
689 | 
690 |         x1, logabsdet = piecewise_rational_quadratic_transform(x1,
691 |                                                                unnormalized_widths,
692 |                                                                unnormalized_heights,
693 |                                                                unnormalized_derivatives,
694 |                                                                inverse=reverse,
695 |                                                                tails='linear',
696 |                                                                tail_bound=self.tail_bound
697 |                                                                )
698 | 
699 |         x = torch.cat([x0, x1], 1) * x_mask
700 |         logdet = torch.sum(logabsdet * x_mask, [1, 2])
701 |         if not reverse:
702 |             return x, logdet
703 |         else:
704 |             return x
705 | 
--------------------------------------------------------------------------------