├── usfgan ├── bin │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ ├── compute_statistics.yaml │ │ ├── train.yaml │ │ ├── decode.yaml │ │ ├── discriminator │ │ │ ├── pwg.yaml │ │ │ ├── univnet.yaml │ │ │ └── hifigan.yaml │ │ ├── generator │ │ │ ├── usfgan.yaml │ │ │ ├── cascade_hn_usfgan.yaml │ │ │ └── parallel_hn_usfgan.yaml │ │ ├── anasyn.yaml │ │ ├── extract_features.yaml │ │ ├── data │ │ │ └── vctk_24kHz.yaml │ │ └── train │ │ │ ├── usfgan.yaml │ │ │ └── hn_usfgan.yaml │ ├── compute_statistics.py │ ├── decode.py │ ├── anasyn.py │ ├── extract_features.py │ └── train.py ├── __init__.py ├── datasets │ ├── __init__.py │ └── audio_feat_dataset.py ├── optimizers │ ├── __init__.py │ └── radam.py ├── models │ ├── __init__.py │ └── generator.py ├── layers │ ├── __init__.py │ ├── upsample.py │ ├── cheaptrick.py │ └── residual_block.py ├── utils │ ├── __init__.py │ ├── filters.py │ ├── index.py │ ├── utils.py │ └── features.py └── losses │ ├── __init__.py │ ├── feat_match.py │ ├── adversarial.py │ ├── source.py │ └── stft.py ├── .gitignore ├── egs └── vctk │ └── data │ ├── scp │ ├── vctk_train_24kHz.scp │ └── vctk_train_24kHz.list │ └── spk_info.yaml ├── LICENSE ├── setup.py └── README.md /usfgan/bin/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /usfgan/bin/config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /usfgan/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | __version__ = "0.0.1.post1" 4 | -------------------------------------------------------------------------------- /usfgan/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from usfgan.datasets.audio_feat_dataset import * # NOQA 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .eggs 3 | *.egg-info 4 | *.log 5 | *.out 6 | .venv 7 | 8 | egs/vctk/exp/ -------------------------------------------------------------------------------- /usfgan/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.optim import * # NOQA 2 | from usfgan.optimizers.radam import * # NOQA 3 | -------------------------------------------------------------------------------- /usfgan/models/__init__.py: -------------------------------------------------------------------------------- 1 | from usfgan.models.discriminator import * # NOQA 2 | from usfgan.models.generator import * # NOQA 3 | -------------------------------------------------------------------------------- /usfgan/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from usfgan.layers.cheaptrick import * # NOQA 2 | from usfgan.layers.residual_block import * # NOQA 3 | from usfgan.layers.upsample import * # NOQA 4 | -------------------------------------------------------------------------------- /usfgan/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from usfgan.utils.features import * # NOQA 2 | from usfgan.utils.filters import * # NOQA 3 | from usfgan.utils.index import * # NOQA 4 | from 
usfgan.utils.utils import * # NOQA 5 | -------------------------------------------------------------------------------- /usfgan/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from usfgan.losses.adversarial import * # NOQA 2 | from usfgan.losses.feat_match import * # NOQA 3 | from usfgan.losses.source import * # NOQA 4 | from usfgan.losses.stft import * # NOQA 5 | -------------------------------------------------------------------------------- /usfgan/bin/config/compute_statistics.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | hydra: 4 | job: 5 | chdir: false 6 | output_subdir: null 7 | job_logging: 8 | formatters: 9 | simple: 10 | format: '[%(levelname)s][%(module)s | %(lineno)s] %(message)s' 11 | 12 | feats: data/scp/vctk_train_24kHz.list # List file of input features. 13 | stats: data/stats/vctk_train_24kHz.joblib # Path to file to output statistics. 14 | feat_types: ['f0', 'contf0', 'mcep', 'mcap'] # Feature types. 15 | -------------------------------------------------------------------------------- /usfgan/bin/config/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - generator: parallel_hn_usfgan 6 | - discriminator: hifigan 7 | - train: hn_usfgan 8 | - data: vctk_24kHz 9 | 10 | hydra: 11 | job: 12 | chdir: false 13 | output_subdir: null 14 | job_logging: 15 | formatters: 16 | simple: 17 | format: '[%(levelname)s][%(module)s | %(lineno)s] %(message)s' 18 | 19 | out_dir: # Directory to output training results. 20 | seed: 12345 # Seed number for random numbers. 21 | -------------------------------------------------------------------------------- /egs/vctk/data/scp/vctk_train_24kHz.scp: -------------------------------------------------------------------------------- 1 | path_to_your_own_dataset_dir/wav/p225/p225_001.wav 2 | path_to_your_own_dataset_dir/wav/p225/p225_002.wav 3 | path_to_your_own_dataset_dir/wav/p225/p225_003.wav 4 | path_to_your_own_dataset_dir/wav/p225/p225_004.wav 5 | path_to_your_own_dataset_dir/wav/p225/p225_005.wav 6 | path_to_your_own_dataset_dir/wav/p225/p225_006.wav 7 | path_to_your_own_dataset_dir/wav/p225/p225_007.wav 8 | path_to_your_own_dataset_dir/wav/p225/p225_008.wav 9 | path_to_your_own_dataset_dir/wav/p225/p225_009.wav 10 | path_to_your_own_dataset_dir/wav/p225/p225_010.wav -------------------------------------------------------------------------------- /egs/vctk/data/scp/vctk_train_24kHz.list: -------------------------------------------------------------------------------- 1 | path_to_your_own_dataset_dir/hdf5/p225/p225_001.h5 2 | path_to_your_own_dataset_dir/hdf5/p225/p225_002.h5 3 | path_to_your_own_dataset_dir/hdf5/p225/p225_003.h5 4 | path_to_your_own_dataset_dir/hdf5/p225/p225_004.h5 5 | path_to_your_own_dataset_dir/hdf5/p225/p225_005.h5 6 | path_to_your_own_dataset_dir/hdf5/p225/p225_006.h5 7 | path_to_your_own_dataset_dir/hdf5/p225/p225_007.h5 8 | path_to_your_own_dataset_dir/hdf5/p225/p225_008.h5 9 | path_to_your_own_dataset_dir/hdf5/p225/p225_009.h5 10 | path_to_your_own_dataset_dir/hdf5/p225/p225_010.h5 -------------------------------------------------------------------------------- /usfgan/bin/config/decode.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - generator: parallel_hn_usfgan 6 | - data: vctk_24kHz 7 | 8 | 
hydra: 9 | job: 10 | chdir: false 11 | output_subdir: null 12 | job_logging: 13 | formatters: 14 | simple: 15 | format: '[%(levelname)s][%(module)s | %(lineno)s] %(message)s' 16 | 17 | out_dir: # Directory to output decoding results. 18 | checkpoint_path: # Path to the checkpoint of pre-trained model. 19 | f0_factor: 1.00 # F0 factor. 20 | seed: 100 # Seed number for random numbers. 21 | save_source: false # Whether to save source excitation signals. -------------------------------------------------------------------------------- /usfgan/bin/config/discriminator/pwg.yaml: -------------------------------------------------------------------------------- 1 | _target_: usfgan.models.PWGDiscriminator 2 | in_channels: 1 # Number of input channels. 3 | out_channels: 1 # Number of output channels. 4 | kernel_size: 3 # Kernel size of conv layers. 5 | layers: 10 # Number of conv layers. 6 | conv_channels: 64 # Number of channels in conv layers. 7 | dilation_factor: 2 # Dilation factor. 8 | bias: true # Whether to use bias parameter in conv. 9 | use_weight_norm: true # Whether to use weight norm. 10 | # If set to true, it will be applied to all of the conv layers. 11 | nonlinear_activation: LeakyReLU # Nonlinear function after each conv. 12 | nonlinear_activation_params: # Nonlinear function parameters. 13 | negative_slope: 0.2 # Alpha in LeakyReLU. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Reo YONEYAMA 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /usfgan/losses/feat_match.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2022 Reo Yoneyama (Nagoya University) 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Feature matching loss module. 7 | 8 | References: 9 | - https://github.com/kan-bayashi/ParallelWaveGAN 10 | 11 | """ 12 | 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | class FeatureMatchLoss(nn.Module): 18 | """Feature matching loss module."""
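# This loss is the L1 distance between discriminator feature maps of
# generated and real audio; the real feature maps are detached in forward()
# so gradients flow only into the generator.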
19 | 20 | def __init__( 21 | self, 22 | average_by_layers=False, 23 | ): 24 | """Initialize FeatureMatchLoss module.""" 25 | super(FeatureMatchLoss, self).__init__() 26 | self.average_by_layers = average_by_layers 27 | 28 | def forward(self, fmaps_fake, fmaps_real): 29 | """Calculate forward propagation. 30 | 31 | Args: 32 | fmaps_fake (list): List of discriminator outputs 33 | calculated from generator outputs. 34 | fmaps_real (list): List of discriminator outputs 35 | calculated from groundtruth. 36 | 37 | Returns: 38 | Tensor: Feature matching loss value. 39 | 40 | """ 41 | feat_match_loss = 0.0 42 | for feat_fake, feat_real in zip(fmaps_fake, fmaps_real): 43 | feat_match_loss += F.l1_loss(feat_fake, feat_real.detach()) 44 | 45 | if self.average_by_layers: 46 | feat_match_loss /= len(fmaps_fake) 47 | 48 | return feat_match_loss 49 | -------------------------------------------------------------------------------- /usfgan/bin/config/discriminator/univnet.yaml: -------------------------------------------------------------------------------- 1 | _target_: usfgan.models.UnivNetMultiResolutionMultiPeriodDiscriminator 2 | fft_sizes: [1024, 2048, 512] 3 | hop_sizes: [120, 240, 50] 4 | win_lengths: [600, 1200, 240] 5 | window: "hann_window" 6 | spectral_discriminator_params: 7 | channels: 32 8 | kernel_sizes: [[3, 9], [3, 9], [3, 9], [3, 9], [3, 3], [3, 3]] 9 | strides: [[1, 1], [1, 2], [1, 2], [1, 2], [1, 1], [1, 1]] 10 | bias: true 11 | nonlinear_activation: "LeakyReLU" 12 | nonlinear_activation_params: 13 | negative_slope: 0.2 14 | periods: [2, 3, 5, 7, 11] # List of periods for multi-period discriminator. 15 | period_discriminator_params: 16 | in_channels: 1 # Number of input channels. 17 | out_channels: 1 # Number of output channels. 18 | kernel_sizes: [5, 3] # List of kernel sizes. 19 | channels: 32 # Initial number of channels. 20 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales. 21 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. 22 | bias: true # Whether to use bias parameter in conv layer. 23 | nonlinear_activation: "LeakyReLU" # Nonlinear activation. 24 | nonlinear_activation_params: # Nonlinear activation parameters. 25 | negative_slope: 0.1 26 | use_weight_norm: true # Whether to apply weight normalization. 27 | use_spectral_norm: false # Whether to apply spectral normalization. 28 | -------------------------------------------------------------------------------- /usfgan/bin/config/generator/usfgan.yaml: -------------------------------------------------------------------------------- 1 | _target_: usfgan.models.USFGANGenerator 2 | source_network_params: 3 | blockA: 30 # Number of adaptive residual blocks. 4 | cycleA: 6 # Number of adaptive dilation cycles. 5 | blockF: 0 # Number of fixed residual blocks. 6 | cycleF: 0 # Number of fixed dilation cycles. 7 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive). 8 | filter_network_params: 9 | blockA: 0 # Number of adaptive residual blocks. 10 | cycleA: 0 # Number of adaptive dilation cycles. 11 | blockF: 30 # Number of fixed residual blocks. 12 | cycleF: 3 # Number of fixed dilation cycles. 13 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive). 14 | in_channels: 1 # Number of input channels. 15 | out_channels: 1 # Number of output channels. 16 | residual_channels: 64 # Number of channels in residual conv. 17 | gate_channels: 128 # Number of channels in gated conv. 18 | skip_channels: 64 # Number of channels in skip conv.
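# Note: with the default feature settings (mcep_dim: 40, mcap_dim: 20), the 62
# channels below would correspond to (40 + 1) + (20 + 1), assuming the 0th
# cepstral coefficients are included.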
19 | aux_channels: 62 # Number of channels for auxiliary feature conv. 20 | aux_context_window: 2 # Context window size for auxiliary feature. 21 | # If set to 2, previous 2 and future 2 frames will be considered. 22 | use_weight_norm: true # Whether to use weight norm. 23 | upsample_params: # Upsampling network parameters. 24 | upsample_scales: [5, 4, 3, 2] # Upsampling scales. Product of these must be the same as hop size. 25 | -------------------------------------------------------------------------------- /usfgan/utils/filters.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2020 Yi-Chiao Wu (Nagoya University) 4 | # based on a WaveNet script by Tomoki Hayashi (Nagoya University) 5 | # (https://github.com/kan-bayashi/PytorchWaveNetVocoder) 6 | # based on sprocket-vc script by Kazuhiro Kobayashi (Nagoya University) 7 | # (https://github.com/k2kobayashi/sprocket) 8 | # MIT License (https://opensource.org/licenses/MIT) 9 | 10 | """Filters.""" 11 | 12 | import numpy as np 13 | from scipy.signal import firwin, lfilter 14 | 15 | NUMTAPS = 255 16 | 17 | 18 | def low_cut_filter(x, fs, cutoff=70): 19 | """Low-cut filter 20 | 21 | Args: 22 | x (ndarray): Waveform sequence 23 | fs (int): Sampling frequency 24 | cutoff (float): Cutoff frequency of low cut filter 25 | 26 | Return: 27 | (ndarray): Low cut filtered waveform sequence 28 | 29 | """ 30 | nyquist = fs // 2 31 | norm_cutoff = cutoff / nyquist 32 | numtaps = NUMTAPS 33 | fil = firwin(numtaps, norm_cutoff, pass_zero=False) 34 | lcf_x = lfilter(fil, 1, x) 35 | 36 | return lcf_x 37 | 38 | 39 | def low_pass_filter(x, fs, cutoff=70): 40 | """Low-pass filter 41 | 42 | Args: 43 | x (ndarray): Waveform sequence 44 | fs (int): Sampling frequency 45 | cutoff (float): Cutoff frequency of low pass filter 46 | 47 | Return: 48 | (ndarray): Low pass filtered waveform sequence 49 | 50 | """ 51 | nyquist = fs // 2 52 | norm_cutoff = cutoff / nyquist 53 | numtaps = NUMTAPS 54 | fil = firwin(numtaps, norm_cutoff, pass_zero=True) 55 | x_pad = np.pad(x, (numtaps, numtaps), "edge") 56 | lpf_x = lfilter(fil, 1, x_pad) 57 | lpf_x = lpf_x[numtaps + numtaps // 2 : -numtaps // 2] 58 | 59 | return lpf_x 60 | -------------------------------------------------------------------------------- /usfgan/bin/config/anasyn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - generator: parallel_hn_usfgan 6 | 7 | hydra: 8 | run: 9 | dir: ./ 10 | output_subdir: null 11 | job_logging: 12 | formatters: 13 | simple: 14 | format: '[%(asctime)s][%(levelname)s][%(module)s | %(lineno)s] %(message)s' 15 | disable_existing_loggers: false 16 | 17 | in_dir: # Path to directory containing the wav files you want to process. 18 | out_dir: # Path to directory to save the synthesized wavs. 19 | f0_factor: 1.00 # F0 scaling factor. 20 | seed: 100 # Seed number for random numbers. 21 | stats: pretrained_model/stats/vctk_train_24kHz.joblib # Path to statistics file. 22 | checkpoint_path: pretrained_model/checkpoint-600000steps.pkl # Path to pre-trained model. 23 | 24 | # The same parameters should be set as in the training. 25 | sample_rate: 24000 # Sampling rate. 26 | frame_period: 5 # Frameshift in ms. 27 | f0_floor: 70 # Minimum F0 for WORLD F0 analysis. 28 | f0_ceil: 500 # Maximum F0 for WORLD F0 analysis. 29 | mcep_dim: 40 # Number of dimensions of MGC. 30 | mcap_dim: 20 # Number of dimensions of mel-cepstral AP.
31 | aux_feats: ["mcep", "mcap"] # Input acoustic features. 32 | dense_factor: 4 # Dense factor in PDCNNs. 33 | df_f0_type: "contf0" # F0 type for dilation factor ("f0" or "contf0"). 34 | sine_amp: 0.1 # Sine amplitude. 35 | noise_amp: 0.003 # Noise amplitude. 36 | sine_f0_type: "contf0" # F0 type for sine signal ("f0" or "contf0"). 37 | signal_types: ["sine", "noise"] # List of input signal types. -------------------------------------------------------------------------------- /usfgan/bin/config/extract_features.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | hydra: 4 | job: 5 | chdir: false 6 | output_subdir: null 7 | job_logging: 8 | formatters: 9 | simple: 10 | format: '[%(levelname)s][%(module)s | %(lineno)s] %(message)s' 11 | 12 | audio: data/scp/vctk_train_24kHz.scp # List file of input wav files. 13 | in_dir: wav # Directory of input audio files. 14 | out_dir: hdf5 # Directory to save generated samples. 15 | feature_format: h5 # Feature format. 16 | sampling_rate: 24000 # Sampling rate. 17 | spkinfo: data/spk_info.yaml # YAML format speaker information. 18 | spkidx: -2 # Speaker index of the split path. 19 | inv: true # If false, wav is restored from acoustic features. 20 | 21 | # Audio preprocess setting. 22 | highpass_cutoff: 70 # Cut-off frequency for low-cut filter. 23 | pow_th: # Threshold of power. 24 | 25 | # Mel-spectrogram extraction setting. 26 | fft_size: 1024 # FFT size. 27 | hop_size: 120 # Hop size. 28 | win_length: 1024 # Window length. 29 | # If set to null, it will be same as fft_size. 30 | window: hann # Window function. 31 | num_mels: 80 # Number of mel basis. 32 | fmin: 0 # Minimum frequency in mel basis calculation. 33 | fmax: null # Maximum frequency in mel basis calculation. 34 | 35 | # WORLD feature extraction setting. 36 | minf0: 70 # F0 setting: minimum f0. 37 | maxf0: 340 # F0 setting: maximum f0. 38 | shiftms: 5 # F0 setting: frame shift (ms). 39 | mcep_dim: 40 # Mel-cepstrum setting: number of dimensions. 40 | mcap_dim: 20 # Mel-cepstral aperiodicity setting: number of dimensions. 41 | alpha: 0.466 # Mel-cepstrum setting: all-pass constant. 42 | -------------------------------------------------------------------------------- /usfgan/bin/config/data/vctk_24kHz.yaml: -------------------------------------------------------------------------------- 1 | # Dataset settings 2 | train_audio: data/scp/vctk_train_24kHz.scp # List file of training audio files. 3 | train_feat: data/scp/vctk_train_24kHz.list # List file of training feature files. 4 | valid_audio: data/scp/vctk_valid_24kHz.scp # List file of validation audio files. 5 | valid_feat: data/scp/vctk_valid_24kHz.list # List file of validation feature files. 6 | eval_feat: data/scp/vctk_eval_24kHz.list # List file of evaluation feature files for decoding. 7 | stats: data/stats/vctk_train_24kHz.joblib # Path to the file of statistics. 8 | allow_cache: false # Whether to allow cache in dataset. If true, more CPU memory is required. 9 | 10 | # Feature settings 11 | sample_rate: 24000 # Sampling rate. 12 | hop_size: 120 # Hop size. 13 | dense_factor: 4 # Dense factor in PDCNNs. 14 | sine_amp: 0.1 # Sine amplitude. 15 | noise_amp: 0.003 # Noise amplitude. 16 | signal_types: ["sine", "noise"] # List of input signal types for generator. 17 | sine_f0_type: "contf0" # F0 type for sine signal ("f0" or "contf0"). 18 | df_f0_type: "contf0" # F0 type for dilation factor ("f0" or "contf0"). 19 | aux_feats: ["mcep", "mcap"] # Auxiliary features. 20 | # "uv": V/UV binary.
21 | # "f0": descrete f0. 22 | # "mcep": mel-cepstral envelope. 23 | # "contf0": continuous f0. 24 | # "mcap": mel-cepstral aperiodicity. 25 | # "codeap": coded aperiodicity. 26 | # "logmsp": log mel-spectrogram. 27 | 28 | # Collater setting 29 | batch_max_length: 18000 # Length of each audio in batch. Make sure dividable by hop_size. 30 | 31 | # Data loader setting 32 | batch_size: 5 # Batch size 33 | num_workers: 1 # Number of workers in Pytorch DataLoader 34 | pin_memory: true # Whether to pin memory in Pytorch DataLoader 35 | 36 | # Other setting 37 | remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_length 38 | -------------------------------------------------------------------------------- /usfgan/bin/config/discriminator/hifigan.yaml: -------------------------------------------------------------------------------- 1 | _target_: usfgan.models.HiFiGANMultiScaleMultiPeriodDiscriminator 2 | scales: 3 # Number of multi-scale discriminator. 3 | scale_downsample_pooling: "AvgPool1d" # Pooling operation for scale discriminator. 4 | scale_downsample_pooling_params: 5 | kernel_size: 4 # Pooling kernel size. 6 | stride: 2 # Pooling stride. 7 | padding: 2 # Padding size. 8 | scale_discriminator_params: 9 | in_channels: 1 # Number of input channels. 10 | out_channels: 1 # Number of output channels. 11 | kernel_sizes: [15, 41, 5, 3] # List of kernel sizes. 12 | channels: 128 # Initial number of channels. 13 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. 14 | max_groups: 16 # Maximum number of groups in downsampling conv layers. 15 | bias: true 16 | downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales. 17 | nonlinear_activation: "LeakyReLU" # Nonlinear activation. 18 | nonlinear_activation_params: 19 | negative_slope: 0.1 20 | follow_official_norm: true # Whether to follow the official norm setting. 21 | periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator. 22 | period_discriminator_params: 23 | in_channels: 1 # Number of input channels. 24 | out_channels: 1 # Number of output channels. 25 | kernel_sizes: [5, 3] # List of kernel sizes. 26 | channels: 32 # Initial number of channels. 27 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales. 28 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. 29 | bias: true # Whether to use bias parameter in conv layer." 30 | nonlinear_activation: "LeakyReLU" # Nonlinear activation. 31 | nonlinear_activation_params: # Nonlinear activation paramters. 32 | negative_slope: 0.1 33 | use_weight_norm: true # Whether to apply weight normalization. 34 | use_spectral_norm: false # Whether to apply spectral normalization. 35 | -------------------------------------------------------------------------------- /usfgan/bin/config/generator/cascade_hn_usfgan.yaml: -------------------------------------------------------------------------------- 1 | _target_: usfgan.models.CascadeHnUSFGANGenerator 2 | harmonic_network_params: 3 | blockA: 20 # Number of adaptive residual blocks. 4 | cycleA: 4 # Number of adaptive dilation cycles. 5 | blockF: 0 # Number of fixed residual blocks. 6 | cycleF: 0 # Number of fixed dilation cycles. 7 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive). 8 | noise_network_params: 9 | blockA: 0 # Number of adaptive residual blocks. 10 | cycleA: 0 # Number of adaptive dilation cycles. 11 | blockF: 5 # Number of fixed residual blocks. 
12 | cycleF: 5 # Number of fixed dilation cycles. 13 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive). 14 | filter_network_params: 15 | blockA: 0 # Number of adaptive residual blocks. 16 | cycleA: 0 # Number of adaptive dilation cycles. 17 | blockF: 30 # Number of fixed residual blocks. 18 | cycleF: 3 # Number of fixed dilation cycles. 19 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive). 20 | periodicity_estimator_params: 21 | conv_layers: 3 # Number of convolution layers. 22 | kernel_size: 5 # Kernel size. 23 | dilation: 1 # Dilation size. 24 | padding_mode: "replicate" # Padding mode. 25 | in_channels: 1 # Number of input channels. 26 | out_channels: 1 # Number of output channels. 27 | residual_channels: 64 # Number of channels in residual conv. 28 | gate_channels: 128 # Number of channels in gated conv. 29 | skip_channels: 64 # Number of channels in skip conv. 30 | aux_channels: 62 # Number of channels for auxiliary feature conv. 31 | aux_context_window: 2 # Context window size for auxiliary feature. 32 | # If set to 2, previous 2 and future 2 frames will be considered. 33 | use_weight_norm: true # Whether to use weight norm. 34 | upsample_params: # Upsampling network parameters. 35 | upsample_scales: [5, 4, 3, 2] # Upsampling scales. Product of these must be the same as hop size. 36 | -------------------------------------------------------------------------------- /usfgan/bin/config/generator/parallel_hn_usfgan.yaml: -------------------------------------------------------------------------------- 1 | _target_: usfgan.models.ParallelHnUSFGANGenerator 2 | harmonic_network_params: 3 | blockA: 20 # Number of adaptive residual blocks. 4 | cycleA: 4 # Number of adaptive dilation cycles. 5 | blockF: 0 # Number of fixed residual blocks. 6 | cycleF: 0 # Number of fixed dilation cycles. 7 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive). 8 | noise_network_params: 9 | blockA: 0 # Number of adaptive residual blocks. 10 | cycleA: 0 # Number of adaptive dilation cycles. 11 | blockF: 5 # Number of fixed residual blocks. 12 | cycleF: 5 # Number of fixed dilation cycles. 13 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive). 14 | filter_network_params: 15 | blockA: 0 # Number of adaptive residual blocks. 16 | cycleA: 0 # Number of adaptive dilation cycles. 17 | blockF: 30 # Number of fixed residual blocks. 18 | cycleF: 3 # Number of fixed dilation cycles. 19 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive). 20 | periodicity_estimator_params: 21 | conv_layers: 3 # Number of convolution layers. 22 | kernel_size: 5 # Kernel size. 23 | dilation: 1 # Dilation size. 24 | padding_mode: "replicate" # Padding mode. 25 | in_channels: 1 # Number of input channels. 26 | out_channels: 1 # Number of output channels. 27 | residual_channels: 64 # Number of channels in residual conv. 28 | gate_channels: 128 # Number of channels in gated conv. 29 | skip_channels: 64 # Number of channels in skip conv. 30 | aux_channels: 62 # Number of channels for auxiliary feature conv. 31 | aux_context_window: 2 # Context window size for auxiliary feature. 32 | # If set to 2, previous 2 and future 2 frames will be considered. 33 | use_weight_norm: true # Whether to use weight norm. 34 | upsample_params: # Upsampling network parameters. 35 | upsample_scales: [5, 4, 3, 2] # Upsampling scales. Product of these must be the same as hop size. 
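# With the default data settings, 5 * 4 * 3 * 2 = 120, which matches hop_size: 120
# in data/vctk_24kHz.yaml (24 kHz sampling rate with a 5 ms frame shift).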
36 | -------------------------------------------------------------------------------- /usfgan/utils/index.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2020 Yi-Chiao Wu (Nagoya University) 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Indexing-related functions.""" 7 | 8 | import torch 9 | from torch.nn import ConstantPad1d as pad1d 10 | 11 | 12 | def pd_indexing(x, d, dilation, batch_index, ch_index): 13 | """Pitch-dependent indexing of past and future samples. 14 | 15 | Args: 16 | x (Tensor): Input feature map (B, C, T). 17 | d (Tensor): Input pitch-dependent dilated factors (B, 1, T). 18 | dilation (Int): Dilation size. 19 | batch_index (Tensor): Batch index 20 | ch_index (Tensor): Channel index 21 | 22 | Returns: 23 | Tensor: Past output tensor (B, out_channels, T) 24 | Tensor: Future output tensor (B, out_channels, T) 25 | 26 | """ 27 | (_, _, batch_length) = d.size() 28 | dilations = d * dilation 29 | 30 | # get past index 31 | idxP = torch.arange(-batch_length, 0).float() 32 | idxP = idxP.to(x.device) 33 | idxP = torch.add(-dilations, idxP) 34 | idxP = idxP.round().long() 35 | maxP = -((torch.min(idxP) + batch_length)) 36 | assert maxP >= 0 37 | idxP = (batch_index, ch_index, idxP) 38 | # padding past tensor 39 | xP = pad1d((maxP, 0), 0)(x) 40 | 41 | # get future index 42 | idxF = torch.arange(0, batch_length).float() 43 | idxF = idxF.to(x.device) 44 | idxF = torch.add(dilations, idxF) 45 | idxF = idxF.round().long() 46 | maxF = torch.max(idxF) - (batch_length - 1) 47 | assert maxF >= 0 48 | idxF = (batch_index, ch_index, idxF) 49 | # padding future tensor 50 | xF = pad1d((0, maxF), 0)(x) 51 | 52 | return xP[idxP], xF[idxF] 53 | 54 | 55 | def index_initial(n_batch, n_ch, tensor=True): 56 | """Tensor batch and channel index initialization. 57 | 58 | Args: 59 | n_batch (Int): Number of batch. 60 | n_ch (Int): Number of channel. 61 | tensor (bool): Return tensor or numpy array 62 | 63 | Returns: 64 | Tensor: Batch index 65 | Tensor: Channel index 66 | 67 | """ 68 | batch_index = [] 69 | for i in range(n_batch): 70 | batch_index.append([[i]] * n_ch) 71 | ch_index = [] 72 | for i in range(n_ch): 73 | ch_index += [[i]] 74 | ch_index = [ch_index] * n_batch 75 | 76 | if tensor: 77 | batch_index = torch.tensor(batch_index) 78 | ch_index = torch.tensor(ch_index) 79 | if torch.cuda.is_available(): 80 | batch_index = batch_index.cuda() 81 | ch_index = ch_index.cuda() 82 | return batch_index, ch_index 83 | -------------------------------------------------------------------------------- /usfgan/bin/compute_statistics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2022 Reo Yoneyama (Nagoya University) 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Feature statistics computing script. 
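Usage (via the console script installed by setup.py):
    usfgan-compute-statistics feats=data/scp/vctk_train_24kHz.list stats=data/stats/vctk_train_24kHz.joblib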
7 | 8 | References: 9 | - https://github.com/bigpon/QPPWG 10 | - https://github.com/kan-bayashi/ParallelWaveGAN 11 | 12 | """ 13 | 14 | import os 15 | from logging import getLogger 16 | 17 | import hydra 18 | import numpy as np 19 | from hydra.utils import to_absolute_path 20 | from joblib import dump, load 21 | from omegaconf import DictConfig, OmegaConf 22 | from sklearn.preprocessing import StandardScaler 23 | from usfgan.utils import read_hdf5, read_txt 24 | 25 | # A logger for this file 26 | logger = getLogger(__name__) 27 | 28 | 29 | def calc_stats(file_list, config): 30 | """Calculate statistics. 31 | Args: 32 | file_list (list): File list. 33 | config (dict): Dictionary of config. 34 | """ 35 | 36 | # define scalers 37 | scaler = load(config.stats) if os.path.isfile(config.stats) else {} 38 | for feat_type in config.feat_types: 39 | if feat_type == "uv": 40 | continue 41 | scaler[feat_type] = StandardScaler() 42 | 43 | # process all of the data 44 | for i, filename in enumerate(file_list): 45 | logger.info(f"now processing {filename} ({i + 1}/{len(file_list)})") 46 | for feat_type in config.feat_types: 47 | if feat_type == "uv": 48 | continue 49 | if feat_type == "f0": 50 | f0 = read_hdf5(to_absolute_path(filename), "/f0") 51 | feat = np.expand_dims(f0[f0 > 0], axis=-1) 52 | else: 53 | feat = read_hdf5(to_absolute_path(filename), f"/{feat_type}") 54 | if feat.shape[0] == 0: 55 | logger.warning(f"feat length is 0 {filename}/{feat_type}") 56 | continue 57 | scaler[feat_type].partial_fit(feat) 58 | 59 | if not os.path.exists(os.path.dirname(config.stats)): 60 | os.makedirs(os.path.dirname(config.stats)) 61 | dump(scaler, to_absolute_path(config.stats)) 62 | logger.info(f"Successfully saved statistics to {config.stats}.") 63 | 64 | 65 | @hydra.main(version_base=None, config_path="config", config_name="compute_statistics") 66 | def main(config: DictConfig): 67 | # show argument 68 | logger.info(OmegaConf.to_yaml(config)) 69 | 70 | # read file list 71 | file_list = read_txt(to_absolute_path(config.feats)) 72 | logger.info(f"number of utterances = {len(file_list)}") 73 | 74 | # calculate statistics 75 | calc_stats(file_list, config) 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Setup Unified Source-Filter GAN Library.""" 4 | 5 | import os 6 | import sys 7 | from distutils.version import LooseVersion 8 | 9 | import pip 10 | from setuptools import find_packages, setup 11 | 12 | if LooseVersion(sys.version) < LooseVersion("3.8"): 13 | raise RuntimeError( 14 | "usfgan requires Python>=3.8, " "but your Python is {}".format(sys.version) 15 | ) 16 | if LooseVersion(pip.__version__) < LooseVersion("21.0.0"): 17 | raise RuntimeError( 18 | "pip>=21.0.0 is required, but your pip is {}. 
" 19 | 'Try again after "pip install -U pip"'.format(pip.__version__) 20 | ) 21 | 22 | requirements = { 23 | "install": [ 24 | "wheel", 25 | "torch>=1.9.0", 26 | "torchaudio>=0.8.1", 27 | "setuptools>=38.5.1", 28 | "librosa>=0.8.0", 29 | "soundfile>=0.10.2", 30 | "tensorboardX>=2.2", 31 | "matplotlib>=3.1.0", 32 | "PyYAML>=3.12", 33 | "tqdm>=4.26.1", 34 | "h5py>=2.10.0", 35 | "pyworld>=0.2.12", 36 | "sprocket-vc", 37 | "protobuf<=3.19.0", 38 | "hydra-core>=1.2", 39 | ], 40 | "setup": [ 41 | "numpy", 42 | "pytest-runner", 43 | ], 44 | "test": [ 45 | "pytest>=3.3.0", 46 | "hacking>=1.1.0", 47 | "flake8>=3.7.8", 48 | "flake8-docstrings>=1.3.1", 49 | ], 50 | } 51 | entry_points = { 52 | "console_scripts": [ 53 | "usfgan-extract-features=usfgan.bin.extract_features:main", 54 | "usfgan-compute-statistics=usfgan.bin.compute_statistics:main", 55 | "usfgan-train=usfgan.bin.train:main", 56 | "usfgan-decode=usfgan.bin.decode:main", 57 | ] 58 | } 59 | 60 | install_requires = requirements["install"] 61 | setup_requires = requirements["setup"] 62 | tests_require = requirements["test"] 63 | extras_require = { 64 | k: v for k, v in requirements.items() if k not in ["install", "setup"] 65 | } 66 | 67 | dirname = os.path.dirname(__file__) 68 | setup( 69 | name="usfgan", 70 | version="0.1", 71 | url="http://github.com/chomeyama/HN-UnifiedSourceFilterGAN", 72 | author="Reo Yoneyama", 73 | author_email="yoneyama.reo@g.sp.m.is.nagoya-u.ac.jp", 74 | description="Harmonic-plus-Noise Unified Source-Filter GAN implementation", 75 | long_description_content_type="text/markdown", 76 | long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(), 77 | license="MIT License", 78 | packages=find_packages(include=["usfgan*"]), 79 | install_requires=install_requires, 80 | setup_requires=setup_requires, 81 | tests_require=tests_require, 82 | extras_require=extras_require, 83 | entry_points=entry_points, 84 | classifiers=[ 85 | "Programming Language :: Python :: 3.9.5", 86 | "Intended Audience :: Science/Research", 87 | "Operating System :: POSIX :: Linux", 88 | "License :: OSI Approved :: MIT License", 89 | "Topic :: Software Development :: Libraries :: Python Modules", 90 | ], 91 | ) 92 | -------------------------------------------------------------------------------- /usfgan/bin/config/train/usfgan.yaml: -------------------------------------------------------------------------------- 1 | # Interval setting 2 | discriminator_train_start_steps: 100000 # Number of steps to start to train discriminator. 3 | train_max_steps: 600000 # Number of pre-training steps. 4 | save_interval_steps: 100000 # Interval steps to save checkpoint. 5 | eval_interval_steps: 2000 # Interval steps to evaluate the network. 6 | log_interval_steps: 2000 # Interval steps to record the training log. 7 | resume: # Epoch to resume training. 8 | 9 | # Loss balancing coefficients. 10 | lambda_stft: 1.0 11 | lambda_source: 1.0 12 | lambda_adv: 4.0 13 | lambda_feat_match: 0.0 14 | 15 | # Multi-resolution STFT loss setting 16 | stft_loss: 17 | _target_: usfgan.losses.MultiResolutionSTFTLoss 18 | fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss. 19 | hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss 20 | win_lengths: [600, 1200, 240] # List of window length for STFT-based loss. 21 | window: hann_window # Window function for STFT-based loss 22 | 23 | # Source regularization loss setting 24 | source_loss: 25 | _target_: usfgan.losses.FlattenLoss 26 | sampling_rate: 24000 # Sampling rate. 
27 | fft_size: 2048 # FFT size. 28 | hop_size: 120 # Hop size. 29 | f0_floor: 70 # Minimum F0. 30 | f0_ceil: 340 # Maximum F0. 31 | power: false # Whether to use power or magnitude spectrogram. 32 | elim_0th: false # Whether to exclude 0th components of cepstrums in 33 | # CheapTrick estimation. If set to true, source-network 34 | # is forced to estimate the power of the output signal. 35 | l2_norm: false # If true, use L2 norm; otherwise, L1 norm. 36 | 37 | # Adversarial loss setting 38 | adversarial_loss: 39 | _target_: usfgan.losses.AdversarialLoss 40 | 41 | # Feature matching loss setting 42 | feat_match_loss: 43 | _target_: usfgan.losses.FeatureMatchLoss 44 | 45 | # Optimizer setting 46 | generator_optimizer: 47 | _target_: usfgan.optimizers.RAdam 48 | lr: 0.0001 # Generator's learning rate. 49 | eps: 1.0e-6 # Generator's epsilon. 50 | weight_decay: 0.0 # Generator's weight decay coefficient. 51 | generator_scheduler: 52 | _target_: torch.optim.lr_scheduler.StepLR 53 | step_size: 200000 # Generator's scheduler step size. 54 | gamma: 0.5 # Generator's scheduler gamma. 55 | # At each step size, lr will be multiplied by this parameter. 56 | generator_grad_norm: 10 # Generator's gradient norm. 57 | discriminator_optimizer: 58 | _target_: usfgan.optimizers.RAdam 59 | lr: 0.00005 # Discriminator's learning rate. 60 | eps: 1.0e-6 # Discriminator's epsilon. 61 | weight_decay: 0.0 # Discriminator's weight decay coefficient. 62 | discriminator_scheduler: 63 | _target_: torch.optim.lr_scheduler.StepLR 64 | step_size: 200000 # Discriminator's scheduler step size. 65 | gamma: 0.5 # Discriminator's scheduler gamma. 66 | # At each step size, lr will be multiplied by this parameter. 67 | discriminator_grad_norm: 10 # Discriminator's gradient norm. 68 | -------------------------------------------------------------------------------- /usfgan/losses/adversarial.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2022 Reo Yoneyama (Nagoya University) 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Adversarial loss modules. 7 | 8 | References: 9 | - https://github.com/kan-bayashi/ParallelWaveGAN 10 | 11 | """ 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | 18 | class AdversarialLoss(nn.Module): 19 | """Adversarial loss module.""" 20 | 21 | def __init__( 22 | self, 23 | average_by_discriminators=False, 24 | loss_type="mse", 25 | ): 26 | """Initialize AdversarialLoss module.""" 27 | super(AdversarialLoss, self).__init__() 28 | self.average_by_discriminators = average_by_discriminators 29 | 30 | assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." 31 | if loss_type == "mse": 32 | self.adv_criterion = self._mse_adv_loss 33 | self.fake_criterion = self._mse_fake_loss 34 | self.real_criterion = self._mse_real_loss 35 | else: 36 | self.adv_criterion = self._hinge_adv_loss 37 | self.fake_criterion = self._hinge_fake_loss 38 | self.real_criterion = self._hinge_real_loss 39 | 40 | def forward(self, p_fakes, p_reals=None): 41 | """Calculate generator/discriminator adversarial loss. 42 | 43 | Args: 44 | p_fakes (list): List of 45 | discriminator outputs calculated from generator outputs. 46 | p_reals (list): List of 47 | discriminator outputs calculated from groundtruth. 48 | 49 | Returns: 50 | Tensor: Generator adversarial loss value. 51 | Tensor: Discriminator real loss value. 52 | Tensor: Discriminator fake loss value.
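Note:
    If p_reals is None, only the generator adversarial loss is
    returned; otherwise, the discriminator fake and real losses
    are returned.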
53 | 54 | """ 55 | # generator adversarial loss 56 | if p_reals is None: 57 | adv_loss = 0.0 58 | for p_fake in p_fakes: 59 | adv_loss += self.adv_criterion(p_fake) 60 | 61 | if self.average_by_discriminators: 62 | adv_loss /= len(p_fakes) 63 | 64 | return adv_loss 65 | 66 | # discriminator adversarial loss 67 | else: 68 | fake_loss = 0.0 69 | real_loss = 0.0 70 | for p_fake, p_real in zip(p_fakes, p_reals): 71 | fake_loss += self.fake_criterion(p_fake) 72 | real_loss += self.real_criterion(p_real) 73 | 74 | if self.average_by_discriminators: 75 | fake_loss /= len(p_fakes) 76 | real_loss /= len(p_reals) 77 | 78 | return fake_loss, real_loss 79 | 80 | def _mse_adv_loss(self, x): 81 | return F.mse_loss(x, x.new_ones(x.size())) 82 | 83 | def _mse_real_loss(self, x): 84 | return F.mse_loss(x, x.new_ones(x.size())) 85 | 86 | def _mse_fake_loss(self, x): 87 | return F.mse_loss(x, x.new_zeros(x.size())) 88 | 89 | def _hinge_adv_loss(self, x): 90 | return -x.mean() 91 | 92 | def _hinge_real_loss(self, x): 93 | return -torch.mean(torch.min(x - 1, x.new_zeros(x.size()))) 94 | 95 | def _hinge_fake_loss(self, x): 96 | return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size()))) 97 | -------------------------------------------------------------------------------- /usfgan/bin/config/train/hn_usfgan.yaml: -------------------------------------------------------------------------------- 1 | # Interval setting 2 | discriminator_train_start_steps: 0 # Number of steps to start to train discriminator. 3 | train_max_steps: 600000 # Number of pre-training steps. 4 | save_interval_steps: 100000 # Interval steps to save checkpoint. 5 | eval_interval_steps: 2000 # Interval steps to evaluate the network. 6 | log_interval_steps: 2000 # Interval steps to record the training log. 7 | resume: # Epoch to resume training. 8 | 9 | # Loss balancing coefficients. 10 | lambda_stft: 45.0 11 | lambda_source: 1.0 12 | lambda_adv: 1.0 13 | lambda_feat_match: 0.0 14 | 15 | # Mel-spectrogram reconstraction loss setting 16 | stft_loss: 17 | _target_: usfgan.losses.MelSpectralLoss 18 | fft_size: 1024 # FFT size for STFT-based loss. 19 | hop_size: 256 # Hop size for STFT-based loss 20 | win_length: 1024 # Window length for STFT-based loss. 21 | window: hann_window # Window function for STFT-based loss. 22 | sampling_rate: 24000 # Samplring rate. 23 | n_mels: 80 # Number of bins of mel-filter-bank. 24 | fmin: 0 # Minimum frequency of mel-filter-bank. 25 | fmax: null # Maximum frequency of mel-filter-bank. 26 | # If null, it will be fs / 2. 27 | 28 | # Source regularization loss setting 29 | source_loss: 30 | _target_: usfgan.losses.ResidualLoss 31 | sampling_rate: 24000 # Sampling rate. 32 | fft_size: 2048 # FFT size. 33 | hop_size: 120 # Hop size. 34 | f0_floor: 70 # Minimum F0. 35 | f0_ceil: 340 # Maximum F0. 36 | n_mels: 80 # Number of mel-filter-bank bins. 37 | fmin: 0 # Minimum frequency of mel-filter-bank. 38 | fmax: null # Maximum frequency of mel-filter-bank. 39 | power: false # Whether to use power or magnitude spectrogram. 40 | elim_0th: true # Whether to exclude 0th components of cepstrums in 41 | # CheapTrick estimation. If set to true, source-network 42 | # is forced to estimate the power of the output signal. 43 | 44 | # Adversarial loss setting 45 | adversarial_loss: 46 | _target_: usfgan.losses.AdversarialLoss 47 | average_by_discriminators: false # Whether to average loss by #discriminators. 
48 | loss_type: mse 49 | 50 | # Feature matching loss setting 51 | feat_match_loss: 52 | _target_: usfgan.losses.FeatureMatchLoss 53 | average_by_layers: false # Whether to average loss by #layers in each discriminator. 54 | 55 | # Optimizer and scheduler setting 56 | generator_optimizer: 57 | _target_: torch.optim.Adam 58 | lr: 2.0e-4 59 | betas: [0.5, 0.9] 60 | weight_decay: 0.0 61 | generator_scheduler: 62 | _target_: torch.optim.lr_scheduler.MultiStepLR 63 | gamma: 0.5 64 | milestones: 65 | - 200000 66 | - 400000 67 | - 600000 68 | - 800000 69 | generator_grad_norm: 10 70 | discriminator_optimizer: 71 | _target_: torch.optim.Adam 72 | lr: 2.0e-4 73 | betas: [0.5, 0.9] 74 | weight_decay: 0.0 75 | discriminator_scheduler: 76 | _target_: torch.optim.lr_scheduler.MultiStepLR 77 | gamma: 0.5 78 | milestones: 79 | - 200000 80 | - 400000 81 | - 600000 82 | - 800000 83 | discriminator_grad_norm: 10 84 | -------------------------------------------------------------------------------- /usfgan/optimizers/radam.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """RAdam optimizer. 4 | 5 | This code is derived from https://github.com/LiyuanLucasLiu/RAdam. 6 | 7 | """ 8 | 9 | import math 10 | 11 | import torch 12 | from torch.optim.optimizer import Optimizer 13 | 14 | 15 | class RAdam(Optimizer): 16 | """Rectified Adam optimizer.""" 17 | 18 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 19 | """Initialize RAdam optimizer.""" 20 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 21 | self.buffer = [[None, None, None] for ind in range(10)] 22 | super(RAdam, self).__init__(params, defaults) 23 | 24 | def __setstate__(self, state): 25 | """Set state.""" 26 | super(RAdam, self).__setstate__(state) 27 | 28 | def step(self, closure=None): 29 | """Run one step.""" 30 | loss = None 31 | if closure is not None: 32 | loss = closure() 33 | 34 | for group in self.param_groups: 35 | 36 | for p in group["params"]: 37 | if p.grad is None: 38 | continue 39 | grad = p.grad.data.float() 40 | if grad.is_sparse: 41 | raise RuntimeError("RAdam does not support sparse gradients") 42 | 43 | p_data_fp32 = p.data.float() 44 | 45 | state = self.state[p] 46 | 47 | if len(state) == 0: 48 | state["step"] = 0 49 | state["exp_avg"] = torch.zeros_like(p_data_fp32) 50 | state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) 51 | else: 52 | state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) 53 | state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) 54 | 55 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 56 | beta1, beta2 = group["betas"] 57 | 58 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 59 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 60 | 61 | state["step"] += 1 62 | buffered = self.buffer[int(state["step"] % 10)] 63 | if state["step"] == buffered[0]: 64 | N_sma, step_size = buffered[1], buffered[2] 65 | else: 66 | buffered[0] = state["step"] 67 | beta2_t = beta2 ** state["step"] 68 | N_sma_max = 2 / (1 - beta2) - 1 69 | N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) 70 | buffered[1] = N_sma 71 | 72 | # more conservative since it's an approximated value 73 | if N_sma >= 5: 74 | step_size = math.sqrt( 75 | (1 - beta2_t) 76 | * (N_sma - 4) 77 | / (N_sma_max - 4) 78 | * (N_sma - 2) 79 | / N_sma 80 | * N_sma_max 81 | / (N_sma_max - 2) 82 | ) / ( 83 | 1 - beta1 ** state["step"] 84 | ) # NOQA 85 | else: 86 | step_size = 1.0 / (1 - beta1 ** state["step"]) 87
| buffered[2] = step_size 88 | 89 | if group["weight_decay"] != 0: 90 | p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32) 91 | 92 | # more conservative since it's an approximated value 93 | if N_sma >= 5: 94 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 95 | p_data_fp32.addcdiv_(-step_size * group["lr"], exp_avg, denom) 96 | else: 97 | p_data_fp32.add_(-step_size * group["lr"], exp_avg) 98 | 99 | p.data.copy_(p_data_fp32) 100 | 101 | return loss 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Harmonic-plus-Noise Unified Source-Filter GAN implementation with PyTorch 2 | 3 | 4 | This repo provides an official PyTorch implementation of [HN-uSFGAN](https://arxiv.org/abs/2205.06053), a high-fidelity and pitch-controllable neural vocoder based on unified source-filter networks.
5 | HN-uSFGAN is an extension of [uSFGAN](https://arxiv.org/abs/2104.04668), and this repo includes the original [uSFGAN implementation](https://github.com/chomeyama/UnifiedSourceFilterGAN) with some modifications.
6 | 7 | For more information, please see the [demo](https://chomeyama.github.io/HN-UnifiedSourceFilterGAN-Demo/). 8 | 9 | This repository is tested under the following conditions. 10 | 11 | - Ubuntu 20.04.3 LTS 12 | - Titan RTX 3090 GPU 13 | - Python 3.9.5 14 | - CUDA 11.5 15 | - cuDNN 8.1.1.33-1+cuda11.2 16 | 17 | ## Environment setup 18 | 19 | ```bash 20 | $ cd HN-UnifiedSourceFilterGAN 21 | $ pip install -e . 22 | ``` 23 | 24 | Please refer to the [Parallel WaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN) repo for more details. 25 | 26 | ## Folder architecture 27 | - **egs**: 28 | The folder for projects. 29 | - **egs/vctk**: 30 | The folder of the VCTK project example. 31 | - **usfgan**: 32 | The folder of the source code. 33 | 34 | ## Run 35 | 36 | In this repo, hyperparameters are managed using [Hydra](https://hydra.cc/docs/intro/).
37 | Hydra provides an easy way to dynamically create a hierarchical configuration by composition and to override it through config files and the command line. 38 | 39 | ### Dataset preparation 40 | 41 | Make a dataset and scp files denoting the paths to each audio file according to your own dataset (e.g., `egs/vctk/data/scp/vctk_train_24kHz.scp`).
42 | Also, list files denoting the paths to the feature files extracted in the next step are required (e.g., `egs/vctk/data/scp/vctk_train_24kHz.list`).
43 | Note that separate scp/list files are needed for the training, validation, and evaluation sets. 44 | 45 | ### Preprocessing 46 | 47 | ```bash 48 | # Move to the project directory 49 | $ cd egs/vctk 50 | 51 | # Extract acoustic features (F0, mel-cepstrum, etc.) 52 | # You can customize parameters according to usfgan/bin/config/extract_features.yaml 53 | $ usfgan-extract-features audio=data/scp/vctk_all_24kHz.scp 54 | 55 | # Compute statistics of training and testing data 56 | $ usfgan-compute-statistics feats=data/scp/vctk_train_24kHz.list stats=data/stats/vctk_train_24kHz.joblib 57 | ``` 58 | 59 | ### Training 60 | 61 | ```bash 62 | # Train a model, customizing the hyperparameters as you like 63 | # The following setting (Parallel-HN-uSFGAN generator with HiFiGAN discriminator) should give the best performance 64 | $ usfgan-train generator=parallel_hn_usfgan discriminator=hifigan train=hn_usfgan data=vctk_24kHz out_dir=exp/parallel_hn_usfgan 65 | ``` 66 | 67 | ### Inference 68 | 69 | ```bash 70 | # Decode with natural acoustic features 71 | $ usfgan-decode out_dir=exp/parallel-hn-usfgan/wav/600000steps checkpoint_path=exp/parallel-hn-usfgan/checkpoints/checkpoint-600000steps.pkl 72 | # Decode with halved F0 73 | $ usfgan-decode out_dir=exp/parallel-hn-usfgan/wav/600000steps checkpoint_path=exp/parallel-hn-usfgan/checkpoints/checkpoint-600000steps.pkl f0_factor=0.50 74 | ``` 75 | 76 | ### Monitor training progress 77 | 78 | ```bash 79 | $ tensorboard --logdir exp 80 | ``` 81 | 82 | ## Citation 83 | If you find the code helpful, please cite the following article. 84 | 85 | ``` 86 | @inproceedings{yoneyama22_interspeech, 87 | author={Reo Yoneyama and Yi-Chiao Wu and Tomoki Toda}, 88 | title={{Unified Source-Filter GAN with Harmonic-plus-Noise Source Excitation Generation}}, 89 | year=2022, 90 | booktitle={Proc. Interspeech 2022}, 91 | pages={848--852}, 92 | doi={10.21437/Interspeech.2022-11130} 93 | } 94 | ``` 95 | 96 | ## Authors 97 | 98 | Development: 99 |
100 | E-mail: `yoneyama.reo@g.sp.m.is.nagoya-u.ac.jp` 101 | 102 | Advisors:
103 | Yi-Chiao Wu @ Nagoya University ([@bigpon](https://github.com/bigpon))
104 | E-mail: `yichiao.wu@g.sp.m.is.nagoya-u.ac.jp`
105 | Tomoki Toda @ Nagoya University
106 | E-mail: `tomoki@icts.nagoya-u.ac.jp` 107 | -------------------------------------------------------------------------------- /usfgan/utils/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2020 Yi-Chiao Wu (Nagoya University) 4 | # based on a Parallel WaveGAN script by Tomoki Hayashi (Nagoya University) 5 | # (https://github.com/kan-bayashi/ParallelWaveGAN) 6 | # MIT License (https://opensource.org/licenses/MIT) 7 | 8 | """Utility functions.""" 9 | 10 | import fnmatch 11 | import os 12 | import sys 13 | from logging import getLogger 14 | 15 | import h5py 16 | import numpy as np 17 | 18 | # A logger for this file 19 | logger = getLogger(__name__) 20 | 21 | 22 | def find_files(root_dir, query="*.wav", include_root_dir=True): 23 | """Find files recursively. 24 | 25 | Args: 26 | root_dir (str): Root directory to find. 27 | query (str): Query to find. 28 | include_root_dir (bool): If False, root_dir name is not included. 29 | 30 | Returns: 31 | list: List of found filenames. 32 | 33 | """ 34 | files = [] 35 | for root, dirnames, filenames in os.walk(root_dir, followlinks=True): 36 | for filename in fnmatch.filter(filenames, query): 37 | files.append(os.path.join(root, filename)) 38 | if not include_root_dir: 39 | files = [file_.replace(root_dir + "/", "") for file_ in files] 40 | 41 | return files 42 | 43 | 44 | def read_hdf5(hdf5_name, hdf5_path): 45 | """Read hdf5 dataset. 46 | 47 | Args: 48 | hdf5_name (str): Filename of hdf5 file. 49 | hdf5_path (str): Dataset name in hdf5 file. 50 | 51 | Return: 52 | any: Dataset values. 53 | 54 | """ 55 | if not os.path.exists(hdf5_name): 56 | logger.error(f"There is no such hdf5 file ({hdf5_name}).") 57 | sys.exit(1) 58 | 59 | hdf5_file = h5py.File(hdf5_name, "r") 60 | 61 | if hdf5_path not in hdf5_file: 62 | logger.error(f"There is no such data in the hdf5 file ({hdf5_path}).") 63 | sys.exit(1) 64 | 65 | hdf5_data = hdf5_file[hdf5_path][()] 66 | hdf5_file.close() 67 | 68 | return hdf5_data 69 | 70 | 71 | def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True): 72 | """Write dataset to hdf5. 73 | 74 | Args: 75 | hdf5_name (str): Hdf5 dataset filename. 76 | hdf5_path (str): Dataset path in hdf5. 77 | write_data (ndarray): Data to write. 78 | is_overwrite (bool): Whether to overwrite dataset. 79 | 80 | """ 81 | # convert to numpy array 82 | write_data = np.array(write_data) 83 | 84 | # check folder existence 85 | folder_name, _ = os.path.split(hdf5_name) 86 | if not os.path.exists(folder_name) and len(folder_name) != 0: 87 | os.makedirs(folder_name) 88 | 89 | # check hdf5 existence 90 | if os.path.exists(hdf5_name): 91 | # if already exists, open with r+ mode 92 | hdf5_file = h5py.File(hdf5_name, "r+") 93 | # check dataset existence 94 | if hdf5_path in hdf5_file: 95 | if is_overwrite: 96 | hdf5_file.__delitem__(hdf5_path) 97 | else: 98 | logger.error( 99 | "Dataset in hdf5 file already exists. " 100 | "If you want to overwrite, please set is_overwrite = True."
101 | ) 102 | hdf5_file.close() 103 | sys.exit(1) 104 | else: 105 | # if not exists, open with w mode 106 | hdf5_file = h5py.File(hdf5_name, "w") 107 | 108 | # write data to hdf5 109 | hdf5_file.create_dataset(hdf5_path, data=write_data) 110 | hdf5_file.flush() 111 | hdf5_file.close() 112 | 113 | 114 | def check_hdf5(hdf5_name, hdf5_path): 115 | """Check hdf5 file existence. 116 | 117 | Args: 118 | hdf5_name (str): filename of hdf5 file 119 | hdf5_path (str): dataset name in hdf5 file 120 | 121 | Return: 122 | (bool): True if the dataset exists 123 | 124 | """ 125 | if not os.path.exists(hdf5_name): 126 | return False 127 | else: 128 | with h5py.File(hdf5_name, "r") as f: 129 | if hdf5_path in f: 130 | return True 131 | else: 132 | return False 133 | 134 | 135 | def read_txt(file_list): 136 | """Read .txt file list. 137 | 138 | Arg: 139 | file_list (str): txt file filename 140 | 141 | Return: 142 | (list): list of read lines 143 | 144 | """ 145 | with open(file_list, "r") as f: 146 | filenames = f.readlines() 147 | return [filename.replace("\n", "") for filename in filenames] 148 | 149 | 150 | def check_filename(list1, list2): 151 | """Check whether the filenames of two lists are matched. 152 | 153 | Args: 154 | list1 (list): file list 1 155 | list2 (list): file list 2 156 | 157 | Return: 158 | (bool): matched (True) or not (False) 159 | 160 | """ 161 | 162 | def _filename(x): 163 | return os.path.basename(x).split(".")[0] 164 | 165 | list1 = list(map(_filename, list1)) 166 | list2 = list(map(_filename, list2)) 167 | 168 | return list1 == list2 169 | -------------------------------------------------------------------------------- /usfgan/bin/decode.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2022 Reo Yoneyama (Nagoya University) 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Decoding Script for Unified Source-Filter GAN.
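Usage (via the console script installed by setup.py):
    usfgan-decode checkpoint_path=<path/to/checkpoint.pkl> out_dir=<path/to/out_dir>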
7 | 8 | References: 9 | - https://github.com/bigpon/QPPWG 10 | - https://github.com/kan-bayashi/ParallelWaveGAN 11 | 12 | """ 13 | 14 | import logging 15 | import os 16 | from logging import getLogger 17 | from time import time 18 | 19 | import hydra 20 | import numpy as np 21 | import soundfile as sf 22 | import torch 23 | import usfgan 24 | from hydra.utils import to_absolute_path 25 | from omegaconf import DictConfig 26 | from tqdm import tqdm 27 | from usfgan.datasets import FeatDataset 28 | from usfgan.utils.features import SignalGenerator 29 | 30 | # A logger for this file 31 | logger = getLogger(__name__) 32 | 33 | 34 | @hydra.main(version_base=None, config_path="config", config_name="decode") 35 | def main(config: DictConfig) -> None: 36 | """Run decoding process.""" 37 | 38 | # fix seed 39 | np.random.seed(config.seed) 40 | torch.manual_seed(config.seed) 41 | torch.cuda.manual_seed(config.seed) 42 | os.environ["PYTHONHASHSEED"] = str(config.seed) 43 | 44 | # set device 45 | if torch.cuda.is_available(): 46 | device = torch.device("cuda") 47 | else: 48 | device = torch.device("cpu") 49 | 50 | # load pre-trained model from checkpoint file 51 | model = hydra.utils.instantiate(config.generator) 52 | model.load_state_dict(torch.load(to_absolute_path(config.checkpoint_path))["model"]["generator"]) 53 | logger.info(f"Loaded model parameters from {config.checkpoint_path}.") 54 | model.remove_weight_norm() 55 | model.eval().to(device) 56 | 57 | # check directory existence 58 | out_dir = to_absolute_path(config.out_dir) 59 | if not os.path.isdir(out_dir): 60 | os.makedirs(out_dir) 61 | 62 | # get dataset 63 | dataset = FeatDataset( 64 | stats=to_absolute_path(config.data.stats), 65 | feat_list=config.data.eval_feat, 66 | return_filename=True, 67 | sample_rate=config.data.sample_rate, 68 | hop_size=config.data.hop_size, 69 | dense_factor=config.data.dense_factor, 70 | df_f0_type=config.data.df_f0_type, 71 | aux_feats=config.data.aux_feats, 72 | f0_factor=config.f0_factor, 73 | ) 74 | logger.info(f"The number of features to be decoded = {len(dataset)}.") 75 | 76 | # get data processor 77 | signal_generator = SignalGenerator( 78 | sample_rate=config.data.sample_rate, 79 | hop_size=config.data.hop_size, 80 | sine_amp=config.data.sine_amp, 81 | noise_amp=config.data.noise_amp, 82 | signal_types=config.data.signal_types, 83 | ) 84 | pad_fn = torch.nn.ReplicationPad1d(config.generator.aux_context_window) 85 | 86 | # start generation 87 | total_rtf = 0.0 88 | with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar: 89 | for idx, (feat_path, c, df, f0, contf0) in enumerate(pbar, 1): 90 | # setup input features 91 | c = pad_fn(torch.FloatTensor(c).unsqueeze(0).transpose(2, 1)).to(device) 92 | df = torch.FloatTensor(df).view(1, 1, -1).to(device) 93 | f0 = torch.FloatTensor(f0).unsqueeze(0).transpose(2, 1).to(device) 94 | contf0 = torch.FloatTensor(contf0).unsqueeze(0).transpose(2, 1).to(device) 95 | # create input signal tensor 96 | if config.data.sine_f0_type == "contf0": 97 | in_signal = signal_generator(contf0) 98 | else: 99 | in_signal = signal_generator(f0) 100 | 101 | # generate 102 | start = time() 103 | y, s = model(in_signal, c, df)[:2] 104 | rtf = (time() - start) / (y.size(-1) / config.data.sample_rate) 105 | pbar.set_postfix({"RTF": rtf}) 106 | total_rtf += rtf 107 | 108 | # save output signal as PCM 16 bit wav file 109 | utt_id = os.path.splitext(os.path.basename(feat_path))[0] 110 | wav_filename = f"{utt_id}_f{config.f0_factor:.2f}.wav" 111 | sf.write( 112 | os.path.join(out_dir, 
wav_filename), 113 | y.view(-1).cpu().numpy(), 114 | config.data.sample_rate, 115 | "PCM_16", 116 | ) 117 | 118 | # save source signal as PCM 16 bit wav file 119 | if config.save_source: 120 | wav_filename = wav_filename.replace(".wav", "_s.wav") 121 | s = s.view(-1).cpu().numpy() 122 | s = s / np.max(np.abs(s)) # normalize 123 | sf.write( 124 | os.path.join(out_dir, wav_filename), 125 | s, 126 | config.data.sample_rate, 127 | "PCM_16", 128 | ) 129 | 130 | # report average RTF 131 | logger.info(f"Finished generation of {idx} utterances (RTF = {total_rtf / idx:.03f}).") 132 | 133 | 134 | if __name__ == "__main__": 135 | main() 136 | -------------------------------------------------------------------------------- /usfgan/utils/features.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2022 Reo Yoneyama (Nagoya University) 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Feature-related functions. 7 | 8 | References: 9 | - https://github.com/bigpon/QPPWG 10 | 11 | """ 12 | 13 | import sys 14 | from logging import getLogger 15 | 16 | import numpy as np 17 | import torch 18 | from torch.nn.functional import interpolate 19 | 20 | # A logger for this file 21 | logger = getLogger(__name__) 22 | 23 | 24 | def validate_length(x, y, hop_size=None): 25 | """Validate and adjust the lengths of two arrays. 26 | 27 | Args: 28 | x (ndarray): Array with x.shape[0] = len_x. 29 | y (ndarray): Array with y.shape[0] = len_y. 30 | hop_size (int): Upsampling factor between y and x. 31 | 32 | Returns: 33 | (ndarray): x trimmed to a length consistent with y. 34 | (ndarray): y trimmed to a length consistent with x. 35 | 36 | """ 37 | if hop_size is None: 38 | if x.shape[0] < y.shape[0]: 39 | y = y[: x.shape[0]] 40 | if x.shape[0] > y.shape[0]: 41 | x = x[: y.shape[0]] 42 | assert len(x) == len(y) 43 | else: 44 | if x.shape[0] > y.shape[0] * hop_size: 45 | x = x[: y.shape[0] * hop_size] 46 | if x.shape[0] < y.shape[0] * hop_size: 47 | mod_y = y.shape[0] * hop_size - x.shape[0] 48 | mod_y_frame = mod_y // hop_size + 1 49 | y = y[:-mod_y_frame] 50 | x = x[: y.shape[0] * hop_size] 51 | assert len(x) == len(y) * hop_size 52 | 53 | return x, y 54 | 55 | 56 | def dilated_factor(batch_f0, fs, dense_factor): 57 | """Pitch-dependent dilated factor. 58 | 59 | Args: 60 | batch_f0 (ndarray): F0 sequence (T). 61 | fs (int): Sampling rate. 62 | dense_factor (int): Number of taps in one cycle. 63 | 64 | Return: 65 | dilated_factors (ndarray): 66 | Float array of the pitch-dependent dilated factors (T). 67 | 68 | """ 69 | batch_f0[batch_f0 == 0] = fs / dense_factor 70 | dilated_factors = np.ones(batch_f0.shape) * fs 71 | dilated_factors /= batch_f0 72 | dilated_factors /= dense_factor 73 | assert np.all(dilated_factors > 0) 74 | 75 | return dilated_factors 76 | 77 |
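# Usage sketch for dilated_factor (illustrative values): at fs = 24000 Hz with
# dense_factor = 4, a steady 200 Hz F0 maps to 24000 / 200 / 4 = 30 taps per
# frame, and unvoiced frames (F0 == 0) are first set to fs / dense_factor,
# which maps them to a factor of 1. The input is modified in place, hence the copy.
def _example_dilated_factor():
    f0 = np.full(10, 200.0)
    df = dilated_factor(f0.copy(), 24000, 4)
    assert np.allclose(df, 30.0)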
78 | class SignalGenerator: 79 | """Input signal generator module.""" 80 | 81 | def __init__( 82 | self, 83 | sample_rate=24000, 84 | hop_size=120, 85 | sine_amp=0.1, 86 | noise_amp=0.003, 87 | signal_types=["sine", "noise"], 88 | ): 89 | """Initialize SignalGenerator module. 90 | 91 | Args: 92 | sample_rate (int): Sampling rate. 93 | hop_size (int): Hop size of input F0. 94 | sine_amp (float): Sine amplitude for NSF-based sine generation. 95 | noise_amp (float): Noise amplitude for NSF-based sine generation. 96 | signal_types (list): List of input signal types for generator. 97 | 98 | """ 99 | self.sample_rate = sample_rate 100 | self.hop_size = hop_size 101 | self.signal_types = signal_types 102 | self.sine_amp = sine_amp 103 | self.noise_amp = noise_amp 104 | 105 | for signal_type in signal_types: 106 | if signal_type not in ["noise", "sine", "uv"]: 107 | logger.error(f"{signal_type} is not a supported type for generator input.") 108 | sys.exit(1) 109 | logger.info(f"Use {signal_types} for generator input signals.") 110 | 111 | @torch.no_grad() 112 | def __call__(self, f0): 113 | signals = [] 114 | for typ in self.signal_types: 115 | if "noise" == typ: 116 | signals.append(self.random_noise(f0)) 117 | if "sine" == typ: 118 | signals.append(self.sinusoid(f0)) 119 | if "uv" == typ: 120 | signals.append(self.vuv_binary(f0)) 121 | 122 | input_batch = signals[0] 123 | for signal in signals[1:]: 124 | input_batch = torch.cat([input_batch, signal], dim=1) 125 | 126 | return input_batch 127 | 128 | @torch.no_grad() 129 | def random_noise(self, f0): 130 | """Calculate noise signals. 131 | 132 | Args: 133 | f0 (Tensor): F0 tensor (B, 1, T // hop_size). 134 | 135 | Returns: 136 | Tensor: Gaussian noise signals (B, 1, T). 137 | 138 | """ 139 | B, _, T = f0.size() 140 | noise = torch.randn((B, 1, T * self.hop_size), device=f0.device) 141 | 142 | return noise 143 | 144 | @torch.no_grad() 145 | def sinusoid(self, f0): 146 | """Calculate sine signals. 147 | 148 | Args: 149 | f0 (Tensor): F0 tensor (B, 1, T // hop_size). 150 | 151 | Returns: 152 | Tensor: Sines generated following NSF (B, 1, T). 153 | 154 | """ 155 | B, _, T = f0.size() 156 | vuv = interpolate((f0 > 0) * torch.ones_like(f0), T * self.hop_size) 157 | phase_inc = (interpolate(f0, T * self.hop_size) / self.sample_rate) % 1  # cycles per sample 158 | sine = vuv * torch.sin(torch.cumsum(phase_inc, dim=2) * 2 * np.pi) * self.sine_amp 159 | if self.noise_amp > 0: 160 | noise_amp = vuv * self.noise_amp + (1.0 - vuv) * self.noise_amp / 3.0 161 | noise = torch.randn((B, 1, T * self.hop_size), device=f0.device) * noise_amp 162 | sine = sine + noise 163 | 164 | return sine 165 | 166 | @torch.no_grad() 167 | def vuv_binary(self, f0): 168 | """Calculate V/UV binary sequences. 169 | 170 | Args: 171 | f0 (Tensor): F0 tensor (B, 1, T // hop_size). 172 | 173 | Returns: 174 | Tensor: V/UV binary sequences (B, 1, T). 175 | 176 | """ 177 | _, _, T = f0.size() 178 | uv = interpolate((f0 > 0) * torch.ones_like(f0), T * self.hop_size) 179 | 180 | return uv 181 |
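# Usage sketch for SignalGenerator (shapes follow the docstrings above): with
# the default "sine" + "noise" types, a (B, 1, T') frame-level F0 tensor yields
# a (B, 2, T' * hop_size) input signal for the generator.
def _example_signal_generator():
    gen = SignalGenerator(sample_rate=24000, hop_size=120)
    f0 = torch.full((1, 1, 2), 200.0)  # two voiced frames
    sig = gen(f0)
    assert sig.shape == (1, 2, 240)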
-------------------------------------------------------------------------------- /usfgan/losses/source.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2022 Reo Yoneyama (Nagoya University) 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Source regularization loss modules.""" 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import usfgan.losses 12 | from librosa.filters import mel as librosa_mel 13 | from usfgan.layers import CheapTrick 14 | 15 | 16 | class FlattenLoss(nn.Module): 17 | """Spectral envelope flattening loss module.""" 18 | 19 | def __init__( 20 | self, 21 | sampling_rate=24000, 22 | hop_size=120, 23 | fft_size=2048, 24 | f0_floor=70, 25 | f0_ceil=340, 26 | power=False, 27 | elim_0th=False, 28 | l2_norm=False, 29 | ): 30 | """Initialize spectral envelope regularization loss module. 31 | 32 | Args: 33 | sampling_rate (int): Sampling rate. 34 | hop_size (int): Hop size. 35 | fft_size (int): FFT size. 36 | f0_floor (int): Minimum F0 value. 37 | f0_ceil (int): Maximum F0 value. 38 | power (bool): Whether to use power or magnitude spectrogram. 39 | elim_0th (bool): Whether to exclude the 0th cepstrum component. 40 | l2_norm (bool): Whether to use L2 norm (MSE) instead of L1 loss. 41 | """ 42 | super(FlattenLoss, self).__init__() 43 | self.hop_size = hop_size 44 | self.power = power 45 | self.elim_0th = elim_0th 46 | self.cheaptrick = CheapTrick( 47 | sampling_rate=sampling_rate, 48 | hop_size=hop_size, 49 | fft_size=fft_size, 50 | f0_floor=f0_floor, 51 | f0_ceil=f0_ceil, 52 | ) 53 | if l2_norm: 54 | self.loss = nn.MSELoss() 55 | else: 56 | self.loss = nn.L1Loss() 57 | 58 | def forward(self, s, f): 59 | """Calculate forward propagation. 60 | 61 | Args: 62 | s (Tensor): Predicted source signal (B, 1, T). 63 | f (Tensor): Extracted F0 sequence (B, 1, T'). 64 | 65 | Returns: 66 | loss (Tensor): Spectral envelope regularization loss value. 67 | 68 | """ 69 | s, f = s.squeeze(1), f.squeeze(1) 70 | e = self.cheaptrick.forward(s, f, self.power, self.elim_0th) 71 | loss = self.loss(e, e.new_zeros(e.size())) 72 | return loss 73 | 74 | 75 | class ResidualLoss(nn.Module): 76 | """hn-uSFGAN source regularization loss module.""" 77 | 78 | def __init__( 79 | self, 80 | sampling_rate=24000, 81 | fft_size=2048, 82 | hop_size=120, 83 | f0_floor=70, 84 | f0_ceil=340, 85 | n_mels=80, 86 | fmin=0, 87 | fmax=None, 88 | power=False, 89 | elim_0th=True, 90 | ): 91 | """Initialize ResidualLoss module. 92 | 93 | Args: 94 | sampling_rate (int): Sampling rate. 95 | fft_size (int): FFT size. 96 | hop_size (int): Hop size. 97 | f0_floor (int): Minimum F0 value. 98 | f0_ceil (int): Maximum F0 value. 99 | n_mels (int): Number of Mel basis. 100 | fmin (int): Minimum frequency for Mel. 101 | fmax (int): Maximum frequency for Mel. 102 | power (bool): Whether to use power or magnitude spectrogram. 103 | elim_0th (bool): Whether to exclude 0th cepstrum in CheapTrick. 104 | If set to True, power is estimated by the source-network. 105 | 106 | """ 107 | super(ResidualLoss, self).__init__() 108 | self.sampling_rate = sampling_rate 109 | self.fft_size = fft_size 110 | self.hop_size = hop_size 111 | self.cheaptrick = CheapTrick( 112 | sampling_rate=sampling_rate, 113 | hop_size=hop_size, 114 | fft_size=fft_size, 115 | f0_floor=f0_floor, 116 | f0_ceil=f0_ceil, 117 | ) 118 | self.win_length = fft_size 119 | self.register_buffer("window", torch.hann_window(self.win_length)) 120 | 121 | # define mel-filter-bank 122 | self.n_mels = n_mels 123 | self.fmin = fmin 124 | self.fmax = fmax if fmax is not None else sampling_rate / 2 125 | melmat = librosa_mel( 126 | sr=sampling_rate, n_fft=fft_size, n_mels=n_mels, fmin=fmin, fmax=self.fmax 127 | ).T 128 | self.register_buffer("melmat", torch.from_numpy(melmat).float()) 129 | 130 | self.power = power 131 | self.elim_0th = elim_0th 132 | 133 | def forward(self, s, y, f): 134 | """Calculate forward propagation. 135 | 136 | Args: 137 | s (Tensor): Predicted source signal (B, 1, T). 138 | y (Tensor): Target signal (B, 1, T). 139 | f (Tensor): Extracted F0 sequence (B, 1, T'). 140 | 141 | Returns: 142 | Tensor: Residual loss value.
143 | 144 | """ 145 | s, y, f = s.squeeze(1), y.squeeze(1), f.squeeze(1) 146 | 147 | with torch.no_grad(): 148 | # calculate log power (or magnitude) spectrograms 149 | e = self.cheaptrick.forward(y, f, self.power, self.elim_0th) 150 | y = usfgan.losses.stft( 151 | y, 152 | self.fft_size, 153 | self.hop_size, 154 | self.win_length, 155 | self.window, 156 | power=self.power, 157 | ) 158 | # adjust length, (B, T', C) 159 | minlen = min(e.size(1), y.size(1)) 160 | e, y = e[:, :minlen, :], y[:, :minlen, :] 161 | 162 | # calculate mean power (or magnitude) of y 163 | if self.elim_0th: 164 | y_mean = y.mean(dim=-1, keepdim=True) 165 | 166 | # calculate target of output source signal 167 | y = torch.log(torch.clamp(y, min=1e-7)) 168 | t = (y - e).exp() 169 | if self.elim_0th: 170 | t_mean = t.mean(dim=-1, keepdim=True) 171 | t = y_mean / t_mean * t 172 | 173 | # apply mel-filter-bank and log 174 | t = torch.matmul(t, self.melmat) 175 | t = torch.log(torch.clamp(t, min=1e-7)) 176 | 177 | # calculate power (or magnitude) spectrogram 178 | s = usfgan.losses.stft( 179 | s, 180 | self.fft_size, 181 | self.hop_size, 182 | self.win_length, 183 | self.window, 184 | power=self.power, 185 | ) 186 | # adjust length, (B, T', C) 187 | minlen = min(minlen, s.size(1)) 188 | s, t = s[:, :minlen, :], t[:, :minlen, :] 189 | 190 | # apply mel-filter-bank and log 191 | s = torch.matmul(s, self.melmat) 192 | s = torch.log(torch.clamp(s, min=1e-7)) 193 | 194 | loss = F.l1_loss(s, t.detach()) 195 | 196 | return loss 197 | -------------------------------------------------------------------------------- /usfgan/bin/anasyn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2022 Reo Yoneyama (Nagoya University) 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Analysis-synthesis script. 
7 | 8 | Analysis: WORLD vocoder 9 | Synthesis: Pre-trained neural vocoder 10 | 11 | """ 12 | 13 | 14 | import copy 15 | import os 16 | from logging import getLogger 17 | 18 | import hydra 19 | import librosa 20 | import numpy as np 21 | import pysptk 22 | import pyworld as pw 23 | import soundfile as sf 24 | import torch 25 | import torch.nn as nn 26 | from hydra.utils import instantiate, to_absolute_path 27 | from joblib import load 28 | from omegaconf import DictConfig 29 | from scipy.interpolate import interp1d 30 | from usfgan.utils.features import SignalGenerator, dilated_factor 31 | # A logger for this file 32 | logger = getLogger(__name__) 33 | 34 | # All-pass-filter coefficients {key -> sampling rate : value -> coefficient} 35 | ALPHA = { 36 | 8000: 0.312, 37 | 12000: 0.369, 38 | 16000: 0.410, 39 | 22050: 0.455, 40 | 24000: 0.466, 41 | 32000: 0.504, 42 | 44100: 0.544, 43 | 48000: 0.554, 44 | } 45 | 46 | 47 | def convert_continuos_f0(f0): 48 | # get uv information as binary 49 | uv = np.float32(f0 != 0) 50 | # get start and end of f0 51 | if (f0 == 0).all(): 52 | logger.warning("All of the F0 values are 0.") 53 | return uv, f0 54 | start_f0 = f0[f0 != 0][0] 55 | end_f0 = f0[f0 != 0][-1] 56 | # padding start and end of f0 sequence 57 | cont_f0 = copy.deepcopy(f0) 58 | start_idx = np.where(cont_f0 == start_f0)[0][0] 59 | end_idx = np.where(cont_f0 == end_f0)[0][-1] 60 | cont_f0[:start_idx] = start_f0 61 | cont_f0[end_idx:] = end_f0 62 | # get non-zero frame index 63 | nz_frames = np.where(cont_f0 != 0)[0] 64 | # perform linear interpolation 65 | f = interp1d(nz_frames, cont_f0[nz_frames]) 66 | cont_f0 = f(np.arange(0, cont_f0.shape[0])) 67 | 68 | return uv, cont_f0 69 | 70 | 71 | @torch.no_grad() 72 | @hydra.main(version_base=None, config_path="config", config_name="anasyn") 73 | def main(config: DictConfig) -> None: 74 | """Run analysis-synthesis process.""" 75 | 76 | np.random.seed(config.seed) 77 | torch.manual_seed(config.seed) 78 | torch.cuda.manual_seed(config.seed) 79 | os.environ["PYTHONHASHSEED"] = str(config.seed) 80 | 81 | # set device 82 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 83 | logger.info(f"Decode on {device}.") 84 | 85 | # load pre-trained model from checkpoint file 86 | model = instantiate(config.generator) 87 | state_dict = torch.load( 88 | to_absolute_path(config.checkpoint_path), map_location="cpu" 89 | ) 90 | model.load_state_dict(state_dict["model"]["generator"]) 91 | logger.info(f"Loaded model parameters from {config.checkpoint_path}.") 92 | model.remove_weight_norm() 93 | model.eval().to(device) 94 | 95 | # get scaler 96 | scaler = load(to_absolute_path(config.stats)) 97 | 98 | # get data processor 99 | hop_size = int(config.sample_rate * config.frame_period * 0.001) 100 | signal_generator = SignalGenerator( 101 | sample_rate=config.sample_rate, 102 | hop_size=hop_size, 103 | sine_amp=config.sine_amp, 104 | noise_amp=config.noise_amp, 105 | signal_types=config.signal_types, 106 | ) 107 | pad_fn = nn.ReplicationPad1d(config.generator.aux_context_window) 108 | 109 | # create output directory 110 | os.makedirs(config.out_dir, exist_ok=True) 111 | 112 | # loop all wav files in in_dir 113 | for wav_file in os.listdir(config.in_dir): 114 | if os.path.splitext(wav_file)[1] != ".wav": 115 | continue 116 | logger.info(f"Processing {wav_file}.") 117 | wav_path = os.path.join(config.in_dir, wav_file) 118 | 119 | # WORLD analysis 120 | x, sr = sf.read(to_absolute_path(wav_path)) 121 | if sr != config.sample_rate: 122 | x =
librosa.resample(x, orig_sr=sr, target_sr=config.sample_rate) 123 | f0, t = pw.harvest( 124 | x, 125 | config.sample_rate, 126 | f0_floor=config.f0_floor, 127 | f0_ceil=config.f0_ceil, 128 | frame_period=config.frame_period, 129 | ) 130 | sp = pw.cheaptrick(x, f0, t, config.sample_rate) 131 | ap = pw.d4c(x, f0, t, config.sample_rate) 132 | mcep = pysptk.sp2mc(sp, order=config.mcep_dim, alpha=ALPHA[config.sample_rate]) 133 | mcap = pysptk.sp2mc(ap, order=config.mcap_dim, alpha=ALPHA[config.sample_rate]) 134 | bap = pw.code_aperiodicity(ap, config.sample_rate) 135 | 136 | # prepare f0 related features 137 | uv, cf0 = convert_continuos_f0(f0) 138 | uv = uv[:, np.newaxis] # (T, 1) 139 | f0 = f0[:, np.newaxis] # (T, 1) 140 | cf0 = cf0[:, np.newaxis] # (T, 1) 141 | f0 *= config.f0_factor 142 | cf0 *= config.f0_factor 143 | 144 | # prepare input acoustic features 145 | c = [] 146 | for feat_type in config.aux_feats: 147 | if feat_type == "f0": 148 | c += [scaler[feat_type].transform(f0)] 149 | elif feat_type in ["cf0", "contf0"]: 150 | c += [scaler[feat_type].transform(cf0)] 151 | elif feat_type == "uv": 152 | c += [scaler[feat_type].transform(uv)] 153 | elif feat_type == "mcep": 154 | c += [scaler[feat_type].transform(mcep)] 155 | elif feat_type == "mcap": 156 | c += [scaler[feat_type].transform(mcap)] 157 | elif feat_type == "bap": 158 | c += [scaler[feat_type].transform(bap)] 159 | c = np.concatenate(c, axis=1) 160 | logger.info(c.shape) 161 | 162 | # prepare dense factors 163 | df = dilated_factor( 164 | cf0 if config.df_f0_type in ["cf0", "contf0"] else f0, 165 | config.sample_rate, 166 | config.dense_factor, 167 | ).repeat(hop_size, axis=0) 168 | 169 | # convert to torch tensors 170 | f0 = torch.FloatTensor(f0).view(1, 1, -1).to(device) 171 | cf0 = torch.FloatTensor(cf0).view(1, 1, -1).to(device) 172 | c = pad_fn(torch.FloatTensor(c).unsqueeze(0).transpose(2, 1).to(device)) 173 | df = torch.FloatTensor(np.array(df)).view(1, 1, -1).to(device) 174 | 175 | # generate input signals 176 | if config.sine_f0_type in ["cf0", "contf0"]: 177 | in_signal = signal_generator(cf0) 178 | elif config.sine_f0_type == "f0": 179 | in_signal = signal_generator(f0) 180 | 181 | # synthesize with the neural vocoder 182 | y = model(in_signal, c, df)[0] 183 | 184 | # save output signal as PCM 16 bit wav file 185 | out_path = os.path.join(config.out_dir, wav_file).replace( 186 | ".wav", f"_f{config.f0_factor:.2f}.wav" 187 | ) 188 | sf.write( 189 | to_absolute_path(out_path), 190 | y.view(-1).cpu().numpy(), 191 | config.sample_rate, 192 | "PCM_16", 193 | ) 194 | 195 | 196 | if __name__ == "__main__": 197 | main() 198 | -------------------------------------------------------------------------------- /usfgan/layers/upsample.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Upsampling module. 4 | 5 | This code is modified from https://github.com/r9y9/wavenet_vocoder. 6 | 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | from usfgan.layers import Conv1d 13 | 14 | 15 | class Stretch2d(torch.nn.Module): 16 | """Stretch2d module.""" 17 | 18 | def __init__(self, x_scale, y_scale, mode="nearest"): 19 | """Initialize Stretch2d module. 20 | 21 | Args: 22 | x_scale (int): X scaling factor (Time axis in spectrogram). 23 | y_scale (int): Y scaling factor (Frequency axis in spectrogram). 24 | mode (str): Interpolation mode. 
25 | 26 | """ 27 | super(Stretch2d, self).__init__() 28 | self.x_scale = x_scale 29 | self.y_scale = y_scale 30 | self.mode = mode 31 | 32 | def forward(self, x): 33 | """Calculate forward propagation. 34 | 35 | Args: 36 | x (Tensor): Input tensor (B, C, F, T). 37 | 38 | Returns: 39 | Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale) 40 | 41 | """ 42 | return F.interpolate( 43 | x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode 44 | ) 45 | 46 | 47 | class Conv2d(torch.nn.Conv2d): 48 | """Conv2d module with customized initialization.""" 49 | 50 | def __init__(self, *args, **kwargs): 51 | """Initialize Conv2d module.""" 52 | super(Conv2d, self).__init__(*args, **kwargs) 53 | 54 | def reset_parameters(self): 55 | """Reset parameters.""" 56 | self.weight.data.fill_(1.0 / np.prod(self.kernel_size)) 57 | if self.bias is not None: 58 | torch.nn.init.constant_(self.bias, 0.0) 59 | 60 | 61 | class UpsampleNetwork(torch.nn.Module): 62 | """Upsampling network module.""" 63 | 64 | def __init__( 65 | self, 66 | upsample_scales, 67 | nonlinear_activation=None, 68 | nonlinear_activation_params={}, 69 | interpolate_mode="nearest", 70 | freq_axis_kernel_size=1, 71 | use_causal_conv=False, 72 | ): 73 | """Initialize upsampling network module. 74 | 75 | Args: 76 | upsample_scales (list): List of upsampling scales. 77 | nonlinear_activation (str): Activation function name. 78 | nonlinear_activation_params (dict): Arguments for specified activation function. 79 | interpolate_mode (str): Interpolation mode. 80 | freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. 81 | 82 | """ 83 | super(UpsampleNetwork, self).__init__() 84 | self.use_causal_conv = use_causal_conv 85 | self.up_layers = torch.nn.ModuleList() 86 | for scale in upsample_scales: 87 | # interpolation layer 88 | stretch = Stretch2d(scale, 1, interpolate_mode) 89 | self.up_layers += [stretch] 90 | 91 | # conv layer 92 | assert ( 93 | freq_axis_kernel_size - 1 94 | ) % 2 == 0, "Not support even number freq axis kernel size." 95 | freq_axis_padding = (freq_axis_kernel_size - 1) // 2 96 | kernel_size = (freq_axis_kernel_size, scale * 2 + 1) 97 | if use_causal_conv: 98 | padding = (freq_axis_padding, scale * 2) 99 | else: 100 | padding = (freq_axis_padding, scale) 101 | conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) 102 | self.up_layers += [conv] 103 | 104 | # nonlinear 105 | if nonlinear_activation is not None: 106 | nonlinear = getattr(torch.nn, nonlinear_activation)( 107 | **nonlinear_activation_params 108 | ) 109 | self.up_layers += [nonlinear] 110 | 111 | def forward(self, c): 112 | """Calculate forward propagation. 113 | 114 | Args: 115 | c : Input tensor (B, C, T). 116 | 117 | Returns: 118 | Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales). 119 | 120 | """ 121 | c = c.unsqueeze(1) # (B, 1, C, T) 122 | for f in self.up_layers: 123 | if self.use_causal_conv and isinstance(f, Conv2d): 124 | c = f(c)[..., : c.size(-1)] 125 | else: 126 | c = f(c) 127 | 128 | return c.squeeze(1) # (B, C, T') 129 | 130 | 131 | class ConvInUpsampleNetwork(torch.nn.Module): 132 | """Convolution + upsampling network module.""" 133 | 134 | def __init__( 135 | self, 136 | upsample_scales, 137 | nonlinear_activation=None, 138 | nonlinear_activation_params={}, 139 | interpolate_mode="nearest", 140 | freq_axis_kernel_size=1, 141 | aux_channels=80, 142 | aux_context_window=0, 143 | use_causal_conv=False, 144 | ): 145 | """Initialize convolution + upsampling network module. 
146 | 147 | Args: 148 | upsample_scales (list): List of upsampling scales. 149 | nonlinear_activation (str): Activation function name. 150 | nonlinear_activation_params (dict): Arguments for specified activation function. 151 | interpolate_mode (str): Interpolation mode. 152 | freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. 153 | aux_channels (int): Number of channels of pre-convolutional layer. 154 | aux_context_window (int): Context window size of the pre-convolutional layer. 155 | use_causal_conv (bool): Whether to use causal structure. 156 | 157 | """ 158 | super(ConvInUpsampleNetwork, self).__init__() 159 | self.aux_context_window = aux_context_window 160 | self.use_causal_conv = use_causal_conv and aux_context_window > 0 161 | # To capture wide-context information in conditional features 162 | kernel_size = ( 163 | aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1 164 | ) 165 | # NOTE(kan-bayashi): Here do not use padding because the input is already padded 166 | self.conv_in = Conv1d( 167 | aux_channels, aux_channels, kernel_size=kernel_size, bias=False 168 | ) 169 | self.upsample = UpsampleNetwork( 170 | upsample_scales=upsample_scales, 171 | nonlinear_activation=nonlinear_activation, 172 | nonlinear_activation_params=nonlinear_activation_params, 173 | interpolate_mode=interpolate_mode, 174 | freq_axis_kernel_size=freq_axis_kernel_size, 175 | use_causal_conv=use_causal_conv, 176 | ) 177 | 178 | def forward(self, c): 179 | """Calculate forward propagation. 180 | 181 | Args: 182 | c (Tensor): Input tensor (B, C, T'). 183 | 184 | Returns: 185 | Tensor: Upsampled tensor (B, C, T), 186 | where T = (T' - aux_context_window * 2) * prod(upsample_scales). 187 | 188 | Note: 189 | The input is assumed to be already padded with aux_context_window.
190 | 191 | """ 192 | c_ = self.conv_in(c) 193 | c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_ 194 | return self.upsample(c) 195 | -------------------------------------------------------------------------------- /egs/vctk/data/spk_info.yaml: -------------------------------------------------------------------------------- 1 | p225: 2 | f0_min: 130 3 | f0_max: 260 4 | pow_th: -20 5 | p226: 6 | f0_min: 80 7 | f0_max: 180 8 | pow_th: -20 9 | p227: 10 | f0_min: 80 11 | f0_max: 190 12 | pow_th: -20 13 | p228: 14 | f0_min: 120 15 | f0_max: 300 16 | pow_th: -20 17 | p229: 18 | f0_min: 140 19 | f0_max: 230 20 | pow_th: -20 21 | p230: 22 | f0_min: 140 23 | f0_max: 260 24 | pow_th: -20 25 | p231: 26 | f0_min: 110 27 | f0_max: 280 28 | pow_th: -20 29 | p232: 30 | f0_min: 70 31 | f0_max: 180 32 | pow_th: -20 33 | p233: 34 | f0_min: 150 35 | f0_max: 290 36 | pow_th: -20 37 | p234: 38 | f0_min: 130 39 | f0_max: 270 40 | pow_th: -20 41 | p236: 42 | f0_min: 170 43 | f0_max: 330 44 | pow_th: -20 45 | p237: 46 | f0_min: 60 47 | f0_max: 120 48 | pow_th: -20 49 | p238: 50 | f0_min: 140 51 | f0_max: 320 52 | pow_th: -20 53 | p239: 54 | f0_min: 140 55 | f0_max: 270 56 | pow_th: -20 57 | p240: 58 | f0_min: 150 59 | f0_max: 330 60 | pow_th: -20 61 | p241: 62 | f0_min: 85 63 | f0_max: 170 64 | pow_th: -20 65 | p243: 66 | f0_min: 80 67 | f0_max: 180 68 | pow_th: -20 69 | p244: 70 | f0_min: 150 71 | f0_max: 320 72 | pow_th: -20 73 | p245: 74 | f0_min: 70 75 | f0_max: 130 76 | pow_th: -20 77 | p246: 78 | f0_min: 70 79 | f0_max: 140 80 | pow_th: -20 81 | p247: 82 | f0_min: 90 83 | f0_max: 200 84 | pow_th: -20 85 | p248: 86 | f0_min: 130 87 | f0_max: 390 88 | pow_th: -20 89 | p249: 90 | f0_min: 130 91 | f0_max: 240 92 | pow_th: -20 93 | p250: 94 | f0_min: 110 95 | f0_max: 380 96 | pow_th: -20 97 | p251: 98 | f0_min: 80 99 | f0_max: 170 100 | pow_th: -20 101 | p252: 102 | f0_min: 80 103 | f0_max: 160 104 | pow_th: -20 105 | p253: 106 | f0_min: 160 107 | f0_max: 310 108 | pow_th: -20 109 | p254: 110 | f0_min: 55 111 | f0_max: 120 112 | pow_th: -20 113 | p255: 114 | f0_min: 95 115 | f0_max: 210 116 | pow_th: -20 117 | p256: 118 | f0_min: 60 119 | f0_max: 130 120 | pow_th: -20 121 | p257: 122 | f0_min: 150 123 | f0_max: 300 124 | pow_th: -20 125 | p258: 126 | f0_min: 90 127 | f0_max: 170 128 | pow_th: -20 129 | p259: 130 | f0_min: 80 131 | f0_max: 200 132 | pow_th: -20 133 | p260: 134 | f0_min: 70 135 | f0_max: 170 136 | pow_th: -20 137 | p261: 138 | f0_min: 140 139 | f0_max: 360 140 | pow_th: -20 141 | p262: 142 | f0_min: 120 143 | f0_max: 250 144 | pow_th: -20 145 | p263: 146 | f0_min: 70 147 | f0_max: 160 148 | pow_th: -20 149 | p264: 150 | f0_min: 150 151 | f0_max: 250 152 | pow_th: -20 153 | p265: 154 | f0_min: 130 155 | f0_max: 290 156 | pow_th: -20 157 | p266: 158 | f0_min: 100 159 | f0_max: 300 160 | pow_th: -20 161 | p267: 162 | f0_min: 130 163 | f0_max: 270 164 | pow_th: -20 165 | p268: 166 | f0_min: 130 167 | f0_max: 300 168 | pow_th: -20 169 | p269: 170 | f0_min: 150 171 | f0_max: 260 172 | pow_th: -20 173 | p270: 174 | f0_min: 75 175 | f0_max: 150 176 | pow_th: -20 177 | p271: 178 | f0_min: 90 179 | f0_max: 160 180 | pow_th: -20 181 | p272: 182 | f0_min: 80 183 | f0_max: 180 184 | pow_th: -20 185 | p273: 186 | f0_min: 100 187 | f0_max: 220 188 | pow_th: -20 189 | p274: 190 | f0_min: 75 191 | f0_max: 150 192 | pow_th: -20 193 | p275: 194 | f0_min: 80 195 | f0_max: 160 196 | pow_th: -20 197 | p276: 198 | f0_min: 170 199 | f0_max: 340 200 | pow_th: -20 201 | p277: 202 | f0_min: 160 203 | 
f0_max: 260 204 | pow_th: -20 205 | p278: 206 | f0_min: 80 207 | f0_max: 170 208 | pow_th: -20 209 | p279: 210 | f0_min: 80 211 | f0_max: 200 212 | pow_th: -20 213 | p280: 214 | f0_min: 140 215 | f0_max: 280 216 | pow_th: -20 217 | p281: 218 | f0_min: 70 219 | f0_max: 140 220 | pow_th: -20 221 | p282: 222 | f0_min: 150 223 | f0_max: 250 224 | pow_th: -20 225 | p283: 226 | f0_min: 160 227 | f0_max: 280 228 | pow_th: -20 229 | p284: 230 | f0_min: 60 231 | f0_max: 170 232 | pow_th: -20 233 | p285: 234 | f0_min: 80 235 | f0_max: 180 236 | pow_th: -20 237 | p286: 238 | f0_min: 90 239 | f0_max: 210 240 | pow_th: -20 241 | p287: 242 | f0_min: 65 243 | f0_max: 140 244 | pow_th: -20 245 | p288: 246 | f0_min: 130 247 | f0_max: 300 248 | pow_th: -20 249 | p292: 250 | f0_min: 70 251 | f0_max: 170 252 | pow_th: -20 253 | p293: 254 | f0_min: 130 255 | f0_max: 260 256 | pow_th: -20 257 | p294: 258 | f0_min: 120 259 | f0_max: 260 260 | pow_th: -20 261 | p295: 262 | f0_min: 140 263 | f0_max: 260 264 | pow_th: -20 265 | p297: 266 | f0_min: 120 267 | f0_max: 300 268 | pow_th: -20 269 | p298: 270 | f0_min: 80 271 | f0_max: 190 272 | pow_th: -20 273 | p299: 274 | f0_min: 110 275 | f0_max: 270 276 | pow_th: -20 277 | p300: 278 | f0_min: 130 279 | f0_max: 270 280 | pow_th: -20 281 | p301: 282 | f0_min: 120 283 | f0_max: 250 284 | pow_th: -20 285 | p302: 286 | f0_min: 90 287 | f0_max: 180 288 | pow_th: -20 289 | p303: 290 | f0_min: 160 291 | f0_max: 300 292 | pow_th: -20 293 | p304: 294 | f0_min: 65 295 | f0_max: 160 296 | pow_th: -20 297 | p305: 298 | f0_min: 160 299 | f0_max: 330 300 | pow_th: -20 301 | p306: 302 | f0_min: 130 303 | f0_max: 260 304 | pow_th: -20 305 | p307: 306 | f0_min: 160 307 | f0_max: 400 308 | pow_th: -20 309 | p308: 310 | f0_min: 140 311 | f0_max: 280 312 | pow_th: -20 313 | p310: 314 | f0_min: 170 315 | f0_max: 350 316 | pow_th: -20 317 | p311: 318 | f0_min: 75 319 | f0_max: 160 320 | pow_th: -20 321 | p312: 322 | f0_min: 130 323 | f0_max: 280 324 | pow_th: -20 325 | p313: 326 | f0_min: 120 327 | f0_max: 260 328 | pow_th: -20 329 | p314: 330 | f0_min: 120 331 | f0_max: 270 332 | pow_th: -20 333 | p315: 334 | f0_min: 90 335 | f0_max: 190 336 | pow_th: -20 337 | p316: 338 | f0_min: 65 339 | f0_max: 140 340 | pow_th: -20 341 | p317: 342 | f0_min: 175 343 | f0_max: 360 344 | pow_th: -20 345 | p318: 346 | f0_min: 110 347 | f0_max: 290 348 | pow_th: -20 349 | p323: 350 | f0_min: 130 351 | f0_max: 390 352 | pow_th: -20 353 | p326: 354 | f0_min: 50 355 | f0_max: 110 356 | pow_th: -20 357 | p329: 358 | f0_min: 150 359 | f0_max: 280 360 | pow_th: -20 361 | p330: 362 | f0_min: 140 363 | f0_max: 270 364 | pow_th: -20 365 | p333: 366 | f0_min: 120 367 | f0_max: 270 368 | pow_th: -20 369 | p334: 370 | f0_min: 65 371 | f0_max: 140 372 | pow_th: -20 373 | p335: 374 | f0_min: 130 375 | f0_max: 310 376 | pow_th: -20 377 | p336: 378 | f0_min: 130 379 | f0_max: 310 380 | pow_th: -20 381 | p339: 382 | f0_min: 120 383 | f0_max: 310 384 | pow_th: -20 385 | p340: 386 | f0_min: 170 387 | f0_max: 270 388 | pow_th: -20 389 | p341: 390 | f0_min: 130 391 | f0_max: 280 392 | pow_th: -20 393 | p343: 394 | f0_min: 110 395 | f0_max: 240 396 | pow_th: -20 397 | p345: 398 | f0_min: 70 399 | f0_max: 140 400 | pow_th: -20 401 | p347: 402 | f0_min: 60 403 | f0_max: 130 404 | pow_th: -20 405 | p351: 406 | f0_min: 170 407 | f0_max: 280 408 | pow_th: -20 409 | p360: 410 | f0_min: 70 411 | f0_max: 160 412 | pow_th: -20 413 | p361: 414 | f0_min: 120 415 | f0_max: 290 416 | pow_th: -20 417 | p362: 418 | f0_min: 130 419 | f0_max: 
300 420 | pow_th: -20 421 | p363: 422 | f0_min: 70 423 | f0_max: 180 424 | pow_th: -20 425 | p364: 426 | f0_min: 65 427 | f0_max: 145 428 | pow_th: -20 429 | p374: 430 | f0_min: 70 431 | f0_max: 190 432 | pow_th: -20 433 | p376: 434 | f0_min: 70 435 | f0_max: 170 436 | pow_th: -20 437 | s5: 438 | f0_min: 140 439 | f0_max: 340 440 | pow_th: -20 441 | -------------------------------------------------------------------------------- /usfgan/layers/cheaptrick.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2022 Reo Yoneyama (Nagoya University) 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Spectral envelopes estimation module based on CheapTrick. 7 | 8 | References: 9 | - https://www.sciencedirect.com/science/article/pii/S0167639314000697 10 | - https://github.com/mmorise/World 11 | 12 | """ 13 | 14 | import math 15 | 16 | import torch 17 | import torch.fft 18 | import torch.nn as nn 19 | 20 | 21 | class AdaptiveWindowing(nn.Module): 22 | """CheapTrick F0 adaptive windowing module.""" 23 | 24 | def __init__( 25 | self, 26 | sampling_rate, 27 | hop_size, 28 | fft_size, 29 | f0_floor, 30 | f0_ceil, 31 | ): 32 | """Initialize AdaptiveWindowing module. 33 | 34 | Args: 35 | sampling_rate (int): Sampling rate. 36 | hop_size (int): Hop size. 37 | fft_size (int): FFT size. 38 | f0_floor (int): Minimum value of F0. 39 | f0_ceil (int): Maximum value of F0. 40 | 41 | """ 42 | super(AdaptiveWindowing, self).__init__() 43 | 44 | self.sampling_rate = sampling_rate 45 | self.hop_size = hop_size 46 | self.fft_size = fft_size 47 | self.register_buffer("window", torch.zeros((f0_ceil + 1, fft_size))) 48 | self.zero_padding = nn.ConstantPad2d((fft_size // 2, fft_size // 2, 0, 0), 0) 49 | 50 | # Pre-calculation of the window functions 51 | for f0 in range(f0_floor, f0_ceil + 1): 52 | half_win_len = round(1.5 * self.sampling_rate / f0) 53 | base_index = torch.arange( 54 | -half_win_len, half_win_len + 1, dtype=torch.int64 55 | ) 56 | position = base_index / 1.5 / self.sampling_rate 57 | left = fft_size // 2 - half_win_len 58 | right = fft_size // 2 + half_win_len + 1 59 | window = torch.zeros(fft_size) 60 | window[left:right] = 0.5 * torch.cos(math.pi * position * f0) + 0.5 61 | average = torch.sum(window * window).pow(0.5) 62 | self.window[f0] = window / average 63 | 64 | def forward(self, x, f, power=False): 65 | """Calculate forward propagation. 66 | 67 | Args: 68 | x (Tensor): Waveform (B, T). 69 | f (Tensor): F0 sequence (B, T'). 70 | power (bool): Whether to use power or magnitude spectrogram. 71 | 72 | Returns: 73 | Tensor: Power (or magnitude) spectrogram (B, T', bin_size). 74 | 75 | """ 76 | # Get the matrix of window functions corresponding to F0 77 | x = self.zero_padding(x).unfold(1, self.fft_size, self.hop_size) 78 | windows = self.window[f] 79 | 80 | # Adaptive windowing and calculate power spectrogram. 81 | # At inference time, use x instead of x[:, :-1, :]. 82 | x = torch.abs(torch.fft.rfft(x[:, :-1, :] * windows)) 83 | x = x.pow(2) if power else x 84 | 85 | return x 86 | 87 | 88 | class AdaptiveLiftering(nn.Module): 89 | """CheapTrick F0 adaptive liftering module.""" 90 | 91 | def __init__( 92 | self, 93 | sampling_rate, 94 | fft_size, 95 | f0_floor, 96 | f0_ceil, 97 | q1=-0.15, 98 | ): 99 | """Initialize AdaptiveLiftering module. 100 | 101 | Args: 102 | sampling_rate (int): Sampling rate. 103 | fft_size (int): FFT size. 104 | f0_floor (int): Minimum value of F0. 105 | f0_ceil (int): Maximum value of F0.
106 | q1 (float): Parameter to remove effect of adjacent harmonics. 107 | 108 | """ 109 | super(AdaptiveLiftering, self).__init__() 110 | 111 | self.sampling_rate = sampling_rate 112 | self.bin_size = fft_size // 2 + 1 113 | self.q1 = q1 114 | self.q0 = 1.0 - 2.0 * q1 115 | self.register_buffer( 116 | "smoothing_lifter", torch.zeros((f0_ceil + 1, self.bin_size)) 117 | ) 118 | self.register_buffer( 119 | "compensation_lifter", torch.zeros((f0_ceil + 1, self.bin_size)) 120 | ) 121 | 122 | # Pre-calculation of the smoothing lifters and compensation lifters 123 | for f0 in range(f0_floor, f0_ceil + 1): 124 | smoothing_lifter = torch.zeros(self.bin_size) 125 | compensation_lifter = torch.zeros(self.bin_size) 126 | quefrency = torch.arange(1, self.bin_size) / sampling_rate 127 | smoothing_lifter[0] = 1.0 128 | smoothing_lifter[1:] = torch.sin(math.pi * f0 * quefrency) / ( 129 | math.pi * f0 * quefrency 130 | ) 131 | compensation_lifter[0] = self.q0 + 2.0 * self.q1 132 | compensation_lifter[1:] = self.q0 + 2.0 * self.q1 * torch.cos( 133 | 2.0 * math.pi * f0 * quefrency 134 | ) 135 | self.smoothing_lifter[f0] = smoothing_lifter 136 | self.compensation_lifter[f0] = compensation_lifter 137 | 138 | def forward(self, x, f, elim_0th=False): 139 | """Calculate forward propagation. 140 | 141 | Args: 142 | x (Tensor): Power (or magnitude) spectrogram (B, T', bin_size). 143 | f (Tensor): F0 sequence (B, T'). 144 | elim_0th (bool): Whether to eliminate the 0th cepstrum component. 145 | 146 | Returns: 147 | Tensor: Estimated spectral envelope (B, T', bin_size). 148 | 149 | """ 150 | # Setting the smoothing lifter and compensation lifter 151 | smoothing_lifter = self.smoothing_lifter[f] 152 | compensation_lifter = self.compensation_lifter[f] 153 | 154 | # Calculating cepstrum 155 | tmp = torch.cat((x, torch.flip(x[:, :, 1:-1], [2])), dim=2) 156 | cepstrum = torch.fft.rfft(torch.log(torch.clamp(tmp, min=1e-7))).real 157 | 158 | # Set the 0th cepstrum to 0 159 | if elim_0th: 160 | cepstrum[..., 0] = 0 161 | 162 | # Liftering cepstrum with the lifters 163 | liftered_cepstrum = cepstrum * smoothing_lifter * compensation_lifter 164 | 165 | # Return the result to the spectral domain 166 | x = torch.fft.irfft(liftered_cepstrum)[:, :, : self.bin_size] 167 | 168 | return x 169 | 170 |
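# The two lifters above realize CheapTrick's cepstral liftering. With
# quefrency q (in seconds) and q0 = 1 - 2 * q1, they are:
#   smoothing lifter:     l_s(q) = sin(pi * f0 * q) / (pi * f0 * q)
#   compensation lifter:  l_c(q) = q0 + 2 * q1 * cos(2 * pi * f0 * q)
# q1 = -0.15 follows the WORLD implementation referenced in the header.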
171 | class CheapTrick(nn.Module): 172 | """CheapTrick based spectral envelope estimation module.""" 173 | 174 | def __init__( 175 | self, 176 | sampling_rate, 177 | hop_size, 178 | fft_size, 179 | f0_floor=70, 180 | f0_ceil=340, 181 | uv_threshold=0, 182 | q1=-0.15, 183 | ): 184 | """Initialize CheapTrick module. 185 | 186 | Args: 187 | sampling_rate (int): Sampling rate. 188 | hop_size (int): Hop size. 189 | fft_size (int): FFT size. 190 | f0_floor (int): Minimum value of F0. 191 | f0_ceil (int): Maximum value of F0. 192 | uv_threshold (float): V/UV determining threshold. 193 | q1 (float): Parameter to remove effect of adjacent harmonics. 194 | 195 | """ 196 | super(CheapTrick, self).__init__() 197 | 198 | # fft_size must be larger than 3.0 * sampling_rate / f0_floor 199 | assert fft_size > 3.0 * sampling_rate / f0_floor 200 | self.f0_floor = f0_floor 201 | self.f0_ceil = f0_ceil 202 | self.uv_threshold = uv_threshold 203 | 204 | self.ada_wind = AdaptiveWindowing( 205 | sampling_rate, 206 | hop_size, 207 | fft_size, 208 | f0_floor, 209 | f0_ceil, 210 | ) 211 | self.ada_lift = AdaptiveLiftering( 212 | sampling_rate, 213 | fft_size, 214 | f0_floor, 215 | f0_ceil, 216 | q1, 217 | ) 218 | 219 | def forward(self, x, f, power=False, elim_0th=False): 220 | """Calculate forward propagation. 221 | 222 | Args: 223 | x (Tensor): Waveform (B, T). 224 | f (Tensor): F0 sequence (B, T'). 225 | power (bool): Whether to use power or magnitude spectrogram. 226 | elim_0th (bool): Whether to eliminate the 0th cepstrum component. 227 | 228 | Returns: 229 | Tensor: Estimated spectral envelope (B, T', bin_size). 230 | 231 | """ 232 | # Step 0: Round F0 values to integers. 233 | voiced = (f > self.uv_threshold) * torch.ones_like(f) 234 | f = voiced * f + (1.0 - voiced) * self.f0_ceil 235 | f = torch.round(torch.clamp(f, min=self.f0_floor, max=self.f0_ceil)).to( 236 | torch.int64 237 | ) 238 | 239 | # Step 1: Adaptive windowing to get the power or magnitude spectrogram. 240 | x = self.ada_wind(x, f, power) 241 | 242 | # Steps 2 and 3: Smoothing (on the log axis) and spectral recovery in the cepstrum domain. 243 | x = self.ada_lift(x, f, elim_0th) 244 | 245 | return x 246 |
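# Usage sketch for CheapTrick (random waveform, constant F0 track; the
# assertion above requires fft_size > 3 * sampling_rate / f0_floor):
def _example_cheaptrick():
    ct = CheapTrick(sampling_rate=24000, hop_size=120, fft_size=2048)
    x = torch.randn(1, 120 * 100)      # (B, T): 100 frames of audio
    f0 = torch.full((1, 100), 150.0)   # (B, T'): frame-level F0 in Hz
    env = ct(x, f0)                    # (B, T', fft_size // 2 + 1)
    assert env.shape == (1, 100, 1025)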
-------------------------------------------------------------------------------- /usfgan/losses/stft.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2022 Reo Yoneyama (Nagoya University) 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """STFT-based loss modules. 7 | 8 | References: 9 | - https://github.com/kan-bayashi/ParallelWaveGAN 10 | 11 | """ 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from librosa.filters import mel as librosa_mel 17 | 18 | 19 | def stft( 20 | x, fft_size, hop_size, win_length, window, center=True, onesided=True, power=False 21 | ): 22 | """Perform STFT and convert to magnitude spectrogram. 23 | 24 | Args: 25 | x (Tensor): Input signal tensor (B, T). 26 | fft_size (int): FFT size. 27 | hop_size (int): Hop size. 28 | win_length (int): Window length. 29 | window (Tensor): Window function tensor. 30 | power (bool): Whether to return the power spectrogram instead. 31 | Returns: 32 | Tensor: Magnitude (or power) spectrogram (B, #frames, fft_size // 2 + 1). 33 | 34 | """ 35 | x_stft = torch.stft( 36 | x, 37 | fft_size, 38 | hop_size, 39 | win_length, 40 | window, 41 | center=center, 42 | onesided=onesided, 43 | return_complex=False, 44 | ) 45 | real = x_stft[..., 0] 46 | imag = x_stft[..., 1] 47 | 48 | if power: 49 | return torch.clamp(real ** 2 + imag ** 2, min=1e-7).transpose(2, 1) 50 | else: 51 | return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) 52 | 53 | 54 | class SpectralConvergenceLoss(nn.Module): 55 | """Spectral convergence loss module.""" 56 | 57 | def __init__(self): 58 | """Initialize spectral convergence loss module.""" 59 | super(SpectralConvergenceLoss, self).__init__() 60 | 61 | def forward(self, x_mag, y_mag): 62 | """Calculate forward propagation. 63 | 64 | Args: 65 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 66 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 67 | 68 | Returns: 69 | Tensor: Spectral convergence loss value. 70 | 71 | """ 72 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") 73 | 74 | 75 | class LogSTFTMagnitudeLoss(nn.Module): 76 | """Log STFT magnitude loss module.""" 77 | 78 | def __init__(self): 79 | """Initialize log STFT magnitude loss module.""" 80 | super(LogSTFTMagnitudeLoss, self).__init__() 81 | 82 | def forward(self, x_mag, y_mag): 83 | """Calculate forward propagation. 84 | 85 | Args: 86 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 87 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 88 | 89 | Returns: 90 | Tensor: Log STFT magnitude loss value. 91 | 92 | """ 93 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) 94 | 95 | 96 | class STFTLoss(nn.Module): 97 | """STFT loss module.""" 98 | 99 | def __init__( 100 | self, fft_size=1024, hop_size=120, win_length=600, window="hann_window" 101 | ): 102 | """Initialize STFT loss module.""" 103 | super(STFTLoss, self).__init__() 104 | self.fft_size = fft_size 105 | self.hop_size = hop_size 106 | self.win_length = win_length 107 | self.register_buffer("window", getattr(torch, window)(win_length)) 108 | self.spectral_convergence_loss = SpectralConvergenceLoss() 109 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 110 | 111 | def forward(self, x, y): 112 | """Calculate forward propagation. 113 | 114 | Args: 115 | x (Tensor): Predicted signal (B, T). 116 | y (Tensor): Groundtruth signal (B, T). 117 | 118 | Returns: 119 | Tensor: Spectral convergence loss value. 120 | Tensor: Log STFT magnitude loss value. 121 | 122 | """ 123 | x_mag = stft(x, self.fft_size, self.hop_size, self.win_length, self.window) 124 | y_mag = stft(y, self.fft_size, self.hop_size, self.win_length, self.window) 125 | sc_loss = self.spectral_convergence_loss(x_mag, y_mag) 126 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 127 | 128 | return sc_loss, mag_loss 129 | 130 |
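# Usage sketch for STFTLoss: both returned terms are scalar tensors computed
# on magnitude spectrograms of the two signals.
def _example_stft_loss():
    criterion = STFTLoss(fft_size=1024, hop_size=120, win_length=600)
    x, y = torch.randn(2, 24000), torch.randn(2, 24000)  # (B, T)
    sc_loss, mag_loss = criterion(x, y)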
131 | class MultiResolutionSTFTLoss(nn.Module): 132 | """Multi-resolution STFT loss module.""" 133 | 134 | def __init__( 135 | self, 136 | fft_sizes=[1024, 2048, 512], 137 | hop_sizes=[120, 240, 50], 138 | win_lengths=[600, 1200, 240], 139 | window="hann_window", 140 | ): 141 | """Initialize multi-resolution STFT loss module. 142 | 143 | Args: 144 | fft_sizes (list): List of FFT sizes. 145 | hop_sizes (list): List of hop sizes. 146 | win_lengths (list): List of window lengths. 147 | window (str): Window function type. 148 | 149 | """ 150 | super(MultiResolutionSTFTLoss, self).__init__() 151 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 152 | self.stft_losses = nn.ModuleList() 153 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 154 | self.stft_losses += [STFTLoss(fs, ss, wl, window)] 155 | 156 | def forward(self, x, y): 157 | """Calculate forward propagation. 158 | 159 | Args: 160 | x (Tensor): Predicted signal (B, 1, T). 161 | y (Tensor): Groundtruth signal (B, 1, T). 162 | 163 | Returns: 164 | Tensor: Multi-resolution spectral convergence loss value. 165 | Tensor: Multi-resolution log STFT magnitude loss value. 166 | 167 | """ 168 | x = x.squeeze(1) 169 | y = y.squeeze(1) 170 | 171 | sc_loss = 0.0 172 | mag_loss = 0.0 173 | for f in self.stft_losses: 174 | sc_l, mag_l = f(x, y) 175 | sc_loss += sc_l 176 | mag_loss += mag_l 177 | sc_loss /= len(self.stft_losses) 178 | mag_loss /= len(self.stft_losses) 179 | 180 | return sc_loss, mag_loss 181 | 182 | 183 | class LogSTFTPowerLoss(nn.Module): 184 | """Log STFT power loss module.""" 185 | 186 | def __init__( 187 | self, fft_size=1024, hop_size=120, win_length=600, window="hann_window" 188 | ): 189 | """Initialize log STFT power loss module.""" 190 | super(LogSTFTPowerLoss, self).__init__() 191 | self.fft_size = fft_size 192 | self.hop_size = hop_size 193 | self.win_length = win_length 194 | self.register_buffer("window", getattr(torch, window)(win_length)) 195 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 196 | self.mse = nn.MSELoss() 197 | 198 | def forward(self, x, y): 199 | """Calculate forward propagation. 200 | 201 | Args: 202 | x (Tensor): Predicted signal (B, T). 203 | y (Tensor): Groundtruth signal (B, T). 204 | 205 | Returns: 206 | Tensor: Log STFT power loss value. 207 | 208 | 209 | """ 210 | x_pow = stft( 211 | x, 212 | self.fft_size, 213 | self.hop_size, 214 | self.win_length, 215 | self.window, 216 | power=True, 217 | ) 218 | y_pow = stft( 219 | y, 220 | self.fft_size, 221 | self.hop_size, 222 | self.win_length, 223 | self.window, 224 | power=True, 225 | ) 226 | stft_loss = ( 227 | self.mse( 228 | torch.log(torch.clamp(x_pow, min=1e-7)), 229 | torch.log(torch.clamp(y_pow, min=1e-7)), 230 | ) 231 | / 2.0 232 | ) 233 | 234 | return stft_loss 235 | 236 | 237 | class MultiResolutionLogSTFTPowerLoss(nn.Module): 238 | """Multi-resolution log STFT power loss module. 239 | 240 | This loss is the same as that used in the Neural Source-Filter model. 241 | https://arxiv.org/abs/1904.12088 242 | """ 243 | 244 | def __init__( 245 | self, 246 | fft_sizes=[1024, 2048, 512], 247 | hop_sizes=[120, 240, 50], 248 | win_lengths=[600, 1200, 240], 249 | window="hann_window", 250 | ): 251 | """Initialize multi-resolution STFT loss module. 252 | 253 | Args: 254 | fft_sizes (list): List of FFT sizes. 255 | hop_sizes (list): List of hop sizes. 256 | win_lengths (list): List of window lengths. 257 | window (str): Window function type. 258 | 259 | """ 260 | super(MultiResolutionLogSTFTPowerLoss, self).__init__() 261 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 262 | self.stft_losses = nn.ModuleList() 263 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 264 | self.stft_losses += [LogSTFTPowerLoss(fs, ss, wl, window)] 265 | 266 | def forward(self, x, y): 267 | """Calculate multi-resolution log STFT power loss. 268 | 269 | Args: 270 | x (Tensor): Predicted signal (B, 1, T). 271 | y (Tensor): Groundtruth signal (B, 1, T). 272 | 273 | Returns: 274 | Tensor: Multi-resolution log STFT power loss value. 275 | 276 | """ 277 | x = x.squeeze(1) 278 | y = y.squeeze(1) 279 | 280 | stft_loss = 0.0 281 | for f in self.stft_losses: 282 | loss_i = f(x, y) 283 | stft_loss += loss_i 284 | stft_loss /= len(self.stft_losses) 285 | 286 | return stft_loss 287 |
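# Each resolution above contributes (1/2) * mean((log|X|^2 - log|Y|^2)^2),
# i.e. an MSE between log power spectra, averaged over the listed
# (fft_size, hop_size, win_length) settings. Usage sketch:
def _example_nsf_loss():
    criterion = MultiResolutionLogSTFTPowerLoss()
    x, y = torch.randn(1, 1, 24000), torch.randn(1, 1, 24000)  # (B, 1, T)
    loss = criterion(x, y)  # scalar tensor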
288 | 289 | class MelSpectralLoss(nn.Module): 290 | """Mel-spectral L1 loss module.""" 291 | 292 | def __init__( 293 | self, 294 | fft_size=1024, 295 | hop_size=120, 296 | win_length=1024, 297 | window="hann_window", 298 | sampling_rate=24000, 299 | n_mels=80, 300 | fmin=0, 301 | fmax=None, 302 | ): 303 | """Initialize MelSpectralLoss module. 304 | 305 | Args: 306 | fft_size (int): FFT size. 307 | hop_size (int): Hop size. 308 | win_length (Optional[int]): Window length. 309 | window (str): Window type. 310 | sampling_rate (int): Sampling rate. 311 | n_mels (int): Number of Mel basis. 312 | fmin (Optional[int]): Minimum frequency of mel-filter-bank. 313 | fmax (Optional[int]): Maximum frequency of mel-filter-bank. 314 | 315 | """ 316 | super().__init__() 317 | self.fft_size = fft_size 318 | self.hop_size = hop_size 319 | self.win_length = win_length if win_length is not None else fft_size 320 | self.register_buffer("window", getattr(torch, window)(self.win_length)) 321 | self.sampling_rate = sampling_rate 322 | self.n_mels = n_mels 323 | self.fmin = fmin 324 | self.fmax = fmax if fmax is not None else sampling_rate / 2 325 | melmat = librosa_mel( 326 | sr=sampling_rate, n_fft=fft_size, n_mels=n_mels, fmin=fmin, fmax=self.fmax 327 | ).T 328 | self.register_buffer("melmat", torch.from_numpy(melmat).float()) 329 | 330 | def forward(self, x, y): 331 | """Calculate Mel-spectral L1 loss. 332 | 333 | Args: 334 | x (Tensor): Generated waveform tensor (B, 1, T). 335 | y (Tensor): Groundtruth waveform tensor (B, 1, T). 336 | 337 | Returns: 338 | Tensor: Mel-spectral L1 loss value. 339 | 340 | """ 341 | x = x.squeeze(1) 342 | y = y.squeeze(1) 343 | x_mag = stft(x, self.fft_size, self.hop_size, self.win_length, self.window) 344 | y_mag = stft(y, self.fft_size, self.hop_size, self.win_length, self.window) 345 | x_log_mel = torch.log(torch.clamp(torch.matmul(x_mag, self.melmat), min=1e-7)) 346 | y_log_mel = torch.log(torch.clamp(torch.matmul(y_mag, self.melmat), min=1e-7)) 347 | mel_loss = F.l1_loss(x_log_mel, y_log_mel) 348 | 349 | return mel_loss 350 |
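# Usage sketch for MelSpectralLoss: an L1 distance between log-mel
# spectrograms of the generated and ground-truth waveforms.
def _example_mel_loss():
    criterion = MelSpectralLoss(sampling_rate=24000, n_mels=80)
    x, y = torch.randn(1, 1, 24000), torch.randn(1, 1, 24000)  # (B, 1, T)
    loss = criterion(x, y)  # scalar tensor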
-------------------------------------------------------------------------------- /usfgan/datasets/audio_feat_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2022 Reo Yoneyama (Nagoya University) 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Dataset modules. 7 | 8 | References: 9 | - https://github.com/bigpon/QPPWG 10 | - https://github.com/kan-bayashi/ParallelWaveGAN 11 | 12 | """ 13 | 14 | from logging import getLogger 15 | from multiprocessing import Manager 16 | 17 | import numpy as np 18 | import soundfile as sf 19 | from hydra.utils import to_absolute_path 20 | from joblib import load 21 | from torch.utils.data import Dataset 22 | from usfgan.utils import ( 23 | check_filename, 24 | dilated_factor, 25 | read_hdf5, 26 | read_txt, 27 | validate_length, 28 | ) 29 | 30 | # A logger for this file 31 | logger = getLogger(__name__) 32 | 33 | 34 | class AudioFeatDataset(Dataset): 35 | """PyTorch compatible audio and acoustic feature dataset.""" 36 | 37 | def __init__( 38 | self, 39 | stats, 40 | audio_list, 41 | feat_list, 42 | audio_length_threshold=None, 43 | feat_length_threshold=None, 44 | return_filename=False, 45 | allow_cache=False, 46 | sample_rate=24000, 47 | hop_size=120, 48 | dense_factor=4, 49 | df_f0_type="contf0", 50 | aux_feats=["mcep", "mcap"], 51 | ): 52 | """Initialize dataset. 53 | 54 | Args: 55 | stats (str): Filename of the statistics file (joblib). 56 | audio_list (str): Filename of the list of audio files. 57 | feat_list (str): Filename of the list of feature files. 58 | audio_length_threshold (int): Threshold to remove short audio files. 59 | feat_length_threshold (int): Threshold to remove short feature files. 60 | return_filename (bool): Whether to return the filename with arrays. 61 | allow_cache (bool): Whether to allow cache of the loaded files. 62 | hop_size (int): Hop size of acoustic features. 63 | dense_factor (int): Number of taps in one cycle. 64 | aux_feats (list): Types of auxiliary features. 65 | 66 | """ 67 | # load audio and feature files & check filename 68 | audio_files = read_txt(to_absolute_path(audio_list)) 69 | feat_files = read_txt(to_absolute_path(feat_list)) 70 | assert check_filename(audio_files, feat_files) 71 | 72 | # filter by threshold 73 | if audio_length_threshold is not None: 74 | audio_lengths = [sf.read(to_absolute_path(f))[0].shape[0] for f in audio_files] 75 | idxs = [ 76 | idx 77 | for idx in range(len(audio_files)) 78 | if audio_lengths[idx] > audio_length_threshold 79 | ] 80 | if len(audio_files) != len(idxs): 81 | logger.warning( 82 | f"Some files are filtered by audio length threshold " 83 | f"({len(audio_files)} -> {len(idxs)})." 84 | ) 85 | audio_files = [audio_files[idx] for idx in idxs] 86 | feat_files = [feat_files[idx] for idx in idxs] 87 | if feat_length_threshold is not None: 88 | f0_lengths = [ 89 | read_hdf5(to_absolute_path(f), "/f0").shape[0] for f in feat_files 90 | ] 91 | idxs = [ 92 | idx 93 | for idx in range(len(feat_files)) 94 | if f0_lengths[idx] > feat_length_threshold 95 | ] 96 | if len(feat_files) != len(idxs): 97 | logger.warning( 98 | f"Some files are filtered by feature length threshold " 99 | f"({len(feat_files)} -> {len(idxs)})." 100 | ) 101 | audio_files = [audio_files[idx] for idx in idxs] 102 | feat_files = [feat_files[idx] for idx in idxs] 103 | 104 | # assert the number of files 105 | assert len(audio_files) != 0, f"{audio_list} is empty." 106 | assert len(audio_files) == len( 107 | feat_files 108 | ), f"Numbers of audio and feature files are different ({len(audio_files)} vs {len(feat_files)})." 109 | 110 | self.audio_files = audio_files 111 | self.feat_files = feat_files 112 | self.return_filename = return_filename 113 | self.allow_cache = allow_cache 114 | self.sample_rate = sample_rate 115 | self.hop_size = hop_size 116 | self.dense_factor = dense_factor 117 | self.aux_feats = aux_feats 118 | self.df_f0_type = df_f0_type 119 | logger.info(f"Feature type : {self.aux_feats}") 120 | 121 | if allow_cache: 122 | # NOTE(kan-bayashi): Manager is needed to share memory across dataloader workers (num_workers > 0) 123 | self.manager = Manager() 124 | self.caches = self.manager.list() 125 | self.caches += [() for _ in range(len(audio_files))] 126 | 127 | # define feature pre-processing function 128 | self.scaler = load(stats) 129 | 130 | def __getitem__(self, idx): 131 | """Get specified idx items. 132 | 133 | Args: 134 | idx (int): Index of the item. 135 | 136 | Returns: 137 | str: Feature file path (only when return_filename = True). 138 | ndarray: Audio signal (T,). 139 | ndarray: Auxiliary features (T', C). 140 | ndarray: Dilated factor sequence (T,). 141 | ndarray: Discrete F0 sequence (T', 1). 142 | ndarray: Continuous F0 sequence (T', 1). 143 |
167 |         # get dilated factor sequence
168 |         f0 = read_hdf5(to_absolute_path(self.feat_files[idx]), "/f0")  # discrete F0
169 |         contf0 = read_hdf5(
170 |             to_absolute_path(self.feat_files[idx]), "/contf0"
171 |         )  # continuous F0
172 |         aux_feats, f0 = validate_length(aux_feats, f0)
173 |         f0, contf0 = validate_length(f0, contf0)
174 |         if self.df_f0_type == "contf0":
175 |             df = dilated_factor(
176 |                 np.squeeze(contf0.copy()), self.sample_rate, self.dense_factor
177 |             )
178 |         else:
179 |             df = dilated_factor(
180 |                 np.squeeze(f0.copy()), self.sample_rate, self.dense_factor
181 |             )
182 |         df = df.repeat(self.hop_size, axis=0)
183 | 
184 |         if self.return_filename:
185 |             items = self.feat_files[idx], audio, aux_feats, df, f0, contf0
186 |         else:
187 |             items = audio, aux_feats, df, f0, contf0
188 | 
189 |         if self.allow_cache:
190 |             self.caches[idx] = items
191 | 
192 |         return items
193 | 
194 |     def __len__(self):
195 |         """Return dataset length.
196 | 
197 |         Returns:
198 |             int: The length of dataset.
199 | 
200 |         """
201 |         return len(self.audio_files)
202 | 
203 | 
204 | class FeatDataset(Dataset):
205 |     """PyTorch compatible acoustic feature dataset."""
206 | 
207 |     def __init__(
208 |         self,
209 |         stats,
210 |         feat_list,
211 |         feat_length_threshold=None,
212 |         return_filename=False,
213 |         allow_cache=False,
214 |         sample_rate=24000,
215 |         hop_size=120,
216 |         dense_factor=4,
217 |         df_f0_type="contf0",
218 |         aux_feats=["mcep", "mcap"],
219 |         f0_factor=1.0,
220 |     ):
221 |         """Initialize dataset.
222 | 
223 |         Args:
224 |             stats (str): Filename of the statistics file (joblib format).
225 |             feat_list (str): Filename of the list of feature files.
226 |             feat_length_threshold (int): Threshold to remove short feature files.
227 |             return_filename (bool): Whether to return the utterance id with arrays.
228 |             allow_cache (bool): Whether to allow cache of the loaded files.
229 |             hop_size (int): Hop size of acoustic features.
230 |             dense_factor (int): Number of taps in one cycle.
231 |             aux_feats (list): Types of auxiliary features.
232 |             f0_factor (float): Ratio of scaled f0.
233 | 
234 |         """
235 |         # load feat. files
236 |         feat_files = read_txt(to_absolute_path(feat_list))
237 | 
238 |         # filter by threshold
239 |         if feat_length_threshold is not None:
240 |             f0_lengths = [
241 |                 read_hdf5(to_absolute_path(f), "/f0").shape[0] for f in feat_files
242 |             ]
243 |             idxs = [
244 |                 idx
245 |                 for idx in range(len(feat_files))
246 |                 if f0_lengths[idx] > feat_length_threshold
247 |             ]
248 |             if len(feat_files) != len(idxs):
249 |                 logger.warning(
250 |                     f"Some files are filtered by feature length threshold "
251 |                     f"({len(feat_files)} -> {len(idxs)})."
252 |                 )
253 |             feat_files = [feat_files[idx] for idx in idxs]
254 | 
255 |         # assert the number of files
256 |         assert len(feat_files) != 0, f"{feat_list} is empty."
257 | 
258 |         self.feat_files = feat_files
259 |         self.return_filename = return_filename
260 |         self.allow_cache = allow_cache
261 |         self.sample_rate = sample_rate
262 |         self.hop_size = hop_size
263 |         self.dense_factor = dense_factor
264 |         self.df_f0_type = df_f0_type
265 |         self.aux_feats = aux_feats
266 |         self.f0_factor = f0_factor
267 |         logger.info(f"Feature type : {self.aux_feats}")
268 | 
269 |         if allow_cache:
270 |             # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
271 |             self.manager = Manager()
272 |             self.caches = self.manager.list()
273 |             self.caches += [() for _ in range(len(feat_files))]
274 | 
275 |         # define feature pre-processing function
276 |         self.scaler = load(stats)
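        # NOTE: the statistics file is a joblib dump assumed to map each
        # feature type to a fitted scaler exposing .transform(), e.g.
        # {"mcep": <scaler>, "mcap": <scaler>, ...}, matching the per-feature
        # lookup in __getitem__ below.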
277 | 
278 |     def __getitem__(self, idx):
279 |         """Get specified idx items.
280 | 
281 |         Args:
282 |             idx (int): Index of the item.
283 | 
284 |         Returns:
285 |             str: Utterance id
286 |                 (returned only when return_filename = True).
287 |             ndarray: Auxiliary features (T', C).
288 |             ndarray: Dilated factors (T,).
289 |             ndarray: Discrete F0 (T', 1).
290 |             ndarray: Continuous F0 (T', 1).
291 | 
292 |         """
293 |         if self.allow_cache and len(self.caches[idx]) != 0:
294 |             return self.caches[idx]
295 | 
296 |         # get auxiliary features
297 |         aux_feats = []
298 |         for feat_type in self.aux_feats:
299 |             aux_feat = read_hdf5(
300 |                 to_absolute_path(self.feat_files[idx]), f"/{feat_type}"
301 |             )
302 |             if "f0" in feat_type:
303 |                 aux_feat *= self.f0_factor
304 |             if feat_type != "uv":
305 |                 aux_feat = self.scaler[f"{feat_type}"].transform(aux_feat)
306 |             aux_feats += [aux_feat]
307 |         aux_feats = np.concatenate(aux_feats, axis=1)
308 | 
309 |         # get dilated factor sequence
310 |         f0 = read_hdf5(to_absolute_path(self.feat_files[idx]), "/f0")  # discrete F0
311 |         contf0 = read_hdf5(
312 |             to_absolute_path(self.feat_files[idx]), "/contf0"
313 |         )  # continuous F0
314 |         f0 *= self.f0_factor
315 |         contf0 *= self.f0_factor
316 |         aux_feats, f0 = validate_length(aux_feats, f0)
317 |         f0, contf0 = validate_length(f0, contf0)
318 |         if self.df_f0_type == "contf0":
319 |             df = dilated_factor(
320 |                 np.squeeze(contf0.copy()), self.sample_rate, self.dense_factor
321 |             )
322 |         else:
323 |             df = dilated_factor(
324 |                 np.squeeze(f0.copy()), self.sample_rate, self.dense_factor
325 |             )
326 |         df = df.repeat(self.hop_size, axis=0)
327 | 
328 |         if self.return_filename:
329 |             items = self.feat_files[idx], aux_feats, df, f0, contf0
330 |         else:
331 |             items = aux_feats, df, f0, contf0
332 | 
333 |         if self.allow_cache:
334 |             self.caches[idx] = items
335 | 
336 |         return items
337 | 
338 |     def __len__(self):
339 |         """Return dataset length.
340 | 
341 |         Returns:
342 |             int: The length of dataset.
343 | 
344 |         """
345 |         return len(self.feat_files)
346 | --------------------------------------------------------------------------------
/usfgan/bin/extract_features.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 | 
6 | """Feature extraction script.
7 | 8 | References: 9 | - https://github.com/bigpon/QPPWG 10 | - https://github.com/kan-bayashi/ParallelWaveGAN 11 | - https://github.com/k2kobayashi/sprocket 12 | 13 | """ 14 | 15 | import copy 16 | import multiprocessing as mp 17 | import os 18 | import sys 19 | from logging import getLogger 20 | 21 | import hydra 22 | import librosa 23 | import numpy as np 24 | import pysptk 25 | import pyworld 26 | import soundfile as sf 27 | import yaml 28 | from hydra.utils import to_absolute_path 29 | from omegaconf import DictConfig, OmegaConf 30 | from scipy.interpolate import interp1d 31 | from scipy.signal import firwin, lfilter 32 | from usfgan.utils import read_txt, write_hdf5 33 | 34 | # A logger for this file 35 | logger = getLogger(__name__) 36 | 37 | 38 | def path_create(wav_list, in_dir, out_dir, extname): 39 | for wav_name in wav_list: 40 | path_replace(wav_name, in_dir, out_dir, extname=extname) 41 | 42 | 43 | def path_replace(filepath, inputpath, outputpath, extname=None): 44 | if extname is not None: 45 | filepath = f"{filepath.split('.')[0]}.{extname}" 46 | filepath = filepath.replace(inputpath, outputpath) 47 | if not os.path.exists(os.path.dirname(filepath)): 48 | os.makedirs(os.path.dirname(filepath)) 49 | return filepath 50 | 51 | 52 | def spk_division(file_list, config, spkinfo, split="/"): 53 | """Divide list into speaker-dependent list 54 | 55 | Args: 56 | file_list (list): Waveform list 57 | config (dict): Config 58 | spkinfo (dict): Dictionary of 59 | speaker-dependent f0 range and power threshold 60 | split: Path split string 61 | 62 | Return: 63 | (list): List of divided file lists 64 | (list): List of speaker-dependent configs 65 | 66 | """ 67 | file_lists, configs, tempf = [], [], [] 68 | prespk = None 69 | for file in file_list: 70 | spk = file.split(split)[config.spkidx] 71 | if spk != prespk: 72 | if tempf: 73 | file_lists.append(tempf) 74 | tempf = [] 75 | prespk = spk 76 | tempc = copy.deepcopy(config) 77 | if spk in spkinfo: 78 | tempc["minf0"] = spkinfo[spk]["f0_min"] 79 | tempc["maxf0"] = spkinfo[spk]["f0_max"] 80 | tempc["pow_th"] = spkinfo[spk]["pow_th"] 81 | else: 82 | msg = f"Since {spk} is not in spkinfo dict, " 83 | msg += "default f0 range and power threshold are used." 
84 |                 logger.info(msg)
85 |                 tempc["minf0"] = 70
86 |                 tempc["maxf0"] = 300
87 |                 tempc["pow_th"] = -20
88 |             configs.append(tempc)
89 |         tempf.append(file)
90 |     file_lists.append(tempf)
91 | 
92 |     return file_lists, configs
93 | 
94 | 
95 | def aux_list_create(wav_list_file, config):
96 |     """Create list of auxiliary acoustic features
97 | 
98 |     Args:
99 |         wav_list_file (str): Filename of wav list
100 |         config (dict): Config
101 | 
102 |     """
103 |     aux_list_file = wav_list_file.replace(".scp", ".list")
104 |     wav_files = read_txt(wav_list_file)
105 |     with open(aux_list_file, "w") as f:
106 |         for wav_name in wav_files:
107 |             feat_name = path_replace(
108 |                 wav_name,
109 |                 config.in_dir,
110 |                 config.out_dir,
111 |                 extname=config.feature_format,
112 |             )
113 |             f.write(f"{feat_name}\n")
114 | 
115 | 
116 | def low_cut_filter(x, fs, cutoff=70):
117 |     """Low cut filter
118 | 
119 |     Args:
120 |         x (ndarray): Waveform sequence
121 |         fs (int): Sampling frequency
122 |         cutoff (float): Cutoff frequency of low cut filter
123 | 
124 |     Return:
125 |         (ndarray): Low cut filtered waveform sequence
126 | 
127 |     """
128 |     nyquist = fs // 2
129 |     norm_cutoff = cutoff / nyquist
130 |     fil = firwin(255, norm_cutoff, pass_zero=False)
131 |     lcf_x = lfilter(fil, 1, x)
132 | 
133 |     return lcf_x
134 | 
135 | 
136 | def low_pass_filter(x, fs, cutoff=70, padding=True):
137 |     """Low pass filter
138 | 
139 |     Args:
140 |         x (ndarray): Waveform sequence
141 |         fs (int): Sampling frequency
142 |         cutoff (float): Cutoff frequency of low pass filter
143 | 
144 |     Return:
145 |         (ndarray): Low pass filtered waveform sequence
146 | 
147 |     """
148 |     nyquist = fs // 2
149 |     norm_cutoff = cutoff / nyquist
150 |     numtaps = 255
151 |     fil = firwin(numtaps, norm_cutoff)
152 |     x_pad = np.pad(x, (numtaps, numtaps), "edge")
153 |     lpf_x = lfilter(fil, 1, x_pad)
154 |     lpf_x = lpf_x[numtaps + numtaps // 2 : -numtaps // 2]
155 | 
156 |     return lpf_x
157 | 
158 | 
159 | def convert_continuous_f0(f0):
160 |     """Convert discrete F0 to continuous F0
161 | 
162 |     Args:
163 |         f0 (ndarray): original f0 sequence with the shape (T)
164 | 
165 |     Return:
166 |         tuple: (uv, cont_f0, is_voiced); uv and cont_f0 with the shape (T)
167 | 
168 |     """
169 |     # get uv information as binary
170 |     uv = np.float32(f0 != 0)
171 |     # get start and end of f0
172 |     if (f0 == 0).all():
173 |         logger.warning("all of the f0 values are 0.")
174 |         return uv, f0, False
175 |     start_f0 = f0[f0 != 0][0]
176 |     end_f0 = f0[f0 != 0][-1]
177 |     # padding start and end of f0 sequence
178 |     cont_f0 = copy.deepcopy(f0)
179 |     start_idx = np.where(cont_f0 == start_f0)[0][0]
180 |     end_idx = np.where(cont_f0 == end_f0)[0][-1]
181 |     cont_f0[:start_idx] = start_f0
182 |     cont_f0[end_idx:] = end_f0
183 |     # get non-zero frame index
184 |     nz_frames = np.where(cont_f0 != 0)[0]
185 |     # perform linear interpolation
186 |     f = interp1d(nz_frames, cont_f0[nz_frames])
187 |     cont_f0 = f(np.arange(0, cont_f0.shape[0]))
188 | 
189 |     return uv, cont_f0, True
190 | 
191 | 
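# Worked example: f0 = [0, 0, 120, 0, 130, 0] gives uv = [0, 0, 1, 0, 1, 0];
# the leading/trailing unvoiced frames are padded with the nearest voiced
# values and the inner gap is linearly interpolated, so
# cont_f0 = [120, 120, 120, 125, 130, 130].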
192 | # Mel-spectrogram and F0 features
193 | def melfilterbank(
194 |     audio,
195 |     sampling_rate,
196 |     fft_size=1024,
197 |     hop_size=256,
198 |     win_length=None,
199 |     window="hann",
200 |     num_mels=80,
201 |     fmin=None,
202 |     fmax=None,
203 | ):
204 |     """Extract linear mel filterbank feature.
205 | 
206 |     Args:
207 |         audio (ndarray): Audio signal (T,).
208 |         sampling_rate (int): Sampling rate.
209 |         fft_size (int): FFT size.
210 |         hop_size (int): Hop size.
211 |         win_length (int): Window length. If set to None, it will be the same as fft_size.
212 |         window (str): Window function type.
213 |         num_mels (int): Number of mel basis.
214 |         fmin (int): Minimum frequency in mel basis calculation.
215 |         fmax (int): Maximum frequency in mel basis calculation.
216 | 
217 |     Returns:
218 |         ndarray: Linear mel filterbank feature (#frames, num_mels).
219 | 
220 |     """
221 |     # get amplitude spectrogram
222 |     x_stft = librosa.stft(
223 |         audio,
224 |         n_fft=fft_size,
225 |         hop_length=hop_size,
226 |         win_length=win_length,
227 |         window=window,
228 |         pad_mode="reflect",
229 |     )
230 |     spc = np.abs(x_stft).T  # (#frames, #bins)
231 | 
232 |     # get mel basis
233 |     fmin = 0 if fmin is None else fmin
234 |     fmax = sampling_rate / 2 if fmax is None else fmax
235 |     mel_basis = librosa.filters.mel(
236 |         sr=sampling_rate,
237 |         n_fft=fft_size,
238 |         n_mels=num_mels,
239 |         fmin=fmin,
240 |         fmax=fmax,
241 |     )
242 | 
243 |     return np.dot(spc, mel_basis.T)
244 | 
245 | 
246 | def melf0_feature_extract(queue, wav_list, config, eps=1e-7):
247 |     """Mel-spc w/ discrete F0 and continuous F0 feature extraction
248 | 
249 |     Args:
250 |         queue (multiprocessing.Queue): the queue to store the file name of utterance
251 |         wav_list (list): list of the wav files
252 |         config (dict): feature extraction config
253 | 
254 |     """
255 |     # extraction
256 |     for i, wav_name in enumerate(wav_list):
257 |         logger.info(f"now processing {wav_name} ({i + 1}/{len(wav_list)})")
258 | 
259 |         # load wavfile
260 |         x, fs = sf.read(to_absolute_path(wav_name))
261 |         x = np.array(x, dtype=np.float64)
262 | 
263 |         # check sampling frequency
264 |         if fs != config.sampling_rate:
265 |             logger.error("sampling frequency is not matched.")
266 |             sys.exit(1)
267 | 
268 |         # apply low-cut-filter
269 |         if config.highpass_cutoff > 0:
270 |             if (x == 0).all():
271 |                 logger.info(f"skip all-zero waveform {wav_name}")
272 |                 continue
273 |             x = low_cut_filter(x, fs, cutoff=config.highpass_cutoff)
274 | 
275 |         # extract WORLD features
276 |         f0, t = pyworld.harvest(
277 |             x,
278 |             fs=config.sampling_rate,
279 |             f0_floor=config.minf0,
280 |             f0_ceil=config.maxf0,
281 |             frame_period=config.shiftms,
282 |         )
283 |         env = pyworld.cheaptrick(
284 |             x,
285 |             f0,
286 |             t,
287 |             fs=config.sampling_rate,
288 |             fft_size=config.fft_size,
289 |         )
290 |         ap = pyworld.d4c(
291 |             x,
292 |             f0,
293 |             t,
294 |             fs=config.sampling_rate,
295 |             fft_size=config.fft_size,
296 |         )
297 |         uv, cont_f0, is_voiced = convert_continuous_f0(f0)
298 |         if is_voiced:
299 |             lpf_fs = int(config.sampling_rate / config.hop_size)
300 |             cont_f0_lpf = low_pass_filter(cont_f0, lpf_fs, cutoff=20)
301 |             next_cutoff = 70
302 |             while not (cont_f0_lpf >= 0).all():
303 |                 cont_f0_lpf = low_pass_filter(cont_f0, lpf_fs, cutoff=next_cutoff)
304 |                 next_cutoff *= 2
305 |         else:
306 |             cont_f0_lpf = cont_f0
307 |             logger.warning(f"all of the f0 values are 0 in {wav_name}.")
308 |         mcep = pysptk.sp2mc(env, order=config.mcep_dim, alpha=config.alpha)
309 |         mcap = pysptk.sp2mc(ap, order=config.mcap_dim, alpha=config.alpha)
310 |         codeap = pyworld.code_aperiodicity(ap, config.sampling_rate)
311 | 
312 |         # extract mel-spectrogram
313 |         msp = melfilterbank(
314 |             x,
315 |             config.sampling_rate,
316 |             fft_size=config.fft_size,
317 |             hop_size=config.hop_size,
318 |             win_length=config.win_length,
319 |             window=config.window,
320 |             num_mels=config.num_mels,
321 |             fmin=config.fmin,
322 |             fmax=config.fmax,
323 |         )
324 | 
325 |         # adjust shapes
326 |         minlen = min(uv.shape[0], msp.shape[0])
327 |         uv = uv[:minlen]
328 |         uv = np.expand_dims(uv, axis=-1)
329 |         f0 = f0[:minlen]
330 |         f0 = np.expand_dims(f0, axis=-1)
331 |         cont_f0_lpf = cont_f0_lpf[:minlen]
332 |         cont_f0_lpf = np.expand_dims(cont_f0_lpf, axis=-1)
333 |         mcep = mcep[:minlen]
334 |         mcap = mcap[:minlen]
335 |         codeap = codeap[:minlen]
336 |         logmsp = np.log10(np.maximum(eps, msp))
337 | 
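        # At this point one utterance yields per-frame features: uv / f0 /
        # contf0 with shape (#frames, 1), mcep and mcap with order + 1
        # coefficients per frame from pysptk.sp2mc, codeap from WORLD, and
        # logmsp, the log10 mel-spectrogram (#frames, num_mels); all of them
        # are written to a single HDF5 file per utterance below.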
338 |         # save features
339 |         feat_name = path_replace(
340 |             wav_name,
341 |             config.in_dir,
342 |             config.out_dir,
343 |             extname=config.feature_format,
344 |         )
345 |         write_hdf5(to_absolute_path(feat_name), "/uv", uv)
346 |         write_hdf5(to_absolute_path(feat_name), "/f0", f0)
347 |         write_hdf5(to_absolute_path(feat_name), "/contf0", cont_f0_lpf)
348 |         write_hdf5(to_absolute_path(feat_name), "/mcep", mcep)
349 |         write_hdf5(to_absolute_path(feat_name), "/mcap", mcap)
350 |         write_hdf5(to_absolute_path(feat_name), "/codeap", codeap)
351 |         write_hdf5(to_absolute_path(feat_name), "/logmsp", logmsp)
352 | 
353 |     queue.put("Finish")
354 | 
355 | 
356 | @hydra.main(version_base=None, config_path="config", config_name="extract_features")
357 | def main(config: DictConfig):
358 |     # show argument
359 |     logger.info(OmegaConf.to_yaml(config))
360 | 
361 |     # read list
362 |     file_list = read_txt(to_absolute_path(config.audio))
363 |     logger.info(f"number of utterances = {len(file_list)}")
364 | 
365 |     # list division
366 |     if config.spkinfo and os.path.exists(to_absolute_path(config.spkinfo)):
367 |         # load speaker info
368 |         with open(to_absolute_path(config.spkinfo), "r") as f:
369 |             spkinfo = yaml.safe_load(f)
370 |         logger.info(f"Spkinfo {config.spkinfo} is used.")
371 |         # divide into each spk list
372 |         file_lists, configs = spk_division(file_list, config, spkinfo)
373 |     else:
374 |         logger.info(
375 |             f"Since spkinfo {config.spkinfo} does not exist, default f0 range and power threshold are used."
376 |         )
377 |         file_lists = np.array_split(file_list, 10)
378 |         file_lists = [f_list.tolist() for f_list in file_lists]
379 |         configs = [config] * len(file_lists)
380 | 
381 |     # set mode
382 |     if config.inv:
383 |         target_fn = melf0_feature_extract
384 |         # create auxiliary feature list
385 |         aux_list_create(to_absolute_path(config.audio), config)
386 |         # create folder
387 |         path_create(file_list, config.in_dir, config.out_dir, config.feature_format)
388 | 
389 |     # multi processing
390 |     processes = []
391 |     queue = mp.Queue()
392 |     for f, _config in zip(file_lists, configs):
393 |         p = mp.Process(
394 |             target=target_fn,
395 |             args=(queue, f, _config),
396 |         )
397 |         p.start()
398 |         processes.append(p)
399 | 
400 |     # wait for all process
401 |     for p in processes:
402 |         p.join()
403 | 
404 | 
405 | if __name__ == "__main__":
406 |     main()
407 | --------------------------------------------------------------------------------
/usfgan/layers/residual_block.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 | 
6 | """Residual block modules.
7 | 8 | References: 9 | - https://github.com/bigpon/QPPWG 10 | - https://github.com/kan-bayashi/ParallelWaveGAN 11 | - https://github.com/r9y9/wavenet_vocoder 12 | 13 | """ 14 | 15 | import math 16 | import sys 17 | from logging import getLogger 18 | 19 | import torch 20 | import torch.nn as nn 21 | from usfgan.utils import pd_indexing 22 | 23 | # A logger for this file 24 | logger = getLogger(__name__) 25 | 26 | 27 | class Conv1d(nn.Conv1d): 28 | """Conv1d module with customized initialization.""" 29 | 30 | def __init__(self, *args, **kwargs): 31 | """Initialize Conv1d module.""" 32 | super(Conv1d, self).__init__(*args, **kwargs) 33 | 34 | def reset_parameters(self): 35 | """Reset parameters.""" 36 | nn.init.kaiming_normal_(self.weight, nonlinearity="relu") 37 | if self.bias is not None: 38 | nn.init.constant_(self.bias, 0.0) 39 | 40 | 41 | class Conv1d1x1(Conv1d): 42 | """1x1 Conv1d with customized initialization.""" 43 | 44 | def __init__(self, in_channels, out_channels, bias=True): 45 | """Initialize 1x1 Conv1d module.""" 46 | super(Conv1d1x1, self).__init__( 47 | in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias 48 | ) 49 | 50 | 51 | class Conv2d(nn.Conv2d): 52 | """Conv2d module with customized initialization.""" 53 | 54 | def __init__(self, *args, **kwargs): 55 | """Initialize Conv2d module.""" 56 | super(Conv2d, self).__init__(*args, **kwargs) 57 | 58 | def reset_parameters(self): 59 | """Reset parameters.""" 60 | nn.init.kaiming_normal_(self.weight, mode="fan_out", nonlinearity="relu") 61 | if self.bias is not None: 62 | nn.init.constant_(self.bias, 0.0) 63 | 64 | 65 | class Conv2d1x1(Conv2d): 66 | """1x1 Conv2d with customized initialization.""" 67 | 68 | def __init__(self, in_channels, out_channels, bias=True): 69 | """Initialize 1x1 Conv2d module.""" 70 | super(Conv2d1x1, self).__init__( 71 | in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias 72 | ) 73 | 74 | 75 | class FixedBlock(nn.Module): 76 | """Fixed block module in QPPWG.""" 77 | 78 | def __init__( 79 | self, 80 | residual_channels=64, 81 | gate_channels=128, 82 | skip_channels=64, 83 | aux_channels=80, 84 | kernel_size=3, 85 | dilation=1, 86 | bias=True, 87 | ): 88 | """Initialize Fixed ResidualBlock module. 89 | 90 | Args: 91 | residual_channels (int): Number of channels for residual connection. 92 | skip_channels (int): Number of channels for skip connection. 93 | aux_channels (int): Local conditioning channels i.e. auxiliary input dimension. 94 | dilation (int): Dilation size. 95 | bias (bool): Whether to add bias parameter in convolution layers. 96 | 97 | """ 98 | super(FixedBlock, self).__init__() 99 | padding = (kernel_size - 1) // 2 * dilation 100 | 101 | # dilation conv 102 | self.conv = Conv1d( 103 | residual_channels, 104 | gate_channels, 105 | kernel_size, 106 | padding=padding, 107 | padding_mode="reflect", 108 | dilation=dilation, 109 | bias=bias, 110 | ) 111 | 112 | # local conditioning 113 | if aux_channels > 0: 114 | self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False) 115 | else: 116 | self.conv1x1_aux = None 117 | 118 | # conv output is split into two groups 119 | gate_out_channels = gate_channels // 2 120 | self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias) 121 | self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias) 122 | 123 | def forward(self, x, c): 124 | """Calculate forward propagation. 125 | 126 | Args: 127 | x (Tensor): Input tensor (B, residual_channels, T). 
128 | c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T). 129 | 130 | Returns: 131 | Tensor: Output tensor for residual connection (B, residual_channels, T). 132 | Tensor: Output tensor for skip connection (B, skip_channels, T). 133 | 134 | """ 135 | residual = x 136 | x = self.conv(x) 137 | 138 | # split into two part for gated activation 139 | splitdim = 1 140 | xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) 141 | 142 | # local conditioning 143 | if c is not None: 144 | assert self.conv1x1_aux is not None 145 | c = self.conv1x1_aux(c) 146 | ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) 147 | xa, xb = xa + ca, xb + cb 148 | 149 | x = torch.tanh(xa) * torch.sigmoid(xb) 150 | 151 | # for skip connection 152 | s = self.conv1x1_skip(x) 153 | 154 | # for residual connection 155 | x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5) 156 | 157 | return x, s 158 | 159 | 160 | class AdaptiveBlock(nn.Module): 161 | """Adaptive block module in QPPWG.""" 162 | 163 | def __init__( 164 | self, 165 | residual_channels=64, 166 | gate_channels=128, 167 | skip_channels=64, 168 | aux_channels=80, 169 | bias=True, 170 | ): 171 | """Initialize Adaptive ResidualBlock module. 172 | 173 | Args: 174 | residual_channels (int): Number of channels for residual connection. 175 | skip_channels (int): Number of channels for skip connection. 176 | aux_channels (int): Local conditioning channels i.e. auxiliary input dimension. 177 | bias (bool): Whether to add bias parameter in convolution layers. 178 | 179 | """ 180 | super(AdaptiveBlock, self).__init__() 181 | 182 | # pitch-dependent dilation conv 183 | self.convP = Conv1d1x1(residual_channels, gate_channels, bias=bias) # past 184 | self.convC = Conv1d1x1(residual_channels, gate_channels, bias=bias) # current 185 | self.convF = Conv1d1x1(residual_channels, gate_channels, bias=bias) # future 186 | 187 | # local conditioning 188 | if aux_channels > 0: 189 | self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False) 190 | else: 191 | self.conv1x1_aux = None 192 | 193 | # conv output is split into two groups 194 | gate_out_channels = gate_channels // 2 195 | self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias) 196 | self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias) 197 | 198 | def forward(self, xC, xP, xF, c): 199 | """Calculate forward propagation. 200 | 201 | Args: 202 | xC (Tensor): Current input tensor (B, residual_channels, T). 203 | xP (Tensor): Past input tensor (B, residual_channels, T). 204 | xF (Tensor): Future input tensor (B, residual_channels, T). 205 | c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T). 206 | 207 | Returns: 208 | Tensor: Output tensor for residual connection (B, residual_channels, T). 209 | Tensor: Output tensor for skip connection (B, skip_channels, T). 
210 | 211 | """ 212 | residual = xC 213 | x = self.convC(xC) + self.convP(xP) + self.convF(xF) 214 | 215 | # split into two part for gated activation 216 | splitdim = 1 217 | xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) 218 | 219 | # local conditioning 220 | if c is not None: 221 | assert self.conv1x1_aux is not None 222 | c = self.conv1x1_aux(c) 223 | ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) 224 | xa, xb = xa + ca, xb + cb 225 | 226 | x = torch.tanh(xa) * torch.sigmoid(xb) 227 | 228 | # for skip connection 229 | s = self.conv1x1_skip(x) 230 | 231 | # for residual connection 232 | x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5) 233 | 234 | return x, s 235 | 236 | 237 | class ResidualBlocks(nn.Module): 238 | """Multiple residual blocks stacking module.""" 239 | 240 | def __init__( 241 | self, 242 | blockA, 243 | cycleA, 244 | blockF, 245 | cycleF, 246 | cascade_mode=0, 247 | residual_channels=64, 248 | gate_channels=128, 249 | skip_channels=64, 250 | aux_channels=80, 251 | ): 252 | """Initialize ResidualBlocks module. 253 | 254 | Args: 255 | blockA (int): Number of adaptive residual blocks. 256 | cycleA (int): Number of dilation cycles of adaptive residual blocks. 257 | blockF (int): Number of fixed residual blocks. 258 | cycleF (int): Number of dilation cycles of fixed residual blocks. 259 | cascade_mode (int): Cascaded mode (0: Adaptive->Fixed; 1: Fixed->Adaptive). 260 | residual_channels (int): Number of channels in residual conv. 261 | gate_channels (int): Number of channels in gated conv. 262 | skip_channels (int): Number of channels in skip conv. 263 | aux_channels (int): Number of channels for auxiliary feature conv. 264 | 265 | """ 266 | super(ResidualBlocks, self).__init__() 267 | 268 | # check the number of blocks and cycles 269 | cycleA = max(cycleA, 1) 270 | cycleF = max(cycleF, 1) 271 | assert blockA % cycleA == 0 272 | self.blockA_per_cycle = blockA // cycleA 273 | assert blockF % cycleF == 0 274 | blockF_per_cycle = blockF // cycleF 275 | 276 | # define adaptive residual blocks 277 | adaptive_blocks = nn.ModuleList() 278 | for block in range(blockA): 279 | conv = AdaptiveBlock( 280 | residual_channels=residual_channels, 281 | gate_channels=gate_channels, 282 | skip_channels=skip_channels, 283 | aux_channels=aux_channels, 284 | ) 285 | adaptive_blocks += [conv] 286 | 287 | # define fixed residual blocks 288 | fixed_blocks = nn.ModuleList() 289 | for block in range(blockF): 290 | dilation = 2 ** (block % blockF_per_cycle) 291 | conv = FixedBlock( 292 | residual_channels=residual_channels, 293 | gate_channels=gate_channels, 294 | skip_channels=skip_channels, 295 | aux_channels=aux_channels, 296 | dilation=dilation, 297 | ) 298 | fixed_blocks += [conv] 299 | 300 | # define cascaded structure 301 | if cascade_mode == 0: # adaptive->fixed 302 | self.conv_dilated = adaptive_blocks.extend(fixed_blocks) 303 | self.block_modes = [True] * blockA + [False] * blockF 304 | elif cascade_mode == 1: # fixed->adaptive 305 | self.conv_dilated = fixed_blocks.extend(adaptive_blocks) 306 | self.block_modes = [False] * blockF + [True] * blockA 307 | else: 308 | logger.error(f"Cascaded mode {cascade_mode} is not supported!") 309 | sys.exit(0) 310 | 311 | def forward(self, x, c, d, batch_index, ch_index): 312 | """Calculate forward propagation. 313 | 314 | Args: 315 | x (Tensor): Input noise signal (B, 1, T). 316 | c (Tensor): Local conditioning auxiliary features (B, C ,T). 317 | d (Tensor): Input pitch-dependent dilated factors (B, 1, T). 
318 | 
319 |         Returns:
320 |             Tensor: Output tensor (B, skip_channels, T).
321 | 
322 |         """
323 |         skips = 0
324 |         blockA_idx = 0
325 |         for f, mode in zip(self.conv_dilated, self.block_modes):
326 |             if mode:  # adaptive block
327 |                 dilation = 2 ** (blockA_idx % self.blockA_per_cycle)
328 |                 xP, xF = pd_indexing(x, d, dilation, batch_index, ch_index)
329 |                 x, h = f(x, xP, xF, c)
330 |                 blockA_idx += 1
331 |             else:  # fixed block
332 |                 x, h = f(x, c)
333 |             skips = h + skips
334 |         skips *= math.sqrt(1.0 / len(self.conv_dilated))
335 | 
336 |         return skips
337 | 
338 | 
339 | class PeriodicityEstimator(nn.Module):
340 |     """Periodicity estimator module."""
341 | 
342 |     def __init__(
343 |         self,
344 |         in_channels,
345 |         residual_channels=64,
346 |         conv_layers=3,
347 |         kernel_size=5,
348 |         dilation=1,
349 |         padding_mode="replicate",
350 |     ):
351 |         """Initialize PeriodicityEstimator module.
352 | 
353 |         Args:
354 |             in_channels (int): Number of input channels.
355 |             residual_channels (int): Number of channels in residual conv.
356 |             conv_layers (int): Number of convolution layers.
357 |             kernel_size (int): Kernel size.
358 |             dilation (int): Dilation size.
359 |             padding_mode (str): Padding mode.
360 | 
361 |         """
362 |         super(PeriodicityEstimator, self).__init__()
363 | 
364 |         modules = []
365 |         for idx in range(conv_layers):
366 |             conv1d = Conv1d(
367 |                 in_channels,
368 |                 residual_channels,
369 |                 kernel_size=kernel_size,
370 |                 dilation=dilation,
371 |                 padding=kernel_size // 2 * dilation,
372 |                 padding_mode=padding_mode,
373 |             )
374 | 
375 |             # initialize the last layer so the initial outputs are sigmoid(0)=0.5 to stabilize training
376 |             if idx != conv_layers - 1:
377 |                 nonlinear = nn.ReLU(inplace=True)
378 |             else:
379 |                 # NOTE: zero init induces nan or inf if weight normalization is used
380 |                 # nn.init.zeros_(conv1d.weight)
381 |                 nn.init.normal_(conv1d.weight, std=1e-4)
382 |                 nonlinear = nn.Sigmoid()
383 | 
384 |             modules += [conv1d, nonlinear]
385 |             in_channels = residual_channels
386 | 
387 |         self.layers = nn.Sequential(*modules)
388 | 
389 |     def forward(self, x):
390 |         """Calculate forward propagation.
391 | 
392 |         Args:
393 |             x (Tensor): Input auxiliary features (B, C, T).
394 | 
395 |         Returns:
396 |             Tensor: Output tensor (B, residual_channels, T).
397 | 398 | """ 399 | return self.layers(x) 400 | -------------------------------------------------------------------------------- /usfgan/models/generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2022 Reo Yoneyama (Nagoya University) 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Unified Source-Filter GAN Generator modules.""" 7 | 8 | from logging import getLogger 9 | 10 | import torch 11 | import torch.nn as nn 12 | from usfgan.layers import Conv1d1x1, ResidualBlocks, upsample 13 | from usfgan.layers.residual_block import PeriodicityEstimator 14 | from usfgan.utils import index_initial 15 | 16 | # A logger for this file 17 | logger = getLogger(__name__) 18 | 19 | 20 | class USFGANGenerator(nn.Module): 21 | """Unified Source-Filter GAN Generator module.""" 22 | 23 | def __init__( 24 | self, 25 | source_network_params={ 26 | "blockA": 30, 27 | "cycleA": 3, 28 | "blockF": 0, 29 | "cycleF": 0, 30 | "cascade_mode": 0, 31 | }, 32 | filter_network_params={ 33 | "blockA": 0, 34 | "cycleA": 0, 35 | "blockF": 30, 36 | "cycleF": 3, 37 | "cascade_mode": 0, 38 | }, 39 | in_channels=1, 40 | out_channels=1, 41 | residual_channels=64, 42 | gate_channels=128, 43 | skip_channels=64, 44 | aux_channels=80, 45 | aux_context_window=2, 46 | use_weight_norm=True, 47 | upsample_params={"upsample_scales": [5, 4, 3, 2]}, 48 | ): 49 | """Initialize USFGANGenerator module. 50 | 51 | Args: 52 | source_network_params (dict): Source-network parameters. 53 | filter_network_params (dict): Filter-network parameters. 54 | in_channels (int): Number of input channels. 55 | out_channels (int): Number of output channels. 56 | residual_channels (int): Number of channels in residual conv. 57 | gate_channels (int): Number of channels in gated conv. 58 | skip_channels (int): Number of channels in skip conv. 59 | aux_channels (int): Number of channels for auxiliary feature conv. 60 | aux_context_window (int): Context window size for auxiliary feature. 61 | use_weight_norm (bool): Whether to use weight norm. 62 | If set to true, it will be applied to all of the conv layers. 63 | upsample_params (dict): Upsampling network parameters. 
64 | 65 | """ 66 | super(USFGANGenerator, self).__init__() 67 | self.in_channels = in_channels 68 | self.out_channels = out_channels 69 | self.aux_channels = aux_channels 70 | self.n_ch = residual_channels 71 | 72 | # define first convolution 73 | self.conv_first = Conv1d1x1(in_channels, residual_channels) 74 | 75 | # define upsampling network 76 | self.upsample_net = getattr(upsample, "ConvInUpsampleNetwork")( 77 | **upsample_params, 78 | aux_channels=aux_channels, 79 | aux_context_window=aux_context_window, 80 | ) 81 | 82 | # define source/filter networks 83 | for params in [ 84 | source_network_params, 85 | filter_network_params, 86 | ]: 87 | params.update( 88 | { 89 | "residual_channels": residual_channels, 90 | "gate_channels": gate_channels, 91 | "skip_channels": skip_channels, 92 | "aux_channels": aux_channels, 93 | } 94 | ) 95 | self.source_network = ResidualBlocks(**source_network_params) 96 | self.filter_network = ResidualBlocks(**filter_network_params) 97 | 98 | # convert source signal to hidden representation 99 | self.conv_mid = Conv1d1x1(out_channels, skip_channels) 100 | 101 | # convert hidden representation to output signal 102 | self.conv_last = nn.Sequential( 103 | nn.ReLU(), 104 | Conv1d1x1(skip_channels, skip_channels), 105 | nn.ReLU(), 106 | Conv1d1x1(skip_channels, out_channels), 107 | ) 108 | 109 | # apply weight norm 110 | if use_weight_norm: 111 | self.apply_weight_norm() 112 | 113 | def forward(self, x, c, d): 114 | """Calculate forward propagation. 115 | 116 | Args: 117 | x (Tensor): Input noise signal (B, 1, T). 118 | c (Tensor): Local conditioning auxiliary features (B, C ,T'). 119 | d (Tensor): Input pitch-dependent dilated factors (B, 1, T). 120 | 121 | Returns: 122 | Tensor: Output tensor (B, 1, T) 123 | 124 | """ 125 | # index initialization 126 | batch_index, ch_index = index_initial(x.size(0), self.n_ch) 127 | 128 | # perform upsampling 129 | c = self.upsample_net(c) 130 | assert c.size(-1) == x.size(-1) 131 | 132 | # encode to hidden representation 133 | x = self.conv_first(x) 134 | 135 | # source excitation generation 136 | x = self.source_network(x, c, d, batch_index, ch_index) 137 | s = self.conv_last(x) 138 | x = self.conv_mid(s) 139 | 140 | # resonance filtering 141 | x = self.filter_network(x, c, d, batch_index, ch_index) 142 | x = self.conv_last(x) 143 | 144 | return x, s 145 | 146 | def remove_weight_norm(self): 147 | """Remove weight normalization module from all of the layers.""" 148 | 149 | def _remove_weight_norm(m): 150 | try: 151 | logger.debug(f"Weight norm is removed from {m}.") 152 | nn.utils.remove_weight_norm(m) 153 | except ValueError: # this module didn't have weight norm 154 | return 155 | 156 | self.apply(_remove_weight_norm) 157 | 158 | def apply_weight_norm(self): 159 | """Apply weight normalization module from all of the layers.""" 160 | 161 | def _apply_weight_norm(m): 162 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d): 163 | nn.utils.weight_norm(m) 164 | logger.debug(f"Weight norm is applied to {m}.") 165 | 166 | self.apply(_apply_weight_norm) 167 | 168 | 169 | class CascadeHnUSFGANGenerator(nn.Module): 170 | """Cascade hn-uSFGAN Generator module.""" 171 | 172 | def __init__( 173 | self, 174 | harmonic_network_params={ 175 | "blockA": 20, 176 | "cycleA": 4, 177 | "blockF": 0, 178 | "cycleF": 0, 179 | "cascade_mode": 0, 180 | }, 181 | noise_network_params={ 182 | "blockA": 0, 183 | "cycleA": 0, 184 | "blockF": 5, 185 | "cycleF": 5, 186 | "cascade_mode": 0, 187 | }, 188 | filter_network_params={ 189 | "blockA": 0, 190 
| "cycleA": 0, 191 | "blockF": 30, 192 | "cycleF": 3, 193 | "cascade_mode": 0, 194 | }, 195 | periodicity_estimator_params={ 196 | "conv_layers": 3, 197 | "kernel_size": 5, 198 | "dilation": 1, 199 | "padding_mode": "replicate", 200 | }, 201 | in_channels=1, 202 | out_channels=1, 203 | residual_channels=64, 204 | gate_channels=128, 205 | skip_channels=64, 206 | aux_channels=80, 207 | aux_context_window=2, 208 | use_weight_norm=True, 209 | upsample_params={"upsample_scales": [5, 4, 3, 2]}, 210 | ): 211 | """Initialize CascadeHnUSFGANGenerator module. 212 | 213 | Args: 214 | harmonic_network_params (dict): Periodic source generation network parameters. 215 | noise_network_params (dict): Aperiodic source generation network parameters. 216 | filter_network_params (dict): Filter network parameters. 217 | periodicity_estimator_params (dict): Periodicity estimation network parameters. 218 | in_channels (int): Number of input channels. 219 | out_channels (int): Number of output channels. 220 | residual_channels (int): Number of channels in residual conv. 221 | gate_channels (int): Number of channels in gated conv. 222 | skip_channels (int): Number of channels in skip conv. 223 | aux_channels (int): Number of channels for auxiliary feature conv. 224 | aux_context_window (int): Context window size for auxiliary feature. 225 | use_weight_norm (bool): Whether to use weight norm. 226 | If set to true, it will be applied to all of the conv layers. 227 | upsample_params (dict): Upsampling network parameters. 228 | 229 | """ 230 | super(CascadeHnUSFGANGenerator, self).__init__() 231 | self.in_channels = in_channels 232 | self.out_channels = out_channels 233 | self.aux_channels = aux_channels 234 | self.n_ch = residual_channels 235 | 236 | # define first convolution 237 | self.conv_first_sine = Conv1d1x1(in_channels, residual_channels) 238 | self.conv_first_noise = Conv1d1x1(in_channels, residual_channels) 239 | self.conv_merge = Conv1d1x1(residual_channels * 2, residual_channels) 240 | 241 | # define upsampling network 242 | self.upsample_net = getattr(upsample, "ConvInUpsampleNetwork")( 243 | **upsample_params, 244 | aux_channels=aux_channels, 245 | aux_context_window=aux_context_window, 246 | ) 247 | 248 | # define harmonic/noise/filter networks 249 | for params in [ 250 | harmonic_network_params, 251 | noise_network_params, 252 | filter_network_params, 253 | ]: 254 | params.update( 255 | { 256 | "residual_channels": residual_channels, 257 | "gate_channels": gate_channels, 258 | "skip_channels": skip_channels, 259 | "aux_channels": aux_channels, 260 | } 261 | ) 262 | self.harmonic_network = ResidualBlocks(**harmonic_network_params) 263 | self.noise_network = ResidualBlocks(**noise_network_params) 264 | self.filter_network = ResidualBlocks(**filter_network_params) 265 | 266 | # define periodicity estimator 267 | self.periodicity_estimator = PeriodicityEstimator( 268 | **periodicity_estimator_params, in_channels=aux_channels 269 | ) 270 | 271 | # convert hidden representation to output signal 272 | self.conv_last = nn.Sequential( 273 | nn.ReLU(), 274 | Conv1d1x1(skip_channels, skip_channels), 275 | nn.ReLU(), 276 | Conv1d1x1(skip_channels, out_channels), 277 | ) 278 | 279 | # apply weight norm 280 | if use_weight_norm: 281 | self.apply_weight_norm() 282 | 283 | def forward(self, x, c, d): 284 | """Calculate forward propagation. 285 | 286 | Args: 287 | x (Tensor): Input noise signal (B, 1, T). 288 | c (Tensor): Local conditioning auxiliary features (B, C ,T'). 
289 | d (Tensor): Input pitch-dependent dilated factors (B, 1, T). 290 | 291 | Returns: 292 | Tensor: Output tensor (B, 1, T) 293 | 294 | """ 295 | # index initialization 296 | batch_index, ch_index = index_initial(x.size(0), self.n_ch) 297 | 298 | # upsample auxiliary features 299 | c = self.upsample_net(c) 300 | assert c.size(-1) == x.size(-1) 301 | 302 | # estimate periodicity 303 | a = self.periodicity_estimator(c) 304 | 305 | # assume the first channel is sine and the other is noise 306 | sine, noise = torch.chunk(x, 2, 1) 307 | 308 | # encode to hidden representation 309 | h = self.conv_first_sine(sine) 310 | n = self.conv_first_noise(noise) 311 | 312 | # generate periodic and aperiodic source latent features 313 | h = self.harmonic_network(h, c, d, batch_index, ch_index) 314 | h = a * h 315 | n = self.conv_merge(torch.cat([h, n], dim=1)) 316 | n = self.noise_network(n, c, d, batch_index, ch_index) 317 | n = (1.0 - a) * n 318 | 319 | # merge periodic and aperiodic latent features 320 | s = h + n 321 | 322 | # resonance filtering 323 | x = self.filter_network(s, c, d, batch_index, ch_index) 324 | x = self.conv_last(x) 325 | 326 | # convert to 1d signal for regularization loss 327 | s = self.conv_last(s) 328 | 329 | # just for debug 330 | with torch.no_grad(): 331 | h = self.conv_last(h) 332 | n = self.conv_last(n) 333 | 334 | return x, s, h, n, a 335 | 336 | def remove_weight_norm(self): 337 | """Remove weight normalization module from all of the layers.""" 338 | 339 | def _remove_weight_norm(m): 340 | try: 341 | logger.debug(f"Weight norm is removed from {m}.") 342 | nn.utils.remove_weight_norm(m) 343 | except ValueError: # this module didn't have weight norm 344 | return 345 | 346 | self.apply(_remove_weight_norm) 347 | 348 | def apply_weight_norm(self): 349 | """Apply weight normalization module from all of the layers.""" 350 | 351 | def _apply_weight_norm(m): 352 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d): 353 | nn.utils.weight_norm(m) 354 | logger.debug(f"Weight norm is applied to {m}.") 355 | 356 | self.apply(_apply_weight_norm) 357 | 358 | 359 | class ParallelHnUSFGANGenerator(nn.Module): 360 | """Parallel hn-uSFGAN Generator module.""" 361 | 362 | def __init__( 363 | self, 364 | harmonic_network_params={ 365 | "blockA": 20, 366 | "cycleA": 4, 367 | "blockF": 0, 368 | "cycleF": 0, 369 | "cascade_mode": 0, 370 | }, 371 | noise_network_params={ 372 | "blockA": 0, 373 | "cycleA": 0, 374 | "blockF": 5, 375 | "cycleF": 5, 376 | "cascade_mode": 0, 377 | }, 378 | filter_network_params={ 379 | "blockA": 0, 380 | "cycleA": 0, 381 | "blockF": 30, 382 | "cycleF": 3, 383 | "cascade_mode": 0, 384 | }, 385 | periodicity_estimator_params={ 386 | "conv_layers": 3, 387 | "kernel_size": 5, 388 | "dilation": 1, 389 | "padding_mode": "replicate", 390 | }, 391 | in_channels=1, 392 | out_channels=1, 393 | residual_channels=64, 394 | gate_channels=128, 395 | skip_channels=64, 396 | aux_channels=80, 397 | aux_context_window=2, 398 | use_weight_norm=True, 399 | upsample_params={"upsample_scales": [5, 4, 3, 2]}, 400 | ): 401 | """Initialize ParallelHnUSFGANGenerator module. 402 | 403 | Args: 404 | harmonic_network_params (dict): Periodic source generation network parameters. 405 | noise_network_params (dict): Aperiodic source generation network parameters. 406 | filter_network_params (dict): Filter network parameters. 407 | periodicity_estimator_params (dict): Periodicity estimation network parameters. 408 | in_channels (int): Number of input channels. 
409 | out_channels (int): Number of output channels. 410 | residual_channels (int): Number of channels in residual conv. 411 | gate_channels (int): Number of channels in gated conv. 412 | skip_channels (int): Number of channels in skip conv. 413 | aux_channels (int): Number of channels for auxiliary feature conv. 414 | aux_context_window (int): Context window size for auxiliary feature. 415 | use_weight_norm (bool): Whether to use weight norm. 416 | If set to true, it will be applied to all of the conv layers. 417 | upsample_params (dict): Upsampling network parameters. 418 | 419 | """ 420 | super(ParallelHnUSFGANGenerator, self).__init__() 421 | self.in_channels = in_channels 422 | self.out_channels = out_channels 423 | self.aux_channels = aux_channels 424 | self.n_ch = residual_channels 425 | 426 | # define first convolution 427 | self.conv_first_sine = Conv1d1x1(in_channels, residual_channels) 428 | self.conv_first_noise = Conv1d1x1(in_channels, residual_channels) 429 | 430 | # define upsampling network 431 | self.upsample_net = getattr(upsample, "ConvInUpsampleNetwork")( 432 | **upsample_params, 433 | aux_channels=aux_channels, 434 | aux_context_window=aux_context_window, 435 | ) 436 | 437 | # define harmonic/noise/filter networks 438 | for params in [ 439 | harmonic_network_params, 440 | noise_network_params, 441 | filter_network_params, 442 | ]: 443 | params.update( 444 | { 445 | "residual_channels": residual_channels, 446 | "gate_channels": gate_channels, 447 | "skip_channels": skip_channels, 448 | "aux_channels": aux_channels, 449 | } 450 | ) 451 | self.harmonic_network = ResidualBlocks(**harmonic_network_params) 452 | self.noise_network = ResidualBlocks(**noise_network_params) 453 | self.filter_network = ResidualBlocks(**filter_network_params) 454 | 455 | # define periodicity estimator 456 | self.periodicity_estimator = PeriodicityEstimator( 457 | **periodicity_estimator_params, in_channels=aux_channels 458 | ) 459 | 460 | # convert hidden representation to output signal 461 | self.conv_last = nn.Sequential( 462 | nn.ReLU(), 463 | Conv1d1x1(skip_channels, skip_channels), 464 | nn.ReLU(), 465 | Conv1d1x1(skip_channels, out_channels), 466 | ) 467 | 468 | # apply weight norm 469 | if use_weight_norm: 470 | self.apply_weight_norm() 471 | 472 | def forward(self, x, c, d): 473 | """Calculate forward propagation. 474 | 475 | Args: 476 | x (Tensor): Input noise signal (B, 1, T). 477 | c (Tensor): Local conditioning auxiliary features (B, C ,T'). 478 | d (Tensor): Input pitch-dependent dilated factors (B, 1, T). 
479 | 480 | Returns: 481 | Tensor: Output tensor (B, 1, T) 482 | 483 | """ 484 | # index initialization 485 | batch_index, ch_index = index_initial(x.size(0), self.n_ch) 486 | 487 | # upsample auxiliary features 488 | c = self.upsample_net(c) 489 | assert c.size(-1) == x.size(-1) 490 | 491 | # estimate periodicity 492 | a = self.periodicity_estimator(c) 493 | 494 | # assume the first channel is sine and the other is noise 495 | sine, noise = torch.chunk(x, 2, 1) 496 | 497 | # encode to hidden representation 498 | h = self.conv_first_sine(sine) 499 | n = self.conv_first_noise(noise) 500 | 501 | # generate periodic and aperiodic source latent features 502 | h = self.harmonic_network(h, c, d, batch_index, ch_index) 503 | n = self.noise_network(n, c, d, batch_index, ch_index) 504 | 505 | # merge periodic and aperiodic latent features 506 | h = a * h 507 | n = (1.0 - a) * n 508 | s = h + n 509 | 510 | # resonance filtering 511 | x = self.filter_network(s, c, d, batch_index, ch_index) 512 | x = self.conv_last(x) 513 | 514 | # convert to 1d signal for regularization loss 515 | s = self.conv_last(s) 516 | 517 | # just for debug 518 | with torch.no_grad(): 519 | h = self.conv_last(h) 520 | n = self.conv_last(n) 521 | 522 | return x, s, h, n, a 523 | 524 | def remove_weight_norm(self): 525 | """Remove weight normalization module from all of the layers.""" 526 | 527 | def _remove_weight_norm(m): 528 | try: 529 | logger.debug(f"Weight norm is removed from {m}.") 530 | nn.utils.remove_weight_norm(m) 531 | except ValueError: # this module didn't have weight norm 532 | return 533 | 534 | self.apply(_remove_weight_norm) 535 | 536 | def apply_weight_norm(self): 537 | """Apply weight normalization module from all of the layers.""" 538 | 539 | def _apply_weight_norm(m): 540 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d): 541 | nn.utils.weight_norm(m) 542 | logger.debug(f"Weight norm is applied to {m}.") 543 | 544 | self.apply(_apply_weight_norm) 545 | -------------------------------------------------------------------------------- /usfgan/bin/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2022 Reo Yoneyama (Nagoya University) 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Training Script for Unified Source-Filter GAN. 
7 | 
8 | References:
9 |     - https://github.com/bigpon/QPPWG
10 |     - https://github.com/kan-bayashi/ParallelWaveGAN
11 | 
12 | """
13 | 
14 | import os
15 | import sys
16 | from collections import defaultdict
17 | from logging import getLogger
18 | 
19 | import hydra
20 | import librosa.display
21 | import matplotlib
22 | import numpy as np
23 | import torch
24 | import usfgan
25 | import usfgan.models
26 | from hydra.utils import to_absolute_path
27 | from omegaconf import DictConfig, OmegaConf
28 | from tensorboardX import SummaryWriter
29 | from torch.utils.data import DataLoader
30 | from tqdm import tqdm
31 | from usfgan.datasets import AudioFeatDataset
32 | from usfgan.utils.features import SignalGenerator
33 | 
34 | # set to avoid matplotlib error in CLI environment
35 | matplotlib.use("Agg")
36 | 
37 | 
38 | # A logger for this file
39 | logger = getLogger(__name__)
40 | 
41 | 
42 | class Trainer(object):
43 |     """Customized trainer module for Unified Source-Filter GAN training."""
44 | 
45 |     def __init__(
46 |         self,
47 |         config,
48 |         steps,
49 |         epochs,
50 |         data_loader,
51 |         model,
52 |         criterion,
53 |         optimizer,
54 |         scheduler,
55 |         device=torch.device("cpu"),
56 |     ):
57 |         """Initialize trainer.
58 | 
59 |         Args:
60 |             config (dict): Config dict loaded from yaml format configuration file.
61 |             steps (int): Initial global steps.
62 |             epochs (int): Initial global epochs.
63 |             data_loader (dict): Dict of data loaders. It must contain "train" and "valid" loaders.
64 |             model (dict): Dict of models. It must contain "generator" and "discriminator" models.
65 |             criterion (dict): Dict of criteria. It must contain "stft", "source", "adversarial", and "feat_match" criteria.
66 |             optimizer (dict): Dict of optimizers. It must contain "generator" and "discriminator" optimizers.
67 |             scheduler (dict): Dict of schedulers. It must contain "generator" and "discriminator" schedulers.
68 |             device (torch.device): Pytorch device instance.
69 | 
70 |         """
71 |         self.config = config
72 |         self.steps = steps
73 |         self.epochs = epochs
74 |         self.data_loader = data_loader
75 |         self.model = model
76 |         self.criterion = criterion
77 |         self.optimizer = optimizer
78 |         self.scheduler = scheduler
79 |         self.device = device
80 |         self.finish_train = False
81 |         self.writer = SummaryWriter(config.out_dir)
82 |         self.total_train_loss = defaultdict(float)
83 |         self.total_eval_loss = defaultdict(float)
84 | 
85 |     def run(self):
86 |         """Run training."""
87 |         self.tqdm = tqdm(
88 |             initial=self.steps, total=self.config.train.train_max_steps, desc="[train]"
89 |         )
90 |         while True:
91 |             # train one epoch
92 |             self._train_epoch()
93 | 
94 |             # check whether training is finished
95 |             if self.finish_train:
96 |                 break
97 | 
98 |         self.tqdm.close()
99 |         logger.info("Finished training.")
100 | 
101 |     def save_checkpoint(self, checkpoint_path):
102 |         """Save checkpoint.
103 | 
104 |         Args:
105 |             checkpoint_path (str): Checkpoint path to be saved.
106 | 107 | """ 108 | state_dict = { 109 | "optimizer": { 110 | "generator": self.optimizer["generator"].state_dict(), 111 | "discriminator": self.optimizer["discriminator"].state_dict(), 112 | }, 113 | "scheduler": { 114 | "generator": self.scheduler["generator"].state_dict(), 115 | "discriminator": self.scheduler["discriminator"].state_dict(), 116 | }, 117 | "steps": self.steps, 118 | "epochs": self.epochs, 119 | } 120 | state_dict["model"] = { 121 | "generator": self.model["generator"].state_dict(), 122 | "discriminator": self.model["discriminator"].state_dict(), 123 | } 124 | 125 | if not os.path.exists(os.path.dirname(checkpoint_path)): 126 | os.makedirs(os.path.dirname(checkpoint_path)) 127 | torch.save(state_dict, checkpoint_path) 128 | 129 | def load_checkpoint(self, checkpoint_path, load_only_params=False): 130 | """Load checkpoint. 131 | 132 | Args: 133 | checkpoint_path (str): Checkpoint path to be loaded. 134 | load_only_params (bool): Whether to load only model parameters. 135 | 136 | """ 137 | state_dict = torch.load(checkpoint_path, map_location="cpu") 138 | self.model["generator"].load_state_dict(state_dict["model"]["generator"]) 139 | self.model["discriminator"].load_state_dict( 140 | state_dict["model"]["discriminator"] 141 | ) 142 | if not load_only_params: 143 | self.steps = state_dict["steps"] 144 | self.epochs = state_dict["epochs"] 145 | self.optimizer["generator"].load_state_dict( 146 | state_dict["optimizer"]["generator"] 147 | ) 148 | self.optimizer["discriminator"].load_state_dict( 149 | state_dict["optimizer"]["discriminator"] 150 | ) 151 | self.scheduler["generator"].load_state_dict( 152 | state_dict["scheduler"]["generator"] 153 | ) 154 | self.scheduler["discriminator"].load_state_dict( 155 | state_dict["scheduler"]["discriminator"] 156 | ) 157 | 158 | def _train_step(self, batch): 159 | """Train model one step.""" 160 | # parse batch 161 | x, y = batch 162 | x = tuple([x_.to(self.device) for x_ in x]) 163 | z, c, df, f0 = x 164 | y_real = y.to(self.device) 165 | 166 | # generator forward 167 | y_fake, s = self.model["generator"](z, c, df)[:2] 168 | 169 | # calculate spectral loss 170 | if isinstance(self.criterion["stft"], usfgan.losses.MultiResolutionSTFTLoss): 171 | # Parallel WaveGAN Multi-Resolution STFT Loss 172 | sc_loss, mag_loss = self.criterion["stft"](y_fake, y_real) 173 | gen_loss = self.config.train.lambda_stft * (sc_loss + mag_loss) 174 | self.total_train_loss["train/spectral_convergence_loss"] += sc_loss.item() 175 | self.total_train_loss["train/log_stft_magnitude_loss"] += mag_loss.item() 176 | elif isinstance( 177 | self.criterion["stft"], usfgan.losses.MultiResolutionLogSTFTPowerLoss 178 | ): 179 | # Neural Source-Filter Multi-Resolution STFT Loss 180 | stft_loss = self.criterion["stft"](y_fake, y_real) 181 | gen_loss = self.config.train.lambda_stft * stft_loss 182 | self.total_train_loss["train/log_stft_power_loss"] += stft_loss.item() 183 | elif isinstance(self.criterion["stft"], usfgan.losses.MelSpectralLoss): 184 | # HiFiGAN Mel-Spectrogram Reconstruction Loss 185 | mel_loss = self.criterion["stft"](y_fake, y_real) 186 | gen_loss = self.config.train.lambda_stft * mel_loss 187 | self.total_train_loss["train/log_mel_spec_loss"] += mel_loss.item() 188 | 189 | # calculate source regularization loss for usfgan-based models 190 | if self.config.train.lambda_source > 0: 191 | if isinstance( 192 | self.criterion["source"], 193 | usfgan.losses.ResidualLoss, 194 | ): 195 | source_loss = self.criterion["source"](s, y_real, f0) 196 | gen_loss += 
self.config.train.lambda_source * source_loss
197 |                 self.total_train_loss["train/source_loss"] += source_loss.item()
198 |             else:
199 |                 source_loss = self.criterion["source"](s, f0)
200 |                 gen_loss += self.config.train.lambda_source * source_loss
201 |                 self.total_train_loss["train/source_loss"] += source_loss.item()
202 | 
203 |         # calculate discriminator related losses
204 |         if self.steps > self.config.train.discriminator_train_start_steps:
205 |             # calculate feature matching loss
206 |             if self.config.train.lambda_feat_match > 0:
207 |                 p_fake, fmaps_fake = self.model["discriminator"](
208 |                     y_fake, return_fmaps=True
209 |                 )
210 |                 with torch.no_grad():
211 |                     p_real, fmaps_real = self.model["discriminator"](
212 |                         y_real, return_fmaps=True
213 |                     )
214 |                 fm_loss = self.criterion["feat_match"](fmaps_fake, fmaps_real)
215 |                 gen_loss += self.config.train.lambda_feat_match * fm_loss
216 |                 self.total_train_loss["train/feat_match_loss"] += fm_loss.item()
217 |             else:
218 |                 p_fake = self.model["discriminator"](y_fake)
219 |             # calculate adversarial loss
220 |             adv_loss = self.criterion["adversarial"](p_fake)
221 |             gen_loss += self.config.train.lambda_adv * adv_loss
222 |             self.total_train_loss["train/adversarial_loss"] += adv_loss.item()
223 | 
224 |         self.total_train_loss["train/generator_loss"] += gen_loss.item()
225 | 
226 |         # update generator
227 |         self.optimizer["generator"].zero_grad()
228 |         gen_loss.backward()
229 |         if self.config.train.generator_grad_norm > 0:
230 |             torch.nn.utils.clip_grad_norm_(
231 |                 self.model["generator"].parameters(),
232 |                 self.config.train.generator_grad_norm,
233 |             )
234 |         self.optimizer["generator"].step()
235 |         self.scheduler["generator"].step()
236 | 
237 |         # discriminator
238 |         if self.steps > self.config.train.discriminator_train_start_steps:
239 |             # re-compute y_fake
240 |             with torch.no_grad():
241 |                 y_fake = self.model["generator"](z, c, df)[0]
242 |             # calculate discriminator loss
243 |             p_fake = self.model["discriminator"](y_fake.detach())
244 |             p_real = self.model["discriminator"](y_real)
245 |             # NOTE: the first argument must be the fake samples
246 |             fake_loss, real_loss = self.criterion["adversarial"](p_fake, p_real)
247 |             dis_loss = fake_loss + real_loss
248 |             self.total_train_loss["train/fake_loss"] += fake_loss.item()
249 |             self.total_train_loss["train/real_loss"] += real_loss.item()
250 |             self.total_train_loss["train/discriminator_loss"] += dis_loss.item()
251 | 
252 |             # update discriminator
253 |             self.optimizer["discriminator"].zero_grad()
254 |             dis_loss.backward()
255 |             if self.config.train.discriminator_grad_norm > 0:
256 |                 torch.nn.utils.clip_grad_norm_(
257 |                     self.model["discriminator"].parameters(),
258 |                     self.config.train.discriminator_grad_norm,
259 |                 )
260 |             self.optimizer["discriminator"].step()
261 |             self.scheduler["discriminator"].step()
262 | 
263 |         # update counts
264 |         self.steps += 1
265 |         self.tqdm.update(1)
266 |         self._check_train_finish()
267 | 
268 |     def _train_epoch(self):
269 |         """Train model one epoch."""
270 |         for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
271 |             # train one step
272 |             self._train_step(batch)
273 | 
274 |             # check interval
275 |             self._check_log_interval()
276 |             self._check_eval_interval()
277 |             self._check_save_interval()
278 | 
279 |             # check whether training is finished
280 |             if self.finish_train:
281 |                 return
282 | 
283 |         # update
284 |         self.epochs += 1
285 |         self.train_steps_per_epoch = train_steps_per_epoch
286 |         logger.info(
287 |             f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
288 |             f"({self.train_steps_per_epoch} steps per epoch)."
289 |         )
290 | 
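    # Objective recap (mirrors _train_step above): the generator minimizes
    #   lambda_stft * L_stft + lambda_source * L_source
    #       + lambda_adv * L_adv (+ lambda_feat_match * L_fm),
    # where the adversarial and feature-matching terms only kick in after
    # discriminator_train_start_steps.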
288 | f"({self.train_steps_per_epoch} steps per epoch)." 289 | ) 290 | 291 | @torch.no_grad() 292 | def _eval_step(self, batch): 293 | """Evaluate model one step.""" 294 | # parse batch 295 | x, y = batch 296 | x = tuple([x_.to(self.device) for x_ in x]) 297 | z, c, df, f0 = x 298 | y_real = y.to(self.device) 299 | 300 | # generator forward 301 | y_fake, s = self.model["generator"](z, c, df)[:2] 302 | 303 | # calculate spectral loss 304 | if isinstance(self.criterion["stft"], usfgan.losses.MultiResolutionSTFTLoss): 305 | # Parallel WaveGAN Multi-Resolution STFT Loss 306 | sc_loss, mag_loss = self.criterion["stft"](y_fake, y_real) 307 | gen_loss = self.config.train.lambda_stft * (sc_loss + mag_loss) 308 | self.total_eval_loss["eval/spectral_convergence_loss"] += sc_loss.item() 309 | self.total_eval_loss["eval/log_stft_magnitude_loss"] += mag_loss.item() 310 | elif isinstance( 311 | self.criterion["stft"], usfgan.losses.MultiResolutionLogSTFTPowerLoss 312 | ): 313 | # Neural Source-Filter Multi-Resolution STFT Loss 314 | stft_loss = self.criterion["stft"](y_fake, y_real) 315 | gen_loss = self.config.train.lambda_stft * stft_loss 316 | self.total_eval_loss["eval/log_stft_power_loss"] += stft_loss.item() 317 | elif isinstance(self.criterion["stft"], usfgan.losses.MelSpectralLoss): 318 | # HiFiGAN Mel-Spectrogram Reconstruction Loss 319 | mel_loss = self.criterion["stft"](y_fake, y_real) 320 | gen_loss = self.config.train.lambda_stft * mel_loss 321 | self.total_eval_loss["eval/log_mel_spec_loss"] += mel_loss.item() 322 | 323 | # calculate source regularization loss for usfgan-based models 324 | if self.config.train.lambda_source > 0: 325 | if isinstance( 326 | self.criterion["source"], 327 | usfgan.losses.ResidualLoss, 328 | ): 329 | source_loss = self.criterion["source"](s, y_real, f0) 330 | gen_loss += self.config.train.lambda_source * source_loss 331 | self.total_eval_loss["eval/source_loss"] += source_loss.item() 332 | else: 333 | source_loss = self.criterion["source"](s, f0) 334 | gen_loss += self.config.train.lambda_source * source_loss 335 | self.total_eval_loss["eval/source_loss"] += source_loss.item() 336 | 337 | # calculate discriminator related losses 338 | if self.steps > self.config.train.discriminator_train_start_steps: 339 | # calculate feature matching loss 340 | if self.config.train.lambda_feat_match > 0: 341 | p_fake, fmaps_fake = self.model["discriminator"]( 342 | y_fake, return_fmaps=True 343 | ) 344 | p_real, fmaps_real = self.model["discriminator"]( 345 | y_real, return_fmaps=True 346 | ) 347 | fm_loss = self.criterion["feat_match"](fmaps_fake, fmaps_real) 348 | gen_loss += self.config.train.lambda_feat_match * fm_loss 349 | self.total_eval_loss["eval/feat_match_loss"] += fm_loss.item() 350 | else: 351 | p_fake = self.model["discriminator"](y_fake) 352 | # calculate adversarial loss 353 | adv_loss = self.criterion["adversarial"](p_fake) 354 | gen_loss += self.config.train.lambda_adv * adv_loss 355 | self.total_eval_loss["eval/adversarial_loss"] += adv_loss.item() 356 | 357 | self.total_eval_loss["eval/generator_loss"] += gen_loss.item() 358 | 359 | # discriminator 360 | if self.steps > self.config.train.discriminator_train_start_steps: 361 | # calculate discriminator loss 362 | p_real = self.model["discriminator"](y_real) 363 | # NOTE: the first augment must to be the fake sample 364 | fake_loss, real_loss = self.criterion["adversarial"](p_fake, p_real) 365 | dis_loss = fake_loss + real_loss 366 | self.total_eval_loss["eval/fake_loss"] += fake_loss.item() 367 | 
self.total_eval_loss["eval/real_loss"] += real_loss.item() 368 | self.total_eval_loss["eval/discriminator_loss"] += dis_loss.item() 369 | 370 | def _eval_epoch(self): 371 | """Evaluate model one epoch.""" 372 | logger.info(f"(Steps: {self.steps}) Start evaluation.") 373 | # change mode 374 | for key in self.model.keys(): 375 | self.model[key].eval() 376 | 377 | # calculate loss for each batch 378 | for eval_steps_per_epoch, batch in enumerate( 379 | tqdm(self.data_loader["valid"], desc="[eval]"), 1 380 | ): 381 | # eval one step 382 | self._eval_step(batch) 383 | 384 | # save intermediate result 385 | if eval_steps_per_epoch == 1: 386 | self._genearete_and_save_intermediate_result(batch) 387 | if eval_steps_per_epoch == 3: 388 | break 389 | 390 | logger.info( 391 | f"(Steps: {self.steps}) Finished evaluation " 392 | f"({eval_steps_per_epoch} steps per epoch)." 393 | ) 394 | 395 | # average loss 396 | for key in self.total_eval_loss.keys(): 397 | self.total_eval_loss[key] /= eval_steps_per_epoch 398 | logger.info( 399 | f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}." 400 | ) 401 | 402 | # record 403 | self._write_to_tensorboard(self.total_eval_loss) 404 | 405 | # reset 406 | self.total_eval_loss = defaultdict(float) 407 | 408 | # restore mode 409 | for key in self.model.keys(): 410 | self.model[key].train() 411 | 412 | @torch.no_grad() 413 | def _genearete_and_save_intermediate_result(self, batch): 414 | """Generate and save intermediate result.""" 415 | # delayed import to avoid error related backend error 416 | import matplotlib.pyplot as plt 417 | 418 | x_batch, y_real_batch = batch 419 | # use only the first sample 420 | x_batch = [x[:1].to(self.device) for x in x_batch] 421 | y_real_batch = y_real_batch[:1] 422 | z_batch, c_batch, df_batch = x_batch[:3] 423 | 424 | # generator forward 425 | h_batch, n_batch, a_batch = None, None, None 426 | if isinstance( 427 | self.model["generator"], 428 | ( 429 | usfgan.models.ParallelHnUSFGANGenerator, 430 | usfgan.models.CascadeHnUSFGANGenerator, 431 | ), 432 | ): 433 | y_fake_batch, s_batch, h_batch, n_batch, a_batch = self.model["generator"]( 434 | z_batch, c_batch, df_batch 435 | ) 436 | else: 437 | y_fake_batch, s_batch = self.model["generator"](z_batch, c_batch, df_batch) 438 | 439 | len50ms = int(self.config.data.sample_rate * 0.05) 440 | start = np.random.randint(0, self.config.data.batch_max_length - len50ms) 441 | end = start + len50ms 442 | 443 | for audio, name, save_wav in zip( 444 | [y_real_batch, y_fake_batch, s_batch, h_batch, n_batch], 445 | ["real", "fake", "source", "harmonic", "noise"], 446 | [True, True, True, False, False], 447 | ): 448 | if audio is not None: 449 | audio = audio.view(-1).cpu().numpy() 450 | 451 | # plot spectrogram 452 | fig = plt.figure(figsize=(8, 6)) 453 | spectrogram = np.abs( 454 | librosa.stft( 455 | y=audio, 456 | n_fft=1024, 457 | hop_length=128, 458 | win_length=1024, 459 | window="hann", 460 | ) 461 | ) 462 | spectrogram_db = librosa.amplitude_to_db(spectrogram, ref=np.max) 463 | librosa.display.specshow( 464 | spectrogram_db, 465 | sr=self.config.data.sample_rate, 466 | y_axis="linear", 467 | x_axis="time", 468 | hop_length=128, 469 | ) 470 | self.writer.add_figure(f"spectrogram/{name}", fig, self.steps) 471 | plt.clf() 472 | plt.close() 473 | 474 | # plot full waveform 475 | fig = plt.figure(figsize=(6, 3)) 476 | plt.plot(audio, linewidth=1) 477 | self.writer.add_figure(f"waveform/{name}", fig, self.steps) 478 | plt.clf() 479 | plt.close() 480 | 481 | # plot short term waveform 
482 |                 fig = plt.figure(figsize=(6, 3))
483 |                 plt.plot(audio[start:end], linewidth=1)
484 |                 self.writer.add_figure(f"short_waveform/{name}", fig, self.steps)
485 |                 plt.clf()
486 |                 plt.close()
487 | 
488 |                 # save as wavfile
489 |                 if save_wav:
490 |                     audio = audio / np.max(np.abs(audio))
491 |                     self.writer.add_audio(
492 |                         f"audio_{name}.wav",
493 |                         audio,
494 |                         self.steps,
495 |                         self.config.data.sample_rate,
496 |                     )
497 | 
498 |         # plot aperiodicity weights
499 |         if a_batch is not None:
500 |             fig = plt.figure(figsize=(6, 4))
501 |             plt.imshow(a_batch.squeeze(0).cpu().numpy(), aspect="auto")
502 |             plt.colorbar()
503 |             self.writer.add_figure("aperiodicity", fig, self.steps)
504 |             plt.clf()
505 |             plt.close()
506 | 
507 |     def _write_to_tensorboard(self, loss):
508 |         """Write to tensorboard."""
509 |         for key, value in loss.items():
510 |             self.writer.add_scalar(key, value, self.steps)
511 | 
512 |     def _check_save_interval(self):
513 |         if self.steps % self.config.train.save_interval_steps == 0:
514 |             self.save_checkpoint(
515 |                 os.path.join(
516 |                     self.config.out_dir,
517 |                     "checkpoints",
518 |                     f"checkpoint-{self.steps}steps.pkl",
519 |                 )
520 |             )
521 |             logger.info(f"Successfully saved checkpoint @ {self.steps} steps.")
522 | 
523 |     def _check_eval_interval(self):
524 |         if self.steps % self.config.train.eval_interval_steps == 0:
525 |             self._eval_epoch()
526 | 
527 |     def _check_log_interval(self):
528 |         if self.steps % self.config.train.log_interval_steps == 0:
529 |             for key in self.total_train_loss.keys():
530 |                 self.total_train_loss[key] /= self.config.train.log_interval_steps
531 |                 logger.info(
532 |                     f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
533 |                 )
534 |             self._write_to_tensorboard(self.total_train_loss)
535 | 
536 |             # reset
537 |             self.total_train_loss = defaultdict(float)
538 | 
539 |     def _check_train_finish(self):
540 |         if self.steps >= self.config.train.train_max_steps:
541 |             self.finish_train = True
542 | 
543 | 
544 | class Collater(object):
545 |     """Customized collater for PyTorch DataLoader in training."""
546 | 
547 |     def __init__(
548 |         self,
549 |         batch_max_length=12000,
550 |         sample_rate=24000,
551 |         hop_size=120,
552 |         aux_context_window=2,
553 |         sine_amp=0.1,
554 |         noise_amp=0.003,
555 |         sine_f0_type="contf0",
556 |         signal_types=["sine", "noise"],
557 |     ):
558 |         """Initialize customized collater for PyTorch DataLoader.
559 | 
560 |         Args:
561 |             batch_max_length (int): The maximum length of a batch segment in samples.
562 |             sample_rate (int): Sampling rate.
563 |             hop_size (int): Hop size of auxiliary features.
564 |             aux_context_window (int): Context window size for auxiliary feature conv.
565 |             sine_amp (float): Amplitude of sine signal.
566 |             noise_amp (float): Amplitude of random noise signal.
567 |             sine_f0_type (str): F0 type for generating sine signal.
568 |             signal_types (list): List of types for input signals.
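  | 
  |         Note:
  |             If batch_max_length is not a multiple of hop_size, it is
  |             rounded down to the nearest multiple (e.g. 12001 -> 12000
  |             with hop_size=120) so that the sample-rate and frame-rate
  |             streams stay aligned.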
569 | 
570 |         """
571 |         if batch_max_length % hop_size != 0:
572 |             batch_max_length -= batch_max_length % hop_size
573 |         assert batch_max_length % hop_size == 0
574 |         self.batch_max_length = batch_max_length
575 |         self.batch_max_frames = batch_max_length // hop_size
576 |         self.sample_rate = sample_rate
577 |         self.hop_size = hop_size
578 |         self.aux_context_window = aux_context_window
579 |         self.sine_f0_type = sine_f0_type
580 |         self.signal_generator = SignalGenerator(
581 |             sample_rate=sample_rate,
582 |             hop_size=hop_size,
583 |             sine_amp=sine_amp,
584 |             noise_amp=noise_amp,
585 |             signal_types=signal_types,
586 |         )
587 | 
588 |     def __call__(self, batch):
589 |         """Convert into batch tensors.
590 | 
591 |         Args:
592 |             batch (list): List of tuples containing audio and feature pairs.
593 | 
594 |         Returns:
595 |             Tensor: Gaussian noise (and sine) batch (B, D, T).
596 |             Tensor: Auxiliary feature batch (B, C, T' + 2 * aux_context_window).
597 |             Tensor: Dilated factor batch (B, 1, T).
598 |             Tensor: F0 sequence batch (B, 1, T').
599 |             Tensor: Target signal batch (B, 1, T).
600 | 
601 |         """
602 |         # check time resolution and randomly crop to batch_max_length
603 |         y_batch, c_batch, df_batch, f0_batch, contf0_batch = [], [], [], [], []
604 |         for idx in range(len(batch)):
605 |             x, c, df, f0, contf0 = batch[idx]
606 |             self._check_length(x, c, df, f0, contf0, 0)
607 |             if len(c) - 2 * self.aux_context_window > self.batch_max_frames:
608 |                 # randomly pick up a segment of batch_max_length samples
609 |                 interval_start = self.aux_context_window
610 |                 interval_end = len(c) - self.batch_max_frames - self.aux_context_window
611 |                 start_frame = np.random.randint(interval_start, interval_end)
612 |                 start_step = start_frame * self.hop_size
613 |                 y = x[start_step : start_step + self.batch_max_length]
614 |                 c = c[
615 |                     start_frame
616 |                     - self.aux_context_window : start_frame
617 |                     + self.aux_context_window
618 |                     + self.batch_max_frames
619 |                 ]
620 |                 df = df[start_step : start_step + self.batch_max_length]
621 |                 f0 = f0[start_frame : start_frame + self.batch_max_frames]
622 |                 contf0 = contf0[start_frame : start_frame + self.batch_max_frames]
623 |                 self._check_length(
624 |                     y,
625 |                     c,
626 |                     df,
627 |                     f0,
628 |                     contf0,
629 |                     self.aux_context_window,
630 |                 )
631 |             else:
632 |                 logger.warning(f"Removed short sample from batch (length={len(x)}).")
633 |                 continue
634 |             y_batch += [y.astype(np.float32).reshape(-1, 1)]  # [(T, 1), ...]
635 |             c_batch += [c.astype(np.float32)]  # [(T' + 2 * aux_context_window, D), ...]
636 |             df_batch += [df.astype(np.float32).reshape(-1, 1)]  # [(T, 1), ...]
637 |             f0_batch += [f0.astype(np.float32).reshape(-1, 1)]  # [(T', 1), ...]
638 |             contf0_batch += [contf0.astype(np.float32).reshape(-1, 1)]  # [(T', 1), ...]
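  |         # every y/df entry now holds batch_max_length samples, every
  |         # f0/contf0 entry holds batch_max_frames frames, and c keeps
  |         # aux_context_window extra frames on both sides for the aux conv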
639 | 
640 |         # convert each batch to tensor; assume that each item in the batch has the same length
641 |         y_batch = torch.FloatTensor(np.array(y_batch)).transpose(2, 1)  # (B, 1, T)
642 |         c_batch = torch.FloatTensor(np.array(c_batch)).transpose(
643 |             2, 1
644 |         )  # (B, C, T' + 2 * aux_context_window)
645 |         df_batch = torch.FloatTensor(np.array(df_batch)).transpose(2, 1)  # (B, 1, T)
646 |         f0_batch = torch.FloatTensor(np.array(f0_batch)).transpose(2, 1)  # (B, 1, T')
647 |         contf0_batch = torch.FloatTensor(np.array(contf0_batch)).transpose(
648 |             2, 1
649 |         )  # (B, 1, T')
650 | 
651 |         # make input signal batch tensor
652 |         if self.sine_f0_type == "contf0":
653 |             in_batch = self.signal_generator(contf0_batch)
654 |         else:
655 |             in_batch = self.signal_generator(f0_batch)
656 | 
657 |         return (in_batch, c_batch, df_batch, f0_batch), y_batch
658 | 
659 |     def _check_length(self, x, c, df, f0, contf0, context_window):
660 |         """Assert that the audio and feature lengths are correctly adjusted for upsampling."""
661 |         assert len(x) == (len(c) - 2 * context_window) * self.hop_size
662 |         assert len(x) == len(df)
663 |         assert len(x) == len(f0) * self.hop_size
664 |         assert len(x) == len(contf0) * self.hop_size
665 | 
666 | 
667 | @hydra.main(version_base=None, config_path="config", config_name="train")
668 | def main(config: DictConfig) -> None:
669 |     """Run training process."""
670 | 
671 |     if not torch.cuda.is_available():
672 |         logger.info("Training on CPU.")
673 |         device = torch.device("cpu")
674 |     else:
675 |         logger.info("Training on GPU.")
676 |         device = torch.device("cuda")
677 |         # effective when using fixed size inputs
678 |         # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
679 |         torch.backends.cudnn.benchmark = True
680 | 
681 |     # fix seed
682 |     np.random.seed(config.seed)
683 |     torch.manual_seed(config.seed)
684 |     torch.cuda.manual_seed(config.seed)
685 |     os.environ["PYTHONHASHSEED"] = str(config.seed)
686 | 
687 |     # check directory existence
688 |     if not os.path.exists(config.out_dir):
689 |         os.makedirs(config.out_dir)
690 | 
691 |     # write config to yaml file
692 |     with open(os.path.join(config.out_dir, "config.yaml"), "w") as f:
693 |         f.write(OmegaConf.to_yaml(config))
694 |     logger.info(OmegaConf.to_yaml(config))
695 | 
696 |     # get dataset
697 |     if config.data.remove_short_samples:
698 |         feat_length_threshold = (
699 |             config.data.batch_max_length // config.data.hop_size
700 |             + 2 * config.generator.aux_context_window
701 |         )
702 |     else:
703 |         feat_length_threshold = None
704 | 
705 |     train_dataset = AudioFeatDataset(
706 |         stats=to_absolute_path(config.data.stats),
707 |         audio_list=to_absolute_path(config.data.train_audio),
708 |         feat_list=to_absolute_path(config.data.train_feat),
709 |         feat_length_threshold=feat_length_threshold,
710 |         allow_cache=config.data.allow_cache,
711 |         sample_rate=config.data.sample_rate,
712 |         hop_size=config.data.hop_size,
713 |         dense_factor=config.data.dense_factor,
714 |         df_f0_type=config.data.df_f0_type,
715 |         aux_feats=config.data.aux_feats,
716 |     )
717 |     logger.info(f"The number of training files = {len(train_dataset)}.")
718 | 
719 |     valid_dataset = AudioFeatDataset(
720 |         stats=to_absolute_path(config.data.stats),
721 |         audio_list=to_absolute_path(config.data.valid_audio),
722 |         feat_list=to_absolute_path(config.data.valid_feat),
723 |         feat_length_threshold=feat_length_threshold,
724 |         allow_cache=config.data.allow_cache,
725 |         sample_rate=config.data.sample_rate,
726 |         hop_size=config.data.hop_size,
727 |         dense_factor=config.data.dense_factor,
728 |         df_f0_type=config.data.df_f0_type,
729 |         aux_feats=config.data.aux_feats,
730 |     )
731 |     logger.info(f"The number of validation files = {len(valid_dataset)}.")
732 | 
733 |     dataset = {
734 |         "train": train_dataset,
735 |         "valid": valid_dataset,
736 |     }
737 | 
738 |     # get data loader
739 |     collater = Collater(
740 |         batch_max_length=config.data.batch_max_length,
741 |         aux_context_window=config.generator.aux_context_window,
742 |         sample_rate=config.data.sample_rate,
743 |         hop_size=config.data.hop_size,
744 |         sine_amp=config.data.sine_amp,
745 |         noise_amp=config.data.noise_amp,
746 |         sine_f0_type=config.data.sine_f0_type,
747 |         signal_types=config.data.signal_types,
748 |     )
749 |     train_sampler, valid_sampler = None, None
750 | 
751 |     data_loader = {
752 |         "train": DataLoader(
753 |             dataset=dataset["train"],
754 |             shuffle=True,
755 |             collate_fn=collater,
756 |             batch_size=config.data.batch_size,
757 |             num_workers=config.data.num_workers,
758 |             sampler=train_sampler,
759 |             pin_memory=config.data.pin_memory,
760 |         ),
761 |         "valid": DataLoader(
762 |             dataset=dataset["valid"],
763 |             shuffle=True,
764 |             collate_fn=collater,
765 |             batch_size=config.data.batch_size,
766 |             num_workers=config.data.num_workers,
767 |             sampler=valid_sampler,
768 |             pin_memory=config.data.pin_memory,
769 |         ),
770 |     }
771 | 
772 |     # define models
773 |     model = {
774 |         "generator": hydra.utils.instantiate(config.generator).to(device),
775 |         "discriminator": hydra.utils.instantiate(config.discriminator).to(device),
776 |     }
777 | 
778 |     # define training criteria
779 |     criterion = {
780 |         "stft": hydra.utils.instantiate(config.train.stft_loss).to(device),
781 |         "adversarial": hydra.utils.instantiate(config.train.adversarial_loss).to(
782 |             device
783 |         ),
784 |     }
785 |     if config.train.lambda_feat_match > 0:
786 |         criterion["feat_match"] = hydra.utils.instantiate(
787 |             config.train.feat_match_loss
788 |         ).to(device)
789 |     if config.train.lambda_source > 0:
790 |         criterion["source"] = hydra.utils.instantiate(config.train.source_loss).to(
791 |             device
792 |         )
793 | 
794 |     # define optimizers and schedulers
795 |     optimizer = {
796 |         "generator": hydra.utils.instantiate(
797 |             config.train.generator_optimizer,
798 |             params=model["generator"].parameters(),
799 |         ),
800 |         "discriminator": hydra.utils.instantiate(
801 |             config.train.discriminator_optimizer,
802 |             params=model["discriminator"].parameters(),
803 |         ),
804 |     }
805 |     scheduler = {
806 |         "generator": hydra.utils.instantiate(
807 |             config.train.generator_scheduler,
808 |             optimizer=optimizer["generator"],
809 |         ),
810 |         "discriminator": hydra.utils.instantiate(
811 |             config.train.discriminator_scheduler,
812 |             optimizer=optimizer["discriminator"],
813 |         ),
814 |     }
815 | 
816 |     # define trainer
817 |     trainer = Trainer(
818 |         config=config,
819 |         steps=0,
820 |         epochs=0,
821 |         data_loader=data_loader,
822 |         model=model,
823 |         criterion=criterion,
824 |         optimizer=optimizer,
825 |         scheduler=scheduler,
826 |         device=device,
827 |     )
828 | 
829 |     # load trained parameters from checkpoint
830 |     if config.train.resume:
831 |         resume = os.path.join(
832 |             config.out_dir, "checkpoints", f"checkpoint-{config.train.resume}steps.pkl"
833 |         )
834 |         if os.path.exists(resume):
835 |             trainer.load_checkpoint(resume)
836 |             logger.info(f"Successfully resumed from {resume}.")
837 |         else:
838 |             logger.error(f"Failed to resume from {resume}.")
839 |             sys.exit(1)
840 |     else:
841 |         logger.info("Start a new training process.")
842 | 
843 |     # run training loop
844 |     try:
845 |         trainer.run()
846 |     except KeyboardInterrupt:
  |         # save the current state so training can be resumed with config.train.resume
847 |         trainer.save_checkpoint(
848 |             os.path.join(
849 |                 config.out_dir, "checkpoints", f"checkpoint-{trainer.steps}steps.pkl"
850 |             )
851 |         )
852 |         logger.info(f"Successfully saved checkpoint @ {trainer.steps} steps.")
853 | 
854 | 
855 | if __name__ == "__main__":
856 |     main()
857 | 
--------------------------------------------------------------------------------