├── usfgan
│   ├── bin
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   ├── compute_statistics.yaml
│   │   │   ├── train.yaml
│   │   │   ├── decode.yaml
│   │   │   ├── discriminator
│   │   │   │   ├── pwg.yaml
│   │   │   │   ├── univnet.yaml
│   │   │   │   └── hifigan.yaml
│   │   │   ├── generator
│   │   │   │   ├── usfgan.yaml
│   │   │   │   ├── cascade_hn_usfgan.yaml
│   │   │   │   └── parallel_hn_usfgan.yaml
│   │   │   ├── anasyn.yaml
│   │   │   ├── extract_features.yaml
│   │   │   ├── data
│   │   │   │   └── vctk_24kHz.yaml
│   │   │   └── train
│   │   │       ├── usfgan.yaml
│   │   │       └── hn_usfgan.yaml
│   │   ├── compute_statistics.py
│   │   ├── decode.py
│   │   ├── anasyn.py
│   │   ├── extract_features.py
│   │   └── train.py
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   └── audio_feat_dataset.py
│   ├── optimizers
│   │   ├── __init__.py
│   │   └── radam.py
│   ├── models
│   │   ├── __init__.py
│   │   └── generator.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── upsample.py
│   │   ├── cheaptrick.py
│   │   └── residual_block.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── filters.py
│   │   ├── index.py
│   │   ├── utils.py
│   │   └── features.py
│   └── losses
│       ├── __init__.py
│       ├── feat_match.py
│       ├── adversarial.py
│       ├── source.py
│       └── stft.py
├── .gitignore
├── egs
│   └── vctk
│       └── data
│           ├── scp
│           │   ├── vctk_train_24kHz.scp
│           │   └── vctk_train_24kHz.list
│           └── spk_info.yaml
├── LICENSE
├── setup.py
└── README.md
/usfgan/bin/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/usfgan/bin/config/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/usfgan/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | __version__ = "0.0.1.post1"
4 |
--------------------------------------------------------------------------------
/usfgan/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from usfgan.datasets.audio_feat_dataset import * # NOQA
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .eggs
3 | *.egg-info
4 | *.log
5 | *.out
6 | .venv
7 |
8 | egs/vctk/exp/
--------------------------------------------------------------------------------
/usfgan/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from torch.optim import * # NOQA
2 | from usfgan.optimizers.radam import * # NOQA
3 |
--------------------------------------------------------------------------------
/usfgan/models/__init__.py:
--------------------------------------------------------------------------------
1 | from usfgan.models.discriminator import * # NOQA
2 | from usfgan.models.generator import * # NOQA
3 |
--------------------------------------------------------------------------------
/usfgan/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from usfgan.layers.cheaptrick import * # NOQA
2 | from usfgan.layers.residual_block import * # NOQA
3 | from usfgan.layers.upsample import * # NOQA
4 |
--------------------------------------------------------------------------------
/usfgan/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from usfgan.utils.features import * # NOQA
2 | from usfgan.utils.filters import * # NOQA
3 | from usfgan.utils.index import * # NOQA
4 | from usfgan.utils.utils import * # NOQA
5 |
--------------------------------------------------------------------------------
/usfgan/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from usfgan.losses.adversarial import * # NOQA
2 | from usfgan.losses.feat_match import * # NOQA
3 | from usfgan.losses.source import * # NOQA
4 | from usfgan.losses.stft import * # NOQA
5 |
--------------------------------------------------------------------------------
/usfgan/bin/config/compute_statistics.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | hydra:
4 | job:
5 | chdir: false
6 | output_subdir: null
7 | job_logging:
8 | formatters:
9 | simple:
10 | format: '[%(levelname)s][%(module)s | %(lineno)s] %(message)s'
11 |
12 | feats: data/scp/vctk_train_24kHz.list # List file of input features.
13 | stats: data/stats/vctk_train_24kHz.joblib # Path to the output statistics file.
14 | feat_types: ['f0', 'contf0', 'mcep', 'mcap'] # Feature types.
15 |
--------------------------------------------------------------------------------
/usfgan/bin/config/train.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _self_
5 | - generator: parallel_hn_usfgan
6 | - discriminator: hifigan
7 | - train: hn_usfgan
8 | - data: vctk_24kHz
9 |
10 | hydra:
11 | job:
12 | chdir: false
13 | output_subdir: null
14 | job_logging:
15 | formatters:
16 | simple:
17 | format: '[%(levelname)s][%(module)s | %(lineno)s] %(message)s'
18 |
19 | out_dir: # Directory to output training results.
20 | seed: 12345 # Seed number for random numbers.
21 |
--------------------------------------------------------------------------------
/egs/vctk/data/scp/vctk_train_24kHz.scp:
--------------------------------------------------------------------------------
1 | path_to_your_own_dataset_dir/wav/p225/p225_001.wav
2 | path_to_your_own_dataset_dir/wav/p225/p225_002.wav
3 | path_to_your_own_dataset_dir/wav/p225/p225_003.wav
4 | path_to_your_own_dataset_dir/wav/p225/p225_004.wav
5 | path_to_your_own_dataset_dir/wav/p225/p225_005.wav
6 | path_to_your_own_dataset_dir/wav/p225/p225_006.wav
7 | path_to_your_own_dataset_dir/wav/p225/p225_007.wav
8 | path_to_your_own_dataset_dir/wav/p225/p225_008.wav
9 | path_to_your_own_dataset_dir/wav/p225/p225_009.wav
10 | path_to_your_own_dataset_dir/wav/p225/p225_010.wav
--------------------------------------------------------------------------------
/egs/vctk/data/scp/vctk_train_24kHz.list:
--------------------------------------------------------------------------------
1 | path_to_your_own_dataset_dir/hdf5/p225/p225_001.h5
2 | path_to_your_own_dataset_dir/hdf5/p225/p225_002.h5
3 | path_to_your_own_dataset_dir/hdf5/p225/p225_003.h5
4 | path_to_your_own_dataset_dir/hdf5/p225/p225_004.h5
5 | path_to_your_own_dataset_dir/hdf5/p225/p225_005.h5
6 | path_to_your_own_dataset_dir/hdf5/p225/p225_006.h5
7 | path_to_your_own_dataset_dir/hdf5/p225/p225_007.h5
8 | path_to_your_own_dataset_dir/hdf5/p225/p225_008.h5
9 | path_to_your_own_dataset_dir/hdf5/p225/p225_009.h5
10 | path_to_your_own_dataset_dir/hdf5/p225/p225_010.h5
--------------------------------------------------------------------------------
/usfgan/bin/config/decode.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _self_
5 | - generator: parallel_hn_usfgan
6 | - data: vctk_24kHz
7 |
8 | hydra:
9 | job:
10 | chdir: false
11 | output_subdir: null
12 | job_logging:
13 | formatters:
14 | simple:
15 | format: '[%(levelname)s][%(module)s | %(lineno)s] %(message)s'
16 |
17 | out_dir: # Directory to output decoding results.
18 | checkpoint_path: # Path to the checkpoint of pre-trained model.
19 | f0_factor: 1.00 # F0 factor.
20 | seed: 100 # Seed number for random numbers.
21 | save_source: false # Whether to save source excitation signals.
--------------------------------------------------------------------------------
/usfgan/bin/config/discriminator/pwg.yaml:
--------------------------------------------------------------------------------
1 | _target_: usfgan.models.PWGDiscriminator
2 | in_channels: 1 # Number of input channels.
3 | out_channels: 1 # Number of output channels.
4 | kernel_size: 3 # Kernel size of conv layers.
5 | layers: 10 # Number of conv layers.
6 | conv_channels: 64 # Number of channels in conv layers.
7 | dilation_factor: 2 # Dilation factor.
8 | bias: true # Whether to use bias parameter in conv.
9 | use_weight_norm: true # Whether to use weight norm.
10 | # If set to true, it will be applied to all of the conv layers.
11 | nonlinear_activation: LeakyReLU # Nonlinear function after each conv.
12 | nonlinear_activation_params: # Nonlinear function parameters
13 | negative_slope: 0.2 # Alpha in LeakyReLU.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Reo YONEYAMA
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/usfgan/losses/feat_match.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """Feature matching loss module.
7 |
8 | References:
9 | - https://github.com/kan-bayashi/ParallelWaveGAN
10 |
11 | """
12 |
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 |
16 |
17 | class FeatureMatchLoss(nn.Module):
18 |     """Feature matching loss module."""
19 |
20 | def __init__(
21 | self,
22 | average_by_layers=False,
23 | ):
24 | """Initialize FeatureMatchLoss module."""
25 | super(FeatureMatchLoss, self).__init__()
26 | self.average_by_layers = average_by_layers
27 |
28 | def forward(self, fmaps_fake, fmaps_real):
29 | """Calculate forward propagation.
30 |
31 | Args:
32 |             fmaps_fake (list): List of discriminator outputs
33 |                 calculated from generator outputs.
34 |             fmaps_real (list): List of discriminator outputs
35 |                 calculated from ground truth.
36 |
37 | Returns:
38 | Tensor: Feature matching loss value.
39 |
40 | """
41 | feat_match_loss = 0.0
42 | for feat_fake, feat_real in zip(fmaps_fake, fmaps_real):
43 | feat_match_loss += F.l1_loss(feat_fake, feat_real.detach())
44 |
45 | if self.average_by_layers:
46 | feat_match_loss /= len(fmaps_fake)
47 |
48 | return feat_match_loss
49 |
--------------------------------------------------------------------------------
/usfgan/bin/config/discriminator/univnet.yaml:
--------------------------------------------------------------------------------
1 | _target_: usfgan.models.UnivNetMultiResolutionMultiPeriodDiscriminator
2 | fft_sizes: [1024, 2048, 512]
3 | hop_sizes: [120, 240, 50]
4 | win_lengths: [600, 1200, 240]
5 | window: "hann_window"
6 | spectral_discriminator_params:
7 | channels: 32
8 | kernel_sizes: [[3, 9], [3, 9], [3, 9], [3, 9], [3, 3], [3, 3]]
9 | strides: [[1, 1], [1, 2], [1, 2], [1, 2], [1, 1], [1, 1]]
10 | bias: true
11 | nonlinear_activation: "LeakyReLU"
12 | nonlinear_activation_params:
13 | negative_slope: 0.2
14 | periods: [2, 3, 5, 7, 11] # List of periods for multi-period discriminators.
15 | period_discriminator_params:
16 | in_channels: 1 # Number of input channels.
17 | out_channels: 1 # Number of output channels.
18 | kernel_sizes: [5, 3] # List of kernel sizes.
19 | channels: 32 # Initial number of channels.
20 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales.
21 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
22 |   bias: true # Whether to use bias parameter in conv layers.
23 |   nonlinear_activation: "LeakyReLU" # Nonlinear activation.
24 |   nonlinear_activation_params: # Nonlinear activation parameters.
25 | negative_slope: 0.1
26 | use_weight_norm: true # Whether to apply weight normalization.
27 | use_spectral_norm: false # Whether to apply spectral normalization.
28 |
--------------------------------------------------------------------------------
/usfgan/bin/config/generator/usfgan.yaml:
--------------------------------------------------------------------------------
1 | _target_: usfgan.models.USFGANGenerator
2 | source_network_params:
3 | blockA: 30 # Number of adaptive residual blocks.
4 | cycleA: 6 # Number of adaptive dilation cycles.
5 | blockF: 0 # Number of fixed residual blocks.
6 | cycleF: 0 # Number of fixed dilation cycles.
7 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive).
8 | filter_network_params:
9 | blockA: 0 # Number of adaptive residual blocks.
10 | cycleA: 0 # Number of adaptive dilation cycles.
11 | blockF: 30 # Number of fixed residual blocks.
12 | cycleF: 3 # Number of fixed dilation cycles.
13 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive).
14 | in_channels: 1 # Number of input channels.
15 | out_channels: 1 # Number of output channels.
16 | residual_channels: 64 # Number of channels in residual conv.
17 | gate_channels: 128 # Number of channels in gated conv.
18 | skip_channels: 64 # Number of channels in skip conv.
19 | aux_channels: 62 # Number of channels for auxiliary feature conv.
20 | aux_context_window: 2 # Context window size for auxiliary feature.
21 | # If set to 2, previous 2 and future 2 frames will be considered.
22 | use_weight_norm: true # Whether to use weight norm.
23 | upsample_params: # Upsampling network parameters.
24 | upsample_scales: [5, 4, 3, 2] # Upsampling scales. Product of these must be the same as hop size.
25 |
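26 | # Note: 5 * 4 * 3 * 2 = 120, matching the hop_size of 120 samples used in the
27 | # 24 kHz data configs.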
--------------------------------------------------------------------------------
/usfgan/utils/filters.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2020 Yi-Chiao Wu (Nagoya University)
4 | # based on a WaveNet script by Tomoki Hayashi (Nagoya University)
5 | # (https://github.com/kan-bayashi/PytorchWaveNetVocoder)
6 | # based on sprocket-vc script by Kazuhiro Kobayashi (Nagoya University)
7 | # (https://github.com/k2kobayashi/sprocket)
8 | # MIT License (https://opensource.org/licenses/MIT)
9 |
10 | """Filters."""
11 |
12 | import numpy as np
13 | from scipy.signal import firwin, lfilter
14 |
15 | NUMTAPS = 255
16 |
17 |
18 | def low_cut_filter(x, fs, cutoff=70):
19 | """Low-cut filter
20 |
21 | Args:
22 | x (ndarray): Waveform sequence
23 | fs (int): Sampling frequency
24 | cutoff (float): Cutoff frequency of low cut filter
25 |
26 | Return:
27 | (ndarray): Low cut filtered waveform sequence
28 |
29 | """
30 | nyquist = fs // 2
31 | norm_cutoff = cutoff / nyquist
32 | numtaps = NUMTAPS
33 | fil = firwin(numtaps, norm_cutoff, pass_zero=False)
34 | lcf_x = lfilter(fil, 1, x)
35 |
36 | return lcf_x
37 |
38 |
39 | def low_pass_filter(x, fs, cutoff=70):
40 | """Low-pass filter
41 |
42 | Args:
43 | x (ndarray): Waveform sequence
44 | fs (int): Sampling frequency
45 | cutoff (float): Cutoff frequency of low pass filter
46 |
47 | Return:
48 | (ndarray): Low pass filtered waveform sequence
49 |
50 | """
51 | nyquist = fs // 2
52 | norm_cutoff = cutoff / nyquist
53 | numtaps = NUMTAPS
54 | fil = firwin(numtaps, norm_cutoff, pass_zero=True)
55 | x_pad = np.pad(x, (numtaps, numtaps), "edge")
56 | lpf_x = lfilter(fil, 1, x_pad)
57 | lpf_x = lpf_x[numtaps + numtaps // 2 : -numtaps // 2]
58 |
59 | return lpf_x
60 |
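61 | # Note: in low_pass_filter, the input is edge-padded by numtaps samples on each
62 | # side before filtering, and the slice [numtaps + numtaps // 2 : -numtaps // 2]
63 | # removes that padding while compensating for the linear-phase FIR group delay
64 | # of numtaps // 2 samples, keeping the output time-aligned with the input.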
--------------------------------------------------------------------------------
/usfgan/bin/config/anasyn.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _self_
5 | - generator: parallel_hn_usfgan
6 |
7 | hydra:
8 | run:
9 | dir: ./
10 | output_subdir: null
11 | job_logging:
12 | formatters:
13 | simple:
14 | format: '[%(asctime)s][%(levelname)s][%(module)s | %(lineno)s] %(message)s'
15 | disable_existing_loggers: false
16 |
17 | in_dir: # Path to the directory containing the wav files to process.
18 | out_dir: # Path to directory to save the synthesized wavs.
19 | f0_factor: 1.00 # F0 scaling factor.
20 | seed: 100 # Seed number for random numbers.
21 | stats: pretrained_model/stats/vctk_train_24kHz.joblib # Path to statistics file.
22 | checkpoint_path: pretrained_model/checkpoint-600000steps.pkl # Path to pre-trained model.
23 |
24 | # The same parameters as used in training should be set here.
25 | sample_rate: 24000 # Sampling rate.
26 | frame_period: 5 # Frameshift in ms.
27 | f0_floor: 70 # Minimum F0 for WORLD F0 analysis.
28 | f0_ceil: 500 # Maximum F0 for WORLD F0 analysis.
29 | mcep_dim: 40 # Number of dimension of MGC.
30 | mcap_dim: 20 # Number of dimension of mel-cepstral AP.
31 | aux_feats: ["mcep", "mcap"] # Input acoustic features.
32 | dense_factor: 4 # Dense factor in PDCNNs.
33 | df_f0_type: "contf0" # F0 type for dilation factor ("f0" or "contf0").
34 | sine_amp: 0.1 # Sine amplitude.
35 | noise_amp: 0.003 # Noise amplitude.
36 | sine_f0_type: "contf0" # F0 type for sine signal ("f0" or "contf0").
37 | signal_types: ["sine", "noise"] # List of input signal types.
--------------------------------------------------------------------------------
/usfgan/bin/config/extract_features.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | hydra:
4 | job:
5 | chdir: false
6 | output_subdir: null
7 | job_logging:
8 | formatters:
9 | simple:
10 | format: '[%(levelname)s][%(module)s | %(lineno)s] %(message)s'
11 |
12 | audio: data/scp/vctk_train_24kHz.scp # List file of input wav files.
13 | in_dir: wav # Directory of input wav files.
14 | out_dir: hdf5 # Directory to save extracted feature files.
15 | feature_format: h5 # Feature format.
16 | sampling_rate: 24000 # Sampling rate.
17 | spkinfo: data/spk_info.yaml # YAML format speaker information.
18 | spkidx: -2 # Index of the speaker name in the split file path.
19 | inv: true # If false, wav is restored from acoustic features.
20 |
21 | # Audio preprocess setting.
22 | highpass_cutoff: 70 # Cut-off-frequency for low-cut-filter.
23 | pow_th: # Threshold of power.
24 |
25 | # Mel-spectrogram extraction setting.
26 | fft_size: 1024 # FFT size.
27 | hop_size: 120 # Hop size.
28 | win_length: 1024 # Window length.
29 | # If set to null, it will be same as fft_size.
30 | window: hann # Window function.
31 | num_mels: 80 # Number of mel basis.
32 | fmin: 0 # Minimum frequency in mel basis calculation.
33 | fmax: null # Maximum frequency in mel basis calculation.
34 |
35 | # WORLD feature extraction setting.
36 | minf0: 70 # F0 setting: minimum f0.
37 | maxf0: 340 # F0 setting: maximum f0.
38 | shiftms: 5 # F0 setting: frame shift (ms).
39 | mcep_dim: 40 # Mel-cepstrum setting: number of dimensions.
40 | mcap_dim: 20 # Mel-cepstral aperiodicity setting: number of dimensions.
41 | alpha: 0.466 # Mel-cepstrum setting: all-pass constant.
42 |
--------------------------------------------------------------------------------
/usfgan/bin/config/data/vctk_24kHz.yaml:
--------------------------------------------------------------------------------
1 | # Dataset settings
2 | train_audio: data/scp/vctk_train_24kHz.scp # List file of training audio files.
3 | train_feat: data/scp/vctk_train_24kHz.list # List file of training feature files.
4 | valid_audio: data/scp/vctk_valid_24kHz.scp # List file of validation audio files.
5 | valid_feat: data/scp/vctk_valid_24kHz.list # List file of validation feature files.
6 | eval_feat: data/scp/vctk_eval_24kHz.list # List file of evaluation feature files for decoding.
7 | stats: data/stats/vctk_train_24kHz.joblib # Path to the file of statistics.
8 | allow_cache: false # Whether to allow cache in dataset. If true, it requires more CPU memory.
9 |
10 | # Feature settings
11 | sample_rate: 24000 # Sampling rate.
12 | hop_size: 120 # Hop size.
13 | dense_factor: 4 # Dense factor in PDCNNs.
14 | sine_amp: 0.1 # Sine amplitude.
15 | noise_amp: 0.003 # Noise amplitude.
16 | signal_types: ["sine", "noise"] # List of input signal types for generator.
17 | sine_f0_type: "contf0" # F0 type for sine signal ("f0" or "contf0").
18 | df_f0_type: "contf0" # F0 type for dilation factor ("f0" or "contf0").
19 | aux_feats: ["mcep", "mcap"] # Auxiliary features.
20 |                             # "uv": V/UV binary.
21 |                             # "f0": discrete f0.
22 |                             # "mcep": mel-cepstral envelope.
23 |                             # "contf0": continuous f0.
24 |                             # "mcap": mel-cepstral aperiodicity.
25 |                             # "codeap": coded aperiodicity.
26 |                             # "logmsp": log mel-spectrogram.
27 |
28 | # Collater setting
29 | batch_max_length: 18000 # Length of each audio segment in a batch. Must be divisible by hop_size.
30 |
31 | # Data loader setting
32 | batch_size: 5 # Batch size
33 | num_workers: 1 # Number of workers in Pytorch DataLoader
34 | pin_memory: true # Whether to pin memory in Pytorch DataLoader
35 |
36 | # Other setting
37 | remove_short_samples: true # Whether to remove samples whose length is shorter than batch_max_length.
38 |
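39 | # Note: with hop_size 120 at 24000 Hz, batch_max_length 18000 corresponds to
40 | # 18000 / 120 = 150 frames, i.e., 0.75 seconds of audio per batch item.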
--------------------------------------------------------------------------------
/usfgan/bin/config/discriminator/hifigan.yaml:
--------------------------------------------------------------------------------
1 | _target_: usfgan.models.HiFiGANMultiScaleMultiPeriodDiscriminator
2 | scales: 3 # Number of multi-scale discriminators.
3 | scale_downsample_pooling: "AvgPool1d" # Pooling operation for scale discriminator.
4 | scale_downsample_pooling_params:
5 | kernel_size: 4 # Pooling kernel size.
6 | stride: 2 # Pooling stride.
7 | padding: 2 # Padding size.
8 | scale_discriminator_params:
9 | in_channels: 1 # Number of input channels.
10 | out_channels: 1 # Number of output channels.
11 | kernel_sizes: [15, 41, 5, 3] # List of kernel sizes.
12 | channels: 128 # Initial number of channels.
13 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
14 | max_groups: 16 # Maximum number of groups in downsampling conv layers.
15 | bias: true
16 | downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales.
17 | nonlinear_activation: "LeakyReLU" # Nonlinear activation.
18 | nonlinear_activation_params:
19 | negative_slope: 0.1
20 | follow_official_norm: true # Whether to follow the official norm setting.
21 | periods: [2, 3, 5, 7, 11] # List of periods for multi-period discriminators.
22 | period_discriminator_params:
23 | in_channels: 1 # Number of input channels.
24 | out_channels: 1 # Number of output channels.
25 | kernel_sizes: [5, 3] # List of kernel sizes.
26 | channels: 32 # Initial number of channels.
27 | downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales.
28 | max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
29 |   bias: true # Whether to use bias parameter in conv layers.
30 |   nonlinear_activation: "LeakyReLU" # Nonlinear activation.
31 |   nonlinear_activation_params: # Nonlinear activation parameters.
32 | negative_slope: 0.1
33 | use_weight_norm: true # Whether to apply weight normalization.
34 | use_spectral_norm: false # Whether to apply spectral normalization.
35 |
--------------------------------------------------------------------------------
/usfgan/bin/config/generator/cascade_hn_usfgan.yaml:
--------------------------------------------------------------------------------
1 | _target_: usfgan.models.CascadeHnUSFGANGenerator
2 | harmonic_network_params:
3 | blockA: 20 # Number of adaptive residual blocks.
4 | cycleA: 4 # Number of adaptive dilation cycles.
5 | blockF: 0 # Number of fixed residual blocks.
6 | cycleF: 0 # Number of fixed dilation cycles.
7 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive).
8 | noise_network_params:
9 | blockA: 0 # Number of adaptive residual blocks.
10 | cycleA: 0 # Number of adaptive dilation cycles.
11 | blockF: 5 # Number of fixed residual blocks.
12 | cycleF: 5 # Number of fixed dilation cycles.
13 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive).
14 | filter_network_params:
15 | blockA: 0 # Number of adaptive residual blocks.
16 | cycleA: 0 # Number of adaptive dilation cycles.
17 | blockF: 30 # Number of fixed residual blocks.
18 | cycleF: 3 # Number of fixed dilation cycles.
19 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive).
20 | periodicity_estimator_params:
21 | conv_layers: 3 # Number of convolution layers.
22 | kernel_size: 5 # Kernel size.
23 | dilation: 1 # Dilation size.
24 | padding_mode: "replicate" # Padding mode.
25 | in_channels: 1 # Number of input channels.
26 | out_channels: 1 # Number of output channels.
27 | residual_channels: 64 # Number of channels in residual conv.
28 | gate_channels: 128 # Number of channels in gated conv.
29 | skip_channels: 64 # Number of channels in skip conv.
30 | aux_channels: 62 # Number of channels for auxiliary feature conv.
31 | aux_context_window: 2 # Context window size for auxiliary feature.
32 | # If set to 2, previous 2 and future 2 frames will be considered.
33 | use_weight_norm: true # Whether to use weight norm.
34 | upsample_params: # Upsampling network parameters.
35 | upsample_scales: [5, 4, 3, 2] # Upsampling scales. Product of these must be the same as hop size.
36 |
--------------------------------------------------------------------------------
/usfgan/bin/config/generator/parallel_hn_usfgan.yaml:
--------------------------------------------------------------------------------
1 | _target_: usfgan.models.ParallelHnUSFGANGenerator
2 | harmonic_network_params:
3 | blockA: 20 # Number of adaptive residual blocks.
4 | cycleA: 4 # Number of adaptive dilation cycles.
5 | blockF: 0 # Number of fixed residual blocks.
6 | cycleF: 0 # Number of fixed dilation cycles.
7 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive).
8 | noise_network_params:
9 | blockA: 0 # Number of adaptive residual blocks.
10 | cycleA: 0 # Number of adaptive dilation cycles.
11 | blockF: 5 # Number of fixed residual blocks.
12 | cycleF: 5 # Number of fixed dilation cycles.
13 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive).
14 | filter_network_params:
15 | blockA: 0 # Number of adaptive residual blocks.
16 | cycleA: 0 # Number of adaptive dilation cycles.
17 | blockF: 30 # Number of fixed residual blocks.
18 | cycleF: 3 # Number of fixed dilation cycles.
19 | cascade_mode: 0 # Network cascaded mode (0: adaptive->fix; 1: fix->adaptive).
20 | periodicity_estimator_params:
21 | conv_layers: 3 # Number of convolution layers.
22 | kernel_size: 5 # Kernel size.
23 | dilation: 1 # Dilation size.
24 | padding_mode: "replicate" # Padding mode.
25 | in_channels: 1 # Number of input channels.
26 | out_channels: 1 # Number of output channels.
27 | residual_channels: 64 # Number of channels in residual conv.
28 | gate_channels: 128 # Number of channels in gated conv.
29 | skip_channels: 64 # Number of channels in skip conv.
30 | aux_channels: 62 # Number of channels for auxiliary feature conv.
31 | aux_context_window: 2 # Context window size for auxiliary feature.
32 | # If set to 2, previous 2 and future 2 frames will be considered.
33 | use_weight_norm: true # Whether to use weight norm.
34 | upsample_params: # Upsampling network parameters.
35 | upsample_scales: [5, 4, 3, 2] # Upsampling scales. Product of these must be the same as hop size.
36 |
--------------------------------------------------------------------------------
/usfgan/utils/index.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2020 Yi-Chiao Wu (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """Indexing-related functions."""
7 |
8 | import torch
9 | from torch.nn import ConstantPad1d as pad1d
10 |
11 |
12 | def pd_indexing(x, d, dilation, batch_index, ch_index):
13 | """Pitch-dependent indexing of past and future samples.
14 |
15 | Args:
16 | x (Tensor): Input feature map (B, C, T).
17 | d (Tensor): Input pitch-dependent dilated factors (B, 1, T).
18 | dilation (Int): Dilation size.
19 | batch_index (Tensor): Batch index
20 | ch_index (Tensor): Channel index
21 |
22 | Returns:
23 | Tensor: Past output tensor (B, out_channels, T)
24 | Tensor: Future output tensor (B, out_channels, T)
25 |
26 | """
27 | (_, _, batch_length) = d.size()
28 | dilations = d * dilation
29 |
30 | # get past index
31 | idxP = torch.arange(-batch_length, 0).float()
32 | idxP = idxP.to(x.device)
33 | idxP = torch.add(-dilations, idxP)
34 | idxP = idxP.round().long()
35 | maxP = -((torch.min(idxP) + batch_length))
36 | assert maxP >= 0
37 | idxP = (batch_index, ch_index, idxP)
38 | # padding past tensor
39 | xP = pad1d((maxP, 0), 0)(x)
40 |
41 | # get future index
42 | idxF = torch.arange(0, batch_length).float()
43 | idxF = idxF.to(x.device)
44 | idxF = torch.add(dilations, idxF)
45 | idxF = idxF.round().long()
46 | maxF = torch.max(idxF) - (batch_length - 1)
47 | assert maxF >= 0
48 | idxF = (batch_index, ch_index, idxF)
49 | # padding future tensor
50 | xF = pad1d((0, maxF), 0)(x)
51 |
52 | return xP[idxP], xF[idxF]
53 |
54 |
55 | def index_initial(n_batch, n_ch, tensor=True):
56 | """Tensor batch and channel index initialization.
57 |
58 | Args:
59 | n_batch (Int): Number of batch.
60 | n_ch (Int): Number of channel.
61 | tensor (bool): Return tensor or numpy array
62 |
63 | Returns:
64 | Tensor: Batch index
65 | Tensor: Channel index
66 |
67 | """
68 | batch_index = []
69 | for i in range(n_batch):
70 | batch_index.append([[i]] * n_ch)
71 | ch_index = []
72 | for i in range(n_ch):
73 | ch_index += [[i]]
74 | ch_index = [ch_index] * n_batch
75 |
76 | if tensor:
77 | batch_index = torch.tensor(batch_index)
78 | ch_index = torch.tensor(ch_index)
79 | if torch.cuda.is_available():
80 | batch_index = batch_index.cuda()
81 | ch_index = ch_index.cuda()
82 | return batch_index, ch_index
83 |
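84 | # Illustrative usage sketch (shapes assumed; B = batch size, C = channels):
85 | #     batch_index, ch_index = index_initial(B, C)
86 | #     xP, xF = pd_indexing(x, d, dilation, batch_index, ch_index)
87 | # For each time step t this gathers, up to rounding,
88 | #     xP[b, c, t] = x[b, c, t - round(d[b, 0, t] * dilation)]
89 | #     xF[b, c, t] = x[b, c, t + round(d[b, 0, t] * dilation)]
90 | # with zero padding wherever the shifted index falls outside the sequence.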
--------------------------------------------------------------------------------
/usfgan/bin/compute_statistics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """Feature statistics computing script.
7 |
8 | References:
9 | - https://github.com/bigpon/QPPWG
10 | - https://github.com/kan-bayashi/ParallelWaveGAN
11 |
12 | """
13 |
14 | import os
15 | from logging import getLogger
16 |
17 | import hydra
18 | import numpy as np
19 | from hydra.utils import to_absolute_path
20 | from joblib import dump, load
21 | from omegaconf import DictConfig, OmegaConf
22 | from sklearn.preprocessing import StandardScaler
23 | from usfgan.utils import read_hdf5, read_txt
24 |
25 | # A logger for this file
26 | logger = getLogger(__name__)
27 |
28 |
29 | def calc_stats(file_list, config):
30 | """Calcute statistics
31 | Args:
32 | file_list (list): File list.
33 | config (dict): Dictionary of config.
34 | """
35 |
36 | # define scalers
37 | scaler = load(config.stats) if os.path.isfile(config.stats) else {}
38 | for feat_type in config.feat_types:
39 | if feat_type == "uv":
40 | continue
41 | scaler[feat_type] = StandardScaler()
42 |
43 | # process over all of data
44 | for i, filename in enumerate(file_list):
45 | logger.info(f"now processing {filename} ({i + 1}/{len(file_list)})")
46 | for feat_type in config.feat_types:
47 | if feat_type == "uv":
48 | continue
49 | if feat_type == "f0":
50 | f0 = read_hdf5(to_absolute_path(filename), "/f0")
51 |                 feat = np.expand_dims(f0[f0 > 0], axis=-1)  # use voiced frames only
52 | else:
53 | feat = read_hdf5(to_absolute_path(filename), f"/{feat_type}")
54 | if feat.shape[0] == 0:
55 | logger.warning(f"feat length is 0 {filename}/{feat_type}")
56 | continue
57 | scaler[feat_type].partial_fit(feat)
58 |
59 | if not os.path.exists(os.path.dirname(config.stats)):
60 | os.makedirs(os.path.dirname(config.stats))
61 | dump(scaler, to_absolute_path(config.stats))
62 | logger.info(f"Successfully saved statistics to {config.stats}.")
63 |
64 |
65 | @hydra.main(version_base=None, config_path="config", config_name="compute_statistics")
66 | def main(config: DictConfig):
67 | # show argument
68 | logger.info(OmegaConf.to_yaml(config))
69 |
70 | # read file list
71 | file_list = read_txt(to_absolute_path(config.feats))
72 | logger.info(f"number of utterances = {len(file_list)}")
73 |
74 | # calculate statistics
75 | calc_stats(file_list, config)
76 |
77 |
78 | if __name__ == "__main__":
79 | main()
80 |
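81 | # Illustrative sketch: the dumped statistics can be reloaded for feature
82 | # normalization, e.g.,
83 | #     scaler = load("data/stats/vctk_train_24kHz.joblib")
84 | #     mcep_normalized = scaler["mcep"].transform(mcep)
85 | # since each entry is a fitted sklearn StandardScaler.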
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Setup Unified Source-Filter GAN Library."""
4 |
5 | import os
6 | import sys
7 | from distutils.version import LooseVersion
8 |
9 | import pip
10 | from setuptools import find_packages, setup
11 |
12 | if LooseVersion(sys.version) < LooseVersion("3.8"):
13 | raise RuntimeError(
14 | "usfgan requires Python>=3.8, " "but your Python is {}".format(sys.version)
15 | )
16 | if LooseVersion(pip.__version__) < LooseVersion("21.0.0"):
17 | raise RuntimeError(
18 | "pip>=21.0.0 is required, but your pip is {}. "
19 | 'Try again after "pip install -U pip"'.format(pip.__version__)
20 | )
21 |
22 | requirements = {
23 | "install": [
24 | "wheel",
25 | "torch>=1.9.0",
26 | "torchaudio>=0.8.1",
27 | "setuptools>=38.5.1",
28 | "librosa>=0.8.0",
29 | "soundfile>=0.10.2",
30 | "tensorboardX>=2.2",
31 | "matplotlib>=3.1.0",
32 | "PyYAML>=3.12",
33 | "tqdm>=4.26.1",
34 | "h5py>=2.10.0",
35 | "pyworld>=0.2.12",
36 | "sprocket-vc",
37 | "protobuf<=3.19.0",
38 | "hydra-core>=1.2",
39 | ],
40 | "setup": [
41 | "numpy",
42 | "pytest-runner",
43 | ],
44 | "test": [
45 | "pytest>=3.3.0",
46 | "hacking>=1.1.0",
47 | "flake8>=3.7.8",
48 | "flake8-docstrings>=1.3.1",
49 | ],
50 | }
51 | entry_points = {
52 | "console_scripts": [
53 | "usfgan-extract-features=usfgan.bin.extract_features:main",
54 | "usfgan-compute-statistics=usfgan.bin.compute_statistics:main",
55 | "usfgan-train=usfgan.bin.train:main",
56 | "usfgan-decode=usfgan.bin.decode:main",
57 | ]
58 | }
59 |
60 | install_requires = requirements["install"]
61 | setup_requires = requirements["setup"]
62 | tests_require = requirements["test"]
63 | extras_require = {
64 | k: v for k, v in requirements.items() if k not in ["install", "setup"]
65 | }
66 |
67 | dirname = os.path.dirname(__file__)
68 | setup(
69 | name="usfgan",
70 | version="0.1",
71 | url="http://github.com/chomeyama/HN-UnifiedSourceFilterGAN",
72 | author="Reo Yoneyama",
73 | author_email="yoneyama.reo@g.sp.m.is.nagoya-u.ac.jp",
74 | description="Harmonic-plus-Noise Unified Source-Filter GAN implementation",
75 | long_description_content_type="text/markdown",
76 | long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(),
77 | license="MIT License",
78 | packages=find_packages(include=["usfgan*"]),
79 | install_requires=install_requires,
80 | setup_requires=setup_requires,
81 | tests_require=tests_require,
82 | extras_require=extras_require,
83 | entry_points=entry_points,
84 | classifiers=[
85 | "Programming Language :: Python :: 3.9.5",
86 | "Intended Audience :: Science/Research",
87 | "Operating System :: POSIX :: Linux",
88 | "License :: OSI Approved :: MIT License",
89 | "Topic :: Software Development :: Libraries :: Python Modules",
90 | ],
91 | )
92 |
--------------------------------------------------------------------------------
/usfgan/bin/config/train/usfgan.yaml:
--------------------------------------------------------------------------------
1 | # Interval setting
2 | discriminator_train_start_steps: 100000 # Number of steps to start to train discriminator.
3 | train_max_steps: 600000 # Maximum number of training steps.
4 | save_interval_steps: 100000 # Interval steps to save checkpoint.
5 | eval_interval_steps: 2000 # Interval steps to evaluate the network.
6 | log_interval_steps: 2000 # Interval steps to record the training log.
7 | resume: # Epoch to resume training.
8 |
9 | # Loss balancing coefficients.
10 | lambda_stft: 1.0
11 | lambda_source: 1.0
12 | lambda_adv: 4.0
13 | lambda_feat_match: 0.0
14 |
15 | # Multi-resolution STFT loss setting
16 | stft_loss:
17 | _target_: usfgan.losses.MultiResolutionSTFTLoss
18 | fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
19 | hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
20 | win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
21 | window: hann_window # Window function for STFT-based loss
22 |
23 | # Source regularization loss setting
24 | source_loss:
25 | _target_: usfgan.losses.FlattenLoss
26 | sampling_rate: 24000 # Sampling rate.
27 | fft_size: 2048 # FFT size.
28 | hop_size: 120 # Hop size.
29 | f0_floor: 70 # Minimum F0.
30 | f0_ceil: 340 # Maximum F0.
31 | power: false # Whether to use power or magnitude spectrogram.
32 | elim_0th: false # Whether to exclude 0th components of cepstrums in
33 | # CheapTrick estimation. If set to true, source-network
34 | # is forced to estimate the power of the output signal.
35 | l2_norm: false # Whether to use L2 norm (if false, L1 norm is used).
36 |
37 | # Adversarial loss setting
38 | adversarial_loss:
39 | _target_: usfgan.losses.AdversarialLoss
40 |
41 | # Feature matching loss setting
42 | feat_match_loss:
43 | _target_: usfgan.losses.FeatureMatchLoss
44 |
45 | # Optimizer setting
46 | generator_optimizer:
47 | _target_: usfgan.optimizers.RAdam
48 | lr: 0.0001 # Generator's learning rate.
49 | eps: 1.0e-6 # Generator's epsilon.
50 | weight_decay: 0.0 # Generator's weight decay coefficient.
51 | generator_scheduler:
52 | _target_: torch.optim.lr_scheduler.StepLR
53 | step_size: 200000 # Generator's scheduler step size.
54 | gamma: 0.5 # Generator's scheduler gamma.
55 | # At each step size, lr will be multiplied by this parameter.
56 | generator_grad_norm: 10 # Generator's gradient norm.
57 | discriminator_optimizer:
58 | _target_: usfgan.optimizers.RAdam
59 | lr: 0.00005 # Discriminator's learning rate.
60 | eps: 1.0e-6 # Discriminator's epsilon.
61 | weight_decay: 0.0 # Discriminator's weight decay coefficient.
62 | discriminator_scheduler:
63 | _target_: torch.optim.lr_scheduler.StepLR
64 | step_size: 200000 # Discriminator's scheduler step size.
65 | gamma: 0.5 # Discriminator's scheduler gamma.
66 | # At each step size, lr will be multiplied by this parameter.
67 | discriminator_grad_norm: 10 # Discriminator's gradient norm.
68 |
--------------------------------------------------------------------------------
/usfgan/losses/adversarial.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """Adversarial loss modules.
7 |
8 | References:
9 | - https://github.com/kan-bayashi/ParallelWaveGAN
10 |
11 | """
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 |
17 |
18 | class AdversarialLoss(nn.Module):
19 | """Adversarial loss module."""
20 |
21 | def __init__(
22 | self,
23 | average_by_discriminators=False,
24 | loss_type="mse",
25 | ):
26 | """Initialize AversarialLoss module."""
27 | super(AdversarialLoss, self).__init__()
28 | self.average_by_discriminators = average_by_discriminators
29 |
30 | assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
31 | if loss_type == "mse":
32 | self.adv_criterion = self._mse_adv_loss
33 | self.fake_criterion = self._mse_fake_loss
34 | self.real_criterion = self._mse_real_loss
35 | else:
36 | self.adv_criterion = self._hinge_adv_loss
37 | self.fake_criterion = self._hinge_fake_loss
38 | self.real_criterion = self._hinge_real_loss
39 |
40 | def forward(self, p_fakes, p_reals=None):
41 | """Calcualate generator/discriminator adversarial loss.
42 |
43 | Args:
44 | p_fakes (list): List of
45 | discriminator outputs calculated from generator outputs.
46 | p_reals (list): List of
47 |                 discriminator outputs calculated from ground truth.
48 |
49 | Returns:
50 | Tensor: Generator adversarial loss value.
51 | Tensor: Discriminator real loss value.
52 | Tensor: Discriminator fake loss value.
53 |
54 | """
55 | # generator adversarial loss
56 | if p_reals is None:
57 | adv_loss = 0.0
58 | for p_fake in p_fakes:
59 | adv_loss += self.adv_criterion(p_fake)
60 |
61 | if self.average_by_discriminators:
62 | adv_loss /= len(p_fakes)
63 |
64 | return adv_loss
65 |
66 | # discriminator adversarial loss
67 | else:
68 | fake_loss = 0.0
69 | real_loss = 0.0
70 | for p_fake, p_real in zip(p_fakes, p_reals):
71 | fake_loss += self.fake_criterion(p_fake)
72 | real_loss += self.real_criterion(p_real)
73 |
74 | if self.average_by_discriminators:
75 | fake_loss /= len(p_fakes)
76 | real_loss /= len(p_reals)
77 |
78 | return fake_loss, real_loss
79 |
80 | def _mse_adv_loss(self, x):
81 | return F.mse_loss(x, x.new_ones(x.size()))
82 |
83 | def _mse_real_loss(self, x):
84 | return F.mse_loss(x, x.new_ones(x.size()))
85 |
86 | def _mse_fake_loss(self, x):
87 | return F.mse_loss(x, x.new_zeros(x.size()))
88 |
89 | def _hinge_adv_loss(self, x):
90 | return -x.mean()
91 |
92 | def _hinge_real_loss(self, x):
93 | return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
94 |
95 | def _hinge_fake_loss(self, x):
96 | return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
97 |
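98 | # Note: the hinge criteria above correspond to
99 | #     L_D = E[max(0, 1 - D(x_real))] + E[max(0, 1 + D(x_fake))]
100 | #     L_G = -E[D(x_fake)]
101 | # while the "mse" (least-squares GAN) criteria regress discriminator outputs
102 | # toward 1 for real inputs and 0 for fake inputs.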
--------------------------------------------------------------------------------
/usfgan/bin/config/train/hn_usfgan.yaml:
--------------------------------------------------------------------------------
1 | # Interval setting
2 | discriminator_train_start_steps: 0 # Number of steps to start to train discriminator.
3 | train_max_steps: 600000 # Maximum number of training steps.
4 | save_interval_steps: 100000 # Interval steps to save checkpoint.
5 | eval_interval_steps: 2000 # Interval steps to evaluate the network.
6 | log_interval_steps: 2000 # Interval steps to record the training log.
7 | resume: # Epoch to resume training.
8 |
9 | # Loss balancing coefficients.
10 | lambda_stft: 45.0
11 | lambda_source: 1.0
12 | lambda_adv: 1.0
13 | lambda_feat_match: 0.0
14 |
15 | # Mel-spectrogram reconstruction loss setting
16 | stft_loss:
17 | _target_: usfgan.losses.MelSpectralLoss
18 | fft_size: 1024 # FFT size for STFT-based loss.
19 | hop_size: 256 # Hop size for STFT-based loss
20 | win_length: 1024 # Window length for STFT-based loss.
21 | window: hann_window # Window function for STFT-based loss.
22 |   sampling_rate: 24000 # Sampling rate.
23 | n_mels: 80 # Number of bins of mel-filter-bank.
24 | fmin: 0 # Minimum frequency of mel-filter-bank.
25 | fmax: null # Maximum frequency of mel-filter-bank.
26 | # If null, it will be fs / 2.
27 |
28 | # Source regularization loss setting
29 | source_loss:
30 | _target_: usfgan.losses.ResidualLoss
31 | sampling_rate: 24000 # Sampling rate.
32 | fft_size: 2048 # FFT size.
33 | hop_size: 120 # Hop size.
34 | f0_floor: 70 # Minimum F0.
35 | f0_ceil: 340 # Maximum F0.
36 | n_mels: 80 # Number of mel-filter-bank bins.
37 | fmin: 0 # Minimum frequency of mel-filter-bank.
38 | fmax: null # Maximum frequency of mel-filter-bank.
39 | power: false # Whether to use power or magnitude spectrogram.
40 | elim_0th: true # Whether to exclude 0th components of cepstrums in
41 | # CheapTrick estimation. If set to true, source-network
42 | # is forced to estimate the power of the output signal.
43 |
44 | # Adversarial loss setting
45 | adversarial_loss:
46 | _target_: usfgan.losses.AdversarialLoss
47 | average_by_discriminators: false # Whether to average loss by #discriminators.
48 | loss_type: mse
49 |
50 | # Feature matching loss setting
51 | feat_match_loss:
52 | _target_: usfgan.losses.FeatureMatchLoss
53 | average_by_layers: false # Whether to average loss by #layers in each discriminator.
54 |
55 | # Optimizer and scheduler setting
56 | generator_optimizer:
57 | _target_: torch.optim.Adam
58 | lr: 2.0e-4
59 | betas: [0.5, 0.9]
60 | weight_decay: 0.0
61 | generator_scheduler:
62 | _target_: torch.optim.lr_scheduler.MultiStepLR
63 | gamma: 0.5
64 | milestones:
65 | - 200000
66 | - 400000
67 | - 600000
68 | - 800000
69 | generator_grad_norm: 10
70 | discriminator_optimizer:
71 | _target_: torch.optim.Adam
72 | lr: 2.0e-4
73 | betas: [0.5, 0.9]
74 | weight_decay: 0.0
75 | discriminator_scheduler:
76 | _target_: torch.optim.lr_scheduler.MultiStepLR
77 | gamma: 0.5
78 | milestones:
79 | - 200000
80 | - 400000
81 | - 600000
82 | - 800000
83 | discriminator_grad_norm: 10
84 |
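85 | # Note: lambda_stft = 45.0 matches the mel-spectrogram loss weight used in
86 | # HiFi-GAN, whereas the original uSFGAN recipe (train/usfgan.yaml) weights its
87 | # multi-resolution STFT loss with lambda_stft = 1.0.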
--------------------------------------------------------------------------------
/usfgan/optimizers/radam.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """RAdam optimizer.
4 |
5 | This code is derived from https://github.com/LiyuanLucasLiu/RAdam.
6 |
7 | """
8 |
9 | import math
10 |
11 | import torch
12 | from torch.optim.optimizer import Optimizer
13 |
14 |
15 | class RAdam(Optimizer):
16 | """Rectified Adam optimizer."""
17 |
18 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
19 | """Initilize RAdam optimizer."""
20 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
21 | self.buffer = [[None, None, None] for ind in range(10)]
22 | super(RAdam, self).__init__(params, defaults)
23 |
24 | def __setstate__(self, state):
25 | """Set state."""
26 | super(RAdam, self).__setstate__(state)
27 |
28 | def step(self, closure=None):
29 | """Run one step."""
30 | loss = None
31 | if closure is not None:
32 | loss = closure()
33 |
34 | for group in self.param_groups:
35 |
36 | for p in group["params"]:
37 | if p.grad is None:
38 | continue
39 | grad = p.grad.data.float()
40 | if grad.is_sparse:
41 | raise RuntimeError("RAdam does not support sparse gradients")
42 |
43 | p_data_fp32 = p.data.float()
44 |
45 | state = self.state[p]
46 |
47 | if len(state) == 0:
48 | state["step"] = 0
49 | state["exp_avg"] = torch.zeros_like(p_data_fp32)
50 | state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
51 | else:
52 | state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
53 | state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
54 |
55 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
56 | beta1, beta2 = group["betas"]
57 |
58 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
59 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
60 |
61 | state["step"] += 1
62 | buffered = self.buffer[int(state["step"] % 10)]
63 | if state["step"] == buffered[0]:
64 | N_sma, step_size = buffered[1], buffered[2]
65 | else:
66 | buffered[0] = state["step"]
67 | beta2_t = beta2 ** state["step"]
68 | N_sma_max = 2 / (1 - beta2) - 1
69 | N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
70 | buffered[1] = N_sma
71 |
72 | # more conservative since it's an approximated value
73 | if N_sma >= 5:
74 | step_size = math.sqrt(
75 | (1 - beta2_t)
76 | * (N_sma - 4)
77 | / (N_sma_max - 4)
78 | * (N_sma - 2)
79 | / N_sma
80 | * N_sma_max
81 | / (N_sma_max - 2)
82 | ) / (
83 | 1 - beta1 ** state["step"]
84 | ) # NOQA
85 | else:
86 | step_size = 1.0 / (1 - beta1 ** state["step"])
87 | buffered[2] = step_size
88 |
89 | if group["weight_decay"] != 0:
90 |                     p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
91 |
92 | # more conservative since it's an approximated value
93 | if N_sma >= 5:
94 | denom = exp_avg_sq.sqrt().add_(group["eps"])
95 |                     p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"])
96 | else:
97 |                     p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"])
98 |
99 | p.data.copy_(p_data_fp32)
100 |
101 | return loss
102 |
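103 | # Note: N_sma approximates the length of the simple moving average (rho_t in
104 | # the RAdam paper). When N_sma >= 5, the variance rectification term is applied
105 | # on top of the usual Adam bias corrections; otherwise the second moment is
106 | # deemed unreliable and the update falls back to a bias-corrected momentum step.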
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Harmonic-plus-Noise Unified Source-Filter GAN implementation with PyTorch
2 |
3 |
4 | This repo provides the official PyTorch implementation of [HN-uSFGAN](https://arxiv.org/abs/2205.06053), a high-fidelity and pitch-controllable neural vocoder based on unified source-filter networks.
5 | HN-uSFGAN is an extended model of [uSFGAN](https://arxiv.org/abs/2104.04668), and this repo includes the original [uSFGAN implementation](https://github.com/chomeyama/UnifiedSourceFilterGAN) with some modifications.
6 |
7 | For more information, please see the [demo](https://chomeyama.github.io/HN-UnifiedSourceFilterGAN-Demo/).
8 |
9 | This repository has been tested in the following environment.
10 |
11 | - Ubuntu 20.04.3 LTS
12 | - Titan RTX 3090 GPU
13 | - Python 3.9.5
14 | - Cuda 11.5
15 | - CuDNN 8.1.1.33-1+cuda11.2
16 |
17 | ## Environment setup
18 |
19 | ```bash
20 | $ cd HN-UnifiedSourceFilterGAN
21 | $ pip install -e .
22 | ```
23 |
24 | Please refer to the [Parallel WaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN) repo for more details.
25 |
26 | ## Folder architecture
27 | - **egs**:
28 | The folder for projects.
29 | - **egs/vctk**:
30 | The folder of the VCTK project example.
31 | - **usfgan**:
32 | The folder of the source codes.
33 |
34 | ## Run
35 |
36 | In this repo, hyperparameters are managed using [Hydra](https://hydra.cc/docs/intro/).
37 | Hydra provides an easy way to dynamically create a hierarchical configuration by composition and to override it through config files and the command line.
38 |
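For example, config groups and individual keys can be selected and overridden from the command line. The group and key names below are the ones defined under `usfgan/bin/config`; the values are illustrative:

```bash
# Select config group options and override scalar keys
$ usfgan-train generator=usfgan discriminator=pwg train=usfgan data=vctk_24kHz out_dir=exp/usfgan seed=12345
```
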
39 | ### Dataset preparation
40 |
41 | Prepare your dataset and scp files listing the paths to the audio files (e.g., `egs/vctk/data/scp/vctk_train_24kHz.scp`).
42 | List files giving the paths to the features extracted in the next step are also required (e.g., `egs/vctk/data/scp/vctk_train_24kHz.list`).
43 | Note that separate scp/list files are needed for training, validation, and evaluation.
44 |
45 | ### Preprocessing
46 |
47 | ```bash
48 | # Move to the project directory
49 | $ cd egs/vctk
50 |
51 | # Extract acoustic features (F0, mel-cepstrum, etc.)
52 | # You can customize parameters according to usfgan/bin/config/extract_features.yaml
53 | $ usfgan-extract-features audio=data/scp/vctk_all_24kHz.scp
54 |
55 | # Compute statistics of training and testing data
56 | $ usfgan-compute-statistics feats=data/scp/vctk_train_24kHz.list stats=data/stats/vctk_train_24kHz.joblib
57 | ```
58 |
59 | ### Training
60 |
61 | ```bash
62 | # Train a model customizing the hyperparameters as you like
63 | # The following setting (Parallel-HN-uSFGAN generator with HiFiGAN discriminator) should give the best performance
64 | $ usfgan-train generator=parallel_hn_usfgan discriminator=hifigan train=hn_usfgan data=vctk_24kHz out_dir=exp/parallel_hn_usfgan
65 | ```
66 |
67 | ### Inference
68 |
69 | ```bash
70 | # Decode with natural acoustic features
71 | $ usfgan-decode out_dir=exp/parallel-hn-usfgan/wav/600000steps checkpoint_path=exp/parallel-hn-usfgan/checkpoints/checkpoint-600000steps.pkl
72 | # Decode with halved f0
73 | $ usfgan-decode out_dir=exp/parallel-hn-usfgan/wav/600000steps checkpoint_path=exp/parallel-hn-usfgan/checkpoints/checkpoint-600000steps.pkl f0_factor=0.50
74 | ```
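
There is also an analysis-synthesis (copy-synthesis) script, `usfgan/bin/anasyn.py`, configured by `usfgan/bin/config/anasyn.yaml`. Since no console script is registered for it in `setup.py`, it is presumably invoked directly; a sketch with illustrative paths:

```bash
# Analysis-synthesis of raw wav files with a pre-trained model
$ python ../../usfgan/bin/anasyn.py in_dir=path/to/input_wavs out_dir=path/to/output_wavs
```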
75 |
76 | ### Monitor training progress
77 |
78 | ```bash
79 | $ tensorboard --logdir exp
80 | ```
81 |
82 | ## Citation
83 | If you find the code helpful, please cite the following article.
84 |
85 | ```
86 | @inproceedings{yoneyama22_interspeech,
87 | author={Reo Yoneyama and Yi-Chiao Wu and Tomoki Toda},
88 | title={{Unified Source-Filter GAN with Harmonic-plus-Noise Source Excitation Generation}},
89 | year=2022,
90 | booktitle={Proc. Interspeech 2022},
91 | pages={848--852},
92 | doi={10.21437/Interspeech.2022-11130}
93 | }
94 | ```
95 |
96 | ## Authors
97 |
98 | Development:
99 | Reo Yoneyama @ Nagoya University ([@chomeyama](https://github.com/chomeyama))
100 | E-mail: `yoneyama.reo@g.sp.m.is.nagoya-u.ac.jp`
101 |
102 | Advisors:
103 | Yi-Chiao Wu @ Nagoya University ([@bigpon](https://github.com/bigpon))
104 | E-mail: `yichiao.wu@g.sp.m.is.nagoya-u.ac.jp`
105 | Tomoki Toda @ Nagoya University
106 | E-mail: `tomoki@icts.nagoya-u.ac.jp`
107 |
--------------------------------------------------------------------------------
/usfgan/utils/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2020 Yi-Chiao Wu (Nagoya University)
4 | # based on a Parallel WaveGAN script by Tomoki Hayashi (Nagoya University)
5 | # (https://github.com/kan-bayashi/ParallelWaveGAN)
6 | # MIT License (https://opensource.org/licenses/MIT)
7 |
8 | """Utility functions."""
9 |
10 | import fnmatch
11 | import os
12 | import sys
13 | from logging import getLogger
14 |
15 | import h5py
16 | import numpy as np
17 |
18 | # A logger for this file
19 | logger = getLogger(__name__)
20 |
21 |
22 | def find_files(root_dir, query="*.wav", include_root_dir=True):
23 | """Find files recursively.
24 |
25 | Args:
26 | root_dir (str): Root root_dir to find.
27 | query (str): Query to find.
28 | include_root_dir (bool): If False, root_dir name is not included.
29 |
30 | Returns:
31 | list: List of found filenames.
32 |
33 | """
34 | files = []
35 | for root, dirnames, filenames in os.walk(root_dir, followlinks=True):
36 | for filename in fnmatch.filter(filenames, query):
37 | files.append(os.path.join(root, filename))
38 | if not include_root_dir:
39 | files = [file_.replace(root_dir + "/", "") for file_ in files]
40 |
41 | return files
42 |
43 |
44 | def read_hdf5(hdf5_name, hdf5_path):
45 | """Read hdf5 dataset.
46 |
47 | Args:
48 | hdf5_name (str): Filename of hdf5 file.
49 | hdf5_path (str): Dataset name in hdf5 file.
50 |
51 | Return:
52 | any: Dataset values.
53 |
54 | """
55 | if not os.path.exists(hdf5_name):
56 | logger.error(f"There is no such a hdf5 file ({hdf5_name}).")
57 | sys.exit(1)
58 |
59 | hdf5_file = h5py.File(hdf5_name, "r")
60 |
61 | if hdf5_path not in hdf5_file:
62 | logger.error(f"There is no such a data in hdf5 file. ({hdf5_path})")
63 | sys.exit(1)
64 |
65 | hdf5_data = hdf5_file[hdf5_path][()]
66 | hdf5_file.close()
67 |
68 | return hdf5_data
69 |
70 |
71 | def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True):
72 | """Write dataset to hdf5.
73 |
74 | Args:
75 | hdf5_name (str): Hdf5 dataset filename.
76 | hdf5_path (str): Dataset path in hdf5.
77 | write_data (ndarray): Data to write.
78 | is_overwrite (bool): Whether to overwrite dataset.
79 |
80 | """
81 | # convert to numpy array
82 | write_data = np.array(write_data)
83 |
84 | # check folder existence
85 | folder_name, _ = os.path.split(hdf5_name)
86 | if not os.path.exists(folder_name) and len(folder_name) != 0:
87 | os.makedirs(folder_name)
88 |
89 | # check hdf5 existence
90 | if os.path.exists(hdf5_name):
91 | # if already exists, open with r+ mode
92 | hdf5_file = h5py.File(hdf5_name, "r+")
93 | # check dataset existence
94 | if hdf5_path in hdf5_file:
95 | if is_overwrite:
96 | hdf5_file.__delitem__(hdf5_path)
97 | else:
98 | logger.error(
99 | "Dataset in hdf5 file already exists. "
100 | "if you want to overwrite, please set is_overwrite = True."
101 | )
102 | hdf5_file.close()
103 | sys.exit(1)
104 | else:
105 | # if not exists, open with w mode
106 | hdf5_file = h5py.File(hdf5_name, "w")
107 |
108 | # write data to hdf5
109 | hdf5_file.create_dataset(hdf5_path, data=write_data)
110 | hdf5_file.flush()
111 | hdf5_file.close()
112 |
113 |
114 | def check_hdf5(hdf5_name, hdf5_path):
115 | """Check hdf5 file existence
116 |
117 | Args:
118 | hdf5_name (str): filename of hdf5 file
119 | hdf5_path (str): dataset name in hdf5 file
120 |
121 | Return:
122 |         (bool): True if the dataset exists.
123 |
124 | """
125 | if not os.path.exists(hdf5_name):
126 | return False
127 | else:
128 | with h5py.File(hdf5_name, "r") as f:
129 | if hdf5_path in f:
130 | return True
131 | else:
132 | return False
133 |
134 |
135 | def read_txt(file_list):
136 | """Read .txt file list
137 |
138 | Arg:
139 | file_list (str): txt file filename
140 |
141 | Return:
142 | (list): list of read lines
143 |
144 | """
145 | with open(file_list, "r") as f:
146 | filenames = f.readlines()
147 | return [filename.replace("\n", "") for filename in filenames]
148 |
149 |
150 | def check_filename(list1, list2):
151 |     """Check whether the filenames of two lists match.
152 | 
153 |     Args:
154 |         list1 (list): File list 1.
155 |         list2 (list): File list 2.
156 | 
157 |     Returns:
158 |         bool: True if all the filenames match, False otherwise.
159 |
160 | """
161 |
162 | def _filename(x):
163 | return os.path.basename(x).split(".")[0]
164 |
165 | list1 = list(map(_filename, list1))
166 | list2 = list(map(_filename, list2))
167 |
168 | return list1 == list2
169 |
--------------------------------------------------------------------------------
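
Note: a minimal usage sketch for the hdf5 helpers above (the filenames are hypothetical, not from the repository):

    import numpy as np
    from usfgan.utils import check_hdf5, read_hdf5, write_hdf5

    # write an F0 array to "dump/utt001.h5" under the dataset name "/f0"
    f0 = np.zeros((100, 1), dtype=np.float32)
    write_hdf5("dump/utt001.h5", "/f0", f0)

    # read it back; check_hdf5 returns True only if both file and dataset exist
    assert check_hdf5("dump/utt001.h5", "/f0")
    f0_loaded = read_hdf5("dump/utt001.h5", "/f0")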
/usfgan/bin/decode.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """Decoding Script for Unified Source-Filter GAN.
7 |
8 | References:
9 | - https://github.com/bigpon/QPPWG
10 | - https://github.com/kan-bayashi/ParallelWaveGAN
11 |
12 | """
13 |
14 | import logging
15 | import os
16 | from logging import getLogger
17 | from time import time
18 |
19 | import hydra
20 | import numpy as np
21 | import soundfile as sf
22 | import torch
23 | import usfgan
24 | from hydra.utils import to_absolute_path
25 | from omegaconf import DictConfig
26 | from tqdm import tqdm
27 | from usfgan.datasets import FeatDataset
28 | from usfgan.utils.features import SignalGenerator
29 |
30 | # A logger for this file
31 | logger = getLogger(__name__)
32 |
33 |
34 | @hydra.main(version_base=None, config_path="config", config_name="decode")
35 | def main(config: DictConfig) -> None:
36 | """Run decoding process."""
37 |
38 | # fix seed
39 | np.random.seed(config.seed)
40 | torch.manual_seed(config.seed)
41 | torch.cuda.manual_seed(config.seed)
42 | os.environ["PYTHONHASHSEED"] = str(config.seed)
43 |
44 | # set device
45 | if torch.cuda.is_available():
46 | device = torch.device("cuda")
47 | else:
48 | device = torch.device("cpu")
49 |
50 | # load pre-trained model from checkpoint file
51 | model = hydra.utils.instantiate(config.generator)
52 | model.load_state_dict(torch.load(to_absolute_path(config.checkpoint_path))["model"]["generator"])
53 | logger.info(f"Loaded model parameters from {config.checkpoint_path}.")
54 | model.remove_weight_norm()
55 | model.eval().to(device)
56 |
57 | # check directory existence
58 | out_dir = to_absolute_path(config.out_dir)
59 | if not os.path.isdir(out_dir):
60 | os.makedirs(out_dir)
61 |
62 | # get dataset
63 | dataset = FeatDataset(
64 | stats=to_absolute_path(config.data.stats),
65 | feat_list=config.data.eval_feat,
66 | return_filename=True,
67 | sample_rate=config.data.sample_rate,
68 | hop_size=config.data.hop_size,
69 | dense_factor=config.data.dense_factor,
70 | df_f0_type=config.data.df_f0_type,
71 | aux_feats=config.data.aux_feats,
72 | f0_factor=config.f0_factor,
73 | )
74 | logger.info(f"The number of features to be decoded = {len(dataset)}.")
75 |
76 | # get data processor
77 | signal_generator = SignalGenerator(
78 | sample_rate=config.data.sample_rate,
79 | hop_size=config.data.hop_size,
80 | sine_amp=config.data.sine_amp,
81 | noise_amp=config.data.noise_amp,
82 | signal_types=config.data.signal_types,
83 | )
84 | pad_fn = torch.nn.ReplicationPad1d(config.generator.aux_context_window)
85 |
86 | # start generation
87 | total_rtf = 0.0
88 | with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar:
89 | for idx, (feat_path, c, df, f0, contf0) in enumerate(pbar, 1):
90 | # setup input features
91 | c = pad_fn(torch.FloatTensor(c).unsqueeze(0).transpose(2, 1)).to(device)
92 | df = torch.FloatTensor(df).view(1, 1, -1).to(device)
93 | f0 = torch.FloatTensor(f0).unsqueeze(0).transpose(2, 1).to(device)
94 | contf0 = torch.FloatTensor(contf0).unsqueeze(0).transpose(2, 1).to(device)
95 | # create input signal tensor
96 | if config.data.sine_f0_type == "contf0":
97 | in_signal = signal_generator(contf0)
98 | else:
99 | in_signal = signal_generator(f0)
100 |
101 | # generate
102 | start = time()
103 | y, s = model(in_signal, c, df)[:2]
104 | rtf = (time() - start) / (y.size(-1) / config.data.sample_rate)
105 | pbar.set_postfix({"RTF": rtf})
106 | total_rtf += rtf
107 |
108 | # save output signal as PCM 16 bit wav file
109 | utt_id = os.path.splitext(os.path.basename(feat_path))[0]
110 | wav_filename = f"{utt_id}_f{config.f0_factor:.2f}.wav"
111 | sf.write(
112 | os.path.join(out_dir, wav_filename),
113 | y.view(-1).cpu().numpy(),
114 | config.data.sample_rate,
115 | "PCM_16",
116 | )
117 |
118 | # save source signal as PCM 16 bit wav file
119 | if config.save_source:
120 | wav_filename = wav_filename.replace(".wav", "_s.wav")
121 | s = s.view(-1).cpu().numpy()
122 | s = s / np.max(np.abs(s)) # normalize
123 | sf.write(
124 | os.path.join(out_dir, wav_filename),
125 | s,
126 | config.data.sample_rate,
127 | "PCM_16",
128 | )
129 |
130 | # report average RTF
131 | logger.info(f"Finished generation of {idx} utterances (RTF = {total_rtf / idx:.03f}).")
132 |
133 |
134 | if __name__ == "__main__":
135 | main()
136 |
--------------------------------------------------------------------------------
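
Note: decode.py is a Hydra entry point (bin/config/decode.yaml), so the config fields used above can be overridden from the command line. A hypothetical invocation, assuming the package is installed; the paths are placeholders:

    python -m usfgan.bin.decode \
        checkpoint_path=exp/usfgan/checkpoint-400000steps.pkl \
        out_dir=exp/usfgan/wav \
        f0_factor=1.00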
/usfgan/utils/features.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """Feature-related functions.
7 |
8 | References:
9 | - https://github.com/bigpon/QPPWG
10 |
11 | """
12 |
13 | import sys
14 | from logging import getLogger
15 |
16 | import numpy as np
17 | import torch
18 | from torch.nn.functional import interpolate
19 |
20 | # A logger for this file
21 | logger = getLogger(__name__)
22 |
23 |
24 | def validate_length(x, y, hop_size=None):
25 |     """Validate and adjust the lengths of two arrays.
26 |
27 | Args:
28 | x (ndarray): numpy array with x.shape[0] = len_x
29 | y (ndarray): numpy array with y.shape[0] = len_y
30 | hop_size (int): upsampling factor
31 |
32 | Returns:
33 |         ndarray: Length-adjusted x.
34 |         ndarray: Length-adjusted y, such that len(x) == len(y) (* hop_size if given).
35 |
36 | """
37 | if hop_size is None:
38 | if x.shape[0] < y.shape[0]:
39 | y = y[: x.shape[0]]
40 | if x.shape[0] > y.shape[0]:
41 | x = x[: y.shape[0]]
42 | assert len(x) == len(y)
43 | else:
44 | if x.shape[0] > y.shape[0] * hop_size:
45 | x = x[: y.shape[0] * hop_size]
46 | if x.shape[0] < y.shape[0] * hop_size:
47 | mod_y = y.shape[0] * hop_size - x.shape[0]
48 | mod_y_frame = mod_y // hop_size + 1
49 | y = y[:-mod_y_frame]
50 | x = x[: y.shape[0] * hop_size]
51 | assert len(x) == len(y) * hop_size
52 |
53 | return x, y
54 |
55 |
56 | def dilated_factor(batch_f0, fs, dense_factor):
57 | """Pitch-dependent dilated factor
58 |
59 | Args:
60 |         batch_f0 (ndarray): F0 sequence (T).
61 |         fs (int): Sampling rate.
62 |         dense_factor (int): Number of taps in one cycle.
63 |
64 |     Returns:
65 |         dilated_factors (ndarray):
66 |             Float array of the pitch-dependent dilated factors (T).
67 |
68 | """
69 |     batch_f0[batch_f0 == 0] = fs / dense_factor  # NOTE: modifies batch_f0 in place
70 | dilated_factors = np.ones(batch_f0.shape) * fs
71 | dilated_factors /= batch_f0
72 | dilated_factors /= dense_factor
73 | assert np.all(dilated_factors > 0)
74 |
75 | return dilated_factors
76 |
77 |
78 | class SignalGenerator:
79 | """Input signal generator module."""
80 |
81 | def __init__(
82 | self,
83 | sample_rate=24000,
84 | hop_size=120,
85 | sine_amp=0.1,
86 | noise_amp=0.003,
87 | signal_types=["sine", "noise"],
88 | ):
89 |         """Initialize SignalGenerator module.
90 |
91 | Args:
92 | sample_rate (int): Sampling rate.
93 | hop_size (int): Hop size of input F0.
94 | sine_amp (float): Sine amplitude for NSF-based sine generation.
95 | noise_amp (float): Noise amplitude for NSF-based sine generation.
96 | signal_types (list): List of input signal types for generator.
97 |
98 | """
99 | self.sample_rate = sample_rate
100 | self.hop_size = hop_size
101 | self.signal_types = signal_types
102 | self.sine_amp = sine_amp
103 | self.noise_amp = noise_amp
104 |
105 | for signal_type in signal_types:
106 |             if signal_type not in ["noise", "sine", "uv"]:
107 |                 logger.error(f"{signal_type} is not a supported type for generator input.")
108 |                 sys.exit(1)
109 | logger.info(f"Use {signal_types} for generator input signals.")
110 |
111 | @torch.no_grad()
112 | def __call__(self, f0):
113 | signals = []
114 | for typ in self.signal_types:
115 | if "noise" == typ:
116 | signals.append(self.random_noise(f0))
117 | if "sine" == typ:
118 | signals.append(self.sinusoid(f0))
119 | if "uv" == typ:
120 | signals.append(self.vuv_binary(f0))
121 |
122 | input_batch = signals[0]
123 | for signal in signals[1:]:
124 |             input_batch = torch.cat([input_batch, signal], dim=1)
125 |
126 | return input_batch
127 |
128 | @torch.no_grad()
129 | def random_noise(self, f0):
130 | """Calculate noise signals.
131 |
132 | Args:
133 | f0 (Tensor): F0 tensor (B, 1, T // hop_size).
134 |
135 | Returns:
136 | Tensor: Gaussian noise signals (B, 1, T).
137 |
138 | """
139 | B, _, T = f0.size()
140 | noise = torch.randn((B, 1, T * self.hop_size), device=f0.device)
141 |
142 | return noise
143 |
144 | @torch.no_grad()
145 | def sinusoid(self, f0):
146 | """Calculate sine signals.
147 |
148 | Args:
149 | f0 (Tensor): F0 tensor (B, 1, T // hop_size).
150 |
151 | Returns:
152 | Tensor: Sines generated following NSF (B, 1, T).
153 |
154 | """
155 | B, _, T = f0.size()
156 | vuv = interpolate((f0 > 0) * torch.ones_like(f0), T * self.hop_size)
157 |         radius = (interpolate(f0, T * self.hop_size) / self.sample_rate) % 1
158 |         sine = vuv * torch.sin(torch.cumsum(radius, dim=2) * 2 * np.pi) * self.sine_amp
159 | if self.noise_amp > 0:
160 | noise_amp = vuv * self.noise_amp + (1.0 - vuv) * self.noise_amp / 3.0
161 | noise = torch.randn((B, 1, T * self.hop_size), device=f0.device) * noise_amp
162 | sine = sine + noise
163 |
164 | return sine
165 |
166 | @torch.no_grad()
167 | def vuv_binary(self, f0):
168 | """Calculate V/UV binary sequences.
169 |
170 | Args:
171 | f0 (Tensor): F0 tensor (B, 1, T // hop_size).
172 |
173 | Returns:
174 | Tensor: V/UV binary sequences (B, 1, T).
175 |
176 | """
177 | _, _, T = f0.size()
178 | uv = interpolate((f0 > 0) * torch.ones_like(f0), T * self.hop_size)
179 |
180 | return uv
181 |
--------------------------------------------------------------------------------
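
Note: a minimal sketch of driving the SignalGenerator above with a frame-level F0 tensor (the values are illustrative):

    import torch
    from usfgan.utils.features import SignalGenerator

    generator = SignalGenerator(
        sample_rate=24000, hop_size=120, signal_types=["sine", "noise"]
    )
    f0 = torch.full((1, 1, 50), 200.0)  # (B, 1, frames): 50 voiced frames at 200 Hz
    in_signal = generator(f0)           # (B, 2, 50 * 120): one channel per signal type
    print(in_signal.shape)              # torch.Size([1, 2, 6000])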
/usfgan/losses/source.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """Source regularization loss modules."""
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import usfgan.losses
12 | from librosa.filters import mel as librosa_mel
13 | from usfgan.layers import CheapTrick
14 |
15 |
16 | class FlattenLoss(nn.Module):
17 |     """Spectral envelope flattening loss module."""
18 |
19 | def __init__(
20 | self,
21 | sampling_rate=24000,
22 | hop_size=120,
23 | fft_size=2048,
24 | f0_floor=70,
25 | f0_ceil=340,
26 | power=False,
27 | elim_0th=False,
28 | l2_norm=False,
29 | ):
30 |         """Initialize spectral envelope regularization loss module.
31 |
32 | Args:
33 | sampling_rate (int): Sampling rate.
34 | hop_size (int): Hop size.
35 | fft_size (int): FFT size.
36 | f0_floor (int): Minimum F0 value.
37 | f0_ceil (int): Maximum F0 value.
38 |             power (bool): Whether to use power or magnitude spectrogram.
39 |             elim_0th (bool): Whether to exclude the 0th cepstrum in CheapTrick.
40 |             l2_norm (bool): Whether to use L2 norm instead of L1 norm.
41 | """
42 | super(FlattenLoss, self).__init__()
43 | self.hop_size = hop_size
44 | self.power = power
45 | self.elim_0th = elim_0th
46 | self.cheaptrick = CheapTrick(
47 | sampling_rate=sampling_rate,
48 | hop_size=hop_size,
49 | fft_size=fft_size,
50 | f0_floor=f0_floor,
51 | f0_ceil=f0_ceil,
52 | )
53 | if l2_norm:
54 | self.loss = nn.MSELoss()
55 | else:
56 | self.loss = nn.L1Loss()
57 |
58 | def forward(self, s, f):
59 | """Calculate forward propagation.
60 |
61 | Args:
62 | s (Tensor): Predicted source signal (B, 1, T).
63 | f (Tensor): Extracted F0 sequence (B, 1, T').
64 |
65 | Returns:
66 | loss (Tensor): Spectral envelope regularization loss value.
67 |
68 | """
69 | s, f = s.squeeze(1), f.squeeze(1)
70 | e = self.cheaptrick.forward(s, f, self.power, self.elim_0th)
71 | loss = self.loss(e, e.new_zeros(e.size()))
72 | return loss
73 |
74 |
75 | class ResidualLoss(nn.Module):
76 |     """hn-uSFGAN source regularization loss module."""
77 |
78 | def __init__(
79 | self,
80 | sampling_rate=24000,
81 | fft_size=2048,
82 | hop_size=120,
83 | f0_floor=70,
84 | f0_ceil=340,
85 | n_mels=80,
86 | fmin=0,
87 | fmax=None,
88 | power=False,
89 | elim_0th=True,
90 | ):
91 |         """Initialize ResidualLoss module.
92 |
93 | Args:
94 | sampling_rate (int): Sampling rate.
95 | fft_size (int): FFT size.
96 | hop_size (int): Hop size.
97 | f0_floor (int): Minimum F0 value.
98 | f0_ceil (int): Maximum F0 value.
99 | n_mels (int): Number of Mel basis.
100 | fmin (int): Minimum frequency for Mel.
101 | fmax (int): Maximum frequency for Mel.
102 | power (bool): Whether to use power or magnitude spectrogram.
103 | elim_0th (bool): Whether to exclude 0th cepstrum in CheapTrick.
104 | If set to true, power is estimated by source-network.
105 |
106 | """
107 | super(ResidualLoss, self).__init__()
108 | self.sampling_rate = sampling_rate
109 | self.fft_size = fft_size
110 | self.hop_size = hop_size
111 | self.cheaptrick = CheapTrick(
112 | sampling_rate=sampling_rate,
113 | hop_size=hop_size,
114 | fft_size=fft_size,
115 | f0_floor=f0_floor,
116 | f0_ceil=f0_ceil,
117 | )
118 | self.win_length = fft_size
119 | self.register_buffer("window", torch.hann_window(self.win_length))
120 |
121 | # define mel-filter-bank
122 | self.n_mels = n_mels
123 | self.fmin = fmin
124 | self.fmax = fmax if fmax is not None else sampling_rate / 2
125 | melmat = librosa_mel(
126 | sr=sampling_rate, n_fft=fft_size, n_mels=n_mels, fmin=fmin, fmax=self.fmax
127 | ).T
128 | self.register_buffer("melmat", torch.from_numpy(melmat).float())
129 |
130 | self.power = power
131 | self.elim_0th = elim_0th
132 |
133 | def forward(self, s, y, f):
134 | """Calculate forward propagation.
135 |
136 | Args:
137 | s (Tensor): Predicted source signal (B, 1, T).
138 | y (Tensor): Target signal (B, 1, T).
139 | f (Tensor): Extracted F0 sequence (B, 1, T').
140 |
141 | Returns:
142 | Tensor: Difference loss value.
143 |
144 | """
145 | s, y, f = s.squeeze(1), y.squeeze(1), f.squeeze(1)
146 |
147 | with torch.no_grad():
148 | # calculate log power (or magnitude) spectrograms
149 | e = self.cheaptrick.forward(y, f, self.power, self.elim_0th)
150 | y = usfgan.losses.stft(
151 | y,
152 | self.fft_size,
153 | self.hop_size,
154 | self.win_length,
155 | self.window,
156 | power=self.power,
157 | )
158 | # adjust length, (B, T', C)
159 | minlen = min(e.size(1), y.size(1))
160 | e, y = e[:, :minlen, :], y[:, :minlen, :]
161 |
162 | # calculate mean power (or magnitude) of y
163 | if self.elim_0th:
164 | y_mean = y.mean(dim=-1, keepdim=True)
165 |
166 | # calculate target of output source signal
167 | y = torch.log(torch.clamp(y, min=1e-7))
168 | t = (y - e).exp()
169 | if self.elim_0th:
170 | t_mean = t.mean(dim=-1, keepdim=True)
171 | t = y_mean / t_mean * t
172 |
173 | # apply mel-filter-bank and log
174 | t = torch.matmul(t, self.melmat)
175 | t = torch.log(torch.clamp(t, min=1e-7))
176 |
177 | # calculate power (or magnitude) spectrogram
178 | s = usfgan.losses.stft(
179 | s,
180 | self.fft_size,
181 | self.hop_size,
182 | self.win_length,
183 | self.window,
184 | power=self.power,
185 | )
186 | # adjust length, (B, T', C)
187 | minlen = min(minlen, s.size(1))
188 | s, t = s[:, :minlen, :], t[:, :minlen, :]
189 |
190 | # apply mel-filter-bank and log
191 | s = torch.matmul(s, self.melmat)
192 | s = torch.log(torch.clamp(s, min=1e-7))
193 |
194 | loss = F.l1_loss(s, t.detach())
195 |
196 | return loss
197 |
--------------------------------------------------------------------------------
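
Note: the target construction inside ResidualLoss.forward above amounts to spectral whitening. Writing |Y| for the (power or magnitude) spectrogram of the target waveform and E for the log spectral envelope estimated by CheapTrick, the regression target for the output source spectrogram is

    t = exp(log|Y| - E),

i.e. the target spectrum with its envelope removed. When elim_0th is enabled, t is additionally rescaled so that its mean matches that of |Y|, because the 0th cepstrum (overall power) was excluded from E; both t and the source spectrogram are then compared on the log-Mel scale with an L1 loss.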
/usfgan/bin/anasyn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """Analysis-synthesis script.
7 |
8 | Analysis: WORLD vocoder
9 | Synthesis: Pre-trained neural vocoder
10 |
11 | """
12 |
13 |
14 | import copy
15 | import os
16 | from logging import getLogger
17 |
18 | import hydra
19 | import librosa
20 | import numpy as np
21 | import pysptk
22 | import pyworld as pw
23 | import soundfile as sf
24 | import torch
25 | import torch.nn as nn
26 | from hydra.utils import instantiate, to_absolute_path
27 | from joblib import load
28 | from omegaconf import DictConfig
29 | from scipy.interpolate import interp1d
30 | from usfgan.utils.features import SignalGenerator, dilated_factor
31 |
32 | logger = getLogger(__name__)  # A logger for this file
33 |
34 | # All-pass-filter coefficients {key -> sampling rate : value -> coefficient}
35 | ALPHA = {
36 | 8000: 0.312,
37 | 12000: 0.369,
38 | 16000: 0.410,
39 | 22050: 0.455,
40 | 24000: 0.466,
41 | 32000: 0.504,
42 | 44100: 0.544,
43 | 48000: 0.554,
44 | }
45 |
46 |
47 | def convert_continuos_f0(f0):
48 |     """Convert discrete F0 to continuous F0 and return the V/UV flags."""
49 |     uv = np.float32(f0 != 0)  # get uv information as binary
50 | # get start and end of f0
51 | if (f0 == 0).all():
52 |         logger.warning("All of the F0 values are 0.")
53 |         return uv, f0
54 | start_f0 = f0[f0 != 0][0]
55 | end_f0 = f0[f0 != 0][-1]
56 | # padding start and end of f0 sequence
57 | cont_f0 = copy.deepcopy(f0)
58 | start_idx = np.where(cont_f0 == start_f0)[0][0]
59 | end_idx = np.where(cont_f0 == end_f0)[0][-1]
60 | cont_f0[:start_idx] = start_f0
61 | cont_f0[end_idx:] = end_f0
62 | # get non-zero frame index
63 | nz_frames = np.where(cont_f0 != 0)[0]
64 | # perform linear interpolation
65 | f = interp1d(nz_frames, cont_f0[nz_frames])
66 | cont_f0 = f(np.arange(0, cont_f0.shape[0]))
67 |
68 | return uv, cont_f0
69 |
70 |
71 | @torch.no_grad()
72 | @hydra.main(version_base=None, config_path="config", config_name="anasyn")
73 | def main(config: DictConfig) -> None:
74 | """Run analysis-synthesis process."""
75 |
76 | np.random.seed(config.seed)
77 | torch.manual_seed(config.seed)
78 | torch.cuda.manual_seed(config.seed)
79 | os.environ["PYTHONHASHSEED"] = str(config.seed)
80 |
81 | # set device
82 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
83 | logger.info(f"Decode on {device}.")
84 |
85 | # load pre-trained model from checkpoint file
86 | model = instantiate(config.generator)
87 | state_dict = torch.load(
88 | to_absolute_path(config.checkpoint_path), map_location="cpu"
89 | )
90 | model.load_state_dict(state_dict["model"]["generator"])
91 | logger.info(f"Loaded model parameters from {config.checkpoint_path}.")
92 | model.remove_weight_norm()
93 | model.eval().to(device)
94 |
95 | # get scaler
96 | scaler = load(to_absolute_path(config.stats))
97 |
98 | # get data processor
99 | hop_size = int(config.sample_rate * config.frame_period * 0.001)
100 | signal_generator = SignalGenerator(
101 | sample_rate=config.sample_rate,
102 | hop_size=hop_size,
103 | sine_amp=config.sine_amp,
104 | noise_amp=config.noise_amp,
105 | signal_types=config.signal_types,
106 | )
107 | pad_fn = nn.ReplicationPad1d(config.generator.aux_context_window)
108 |
109 | # create output directory
110 | os.makedirs(config.out_dir, exist_ok=True)
111 |
112 | # loop all wav files in in_dir
113 | for wav_file in os.listdir(config.in_dir):
114 |         if os.path.splitext(wav_file)[1] != ".wav":
115 |             continue
116 |         logger.info(f"Processing {wav_file}.")
117 | wav_path = os.path.join(config.in_dir, wav_file)
118 |
119 | # WORLD analysis
120 | x, sr = sf.read(to_absolute_path(wav_path))
121 | if sr != config.sample_rate:
122 | x = librosa.resample(x, orig_sr=sr, target_sr=config.sample_rate)
123 | f0, t = pw.harvest(
124 | x,
125 | config.sample_rate,
126 | f0_floor=config.f0_floor,
127 | f0_ceil=config.f0_ceil,
128 | frame_period=config.frame_period,
129 | )
130 | sp = pw.cheaptrick(x, f0, t, config.sample_rate)
131 | ap = pw.d4c(x, f0, t, config.sample_rate)
132 | mcep = pysptk.sp2mc(sp, order=config.mcep_dim, alpha=ALPHA[config.sample_rate])
133 | mcap = pysptk.sp2mc(ap, order=config.mcap_dim, alpha=ALPHA[config.sample_rate])
134 | bap = pw.code_aperiodicity(ap, config.sample_rate)
135 |
136 | # prepare f0 related features
137 | uv, cf0 = convert_continuos_f0(f0)
138 | uv = uv[:, np.newaxis] # (T, 1)
139 | f0 = f0[:, np.newaxis] # (T, 1)
140 | cf0 = cf0[:, np.newaxis] # (T, 1)
141 | f0 *= config.f0_factor
142 | cf0 *= config.f0_factor
143 |
144 | # prepare input acoustic features
145 | c = []
146 | for feat_type in config.aux_feats:
147 | if feat_type == "f0":
148 | c += [scaler[feat_type].transform(f0)]
149 | elif feat_type in ["cf0", "contf0"]:
150 | c += [scaler[feat_type].transform(cf0)]
151 | elif feat_type == "uv":
152 | c += [scaler[feat_type].transform(uv)]
153 | elif feat_type == "mcep":
154 | c += [scaler[feat_type].transform(mcep)]
155 | elif feat_type == "mcap":
156 | c += [scaler[feat_type].transform(mcap)]
157 | elif feat_type == "bap":
158 | c += [scaler[feat_type].transform(bap)]
159 | c = np.concatenate(c, axis=1)
160 |         logger.info(f"Auxiliary feature shape: {c.shape}.")
161 |
162 | # prepare dense factors
163 | df = dilated_factor(
164 | cf0 if config.df_f0_type in ["cf0", "contf0"] else f0,
165 | config.sample_rate,
166 | config.dense_factor,
167 | ).repeat(hop_size, axis=0)
168 |
169 | # convert to torch tensors
170 | f0 = torch.FloatTensor(f0).view(1, 1, -1).to(device)
171 | cf0 = torch.FloatTensor(cf0).view(1, 1, -1).to(device)
172 | c = pad_fn(torch.FloatTensor(c).unsqueeze(0).transpose(2, 1).to(device))
173 | df = torch.FloatTensor(np.array(df)).view(1, 1, -1).to(device)
174 |
175 | # generate input signals
176 | if config.sine_f0_type in ["cf0", "contf0"]:
177 | in_signal = signal_generator(cf0)
178 | elif config.sine_f0_type == "f0":
179 | in_signal = signal_generator(f0)
180 |
181 | # synthesize with the neural vocoder
182 | y = model(in_signal, c, df)[0]
183 |
184 | # save output signal as PCM 16 bit wav file
185 | out_path = os.path.join(config.out_dir, wav_file).replace(
186 | ".wav", f"_f{config.f0_factor:.2f}.wav"
187 | )
188 | sf.write(
189 | to_absolute_path(out_path),
190 | y.view(-1).cpu().numpy(),
191 | config.sample_rate,
192 | "PCM_16",
193 | )
194 |
195 |
196 | if __name__ == "__main__":
197 | main()
198 |
--------------------------------------------------------------------------------
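
Note: a worked example of convert_continuos_f0 above; the outputs follow from the code (importing the module also pulls in its WORLD and PyTorch dependencies):

    import numpy as np
    from usfgan.bin.anasyn import convert_continuos_f0

    f0 = np.array([0.0, 0.0, 200.0, 0.0, 220.0, 0.0])
    uv, cont_f0 = convert_continuos_f0(f0)
    # uv:      [0., 0., 1., 0., 1., 0.]  (voiced/unvoiced flags)
    # cont_f0: [200., 200., 200., 210., 220., 220.]
    # edges are padded with the first/last voiced value, and the unvoiced
    # gap at index 3 is filled by linear interpolation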
/usfgan/layers/upsample.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Upsampling module.
4 |
5 | This code is modified from https://github.com/r9y9/wavenet_vocoder.
6 |
7 | """
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn.functional as F
12 | from usfgan.layers import Conv1d
13 |
14 |
15 | class Stretch2d(torch.nn.Module):
16 | """Stretch2d module."""
17 |
18 | def __init__(self, x_scale, y_scale, mode="nearest"):
19 | """Initialize Stretch2d module.
20 |
21 | Args:
22 | x_scale (int): X scaling factor (Time axis in spectrogram).
23 | y_scale (int): Y scaling factor (Frequency axis in spectrogram).
24 | mode (str): Interpolation mode.
25 |
26 | """
27 | super(Stretch2d, self).__init__()
28 | self.x_scale = x_scale
29 | self.y_scale = y_scale
30 | self.mode = mode
31 |
32 | def forward(self, x):
33 | """Calculate forward propagation.
34 |
35 | Args:
36 | x (Tensor): Input tensor (B, C, F, T).
37 |
38 | Returns:
39 | Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale)
40 |
41 | """
42 | return F.interpolate(
43 | x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode
44 | )
45 |
46 |
47 | class Conv2d(torch.nn.Conv2d):
48 | """Conv2d module with customized initialization."""
49 |
50 | def __init__(self, *args, **kwargs):
51 | """Initialize Conv2d module."""
52 | super(Conv2d, self).__init__(*args, **kwargs)
53 |
54 | def reset_parameters(self):
55 | """Reset parameters."""
56 | self.weight.data.fill_(1.0 / np.prod(self.kernel_size))
57 | if self.bias is not None:
58 | torch.nn.init.constant_(self.bias, 0.0)
59 |
60 |
61 | class UpsampleNetwork(torch.nn.Module):
62 | """Upsampling network module."""
63 |
64 | def __init__(
65 | self,
66 | upsample_scales,
67 | nonlinear_activation=None,
68 | nonlinear_activation_params={},
69 | interpolate_mode="nearest",
70 | freq_axis_kernel_size=1,
71 | use_causal_conv=False,
72 | ):
73 | """Initialize upsampling network module.
74 |
75 | Args:
76 | upsample_scales (list): List of upsampling scales.
77 | nonlinear_activation (str): Activation function name.
78 | nonlinear_activation_params (dict): Arguments for specified activation function.
79 | interpolate_mode (str): Interpolation mode.
80 |             freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
81 |             use_causal_conv (bool): Whether to use causal structure.
82 |         """
83 | super(UpsampleNetwork, self).__init__()
84 | self.use_causal_conv = use_causal_conv
85 | self.up_layers = torch.nn.ModuleList()
86 | for scale in upsample_scales:
87 | # interpolation layer
88 | stretch = Stretch2d(scale, 1, interpolate_mode)
89 | self.up_layers += [stretch]
90 |
91 | # conv layer
92 | assert (
93 | freq_axis_kernel_size - 1
94 |             ) % 2 == 0, "Even number freq axis kernel size is not supported."
95 | freq_axis_padding = (freq_axis_kernel_size - 1) // 2
96 | kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
97 | if use_causal_conv:
98 | padding = (freq_axis_padding, scale * 2)
99 | else:
100 | padding = (freq_axis_padding, scale)
101 | conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
102 | self.up_layers += [conv]
103 |
104 | # nonlinear
105 | if nonlinear_activation is not None:
106 | nonlinear = getattr(torch.nn, nonlinear_activation)(
107 | **nonlinear_activation_params
108 | )
109 | self.up_layers += [nonlinear]
110 |
111 | def forward(self, c):
112 | """Calculate forward propagation.
113 |
114 | Args:
115 |             c (Tensor): Input tensor (B, C, T).
116 |
117 | Returns:
118 | Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales).
119 |
120 | """
121 | c = c.unsqueeze(1) # (B, 1, C, T)
122 | for f in self.up_layers:
123 | if self.use_causal_conv and isinstance(f, Conv2d):
124 | c = f(c)[..., : c.size(-1)]
125 | else:
126 | c = f(c)
127 |
128 | return c.squeeze(1) # (B, C, T')
129 |
130 |
131 | class ConvInUpsampleNetwork(torch.nn.Module):
132 | """Convolution + upsampling network module."""
133 |
134 | def __init__(
135 | self,
136 | upsample_scales,
137 | nonlinear_activation=None,
138 | nonlinear_activation_params={},
139 | interpolate_mode="nearest",
140 | freq_axis_kernel_size=1,
141 | aux_channels=80,
142 | aux_context_window=0,
143 | use_causal_conv=False,
144 | ):
145 | """Initialize convolution + upsampling network module.
146 |
147 | Args:
148 | upsample_scales (list): List of upsampling scales.
149 | nonlinear_activation (str): Activation function name.
150 | nonlinear_activation_params (dict): Arguments for specified activation function.
151 |             interpolate_mode (str): Interpolation mode.
152 | freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
153 | aux_channels (int): Number of channels of pre-convolutional layer.
154 | aux_context_window (int): Context window size of the pre-convolutional layer.
155 | use_causal_conv (bool): Whether to use causal structure.
156 |
157 | """
158 | super(ConvInUpsampleNetwork, self).__init__()
159 | self.aux_context_window = aux_context_window
160 | self.use_causal_conv = use_causal_conv and aux_context_window > 0
161 | # To capture wide-context information in conditional features
162 | kernel_size = (
163 | aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
164 | )
165 | # NOTE(kan-bayashi): Here do not use padding because the input is already padded
166 | self.conv_in = Conv1d(
167 | aux_channels, aux_channels, kernel_size=kernel_size, bias=False
168 | )
169 | self.upsample = UpsampleNetwork(
170 | upsample_scales=upsample_scales,
171 | nonlinear_activation=nonlinear_activation,
172 | nonlinear_activation_params=nonlinear_activation_params,
173 | interpolate_mode=interpolate_mode,
174 | freq_axis_kernel_size=freq_axis_kernel_size,
175 | use_causal_conv=use_causal_conv,
176 | )
177 |
178 | def forward(self, c):
179 | """Calculate forward propagation.
180 |
181 | Args:
182 |             c (Tensor): Input tensor (B, C, T').
183 |
184 | Returns:
185 | Tensor: Upsampled tensor (B, C, T),
186 | where T = (T' - aux_context_window * 2) * prod(upsample_scales).
187 |
188 | Note:
189 | The length of inputs considers the context window size.
190 |
191 | """
192 | c_ = self.conv_in(c)
193 | c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_
194 | return self.upsample(c)
195 |
--------------------------------------------------------------------------------
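
Note: a minimal sketch of the ConvInUpsampleNetwork above for a hop size of 120 (decomposed as 5 * 4 * 3 * 2); the parameter values are illustrative, and the class is assumed to be exported by the usfgan.layers wildcard imports:

    import torch
    from usfgan.layers import ConvInUpsampleNetwork

    upsample = ConvInUpsampleNetwork(
        upsample_scales=[5, 4, 3, 2],
        aux_channels=80,
        aux_context_window=2,
    )
    c = torch.randn(1, 80, 104)  # (B, C, T' + 2 * aux_context_window), T' = 100
    y = upsample(c)
    print(y.shape)               # torch.Size([1, 80, 12000]) = (B, C, T' * 120)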
/egs/vctk/data/spk_info.yaml:
--------------------------------------------------------------------------------
1 | p225:
2 | f0_min: 130
3 | f0_max: 260
4 | pow_th: -20
5 | p226:
6 | f0_min: 80
7 | f0_max: 180
8 | pow_th: -20
9 | p227:
10 | f0_min: 80
11 | f0_max: 190
12 | pow_th: -20
13 | p228:
14 | f0_min: 120
15 | f0_max: 300
16 | pow_th: -20
17 | p229:
18 | f0_min: 140
19 | f0_max: 230
20 | pow_th: -20
21 | p230:
22 | f0_min: 140
23 | f0_max: 260
24 | pow_th: -20
25 | p231:
26 | f0_min: 110
27 | f0_max: 280
28 | pow_th: -20
29 | p232:
30 | f0_min: 70
31 | f0_max: 180
32 | pow_th: -20
33 | p233:
34 | f0_min: 150
35 | f0_max: 290
36 | pow_th: -20
37 | p234:
38 | f0_min: 130
39 | f0_max: 270
40 | pow_th: -20
41 | p236:
42 | f0_min: 170
43 | f0_max: 330
44 | pow_th: -20
45 | p237:
46 | f0_min: 60
47 | f0_max: 120
48 | pow_th: -20
49 | p238:
50 | f0_min: 140
51 | f0_max: 320
52 | pow_th: -20
53 | p239:
54 | f0_min: 140
55 | f0_max: 270
56 | pow_th: -20
57 | p240:
58 | f0_min: 150
59 | f0_max: 330
60 | pow_th: -20
61 | p241:
62 | f0_min: 85
63 | f0_max: 170
64 | pow_th: -20
65 | p243:
66 | f0_min: 80
67 | f0_max: 180
68 | pow_th: -20
69 | p244:
70 | f0_min: 150
71 | f0_max: 320
72 | pow_th: -20
73 | p245:
74 | f0_min: 70
75 | f0_max: 130
76 | pow_th: -20
77 | p246:
78 | f0_min: 70
79 | f0_max: 140
80 | pow_th: -20
81 | p247:
82 | f0_min: 90
83 | f0_max: 200
84 | pow_th: -20
85 | p248:
86 | f0_min: 130
87 | f0_max: 390
88 | pow_th: -20
89 | p249:
90 | f0_min: 130
91 | f0_max: 240
92 | pow_th: -20
93 | p250:
94 | f0_min: 110
95 | f0_max: 380
96 | pow_th: -20
97 | p251:
98 | f0_min: 80
99 | f0_max: 170
100 | pow_th: -20
101 | p252:
102 | f0_min: 80
103 | f0_max: 160
104 | pow_th: -20
105 | p253:
106 | f0_min: 160
107 | f0_max: 310
108 | pow_th: -20
109 | p254:
110 | f0_min: 55
111 | f0_max: 120
112 | pow_th: -20
113 | p255:
114 | f0_min: 95
115 | f0_max: 210
116 | pow_th: -20
117 | p256:
118 | f0_min: 60
119 | f0_max: 130
120 | pow_th: -20
121 | p257:
122 | f0_min: 150
123 | f0_max: 300
124 | pow_th: -20
125 | p258:
126 | f0_min: 90
127 | f0_max: 170
128 | pow_th: -20
129 | p259:
130 | f0_min: 80
131 | f0_max: 200
132 | pow_th: -20
133 | p260:
134 | f0_min: 70
135 | f0_max: 170
136 | pow_th: -20
137 | p261:
138 | f0_min: 140
139 | f0_max: 360
140 | pow_th: -20
141 | p262:
142 | f0_min: 120
143 | f0_max: 250
144 | pow_th: -20
145 | p263:
146 | f0_min: 70
147 | f0_max: 160
148 | pow_th: -20
149 | p264:
150 | f0_min: 150
151 | f0_max: 250
152 | pow_th: -20
153 | p265:
154 | f0_min: 130
155 | f0_max: 290
156 | pow_th: -20
157 | p266:
158 | f0_min: 100
159 | f0_max: 300
160 | pow_th: -20
161 | p267:
162 | f0_min: 130
163 | f0_max: 270
164 | pow_th: -20
165 | p268:
166 | f0_min: 130
167 | f0_max: 300
168 | pow_th: -20
169 | p269:
170 | f0_min: 150
171 | f0_max: 260
172 | pow_th: -20
173 | p270:
174 | f0_min: 75
175 | f0_max: 150
176 | pow_th: -20
177 | p271:
178 | f0_min: 90
179 | f0_max: 160
180 | pow_th: -20
181 | p272:
182 | f0_min: 80
183 | f0_max: 180
184 | pow_th: -20
185 | p273:
186 | f0_min: 100
187 | f0_max: 220
188 | pow_th: -20
189 | p274:
190 | f0_min: 75
191 | f0_max: 150
192 | pow_th: -20
193 | p275:
194 | f0_min: 80
195 | f0_max: 160
196 | pow_th: -20
197 | p276:
198 | f0_min: 170
199 | f0_max: 340
200 | pow_th: -20
201 | p277:
202 | f0_min: 160
203 | f0_max: 260
204 | pow_th: -20
205 | p278:
206 | f0_min: 80
207 | f0_max: 170
208 | pow_th: -20
209 | p279:
210 | f0_min: 80
211 | f0_max: 200
212 | pow_th: -20
213 | p280:
214 | f0_min: 140
215 | f0_max: 280
216 | pow_th: -20
217 | p281:
218 | f0_min: 70
219 | f0_max: 140
220 | pow_th: -20
221 | p282:
222 | f0_min: 150
223 | f0_max: 250
224 | pow_th: -20
225 | p283:
226 | f0_min: 160
227 | f0_max: 280
228 | pow_th: -20
229 | p284:
230 | f0_min: 60
231 | f0_max: 170
232 | pow_th: -20
233 | p285:
234 | f0_min: 80
235 | f0_max: 180
236 | pow_th: -20
237 | p286:
238 | f0_min: 90
239 | f0_max: 210
240 | pow_th: -20
241 | p287:
242 | f0_min: 65
243 | f0_max: 140
244 | pow_th: -20
245 | p288:
246 | f0_min: 130
247 | f0_max: 300
248 | pow_th: -20
249 | p292:
250 | f0_min: 70
251 | f0_max: 170
252 | pow_th: -20
253 | p293:
254 | f0_min: 130
255 | f0_max: 260
256 | pow_th: -20
257 | p294:
258 | f0_min: 120
259 | f0_max: 260
260 | pow_th: -20
261 | p295:
262 | f0_min: 140
263 | f0_max: 260
264 | pow_th: -20
265 | p297:
266 | f0_min: 120
267 | f0_max: 300
268 | pow_th: -20
269 | p298:
270 | f0_min: 80
271 | f0_max: 190
272 | pow_th: -20
273 | p299:
274 | f0_min: 110
275 | f0_max: 270
276 | pow_th: -20
277 | p300:
278 | f0_min: 130
279 | f0_max: 270
280 | pow_th: -20
281 | p301:
282 | f0_min: 120
283 | f0_max: 250
284 | pow_th: -20
285 | p302:
286 | f0_min: 90
287 | f0_max: 180
288 | pow_th: -20
289 | p303:
290 | f0_min: 160
291 | f0_max: 300
292 | pow_th: -20
293 | p304:
294 | f0_min: 65
295 | f0_max: 160
296 | pow_th: -20
297 | p305:
298 | f0_min: 160
299 | f0_max: 330
300 | pow_th: -20
301 | p306:
302 | f0_min: 130
303 | f0_max: 260
304 | pow_th: -20
305 | p307:
306 | f0_min: 160
307 | f0_max: 400
308 | pow_th: -20
309 | p308:
310 | f0_min: 140
311 | f0_max: 280
312 | pow_th: -20
313 | p310:
314 | f0_min: 170
315 | f0_max: 350
316 | pow_th: -20
317 | p311:
318 | f0_min: 75
319 | f0_max: 160
320 | pow_th: -20
321 | p312:
322 | f0_min: 130
323 | f0_max: 280
324 | pow_th: -20
325 | p313:
326 | f0_min: 120
327 | f0_max: 260
328 | pow_th: -20
329 | p314:
330 | f0_min: 120
331 | f0_max: 270
332 | pow_th: -20
333 | p315:
334 | f0_min: 90
335 | f0_max: 190
336 | pow_th: -20
337 | p316:
338 | f0_min: 65
339 | f0_max: 140
340 | pow_th: -20
341 | p317:
342 | f0_min: 175
343 | f0_max: 360
344 | pow_th: -20
345 | p318:
346 | f0_min: 110
347 | f0_max: 290
348 | pow_th: -20
349 | p323:
350 | f0_min: 130
351 | f0_max: 390
352 | pow_th: -20
353 | p326:
354 | f0_min: 50
355 | f0_max: 110
356 | pow_th: -20
357 | p329:
358 | f0_min: 150
359 | f0_max: 280
360 | pow_th: -20
361 | p330:
362 | f0_min: 140
363 | f0_max: 270
364 | pow_th: -20
365 | p333:
366 | f0_min: 120
367 | f0_max: 270
368 | pow_th: -20
369 | p334:
370 | f0_min: 65
371 | f0_max: 140
372 | pow_th: -20
373 | p335:
374 | f0_min: 130
375 | f0_max: 310
376 | pow_th: -20
377 | p336:
378 | f0_min: 130
379 | f0_max: 310
380 | pow_th: -20
381 | p339:
382 | f0_min: 120
383 | f0_max: 310
384 | pow_th: -20
385 | p340:
386 | f0_min: 170
387 | f0_max: 270
388 | pow_th: -20
389 | p341:
390 | f0_min: 130
391 | f0_max: 280
392 | pow_th: -20
393 | p343:
394 | f0_min: 110
395 | f0_max: 240
396 | pow_th: -20
397 | p345:
398 | f0_min: 70
399 | f0_max: 140
400 | pow_th: -20
401 | p347:
402 | f0_min: 60
403 | f0_max: 130
404 | pow_th: -20
405 | p351:
406 | f0_min: 170
407 | f0_max: 280
408 | pow_th: -20
409 | p360:
410 | f0_min: 70
411 | f0_max: 160
412 | pow_th: -20
413 | p361:
414 | f0_min: 120
415 | f0_max: 290
416 | pow_th: -20
417 | p362:
418 | f0_min: 130
419 | f0_max: 300
420 | pow_th: -20
421 | p363:
422 | f0_min: 70
423 | f0_max: 180
424 | pow_th: -20
425 | p364:
426 | f0_min: 65
427 | f0_max: 145
428 | pow_th: -20
429 | p374:
430 | f0_min: 70
431 | f0_max: 190
432 | pow_th: -20
433 | p376:
434 | f0_min: 70
435 | f0_max: 170
436 | pow_th: -20
437 | s5:
438 | f0_min: 140
439 | f0_max: 340
440 | pow_th: -20
441 |
--------------------------------------------------------------------------------
/usfgan/layers/cheaptrick.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """Spectral envelopes estimation module based on CheapTrick.
7 |
8 | References:
9 | - https://www.sciencedirect.com/science/article/pii/S0167639314000697
10 | - https://github.com/mmorise/World
11 |
12 | """
13 |
14 | import math
15 |
16 | import torch
17 | import torch.fft
18 | import torch.nn as nn
19 |
20 |
21 | class AdaptiveWindowing(nn.Module):
22 |     """CheapTrick F0 adaptive windowing module."""
23 |
24 | def __init__(
25 | self,
26 | sampling_rate,
27 | hop_size,
28 | fft_size,
29 | f0_floor,
30 | f0_ceil,
31 | ):
32 |         """Initialize AdaptiveWindowing module.
33 |
34 | Args:
35 | sampling_rate (int): Sampling rate.
36 | hop_size (int): Hop size.
37 | fft_size (int): FFT size.
38 | f0_floor (int): Minimum value of F0.
39 | f0_ceil (int): Maximum value of F0.
40 |
41 | """
42 | super(AdaptiveWindowing, self).__init__()
43 |
44 | self.sampling_rate = sampling_rate
45 | self.hop_size = hop_size
46 | self.fft_size = fft_size
47 | self.register_buffer("window", torch.zeros((f0_ceil + 1, fft_size)))
48 | self.zero_padding = nn.ConstantPad2d((fft_size // 2, fft_size // 2, 0, 0), 0)
49 |
50 | # Pre-calculation of the window functions
51 | for f0 in range(f0_floor, f0_ceil + 1):
52 | half_win_len = round(1.5 * self.sampling_rate / f0)
53 | base_index = torch.arange(
54 | -half_win_len, half_win_len + 1, dtype=torch.int64
55 | )
56 | position = base_index / 1.5 / self.sampling_rate
57 | left = fft_size // 2 - half_win_len
58 | right = fft_size // 2 + half_win_len + 1
59 | window = torch.zeros(fft_size)
60 | window[left:right] = 0.5 * torch.cos(math.pi * position * f0) + 0.5
61 | average = torch.sum(window * window).pow(0.5)
62 | self.window[f0] = window / average
63 |
64 | def forward(self, x, f, power=False):
65 | """Calculate forward propagation.
66 |
67 | Args:
68 |             x (Tensor): Waveform (B, T).
69 |             f (Tensor): F0 sequence (B, T').
70 |             power (bool): Whether to use power or magnitude.
71 | 
72 |         Returns:
73 |             Tensor: Power (or magnitude) spectrogram (B, T', fft_size // 2 + 1).
74 |
75 | """
76 | # Get the matrix of window functions corresponding to F0
77 | x = self.zero_padding(x).unfold(1, self.fft_size, self.hop_size)
78 | windows = self.window[f]
79 |
80 | # Adaptive windowing and calculate power spectrogram.
81 | # In test, change x[:, : -1, :] to x.
82 | x = torch.abs(torch.fft.rfft(x[:, :-1, :] * windows))
83 | x = x.pow(2) if power else x
84 |
85 | return x
86 |
87 |
88 | class AdaptiveLiftering(nn.Module):
89 |     """CheapTrick F0 adaptive liftering module."""
90 |
91 | def __init__(
92 | self,
93 | sampling_rate,
94 | fft_size,
95 | f0_floor,
96 | f0_ceil,
97 | q1=-0.15,
98 | ):
99 |         """Initialize AdaptiveLiftering module.
100 |
101 | Args:
102 | sampling_rate (int): Sampling rate.
103 | fft_size (int): FFT size.
104 | f0_floor (int): Minimum value of F0.
105 | f0_ceil (int): Maximum value of F0.
106 | q1 (float): Parameter to remove effect of adjacent harmonics.
107 |
108 | """
109 | super(AdaptiveLiftering, self).__init__()
110 |
111 | self.sampling_rate = sampling_rate
112 | self.bin_size = fft_size // 2 + 1
113 | self.q1 = q1
114 | self.q0 = 1.0 - 2.0 * q1
115 | self.register_buffer(
116 | "smoothing_lifter", torch.zeros((f0_ceil + 1, self.bin_size))
117 | )
118 | self.register_buffer(
119 | "compensation_lifter", torch.zeros((f0_ceil + 1, self.bin_size))
120 | )
121 |
122 | # Pre-calculation of the smoothing lifters and compensation lifters
123 | for f0 in range(f0_floor, f0_ceil + 1):
124 | smoothing_lifter = torch.zeros(self.bin_size)
125 | compensation_lifter = torch.zeros(self.bin_size)
126 | quefrency = torch.arange(1, self.bin_size) / sampling_rate
127 | smoothing_lifter[0] = 1.0
128 | smoothing_lifter[1:] = torch.sin(math.pi * f0 * quefrency) / (
129 | math.pi * f0 * quefrency
130 | )
131 | compensation_lifter[0] = self.q0 + 2.0 * self.q1
132 | compensation_lifter[1:] = self.q0 + 2.0 * self.q1 * torch.cos(
133 | 2.0 * math.pi * f0 * quefrency
134 | )
135 | self.smoothing_lifter[f0] = smoothing_lifter
136 | self.compensation_lifter[f0] = compensation_lifter
137 |
138 | def forward(self, x, f, elim_0th=False):
139 | """Calculate forward propagation.
140 |
141 | Args:
142 |             x (Tensor): Power (or magnitude) spectrogram (B, T', bin_size).
143 |             f (Tensor): F0 sequence (B, T').
144 |             elim_0th (bool): Whether to eliminate the 0th cepstrum component.
145 | 
146 |         Returns:
147 |             Tensor: Estimated log spectral envelope (B, T', bin_size).
148 |
149 | """
150 | # Setting the smoothing lifter and compensation lifter
151 | smoothing_lifter = self.smoothing_lifter[f]
152 | compensation_lifter = self.compensation_lifter[f]
153 |
154 | # Calculating cepstrum
155 | tmp = torch.cat((x, torch.flip(x[:, :, 1:-1], [2])), dim=2)
156 | cepstrum = torch.fft.rfft(torch.log(torch.clamp(tmp, min=1e-7))).real
157 |
158 | # Set the 0th cepstrum to 0
159 | if elim_0th:
160 | cepstrum[..., 0] = 0
161 |
162 | # Liftering cepstrum with the lifters
163 | liftered_cepstrum = cepstrum * smoothing_lifter * compensation_lifter
164 |
165 | # Return the result to the spectral domain
166 | x = torch.fft.irfft(liftered_cepstrum)[:, :, : self.bin_size]
167 |
168 | return x
169 |
170 |
171 | class CheapTrick(nn.Module):
172 | """CheapTrick based spectral envelope estimation module."""
173 |
174 | def __init__(
175 | self,
176 | sampling_rate,
177 | hop_size,
178 | fft_size,
179 | f0_floor=70,
180 | f0_ceil=340,
181 | uv_threshold=0,
182 | q1=-0.15,
183 | ):
184 |         """Initialize CheapTrick module.
185 |
186 | Args:
187 | sampling_rate (int): Sampling rate.
188 | hop_size (int): Hop size.
189 | fft_size (int): FFT size.
190 | f0_floor (int): Minimum value of F0.
191 | f0_ceil (int): Maximum value of F0.
192 | uv_threshold (float): V/UV determining threshold.
193 | q1 (float): Parameter to remove effect of adjacent harmonics.
194 |
195 | """
196 | super(CheapTrick, self).__init__()
197 |
198 | # fft_size must be larger than 3.0 * sampling_rate / f0_floor
199 | assert fft_size > 3.0 * sampling_rate / f0_floor
200 | self.f0_floor = f0_floor
201 | self.f0_ceil = f0_ceil
202 | self.uv_threshold = uv_threshold
203 |
204 | self.ada_wind = AdaptiveWindowing(
205 | sampling_rate,
206 | hop_size,
207 | fft_size,
208 | f0_floor,
209 | f0_ceil,
210 | )
211 | self.ada_lift = AdaptiveLiftering(
212 | sampling_rate,
213 | fft_size,
214 | f0_floor,
215 | f0_ceil,
216 | q1,
217 | )
218 |
219 | def forward(self, x, f, power=False, elim_0th=False):
220 | """Calculate forward propagation.
221 |
222 | Args:
223 |             x (Tensor): Waveform (B, T).
224 |             f (Tensor): F0 sequence (B, T').
225 |             power (bool): Whether to use power or magnitude spectrogram.
226 |             elim_0th (bool): Whether to eliminate the 0th cepstrum component.
227 | 
228 |         Returns:
229 |             Tensor: Estimated log spectral envelope (B, T', bin_size).
230 |
231 | """
232 | # Step0: Round F0 values to integers.
233 | voiced = (f > self.uv_threshold) * torch.ones_like(f)
234 | f = voiced * f + (1.0 - voiced) * self.f0_ceil
235 | f = torch.round(torch.clamp(f, min=self.f0_floor, max=self.f0_ceil)).to(
236 | torch.int64
237 | )
238 |
239 | # Step1: Adaptive windowing and calculate power or amplitude spectrogram.
240 | x = self.ada_wind(x, f, power)
241 |
242 |         # Step2: Smoothing (log axis) and spectral recovery on the cepstrum domain.
243 | x = self.ada_lift(x, f, elim_0th)
244 |
245 | return x
246 |
--------------------------------------------------------------------------------
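
Note: a minimal sketch of estimating a spectral envelope with the CheapTrick module above; fft_size must exceed 3.0 * sampling_rate / f0_floor, and the values are illustrative:

    import torch
    from usfgan.layers import CheapTrick

    cheaptrick = CheapTrick(
        sampling_rate=24000, hop_size=120, fft_size=2048, f0_floor=70, f0_ceil=340
    )
    x = torch.randn(1, 120 * 100)    # waveform (B, T)
    f = torch.full((1, 100), 150.0)  # frame-level F0 (B, T')
    e = cheaptrick(x, f)             # log spectral envelope (B, T', fft_size // 2 + 1)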
/usfgan/losses/stft.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """STFT-based loss modules.
7 |
8 | References:
9 | - https://github.com/kan-bayashi/ParallelWaveGAN
10 |
11 | """
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from librosa.filters import mel as librosa_mel
17 |
18 |
19 | def stft(
20 | x, fft_size, hop_size, win_length, window, center=True, onesided=True, power=False
21 | ):
22 | """Perform STFT and convert to magnitude spectrogram.
23 |
24 | Args:
25 | x (Tensor): Input signal tensor (B, T).
26 | fft_size (int): FFT size.
27 | hop_size (int): Hop size.
28 | win_length (int): Window length.
29 |         window (Tensor): Window function tensor.
30 |
31 | Returns:
32 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
33 |
34 | """
35 | x_stft = torch.stft(
36 | x,
37 | fft_size,
38 | hop_size,
39 | win_length,
40 | window,
41 | center=center,
42 | onesided=onesided,
43 | return_complex=False,
44 | )
45 | real = x_stft[..., 0]
46 | imag = x_stft[..., 1]
47 |
48 | if power:
49 | return torch.clamp(real ** 2 + imag ** 2, min=1e-7).transpose(2, 1)
50 | else:
51 | return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
52 |
53 |
54 | class SpectralConvergenceLoss(nn.Module):
55 | """Spectral convergence loss module."""
56 |
57 | def __init__(self):
58 |         """Initialize spectral convergence loss module."""
59 | super(SpectralConvergenceLoss, self).__init__()
60 |
61 | def forward(self, x_mag, y_mag):
62 | """Calculate forward propagation.
63 |
64 | Args:
65 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
66 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
67 |
68 | Returns:
69 | Tensor: Spectral convergence loss value.
70 |
71 | """
72 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
73 |
74 |
75 | class LogSTFTMagnitudeLoss(nn.Module):
76 | """Log STFT magnitude loss module."""
77 |
78 | def __init__(self):
79 |         """Initialize log STFT magnitude loss module."""
80 | super(LogSTFTMagnitudeLoss, self).__init__()
81 |
82 | def forward(self, x_mag, y_mag):
83 | """Calculate forward propagation.
84 |
85 | Args:
86 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
87 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
88 |
89 | Returns:
90 | Tensor: Log STFT magnitude loss value.
91 |
92 | """
93 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
94 |
95 |
96 | class STFTLoss(nn.Module):
97 | """STFT loss module."""
98 |
99 | def __init__(
100 | self, fft_size=1024, hop_size=120, win_length=600, window="hann_window"
101 | ):
102 | """Initialize STFT loss module."""
103 | super(STFTLoss, self).__init__()
104 | self.fft_size = fft_size
105 | self.hop_size = hop_size
106 | self.win_length = win_length
107 | self.register_buffer("window", getattr(torch, window)(win_length))
108 | self.spectral_convergence_loss = SpectralConvergenceLoss()
109 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
110 |
111 | def forward(self, x, y):
112 | """Calculate forward propagation.
113 |
114 | Args:
115 | x (Tensor): Predicted signal (B, T).
116 | y (Tensor): Groundtruth signal (B, T).
117 |
118 | Returns:
119 | Tensor: Spectral convergence loss value.
120 | Tensor: Log STFT magnitude loss value.
121 |
122 | """
123 | x_mag = stft(x, self.fft_size, self.hop_size, self.win_length, self.window)
124 | y_mag = stft(y, self.fft_size, self.hop_size, self.win_length, self.window)
125 | sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
126 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
127 |
128 | return sc_loss, mag_loss
129 |
130 |
131 | class MultiResolutionSTFTLoss(nn.Module):
132 | """Multi resolution STFT loss module."""
133 |
134 | def __init__(
135 | self,
136 | fft_sizes=[1024, 2048, 512],
137 | hop_sizes=[120, 240, 50],
138 | win_lengths=[600, 1200, 240],
139 | window="hann_window",
140 | ):
141 | """Initialize Multi resolution STFT loss module.
142 |
143 | Args:
144 | fft_sizes (list): List of FFT sizes.
145 | hop_sizes (list): List of hop sizes.
146 | win_lengths (list): List of window lengths.
147 | window (str): Window function type.
148 |
149 | """
150 | super(MultiResolutionSTFTLoss, self).__init__()
151 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
152 | self.stft_losses = nn.ModuleList()
153 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
154 | self.stft_losses += [STFTLoss(fs, ss, wl, window)]
155 |
156 | def forward(self, x, y):
157 | """Calculate forward propagation.
158 |
159 | Args:
160 | x (Tensor): Predicted signal (B, 1, T).
161 | y (Tensor): Groundtruth signal (B, 1, T).
162 |
163 | Returns:
164 | Tensor: Multi resolution spectral convergence loss value.
165 | Tensor: Multi resolution log STFT magnitude loss value.
166 |
167 | """
168 | x = x.squeeze(1)
169 | y = y.squeeze(1)
170 |
171 | sc_loss = 0.0
172 | mag_loss = 0.0
173 | for f in self.stft_losses:
174 | sc_l, mag_l = f(x, y)
175 | sc_loss += sc_l
176 | mag_loss += mag_l
177 | sc_loss /= len(self.stft_losses)
178 | mag_loss /= len(self.stft_losses)
179 |
180 | return sc_loss, mag_loss
181 |
182 |
183 | class LogSTFTPowerLoss(nn.Module):
184 | """Log STFT power loss module."""
185 |
186 | def __init__(
187 | self, fft_size=1024, hop_size=120, win_length=600, window="hann_window"
188 | ):
189 |         """Initialize log STFT power loss module."""
190 | super(LogSTFTPowerLoss, self).__init__()
191 | self.fft_size = fft_size
192 | self.hop_size = hop_size
193 | self.win_length = win_length
194 | self.register_buffer("window", getattr(torch, window)(win_length))
195 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
196 | self.mse = nn.MSELoss()
197 |
198 | def forward(self, x, y):
199 | """Calculate forward propagation.
200 |
201 | Args:
202 | x (Tensor): Predicted signal (B, T).
203 | y (Tensor): Groundtruth signal (B, T).
204 |
205 | Returns:
206 |             Tensor: Log STFT power loss value computed on the log power
207 |                 spectrograms of x and y.
208 |
209 | """
210 | x_pow = stft(
211 | x,
212 | self.fft_size,
213 | self.hop_size,
214 | self.win_length,
215 | self.window,
216 | power=True,
217 | )
218 | y_pow = stft(
219 | y,
220 | self.fft_size,
221 | self.hop_size,
222 | self.win_length,
223 | self.window,
224 | power=True,
225 | )
226 | stft_loss = (
227 | self.mse(
228 | torch.log(torch.clamp(x_pow, min=1e-7)),
229 | torch.log(torch.clamp(y_pow, min=1e-7)),
230 | )
231 | / 2.0
232 | )
233 |
234 | return stft_loss
235 |
236 |
237 | class MultiResolutionLogSTFTPowerLoss(nn.Module):
238 | """Multi-resolution log STFT power loss module.
239 |
240 |     This loss is the same as the loss used in the Neural Source-Filter model.
241 | https://arxiv.org/abs/1904.12088
242 | """
243 |
244 | def __init__(
245 | self,
246 | fft_sizes=[1024, 2048, 512],
247 | hop_sizes=[120, 240, 50],
248 | win_lengths=[600, 1200, 240],
249 | window="hann_window",
250 | ):
251 | """Initialize Multi-resolution STFT loss module.
252 |
253 | Args:
254 | fft_sizes (list): List of FFT sizes.
255 | hop_sizes (list): List of hop sizes.
256 | win_lengths (list): List of window lengths.
257 | window (str): Window function type.
258 |
259 | """
260 | super(MultiResolutionLogSTFTPowerLoss, self).__init__()
261 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
262 | self.stft_losses = nn.ModuleList()
263 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
264 | self.stft_losses += [LogSTFTPowerLoss(fs, ss, wl, window)]
265 |
266 | def forward(self, x, y):
267 | """Multi-resolution log STFT power loss value.
268 |
269 | Args:
270 | x (Tensor): Predicted signal (B, 1, T).
271 | y (Tensor): Groundtruth signal (B, 1, T).
272 |
273 | Returns:
274 | Tensor: Multi-resolution log STFT power loss value.
275 |
276 | """
277 | x = x.squeeze(1)
278 | y = y.squeeze(1)
279 |
280 | stft_loss = 0.0
281 | for f in self.stft_losses:
282 |             loss = f(x, y)
283 |             stft_loss += loss
284 | stft_loss /= len(self.stft_losses)
285 |
286 | return stft_loss
287 |
288 |
289 | class MelSpectralLoss(nn.Module):
290 | """Mel-spectral L1 loss module."""
291 |
292 | def __init__(
293 | self,
294 | fft_size=1024,
295 | hop_size=120,
296 | win_length=1024,
297 | window="hann_window",
298 | sampling_rate=24000,
299 | n_mels=80,
300 | fmin=0,
301 | fmax=None,
302 | ):
303 |         """Initialize Mel-spectral L1 loss module.
304 |
305 | Args:
306 | fft_size (int): FFT points.
307 |             hop_size (int): Hop size.
308 | win_length (Optional[int]): Window length.
309 | window (str): Window type.
310 | sampling_rate (int): Sampling rate.
311 | n_mels (int): Number of Mel basis.
312 | fmin (Optional[int]): Minimum frequency of mel-filter-bank.
313 | fmax (Optional[int]): Maximum frequency of mel-filter-bank.
314 |
315 | """
316 | super().__init__()
317 | self.fft_size = fft_size
318 | self.hop_size = hop_size
319 | self.win_length = win_length if win_length is not None else fft_size
320 | self.register_buffer("window", getattr(torch, window)(self.win_length))
321 | self.sampling_rate = sampling_rate
322 | self.n_mels = n_mels
323 | self.fmin = fmin
324 | self.fmax = fmax if fmax is not None else sampling_rate / 2
325 | melmat = librosa_mel(
326 |             sr=sampling_rate, n_fft=fft_size, n_mels=n_mels, fmin=fmin, fmax=self.fmax
327 | ).T
328 | self.register_buffer("melmat", torch.from_numpy(melmat).float())
329 |
330 | def forward(self, x, y):
331 | """Calculate Mel-spectral L1 loss.
332 |
333 | Args:
334 | x (Tensor): Generated waveform tensor (B, 1, T).
335 | y (Tensor): Groundtruth waveform tensor (B, 1, T).
336 |
337 | Returns:
338 | Tensor: Mel-spectral L1 loss value.
339 |
340 | """
341 | x = x.squeeze(1)
342 | y = y.squeeze(1)
343 | x_mag = stft(x, self.fft_size, self.hop_size, self.win_length, self.window)
344 | y_mag = stft(y, self.fft_size, self.hop_size, self.win_length, self.window)
345 | x_log_mel = torch.log(torch.clamp(torch.matmul(x_mag, self.melmat), min=1e-7))
346 | y_log_mel = torch.log(torch.clamp(torch.matmul(y_mag, self.melmat), min=1e-7))
347 | mel_loss = F.l1_loss(x_log_mel, y_log_mel)
348 |
349 | return mel_loss
350 |
--------------------------------------------------------------------------------
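
Note: a minimal sketch of using MultiResolutionSTFTLoss above in a training step, assuming the class is exported from usfgan.losses via the package __init__ (the tensors are illustrative):

    import torch
    from usfgan.losses import MultiResolutionSTFTLoss

    criterion = MultiResolutionSTFTLoss()  # default FFT sizes [1024, 2048, 512]
    y_hat = torch.randn(4, 1, 12000)       # generated waveforms (B, 1, T)
    y = torch.randn(4, 1, 12000)           # target waveforms (B, 1, T)
    sc_loss, mag_loss = criterion(y_hat, y)
    loss = sc_loss + mag_loss              # the two terms are typically summed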
/usfgan/datasets/audio_feat_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """Dataset modules.
7 |
8 | References:
9 | - https://github.com/bigpon/QPPWG
10 | - https://github.com/kan-bayashi/ParallelWaveGAN
11 |
12 | """
13 |
14 | from logging import getLogger
15 | from multiprocessing import Manager
16 |
17 | import numpy as np
18 | import soundfile as sf
19 | from hydra.utils import to_absolute_path
20 | from joblib import load
21 | from torch.utils.data import Dataset
22 | from usfgan.utils import (
23 | check_filename,
24 | dilated_factor,
25 | read_hdf5,
26 | read_txt,
27 | validate_length,
28 | )
29 |
30 | # A logger for this file
31 | logger = getLogger(__name__)
32 |
33 |
34 | class AudioFeatDataset(Dataset):
35 | """PyTorch compatible audio and acoustic feat. dataset."""
36 |
37 | def __init__(
38 | self,
39 | stats,
40 | audio_list,
41 | feat_list,
42 | audio_length_threshold=None,
43 | feat_length_threshold=None,
44 | return_filename=False,
45 | allow_cache=False,
46 | sample_rate=24000,
47 | hop_size=120,
48 | dense_factor=4,
49 | df_f0_type="contf0",
50 | aux_feats=["mcep", "mcap"],
51 | ):
52 | """Initialize dataset.
53 |
54 | Args:
55 | stats (str): Filename of the statistic hdf5 file.
56 | audio_list (str): Filename of the list of audio files.
57 | feat_list (str): Filename of the list of feature files.
58 | audio_length_threshold (int): Threshold to remove short audio files.
59 | feat_length_threshold (int): Threshold to remove short feature files.
60 | return_filename (bool): Whether to return the filename with arrays.
61 | allow_cache (bool): Whether to allow cache of the loaded files.
62 |             hop_size (int): Hop size of acoustic features.
63 |             dense_factor (int): Number of taps in one cycle.
64 |             aux_feats (list): Types of auxiliary features.
65 |
66 | """
67 | # load audio and feature files & check filename
68 | audio_files = read_txt(to_absolute_path(audio_list))
69 | feat_files = read_txt(to_absolute_path(feat_list))
70 | assert check_filename(audio_files, feat_files)
71 |
72 | # filter by threshold
73 | if audio_length_threshold is not None:
74 |             audio_lengths = [sf.read(to_absolute_path(f))[0].shape[0] for f in audio_files]
75 | idxs = [
76 | idx
77 | for idx in range(len(audio_files))
78 | if audio_lengths[idx] > audio_length_threshold
79 | ]
80 | if len(audio_files) != len(idxs):
81 | logger.warning(
82 | f"Some files are filtered by audio length threshold "
83 | f"({len(audio_files)} -> {len(idxs)})."
84 | )
85 | audio_files = [audio_files[idx] for idx in idxs]
86 | feat_files = [feat_files[idx] for idx in idxs]
87 | if feat_length_threshold is not None:
88 | f0_lengths = [
89 | read_hdf5(to_absolute_path(f), "/f0").shape[0] for f in feat_files
90 | ]
91 | idxs = [
92 | idx
93 | for idx in range(len(feat_files))
94 | if f0_lengths[idx] > feat_length_threshold
95 | ]
96 | if len(feat_files) != len(idxs):
97 | logger.warning(
98 |                     f"Some files are filtered by feature length threshold "
99 | f"({len(feat_files)} -> {len(idxs)})."
100 | )
101 | audio_files = [audio_files[idx] for idx in idxs]
102 | feat_files = [feat_files[idx] for idx in idxs]
103 |
104 | # assert the number of files
105 |         assert len(audio_files) != 0, f"{audio_list} is empty."
106 | assert len(audio_files) == len(
107 | feat_files
108 | ), f"Number of audio and features files are different ({len(audio_files)} vs {len(feat_files)})."
109 |
110 | self.audio_files = audio_files
111 | self.feat_files = feat_files
112 | self.return_filename = return_filename
113 | self.allow_cache = allow_cache
114 | self.sample_rate = sample_rate
115 | self.hop_size = hop_size
116 | self.dense_factor = dense_factor
117 | self.aux_feats = aux_feats
118 | self.df_f0_type = df_f0_type
119 | logger.info(f"Feature type : {self.aux_feats}")
120 |
121 | if allow_cache:
122 |             # NOTE(kan-bayashi): Manager is needed to share memory in dataloaders with num_workers > 0
123 | self.manager = Manager()
124 | self.caches = self.manager.list()
125 | self.caches += [() for _ in range(len(audio_files))]
126 |
127 | # define feature pre-processing function
128 | self.scaler = load(stats)
129 |
130 | def __getitem__(self, idx):
131 | """Get specified idx items.
132 |
133 | Args:
134 | idx (int): Index of the item.
135 |
136 | Returns:
137 |             str: Feature filename (only when return_filename is True).
138 |             ndarray: Audio signal (T,).
139 |             ndarray: Auxiliary features (T', C).
140 |             ndarray: Dilated factor sequence (T, 1),
141 |                 upsampled to the waveform resolution.
142 |             ndarray: Discrete F0 sequence (T', 1).
143 |             ndarray: Continuous F0 sequence (T', 1).
144 |
145 | """
146 | if self.allow_cache and len(self.caches[idx]) != 0:
147 | return self.caches[idx]
148 | # load audio and features
149 | audio, sr = sf.read(to_absolute_path(self.audio_files[idx]))
150 | # audio & feature pre-processing
151 | audio = audio.astype(np.float32)
152 |
153 | # get auxiliary features
154 | aux_feats = []
155 | for feat_type in self.aux_feats:
156 | aux_feat = read_hdf5(
157 | to_absolute_path(self.feat_files[idx]), f"/{feat_type}"
158 | )
159 | if feat_type != "uv":
160 | aux_feat = self.scaler[f"{feat_type}"].transform(aux_feat)
161 | aux_feats += [aux_feat]
162 | aux_feats = np.concatenate(aux_feats, axis=1)
163 |
164 | # adjust length
165 | audio, aux_feats = validate_length(audio, aux_feats, self.hop_size)
166 |
167 | # get dilated factor sequence
168 |         f0 = read_hdf5(to_absolute_path(self.feat_files[idx]), "/f0")  # discrete F0
169 | contf0 = read_hdf5(
170 | to_absolute_path(self.feat_files[idx]), "/contf0"
171 | ) # continuous F0
172 | aux_feats, f0 = validate_length(aux_feats, f0)
173 | f0, contf0 = validate_length(f0, contf0)
174 | if self.df_f0_type == "contf0":
175 | df = dilated_factor(
176 | np.squeeze(contf0.copy()), self.sample_rate, self.dense_factor
177 | )
178 | else:
179 | df = dilated_factor(
180 | np.squeeze(f0.copy()), self.sample_rate, self.dense_factor
181 | )
182 | df = df.repeat(self.hop_size, axis=0)
183 |
184 | if self.return_filename:
185 | items = self.feat_files[idx], audio, aux_feats, df, f0, contf0
186 | else:
187 | items = audio, aux_feats, df, f0, contf0
188 |
189 | if self.allow_cache:
190 | self.caches[idx] = items
191 |
192 | return items
193 |
194 | def __len__(self):
195 | """Return dataset length.
196 |
197 | Returns:
198 | int: The length of dataset.
199 |
200 | """
201 | return len(self.audio_files)
202 |
203 |
204 | class FeatDataset(Dataset):
205 |     """PyTorch compatible acoustic feature dataset."""
206 |
207 | def __init__(
208 | self,
209 | stats,
210 | feat_list,
211 | feat_length_threshold=None,
212 | return_filename=False,
213 | allow_cache=False,
214 | sample_rate=24000,
215 | hop_size=120,
216 | dense_factor=4,
217 | df_f0_type="contf0",
218 | aux_feats=["mcep", "mcap"],
219 | f0_factor=1.0,
220 | ):
221 | """Initialize dataset.
222 |
223 | Args:
224 | stats (str): Filename of the statistic hdf5 file.
225 | feat_list (str): Filename of the list of feature files.
226 | feat_length_threshold (int): Threshold to remove short feature files.
227 | return_filename (bool): Whether to return the utterance id with arrays.
228 | allow_cache (bool): Whether to allow cache of the loaded files.
229 |             hop_size (int): Hop size of acoustic features.
230 |             dense_factor (int): Number of taps in one cycle.
231 |             aux_feats (list): Types of auxiliary features.
232 | f0_factor (float): Ratio of scaled f0.
233 |
234 | """
235 | # load feat. files
236 | feat_files = read_txt(to_absolute_path(feat_list))
237 |
238 | # filter by threshold
239 | if feat_length_threshold is not None:
240 | f0_lengths = [
241 | read_hdf5(to_absolute_path(f), "/f0").shape[0] for f in feat_files
242 | ]
243 | idxs = [
244 | idx
245 | for idx in range(len(feat_files))
246 | if f0_lengths[idx] > feat_length_threshold
247 | ]
248 | if len(feat_files) != len(idxs):
249 | logger.warning(
250 |                     f"Some files are filtered by feature length threshold "
251 | f"({len(feat_files)} -> {len(idxs)})."
252 | )
253 | feat_files = [feat_files[idx] for idx in idxs]
254 |
255 | # assert the number of files
256 |         assert len(feat_files) != 0, f"{feat_list} is empty."
257 |
258 | self.feat_files = feat_files
259 | self.return_filename = return_filename
260 | self.allow_cache = allow_cache
261 | self.sample_rate = sample_rate
262 | self.hop_size = hop_size
263 | self.dense_factor = dense_factor
264 | self.df_f0_type = df_f0_type
265 | self.aux_feats = aux_feats
266 | self.f0_factor = f0_factor
267 | logger.info(f"Feature type : {self.aux_feats}")
268 |
269 | if allow_cache:
270 |             # NOTE(kan-bayashi): Manager is needed to share memory in dataloaders with num_workers > 0
271 | self.manager = Manager()
272 | self.caches = self.manager.list()
273 | self.caches += [() for _ in range(len(feat_files))]
274 |
275 | # define feature pre-processing function
276 | self.scaler = load(stats)
277 |
278 | def __getitem__(self, idx):
279 | """Get specified idx items.
280 |
281 | Args:
282 | idx (int): Index of the item.
283 |
284 | Returns:
285 |             str: Feature filename (only when return_filename is True).
286 |             ndarray: Auxiliary features (T', C).
287 |             ndarray: Dilated factor sequence (T, 1),
288 |                 upsampled to the waveform resolution.
289 |             ndarray: Discrete F0 sequence (T', 1).
290 |             ndarray: Continuous F0 sequence (T', 1).
291 |
292 | """
293 | if self.allow_cache and len(self.caches[idx]) != 0:
294 | return self.caches[idx]
295 |
296 | # get auxiliary features
297 | aux_feats = []
298 | for feat_type in self.aux_feats:
299 | aux_feat = read_hdf5(
300 | to_absolute_path(self.feat_files[idx]), f"/{feat_type}"
301 | )
302 | if "f0" in feat_type:
303 | aux_feat *= self.f0_factor
304 | if feat_type != "uv":
305 | aux_feat = self.scaler[f"{feat_type}"].transform(aux_feat)
306 | aux_feats += [aux_feat]
307 | aux_feats = np.concatenate(aux_feats, axis=1)
308 |
309 | # get dilated factor sequence
310 |         f0 = read_hdf5(to_absolute_path(self.feat_files[idx]), "/f0")  # discrete F0
311 | contf0 = read_hdf5(
312 | to_absolute_path(self.feat_files[idx]), "/contf0"
313 | ) # continuous F0
314 | f0 *= self.f0_factor
315 | contf0 *= self.f0_factor
316 | aux_feats, f0 = validate_length(aux_feats, f0)
317 | f0, contf0 = validate_length(f0, contf0)
318 | if self.df_f0_type == "contf0":
319 | df = dilated_factor(
320 | np.squeeze(contf0.copy()), self.sample_rate, self.dense_factor
321 | )
322 | else:
323 | df = dilated_factor(
324 | np.squeeze(f0.copy()), self.sample_rate, self.dense_factor
325 | )
326 | df = df.repeat(self.hop_size, axis=0)
327 |
328 | if self.return_filename:
329 | items = self.feat_files[idx], aux_feats, df, f0, contf0
330 | else:
331 | items = aux_feats, df, f0, contf0
332 |
333 | if self.allow_cache:
334 | self.caches[idx] = items
335 |
336 | return items
337 |
338 | def __len__(self):
339 | """Return dataset length.
340 |
341 | Returns:
342 | int: The length of dataset.
343 |
344 | """
345 | return len(self.feat_files)
346 |
--------------------------------------------------------------------------------
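
A minimal sketch of how AudioFeatDataset is typically consumed, assuming the statistics dump and list files already exist; the paths below are hypothetical placeholders, and the actual training pipeline in usfgan/bin/train.py adds batching and collation on top of this.

from torch.utils.data import DataLoader
from usfgan.datasets import AudioFeatDataset

dataset = AudioFeatDataset(
    stats="data/stats/scaler.joblib",            # hypothetical joblib scaler dump
    audio_list="data/scp/vctk_train_24kHz.scp",  # hypothetical wav list
    feat_list="data/scp/vctk_train_24kHz.list",  # hypothetical feature list
    allow_cache=True,  # cache loaded items across epochs
)
# one utterance: audio (T,), aux feats (T', C), dilated factors (T, 1), f0, contf0
audio, aux_feats, df, f0, contf0 = dataset[0]
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
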
/usfgan/bin/extract_features.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """Feature extraction script.
7 |
8 | References:
9 | - https://github.com/bigpon/QPPWG
10 | - https://github.com/kan-bayashi/ParallelWaveGAN
11 | - https://github.com/k2kobayashi/sprocket
12 |
13 | """
14 |
15 | import copy
16 | import multiprocessing as mp
17 | import os
18 | import sys
19 | from logging import getLogger
20 |
21 | import hydra
22 | import librosa
23 | import numpy as np
24 | import pysptk
25 | import pyworld
26 | import soundfile as sf
27 | import yaml
28 | from hydra.utils import to_absolute_path
29 | from omegaconf import DictConfig, OmegaConf
30 | from scipy.interpolate import interp1d
31 | from scipy.signal import firwin, lfilter
32 | from usfgan.utils import read_txt, write_hdf5
33 |
34 | # A logger for this file
35 | logger = getLogger(__name__)
36 |
37 |
38 | def path_create(wav_list, in_dir, out_dir, extname):
39 | for wav_name in wav_list:
40 | path_replace(wav_name, in_dir, out_dir, extname=extname)
41 |
42 |
43 | def path_replace(filepath, inputpath, outputpath, extname=None):
44 | if extname is not None:
45 |         filepath = f"{os.path.splitext(filepath)[0]}.{extname}"  # robust to dots in directory names
46 | filepath = filepath.replace(inputpath, outputpath)
47 | if not os.path.exists(os.path.dirname(filepath)):
48 | os.makedirs(os.path.dirname(filepath))
49 | return filepath
50 |
51 |
52 | def spk_division(file_list, config, spkinfo, split="/"):
53 |     """Divide a file list into speaker-dependent lists
54 |
55 | Args:
56 | file_list (list): Waveform list
57 | config (dict): Config
58 | spkinfo (dict): Dictionary of
59 | speaker-dependent f0 range and power threshold
60 |         split (str): Path split string
61 |
62 | Return:
63 | (list): List of divided file lists
64 | (list): List of speaker-dependent configs
65 |
66 | """
67 | file_lists, configs, tempf = [], [], []
68 | prespk = None
69 | for file in file_list:
70 | spk = file.split(split)[config.spkidx]
71 | if spk != prespk:
72 | if tempf:
73 | file_lists.append(tempf)
74 | tempf = []
75 | prespk = spk
76 | tempc = copy.deepcopy(config)
77 | if spk in spkinfo:
78 | tempc["minf0"] = spkinfo[spk]["f0_min"]
79 | tempc["maxf0"] = spkinfo[spk]["f0_max"]
80 | tempc["pow_th"] = spkinfo[spk]["pow_th"]
81 | else:
82 | msg = f"Since {spk} is not in spkinfo dict, "
83 | msg += "default f0 range and power threshold are used."
84 | logger.info(msg)
85 | tempc["minf0"] = 70
86 | tempc["maxf0"] = 300
87 | tempc["pow_th"] = -20
88 | configs.append(tempc)
89 | tempf.append(file)
90 | file_lists.append(tempf)
91 |
92 | return file_lists, configs
93 |
94 |
95 | def aux_list_create(wav_list_file, config):
96 | """Create list of auxiliary acoustic features
97 |
98 | Args:
99 | wav_list_file (str): Filename of wav list
100 | config (dict): Config
101 |
102 | """
103 | aux_list_file = wav_list_file.replace(".scp", ".list")
104 | wav_files = read_txt(wav_list_file)
105 | with open(aux_list_file, "w") as f:
106 | for wav_name in wav_files:
107 | feat_name = path_replace(
108 | wav_name,
109 | config.in_dir,
110 | config.out_dir,
111 | extname=config.feature_format,
112 | )
113 | f.write(f"{feat_name}\n")
114 |
115 |
116 | def low_cut_filter(x, fs, cutoff=70):
117 | """Low cut filter
118 |
119 | Args:
120 | x (ndarray): Waveform sequence
121 | fs (int): Sampling frequency
122 | cutoff (float): Cutoff frequency of low cut filter
123 |
124 | Return:
125 | (ndarray): Low cut filtered waveform sequence
126 |
127 | """
128 | nyquist = fs // 2
129 | norm_cutoff = cutoff / nyquist
130 | fil = firwin(255, norm_cutoff, pass_zero=False)
131 | lcf_x = lfilter(fil, 1, x)
132 |
133 | return lcf_x
134 |
135 |
136 | def low_pass_filter(x, fs, cutoff=70, padding=True):
137 | """Low pass filter
138 |
139 | Args:
140 | x (ndarray): Waveform sequence
141 | fs (int): Sampling frequency
142 | cutoff (float): Cutoff frequency of low pass filter
143 |
144 | Return:
145 | (ndarray): Low pass filtered waveform sequence
146 |
147 | """
148 | nyquist = fs // 2
149 | norm_cutoff = cutoff / nyquist
150 | numtaps = 255
151 | fil = firwin(numtaps, norm_cutoff)
152 | x_pad = np.pad(x, (numtaps, numtaps), "edge")
153 | lpf_x = lfilter(fil, 1, x_pad)
154 | lpf_x = lpf_x[numtaps + numtaps // 2 : -numtaps // 2]
155 |
156 | return lpf_x
157 |
158 |
159 | def convert_continuos_f0(f0):
160 | """Convert F0 to continuous F0
161 |
162 | Args:
163 | f0 (ndarray): original f0 sequence with the shape (T)
164 |
165 |     Return:
166 |         (ndarray): uv binary sequence with the shape (T)
167 |         (ndarray): continuous f0 (T), and (bool): False if all f0 values are 0
168 | """
169 | # get uv information as binary
170 | uv = np.float32(f0 != 0)
171 | # get start and end of f0
172 | if (f0 == 0).all():
173 |         logger.warning("all of the f0 values are 0.")
174 | return uv, f0, False
175 | start_f0 = f0[f0 != 0][0]
176 | end_f0 = f0[f0 != 0][-1]
177 | # padding start and end of f0 sequence
178 | cont_f0 = copy.deepcopy(f0)
179 | start_idx = np.where(cont_f0 == start_f0)[0][0]
180 | end_idx = np.where(cont_f0 == end_f0)[0][-1]
181 | cont_f0[:start_idx] = start_f0
182 | cont_f0[end_idx:] = end_f0
183 | # get non-zero frame index
184 | nz_frames = np.where(cont_f0 != 0)[0]
185 | # perform linear interpolation
186 | f = interp1d(nz_frames, cont_f0[nz_frames])
187 | cont_f0 = f(np.arange(0, cont_f0.shape[0]))
188 |
189 | return uv, cont_f0, True
190 |
191 |
192 | # Mel-spectrogram and F0 features
193 | def melfilterbank(
194 | audio,
195 | sampling_rate,
196 | fft_size=1024,
197 | hop_size=256,
198 | win_length=None,
199 | window="hann",
200 | num_mels=80,
201 | fmin=None,
202 | fmax=None,
203 | ):
204 | """Extract linear mel filterbank feature.
205 |
206 | Args:
207 | audio (ndarray): Audio signal (T,).
208 | sampling_rate (int): Sampling rate.
209 | fft_size (int): FFT size.
210 | hop_size (int): Hop size.
211 | win_length (int): Window length. If set to None, it will be the same as fft_size.
212 | window (str): Window function type.
213 | num_mels (int): Number of mel basis.
214 | fmin (int): Minimum frequency in mel basis calculation.
215 | fmax (int): Maximum frequency in mel basis calculation.
216 |
217 | Returns:
218 | ndarray: Linear mel filterbank feature (#frames, num_mels).
219 |
220 | """
221 | # get amplitude spectrogram
222 | x_stft = librosa.stft(
223 | audio,
224 | n_fft=fft_size,
225 | hop_length=hop_size,
226 | win_length=win_length,
227 | window=window,
228 | pad_mode="reflect",
229 | )
230 | spc = np.abs(x_stft).T # (#frames, #bins)
231 |
232 | # get mel basis
233 | fmin = 0 if fmin is None else fmin
234 | fmax = sampling_rate / 2 if fmax is None else fmax
235 | mel_basis = librosa.filters.mel(
236 | sr=sampling_rate,
237 | n_fft=fft_size,
238 | n_mels=num_mels,
239 | fmin=fmin,
240 | fmax=fmax,
241 | )
242 |
243 | return np.dot(spc, mel_basis.T)
244 |
245 |
246 | def melf0_feature_extract(queue, wav_list, config, eps=1e-7):
247 | """Mel-spc w/ discrete F0 and continuous F0 feature extraction
248 |
249 | Args:
250 |         queue (multiprocessing.Queue): Queue to notify the completion of each worker
251 | wav_list (list): list of the wav files
252 | config (dict): feature extraction config
253 |
254 | """
255 | # extraction
256 | for i, wav_name in enumerate(wav_list):
257 | logger.info(f"now processing {wav_name} ({i + 1}/{len(wav_list)})")
258 |
259 | # load wavfile
260 | x, fs = sf.read(to_absolute_path(wav_name))
261 |         x = np.array(x, dtype=np.float64)
262 |
263 | # check sampling frequency
264 |         if fs != config.sampling_rate:
265 |             logger.error(f"sampling frequency of {wav_name} does not match the config.")
266 | sys.exit(1)
267 |
268 | # apply low-cut-filter
269 | if config.highpass_cutoff > 0:
270 | if (x == 0).all():
271 |                 logger.info(f"skip {wav_name} because it is all zeros.")
272 | continue
273 | x = low_cut_filter(x, fs, cutoff=config.highpass_cutoff)
274 |
275 | # extract WORLD features
276 | f0, t = pyworld.harvest(
277 | x,
278 | fs=config.sampling_rate,
279 | f0_floor=config.minf0,
280 | f0_ceil=config.maxf0,
281 | frame_period=config.shiftms,
282 | )
283 | env = pyworld.cheaptrick(
284 | x,
285 | f0,
286 | t,
287 | fs=config.sampling_rate,
288 | fft_size=config.fft_size,
289 | )
290 | ap = pyworld.d4c(
291 | x,
292 | f0,
293 | t,
294 | fs=config.sampling_rate,
295 | fft_size=config.fft_size,
296 | )
297 | uv, cont_f0, is_all_uv = convert_continuos_f0(f0)
298 | if is_all_uv:
299 | lpf_fs = int(config.sampling_rate / config.hop_size)
300 | cont_f0_lpf = low_pass_filter(cont_f0, lpf_fs, cutoff=20)
301 | next_cutoff = 70
302 |             while not (cont_f0_lpf >= 0).all():
303 | cont_f0_lpf = low_pass_filter(cont_f0, lpf_fs, cutoff=next_cutoff)
304 | next_cutoff *= 2
305 | else:
306 | cont_f0_lpf = cont_f0
307 |             logger.warning(f"all of the f0 values of {wav_name} are 0.")
308 | mcep = pysptk.sp2mc(env, order=config.mcep_dim, alpha=config.alpha)
309 | mcap = pysptk.sp2mc(ap, order=config.mcap_dim, alpha=config.alpha)
310 | codeap = pyworld.code_aperiodicity(ap, config.sampling_rate)
311 |
312 | # extract mel-spectrogram
313 | msp = melfilterbank(
314 | x,
315 | config.sampling_rate,
316 | fft_size=config.fft_size,
317 | hop_size=config.hop_size,
318 | win_length=config.win_length,
319 | window=config.window,
320 | num_mels=config.num_mels,
321 | fmin=config.fmin,
322 | fmax=config.fmax,
323 | )
324 |
325 | # adjust shapes
326 | minlen = min(uv.shape[0], msp.shape[0])
327 | uv = uv[:minlen]
328 | uv = np.expand_dims(uv, axis=-1)
329 | f0 = f0[:minlen]
330 | f0 = np.expand_dims(f0, axis=-1)
331 | cont_f0_lpf = cont_f0_lpf[:minlen]
332 | cont_f0_lpf = np.expand_dims(cont_f0_lpf, axis=-1)
333 | mcep = mcep[:minlen]
334 | mcap = mcap[:minlen]
335 | codeap = codeap[:minlen]
336 | logmsp = np.log10(np.maximum(eps, msp))
337 |
338 | # save features
339 | feat_name = path_replace(
340 | wav_name,
341 | config.in_dir,
342 | config.out_dir,
343 | extname=config.feature_format,
344 | )
345 | write_hdf5(to_absolute_path(feat_name), "/uv", uv)
346 | write_hdf5(to_absolute_path(feat_name), "/f0", f0)
347 | write_hdf5(to_absolute_path(feat_name), "/contf0", cont_f0_lpf)
348 | write_hdf5(to_absolute_path(feat_name), "/mcep", mcep)
349 | write_hdf5(to_absolute_path(feat_name), "/mcap", mcap)
350 | write_hdf5(to_absolute_path(feat_name), "/codeap", codeap)
351 | write_hdf5(to_absolute_path(feat_name), "/logmsp", logmsp)
352 |
353 | queue.put("Finish")
354 |
355 |
356 | @hydra.main(version_base=None, config_path="config", config_name="extract_features")
357 | def main(config: DictConfig):
358 | # show argument
359 | logger.info(OmegaConf.to_yaml(config))
360 |
361 | # read list
362 | file_list = read_txt(to_absolute_path(config.audio))
363 | logger.info(f"number of utterances = {len(file_list)}")
364 |
365 | # list division
366 | if config.spkinfo and os.path.exists(to_absolute_path(config.spkinfo)):
367 | # load speaker info
368 | with open(to_absolute_path(config.spkinfo), "r") as f:
369 | spkinfo = yaml.safe_load(f)
370 | logger.info(f"Spkinfo {config.spkinfo} is used.")
371 | # divide into each spk list
372 | file_lists, configs = spk_division(file_list, config, spkinfo)
373 | else:
374 | logger.info(
375 |             f"Since spkinfo {config.spkinfo} does not exist, default f0 range and power threshold are used."
376 | )
377 | file_lists = np.array_split(file_list, 10)
378 | file_lists = [f_list.tolist() for f_list in file_lists]
379 | configs = [config] * len(file_lists)
380 |
381 | # set mode
382 | if config.inv:
383 | target_fn = melf0_feature_extract
384 | # create auxiliary feature list
385 | aux_list_create(to_absolute_path(config.audio), config)
386 | # create folder
387 | path_create(file_list, config.in_dir, config.out_dir, config.feature_format)
388 |
389 | # multi processing
390 | processes = []
391 | queue = mp.Queue()
392 | for f, _config in zip(file_lists, configs):
393 | p = mp.Process(
394 | target=target_fn,
395 | args=(queue, f, _config),
396 | )
397 | p.start()
398 | processes.append(p)
399 |
400 | # wait for all process
401 | for p in processes:
402 | p.join()
403 |
404 |
405 | if __name__ == "__main__":
406 | main()
407 |
--------------------------------------------------------------------------------
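
A toy check of the continuous-F0 conversion above: unvoiced frames (f0 == 0) between voiced segments are filled by linear interpolation, and leading/trailing unvoiced frames are padded with the first/last voiced values. This sketch assumes the function is imported from this script (note the original spelling of the function name).

import numpy as np
from usfgan.bin.extract_features import convert_continuos_f0

f0 = np.array([0.0, 0.0, 100.0, 0.0, 0.0, 200.0, 0.0])
uv, cont_f0, has_voiced = convert_continuos_f0(f0)
print(uv)          # [0. 0. 1. 0. 0. 1. 0.]
print(cont_f0)     # [100. 100. 100. 133.33 166.67 200. 200.] (approximately)
print(has_voiced)  # True
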
/usfgan/layers/residual_block.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """Residual block modules.
7 |
8 | References:
9 | - https://github.com/bigpon/QPPWG
10 | - https://github.com/kan-bayashi/ParallelWaveGAN
11 | - https://github.com/r9y9/wavenet_vocoder
12 |
13 | """
14 |
15 | import math
16 | import sys
17 | from logging import getLogger
18 |
19 | import torch
20 | import torch.nn as nn
21 | from usfgan.utils import pd_indexing
22 |
23 | # A logger for this file
24 | logger = getLogger(__name__)
25 |
26 |
27 | class Conv1d(nn.Conv1d):
28 | """Conv1d module with customized initialization."""
29 |
30 | def __init__(self, *args, **kwargs):
31 | """Initialize Conv1d module."""
32 | super(Conv1d, self).__init__(*args, **kwargs)
33 |
34 | def reset_parameters(self):
35 | """Reset parameters."""
36 | nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
37 | if self.bias is not None:
38 | nn.init.constant_(self.bias, 0.0)
39 |
40 |
41 | class Conv1d1x1(Conv1d):
42 | """1x1 Conv1d with customized initialization."""
43 |
44 | def __init__(self, in_channels, out_channels, bias=True):
45 | """Initialize 1x1 Conv1d module."""
46 | super(Conv1d1x1, self).__init__(
47 | in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
48 | )
49 |
50 |
51 | class Conv2d(nn.Conv2d):
52 | """Conv2d module with customized initialization."""
53 |
54 | def __init__(self, *args, **kwargs):
55 | """Initialize Conv2d module."""
56 | super(Conv2d, self).__init__(*args, **kwargs)
57 |
58 | def reset_parameters(self):
59 | """Reset parameters."""
60 | nn.init.kaiming_normal_(self.weight, mode="fan_out", nonlinearity="relu")
61 | if self.bias is not None:
62 | nn.init.constant_(self.bias, 0.0)
63 |
64 |
65 | class Conv2d1x1(Conv2d):
66 | """1x1 Conv2d with customized initialization."""
67 |
68 | def __init__(self, in_channels, out_channels, bias=True):
69 | """Initialize 1x1 Conv2d module."""
70 | super(Conv2d1x1, self).__init__(
71 | in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
72 | )
73 |
74 |
75 | class FixedBlock(nn.Module):
76 | """Fixed block module in QPPWG."""
77 |
78 | def __init__(
79 | self,
80 | residual_channels=64,
81 | gate_channels=128,
82 | skip_channels=64,
83 | aux_channels=80,
84 | kernel_size=3,
85 | dilation=1,
86 | bias=True,
87 | ):
88 | """Initialize Fixed ResidualBlock module.
89 |
90 | Args:
91 | residual_channels (int): Number of channels for residual connection.
92 | skip_channels (int): Number of channels for skip connection.
93 | aux_channels (int): Local conditioning channels i.e. auxiliary input dimension.
94 | dilation (int): Dilation size.
95 | bias (bool): Whether to add bias parameter in convolution layers.
96 |
97 | """
98 | super(FixedBlock, self).__init__()
99 | padding = (kernel_size - 1) // 2 * dilation
100 |
101 | # dilation conv
102 | self.conv = Conv1d(
103 | residual_channels,
104 | gate_channels,
105 | kernel_size,
106 | padding=padding,
107 | padding_mode="reflect",
108 | dilation=dilation,
109 | bias=bias,
110 | )
111 |
112 | # local conditioning
113 | if aux_channels > 0:
114 | self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
115 | else:
116 | self.conv1x1_aux = None
117 |
118 | # conv output is split into two groups
119 | gate_out_channels = gate_channels // 2
120 | self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
121 | self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias)
122 |
123 | def forward(self, x, c):
124 | """Calculate forward propagation.
125 |
126 | Args:
127 | x (Tensor): Input tensor (B, residual_channels, T).
128 | c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T).
129 |
130 | Returns:
131 | Tensor: Output tensor for residual connection (B, residual_channels, T).
132 | Tensor: Output tensor for skip connection (B, skip_channels, T).
133 |
134 | """
135 | residual = x
136 | x = self.conv(x)
137 |
138 |         # split into two parts for gated activation
139 | splitdim = 1
140 | xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
141 |
142 | # local conditioning
143 | if c is not None:
144 | assert self.conv1x1_aux is not None
145 | c = self.conv1x1_aux(c)
146 | ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
147 | xa, xb = xa + ca, xb + cb
148 |
149 | x = torch.tanh(xa) * torch.sigmoid(xb)
150 |
151 | # for skip connection
152 | s = self.conv1x1_skip(x)
153 |
154 | # for residual connection
155 | x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5)
156 |
157 | return x, s
158 |
159 |
160 | class AdaptiveBlock(nn.Module):
161 | """Adaptive block module in QPPWG."""
162 |
163 | def __init__(
164 | self,
165 | residual_channels=64,
166 | gate_channels=128,
167 | skip_channels=64,
168 | aux_channels=80,
169 | bias=True,
170 | ):
171 | """Initialize Adaptive ResidualBlock module.
172 |
173 | Args:
174 | residual_channels (int): Number of channels for residual connection.
175 | skip_channels (int): Number of channels for skip connection.
176 | aux_channels (int): Local conditioning channels i.e. auxiliary input dimension.
177 | bias (bool): Whether to add bias parameter in convolution layers.
178 |
179 | """
180 | super(AdaptiveBlock, self).__init__()
181 |
182 | # pitch-dependent dilation conv
183 | self.convP = Conv1d1x1(residual_channels, gate_channels, bias=bias) # past
184 | self.convC = Conv1d1x1(residual_channels, gate_channels, bias=bias) # current
185 | self.convF = Conv1d1x1(residual_channels, gate_channels, bias=bias) # future
186 |
187 | # local conditioning
188 | if aux_channels > 0:
189 | self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
190 | else:
191 | self.conv1x1_aux = None
192 |
193 | # conv output is split into two groups
194 | gate_out_channels = gate_channels // 2
195 | self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
196 | self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias)
197 |
198 | def forward(self, xC, xP, xF, c):
199 | """Calculate forward propagation.
200 |
201 | Args:
202 | xC (Tensor): Current input tensor (B, residual_channels, T).
203 | xP (Tensor): Past input tensor (B, residual_channels, T).
204 | xF (Tensor): Future input tensor (B, residual_channels, T).
205 | c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T).
206 |
207 | Returns:
208 | Tensor: Output tensor for residual connection (B, residual_channels, T).
209 | Tensor: Output tensor for skip connection (B, skip_channels, T).
210 |
211 | """
212 | residual = xC
213 | x = self.convC(xC) + self.convP(xP) + self.convF(xF)
214 |
215 |         # split into two parts for gated activation
216 | splitdim = 1
217 | xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
218 |
219 | # local conditioning
220 | if c is not None:
221 | assert self.conv1x1_aux is not None
222 | c = self.conv1x1_aux(c)
223 | ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
224 | xa, xb = xa + ca, xb + cb
225 |
226 | x = torch.tanh(xa) * torch.sigmoid(xb)
227 |
228 | # for skip connection
229 | s = self.conv1x1_skip(x)
230 |
231 | # for residual connection
232 | x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5)
233 |
234 | return x, s
235 |
236 |
237 | class ResidualBlocks(nn.Module):
238 | """Multiple residual blocks stacking module."""
239 |
240 | def __init__(
241 | self,
242 | blockA,
243 | cycleA,
244 | blockF,
245 | cycleF,
246 | cascade_mode=0,
247 | residual_channels=64,
248 | gate_channels=128,
249 | skip_channels=64,
250 | aux_channels=80,
251 | ):
252 | """Initialize ResidualBlocks module.
253 |
254 | Args:
255 | blockA (int): Number of adaptive residual blocks.
256 | cycleA (int): Number of dilation cycles of adaptive residual blocks.
257 | blockF (int): Number of fixed residual blocks.
258 | cycleF (int): Number of dilation cycles of fixed residual blocks.
259 | cascade_mode (int): Cascaded mode (0: Adaptive->Fixed; 1: Fixed->Adaptive).
260 | residual_channels (int): Number of channels in residual conv.
261 | gate_channels (int): Number of channels in gated conv.
262 | skip_channels (int): Number of channels in skip conv.
263 | aux_channels (int): Number of channels for auxiliary feature conv.
264 |
265 | """
266 | super(ResidualBlocks, self).__init__()
267 |
268 | # check the number of blocks and cycles
269 | cycleA = max(cycleA, 1)
270 | cycleF = max(cycleF, 1)
271 | assert blockA % cycleA == 0
272 | self.blockA_per_cycle = blockA // cycleA
273 | assert blockF % cycleF == 0
274 | blockF_per_cycle = blockF // cycleF
275 |
276 | # define adaptive residual blocks
277 | adaptive_blocks = nn.ModuleList()
278 | for block in range(blockA):
279 | conv = AdaptiveBlock(
280 | residual_channels=residual_channels,
281 | gate_channels=gate_channels,
282 | skip_channels=skip_channels,
283 | aux_channels=aux_channels,
284 | )
285 | adaptive_blocks += [conv]
286 |
287 | # define fixed residual blocks
288 | fixed_blocks = nn.ModuleList()
289 | for block in range(blockF):
290 | dilation = 2 ** (block % blockF_per_cycle)
291 | conv = FixedBlock(
292 | residual_channels=residual_channels,
293 | gate_channels=gate_channels,
294 | skip_channels=skip_channels,
295 | aux_channels=aux_channels,
296 | dilation=dilation,
297 | )
298 | fixed_blocks += [conv]
299 |
300 | # define cascaded structure
301 | if cascade_mode == 0: # adaptive->fixed
302 | self.conv_dilated = adaptive_blocks.extend(fixed_blocks)
303 | self.block_modes = [True] * blockA + [False] * blockF
304 | elif cascade_mode == 1: # fixed->adaptive
305 | self.conv_dilated = fixed_blocks.extend(adaptive_blocks)
306 | self.block_modes = [False] * blockF + [True] * blockA
307 | else:
308 | logger.error(f"Cascaded mode {cascade_mode} is not supported!")
309 |             sys.exit(1)
310 |
311 | def forward(self, x, c, d, batch_index, ch_index):
312 | """Calculate forward propagation.
313 |
314 | Args:
315 | x (Tensor): Input noise signal (B, 1, T).
316 |             c (Tensor): Local conditioning auxiliary features (B, C, T).
317 | d (Tensor): Input pitch-dependent dilated factors (B, 1, T).
318 |
319 | Returns:
320 |             Tensor: Sum of skip connections (B, skip_channels, T).
321 |
322 | """
323 | skips = 0
324 | blockA_idx = 0
325 | for f, mode in zip(self.conv_dilated, self.block_modes):
326 | if mode: # adaptive block
327 | dilation = 2 ** (blockA_idx % self.blockA_per_cycle)
328 | xP, xF = pd_indexing(x, d, dilation, batch_index, ch_index)
329 | x, h = f(x, xP, xF, c)
330 | blockA_idx += 1
331 | else: # fixed block
332 | x, h = f(x, c)
333 | skips = h + skips
334 | skips *= math.sqrt(1.0 / len(self.conv_dilated))
335 |
336 |         return skips
337 |
338 |
339 | class PeriodicityEstimator(nn.Module):
340 | """Periodicity estimator module."""
341 |
342 | def __init__(
343 | self,
344 | in_channels,
345 | residual_channels=64,
346 | conv_layers=3,
347 | kernel_size=5,
348 | dilation=1,
349 | padding_mode="replicate",
350 | ):
351 |         """Initialize PeriodicityEstimator module.
352 |
353 | Args:
354 | in_channels (int): Number of input channels.
355 | residual_channels (int): Number of channels in residual conv.
356 |             conv_layers (int): Number of convolution layers.
357 | kernel_size (int): Kernel size.
358 | dilation (int): Dilation size.
359 | padding_mode (str): Padding mode.
360 |
361 | """
362 | super(PeriodicityEstimator, self).__init__()
363 |
364 | modules = []
365 | for idx in range(conv_layers):
366 | conv1d = Conv1d(
367 | in_channels,
368 | residual_channels,
369 | kernel_size=kernel_size,
370 | dilation=dilation,
371 | padding=kernel_size // 2 * dilation,
372 | padding_mode=padding_mode,
373 | )
374 |
375 |             # initialize the last layer so that initial outputs are around sigmoid(0) = 0.5 to stabilize training
376 | if idx != conv_layers - 1:
377 | nonlinear = nn.ReLU(inplace=True)
378 | else:
379 | # NOTE: zero init induces nan or inf if weight normalization is used
380 | # nn.init.zeros_(conv1d.weight)
381 | nn.init.normal_(conv1d.weight, std=1e-4)
382 | nonlinear = nn.Sigmoid()
383 |
384 | modules += [conv1d, nonlinear]
385 | in_channels = residual_channels
386 |
387 | self.layers = nn.Sequential(*modules)
388 |
389 | def forward(self, x):
390 | """Calculate forward propagation.
391 |
392 | Args:
393 |             x (Tensor): Input auxiliary features (B, C, T).
394 |
395 | Returns:
396 | Tensor: Output tensor (B, residual_channels, T).
397 |
398 | """
399 | return self.layers(x)
400 |
--------------------------------------------------------------------------------
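
A quick shape check for the gated residual unit above: with the default channel sizes, FixedBlock maps a (B, 64, T) hidden sequence and an 80-dim conditioning sequence to equal-length residual and skip outputs. The tensors are random placeholders.

import torch
from usfgan.layers.residual_block import FixedBlock

block = FixedBlock(residual_channels=64, gate_channels=128,
                   skip_channels=64, aux_channels=80, dilation=2)
x = torch.randn(4, 64, 100)  # hidden features (B, residual_channels, T)
c = torch.randn(4, 80, 100)  # conditioning features (B, aux_channels, T)
x_out, skip = block(x, c)    # residual and skip outputs
assert x_out.shape == (4, 64, 100) and skip.shape == (4, 64, 100)
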
/usfgan/models/generator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """Unified Source-Filter GAN Generator modules."""
7 |
8 | from logging import getLogger
9 |
10 | import torch
11 | import torch.nn as nn
12 | from usfgan.layers import Conv1d1x1, ResidualBlocks, upsample
13 | from usfgan.layers.residual_block import PeriodicityEstimator
14 | from usfgan.utils import index_initial
15 |
16 | # A logger for this file
17 | logger = getLogger(__name__)
18 |
19 |
20 | class USFGANGenerator(nn.Module):
21 | """Unified Source-Filter GAN Generator module."""
22 |
23 | def __init__(
24 | self,
25 | source_network_params={
26 | "blockA": 30,
27 | "cycleA": 3,
28 | "blockF": 0,
29 | "cycleF": 0,
30 | "cascade_mode": 0,
31 | },
32 | filter_network_params={
33 | "blockA": 0,
34 | "cycleA": 0,
35 | "blockF": 30,
36 | "cycleF": 3,
37 | "cascade_mode": 0,
38 | },
39 | in_channels=1,
40 | out_channels=1,
41 | residual_channels=64,
42 | gate_channels=128,
43 | skip_channels=64,
44 | aux_channels=80,
45 | aux_context_window=2,
46 | use_weight_norm=True,
47 | upsample_params={"upsample_scales": [5, 4, 3, 2]},
48 | ):
49 | """Initialize USFGANGenerator module.
50 |
51 | Args:
52 | source_network_params (dict): Source-network parameters.
53 | filter_network_params (dict): Filter-network parameters.
54 | in_channels (int): Number of input channels.
55 | out_channels (int): Number of output channels.
56 | residual_channels (int): Number of channels in residual conv.
57 | gate_channels (int): Number of channels in gated conv.
58 | skip_channels (int): Number of channels in skip conv.
59 | aux_channels (int): Number of channels for auxiliary feature conv.
60 | aux_context_window (int): Context window size for auxiliary feature.
61 | use_weight_norm (bool): Whether to use weight norm.
62 | If set to true, it will be applied to all of the conv layers.
63 | upsample_params (dict): Upsampling network parameters.
64 |
65 | """
66 | super(USFGANGenerator, self).__init__()
67 | self.in_channels = in_channels
68 | self.out_channels = out_channels
69 | self.aux_channels = aux_channels
70 | self.n_ch = residual_channels
71 |
72 | # define first convolution
73 | self.conv_first = Conv1d1x1(in_channels, residual_channels)
74 |
75 | # define upsampling network
76 | self.upsample_net = getattr(upsample, "ConvInUpsampleNetwork")(
77 | **upsample_params,
78 | aux_channels=aux_channels,
79 | aux_context_window=aux_context_window,
80 | )
81 |
82 | # define source/filter networks
83 | for params in [
84 | source_network_params,
85 | filter_network_params,
86 | ]:
87 | params.update(
88 | {
89 | "residual_channels": residual_channels,
90 | "gate_channels": gate_channels,
91 | "skip_channels": skip_channels,
92 | "aux_channels": aux_channels,
93 | }
94 | )
95 | self.source_network = ResidualBlocks(**source_network_params)
96 | self.filter_network = ResidualBlocks(**filter_network_params)
97 |
98 | # convert source signal to hidden representation
99 | self.conv_mid = Conv1d1x1(out_channels, skip_channels)
100 |
101 | # convert hidden representation to output signal
102 | self.conv_last = nn.Sequential(
103 | nn.ReLU(),
104 | Conv1d1x1(skip_channels, skip_channels),
105 | nn.ReLU(),
106 | Conv1d1x1(skip_channels, out_channels),
107 | )
108 |
109 | # apply weight norm
110 | if use_weight_norm:
111 | self.apply_weight_norm()
112 |
113 | def forward(self, x, c, d):
114 | """Calculate forward propagation.
115 |
116 | Args:
117 | x (Tensor): Input noise signal (B, 1, T).
118 |             c (Tensor): Local conditioning auxiliary features (B, C, T').
119 | d (Tensor): Input pitch-dependent dilated factors (B, 1, T).
120 |
121 | Returns:
122 |             Tensor: Output signal (B, 1, T).
123 |             Tensor: Source excitation signal (B, 1, T).
124 | """
125 | # index initialization
126 | batch_index, ch_index = index_initial(x.size(0), self.n_ch)
127 |
128 | # perform upsampling
129 | c = self.upsample_net(c)
130 | assert c.size(-1) == x.size(-1)
131 |
132 | # encode to hidden representation
133 | x = self.conv_first(x)
134 |
135 | # source excitation generation
136 | x = self.source_network(x, c, d, batch_index, ch_index)
137 | s = self.conv_last(x)
138 | x = self.conv_mid(s)
139 |
140 | # resonance filtering
141 | x = self.filter_network(x, c, d, batch_index, ch_index)
142 | x = self.conv_last(x)
143 |
144 | return x, s
145 |
146 | def remove_weight_norm(self):
147 | """Remove weight normalization module from all of the layers."""
148 |
149 | def _remove_weight_norm(m):
150 | try:
151 | logger.debug(f"Weight norm is removed from {m}.")
152 | nn.utils.remove_weight_norm(m)
153 | except ValueError: # this module didn't have weight norm
154 | return
155 |
156 | self.apply(_remove_weight_norm)
157 |
158 | def apply_weight_norm(self):
159 |         """Apply weight normalization module to all of the layers."""
160 |
161 | def _apply_weight_norm(m):
162 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d):
163 | nn.utils.weight_norm(m)
164 | logger.debug(f"Weight norm is applied to {m}.")
165 |
166 | self.apply(_apply_weight_norm)
167 |
168 |
169 | class CascadeHnUSFGANGenerator(nn.Module):
170 | """Cascade hn-uSFGAN Generator module."""
171 |
172 | def __init__(
173 | self,
174 | harmonic_network_params={
175 | "blockA": 20,
176 | "cycleA": 4,
177 | "blockF": 0,
178 | "cycleF": 0,
179 | "cascade_mode": 0,
180 | },
181 | noise_network_params={
182 | "blockA": 0,
183 | "cycleA": 0,
184 | "blockF": 5,
185 | "cycleF": 5,
186 | "cascade_mode": 0,
187 | },
188 | filter_network_params={
189 | "blockA": 0,
190 | "cycleA": 0,
191 | "blockF": 30,
192 | "cycleF": 3,
193 | "cascade_mode": 0,
194 | },
195 | periodicity_estimator_params={
196 | "conv_layers": 3,
197 | "kernel_size": 5,
198 | "dilation": 1,
199 | "padding_mode": "replicate",
200 | },
201 | in_channels=1,
202 | out_channels=1,
203 | residual_channels=64,
204 | gate_channels=128,
205 | skip_channels=64,
206 | aux_channels=80,
207 | aux_context_window=2,
208 | use_weight_norm=True,
209 | upsample_params={"upsample_scales": [5, 4, 3, 2]},
210 | ):
211 | """Initialize CascadeHnUSFGANGenerator module.
212 |
213 | Args:
214 | harmonic_network_params (dict): Periodic source generation network parameters.
215 | noise_network_params (dict): Aperiodic source generation network parameters.
216 | filter_network_params (dict): Filter network parameters.
217 | periodicity_estimator_params (dict): Periodicity estimation network parameters.
218 | in_channels (int): Number of input channels.
219 | out_channels (int): Number of output channels.
220 | residual_channels (int): Number of channels in residual conv.
221 | gate_channels (int): Number of channels in gated conv.
222 | skip_channels (int): Number of channels in skip conv.
223 | aux_channels (int): Number of channels for auxiliary feature conv.
224 | aux_context_window (int): Context window size for auxiliary feature.
225 | use_weight_norm (bool): Whether to use weight norm.
226 | If set to true, it will be applied to all of the conv layers.
227 | upsample_params (dict): Upsampling network parameters.
228 |
229 | """
230 | super(CascadeHnUSFGANGenerator, self).__init__()
231 | self.in_channels = in_channels
232 | self.out_channels = out_channels
233 | self.aux_channels = aux_channels
234 | self.n_ch = residual_channels
235 |
236 | # define first convolution
237 | self.conv_first_sine = Conv1d1x1(in_channels, residual_channels)
238 | self.conv_first_noise = Conv1d1x1(in_channels, residual_channels)
239 | self.conv_merge = Conv1d1x1(residual_channels * 2, residual_channels)
240 |
241 | # define upsampling network
242 | self.upsample_net = getattr(upsample, "ConvInUpsampleNetwork")(
243 | **upsample_params,
244 | aux_channels=aux_channels,
245 | aux_context_window=aux_context_window,
246 | )
247 |
248 | # define harmonic/noise/filter networks
249 | for params in [
250 | harmonic_network_params,
251 | noise_network_params,
252 | filter_network_params,
253 | ]:
254 | params.update(
255 | {
256 | "residual_channels": residual_channels,
257 | "gate_channels": gate_channels,
258 | "skip_channels": skip_channels,
259 | "aux_channels": aux_channels,
260 | }
261 | )
262 | self.harmonic_network = ResidualBlocks(**harmonic_network_params)
263 | self.noise_network = ResidualBlocks(**noise_network_params)
264 | self.filter_network = ResidualBlocks(**filter_network_params)
265 |
266 | # define periodicity estimator
267 | self.periodicity_estimator = PeriodicityEstimator(
268 | **periodicity_estimator_params, in_channels=aux_channels
269 | )
270 |
271 | # convert hidden representation to output signal
272 | self.conv_last = nn.Sequential(
273 | nn.ReLU(),
274 | Conv1d1x1(skip_channels, skip_channels),
275 | nn.ReLU(),
276 | Conv1d1x1(skip_channels, out_channels),
277 | )
278 |
279 | # apply weight norm
280 | if use_weight_norm:
281 | self.apply_weight_norm()
282 |
283 | def forward(self, x, c, d):
284 | """Calculate forward propagation.
285 |
286 | Args:
287 |             x (Tensor): Input sine and noise signals (B, 2, T).
288 |             c (Tensor): Local conditioning auxiliary features (B, C, T').
289 | d (Tensor): Input pitch-dependent dilated factors (B, 1, T).
290 |
291 | Returns:
292 |             Tensor: Output signal, source excitation, harmonic, and noise
293 |             signals (each (B, 1, T)), and periodicity (B, residual_channels, T).
294 | """
295 | # index initialization
296 | batch_index, ch_index = index_initial(x.size(0), self.n_ch)
297 |
298 | # upsample auxiliary features
299 | c = self.upsample_net(c)
300 | assert c.size(-1) == x.size(-1)
301 |
302 | # estimate periodicity
303 | a = self.periodicity_estimator(c)
304 |
305 | # assume the first channel is sine and the other is noise
306 | sine, noise = torch.chunk(x, 2, 1)
307 |
308 | # encode to hidden representation
309 | h = self.conv_first_sine(sine)
310 | n = self.conv_first_noise(noise)
311 |
312 | # generate periodic and aperiodic source latent features
313 | h = self.harmonic_network(h, c, d, batch_index, ch_index)
314 | h = a * h
315 | n = self.conv_merge(torch.cat([h, n], dim=1))
316 | n = self.noise_network(n, c, d, batch_index, ch_index)
317 | n = (1.0 - a) * n
318 |
319 | # merge periodic and aperiodic latent features
320 | s = h + n
321 |
322 | # resonance filtering
323 | x = self.filter_network(s, c, d, batch_index, ch_index)
324 | x = self.conv_last(x)
325 |
326 | # convert to 1d signal for regularization loss
327 | s = self.conv_last(s)
328 |
329 | # just for debug
330 | with torch.no_grad():
331 | h = self.conv_last(h)
332 | n = self.conv_last(n)
333 |
334 | return x, s, h, n, a
335 |
336 | def remove_weight_norm(self):
337 | """Remove weight normalization module from all of the layers."""
338 |
339 | def _remove_weight_norm(m):
340 | try:
341 | logger.debug(f"Weight norm is removed from {m}.")
342 | nn.utils.remove_weight_norm(m)
343 | except ValueError: # this module didn't have weight norm
344 | return
345 |
346 | self.apply(_remove_weight_norm)
347 |
348 | def apply_weight_norm(self):
349 |         """Apply weight normalization module to all of the layers."""
350 |
351 | def _apply_weight_norm(m):
352 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d):
353 | nn.utils.weight_norm(m)
354 | logger.debug(f"Weight norm is applied to {m}.")
355 |
356 | self.apply(_apply_weight_norm)
357 |
358 |
359 | class ParallelHnUSFGANGenerator(nn.Module):
360 | """Parallel hn-uSFGAN Generator module."""
361 |
362 | def __init__(
363 | self,
364 | harmonic_network_params={
365 | "blockA": 20,
366 | "cycleA": 4,
367 | "blockF": 0,
368 | "cycleF": 0,
369 | "cascade_mode": 0,
370 | },
371 | noise_network_params={
372 | "blockA": 0,
373 | "cycleA": 0,
374 | "blockF": 5,
375 | "cycleF": 5,
376 | "cascade_mode": 0,
377 | },
378 | filter_network_params={
379 | "blockA": 0,
380 | "cycleA": 0,
381 | "blockF": 30,
382 | "cycleF": 3,
383 | "cascade_mode": 0,
384 | },
385 | periodicity_estimator_params={
386 | "conv_layers": 3,
387 | "kernel_size": 5,
388 | "dilation": 1,
389 | "padding_mode": "replicate",
390 | },
391 | in_channels=1,
392 | out_channels=1,
393 | residual_channels=64,
394 | gate_channels=128,
395 | skip_channels=64,
396 | aux_channels=80,
397 | aux_context_window=2,
398 | use_weight_norm=True,
399 | upsample_params={"upsample_scales": [5, 4, 3, 2]},
400 | ):
401 | """Initialize ParallelHnUSFGANGenerator module.
402 |
403 | Args:
404 | harmonic_network_params (dict): Periodic source generation network parameters.
405 | noise_network_params (dict): Aperiodic source generation network parameters.
406 | filter_network_params (dict): Filter network parameters.
407 | periodicity_estimator_params (dict): Periodicity estimation network parameters.
408 | in_channels (int): Number of input channels.
409 | out_channels (int): Number of output channels.
410 | residual_channels (int): Number of channels in residual conv.
411 | gate_channels (int): Number of channels in gated conv.
412 | skip_channels (int): Number of channels in skip conv.
413 | aux_channels (int): Number of channels for auxiliary feature conv.
414 | aux_context_window (int): Context window size for auxiliary feature.
415 | use_weight_norm (bool): Whether to use weight norm.
416 | If set to true, it will be applied to all of the conv layers.
417 | upsample_params (dict): Upsampling network parameters.
418 |
419 | """
420 | super(ParallelHnUSFGANGenerator, self).__init__()
421 | self.in_channels = in_channels
422 | self.out_channels = out_channels
423 | self.aux_channels = aux_channels
424 | self.n_ch = residual_channels
425 |
426 | # define first convolution
427 | self.conv_first_sine = Conv1d1x1(in_channels, residual_channels)
428 | self.conv_first_noise = Conv1d1x1(in_channels, residual_channels)
429 |
430 | # define upsampling network
431 | self.upsample_net = getattr(upsample, "ConvInUpsampleNetwork")(
432 | **upsample_params,
433 | aux_channels=aux_channels,
434 | aux_context_window=aux_context_window,
435 | )
436 |
437 | # define harmonic/noise/filter networks
438 | for params in [
439 | harmonic_network_params,
440 | noise_network_params,
441 | filter_network_params,
442 | ]:
443 | params.update(
444 | {
445 | "residual_channels": residual_channels,
446 | "gate_channels": gate_channels,
447 | "skip_channels": skip_channels,
448 | "aux_channels": aux_channels,
449 | }
450 | )
451 | self.harmonic_network = ResidualBlocks(**harmonic_network_params)
452 | self.noise_network = ResidualBlocks(**noise_network_params)
453 | self.filter_network = ResidualBlocks(**filter_network_params)
454 |
455 | # define periodicity estimator
456 | self.periodicity_estimator = PeriodicityEstimator(
457 | **periodicity_estimator_params, in_channels=aux_channels
458 | )
459 |
460 | # convert hidden representation to output signal
461 | self.conv_last = nn.Sequential(
462 | nn.ReLU(),
463 | Conv1d1x1(skip_channels, skip_channels),
464 | nn.ReLU(),
465 | Conv1d1x1(skip_channels, out_channels),
466 | )
467 |
468 | # apply weight norm
469 | if use_weight_norm:
470 | self.apply_weight_norm()
471 |
472 | def forward(self, x, c, d):
473 | """Calculate forward propagation.
474 |
475 | Args:
476 |             x (Tensor): Input sine and noise signals (B, 2, T).
477 |             c (Tensor): Local conditioning auxiliary features (B, C, T').
478 | d (Tensor): Input pitch-dependent dilated factors (B, 1, T).
479 |
480 | Returns:
481 |             Tensor: Output signal, source excitation, harmonic, and noise
482 |             signals (each (B, 1, T)), and periodicity (B, residual_channels, T).
483 | """
484 | # index initialization
485 | batch_index, ch_index = index_initial(x.size(0), self.n_ch)
486 |
487 | # upsample auxiliary features
488 | c = self.upsample_net(c)
489 | assert c.size(-1) == x.size(-1)
490 |
491 | # estimate periodicity
492 | a = self.periodicity_estimator(c)
493 |
494 | # assume the first channel is sine and the other is noise
495 | sine, noise = torch.chunk(x, 2, 1)
496 |
497 | # encode to hidden representation
498 | h = self.conv_first_sine(sine)
499 | n = self.conv_first_noise(noise)
500 |
501 | # generate periodic and aperiodic source latent features
502 | h = self.harmonic_network(h, c, d, batch_index, ch_index)
503 | n = self.noise_network(n, c, d, batch_index, ch_index)
504 |
505 | # merge periodic and aperiodic latent features
506 | h = a * h
507 | n = (1.0 - a) * n
508 | s = h + n
509 |
510 | # resonance filtering
511 | x = self.filter_network(s, c, d, batch_index, ch_index)
512 | x = self.conv_last(x)
513 |
514 | # convert to 1d signal for regularization loss
515 | s = self.conv_last(s)
516 |
517 | # just for debug
518 | with torch.no_grad():
519 | h = self.conv_last(h)
520 | n = self.conv_last(n)
521 |
522 | return x, s, h, n, a
523 |
524 | def remove_weight_norm(self):
525 | """Remove weight normalization module from all of the layers."""
526 |
527 | def _remove_weight_norm(m):
528 | try:
529 | logger.debug(f"Weight norm is removed from {m}.")
530 | nn.utils.remove_weight_norm(m)
531 | except ValueError: # this module didn't have weight norm
532 | return
533 |
534 | self.apply(_remove_weight_norm)
535 |
536 | def apply_weight_norm(self):
537 |         """Apply weight normalization module to all of the layers."""
538 |
539 | def _apply_weight_norm(m):
540 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d):
541 | nn.utils.weight_norm(m)
542 | logger.debug(f"Weight norm is applied to {m}.")
543 |
544 | self.apply(_apply_weight_norm)
545 |
--------------------------------------------------------------------------------
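
A minimal forward-pass sketch for USFGANGenerator with the default configuration. The default upsample scales give a hop size of 5 * 4 * 3 * 2 = 120; the extra conditioning frames below assume the upsampling network consumes aux_context_window frames of context on each side (2 * aux_context_window in total), as in the Parallel WaveGAN ConvInUpsampleNetwork, so the frame arithmetic here is an assumption. All tensors are random placeholders.

import torch
from usfgan.models import USFGANGenerator

model = USFGANGenerator()
B, n_frames, hop, win = 2, 10, 120, 2       # win = aux_context_window (default)
z = torch.randn(B, 1, n_frames * hop)       # input noise signal (B, 1, T)
c = torch.randn(B, 80, n_frames + 2 * win)  # aux features with context frames
d = torch.ones(B, 1, n_frames * hop)        # pitch-dependent dilation factors
y, s = model(z, c, d)                       # output waveform and source signal
assert y.shape == (B, 1, n_frames * hop)
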
/usfgan/bin/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright 2022 Reo Yoneyama (Nagoya University)
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | """Training Script for Unified Source-Filter GAN.
7 |
8 | References:
9 | - https://github.com/bigpon/QPPWG
10 | - https://github.com/kan-bayashi/ParallelWaveGAN
11 |
12 | """
13 |
14 | import os
15 | import sys
16 | from collections import defaultdict
17 | from logging import getLogger
18 |
19 | import hydra
20 | import librosa.display
21 | import matplotlib
22 | import numpy as np
23 | import torch
24 | import usfgan
25 | import usfgan.models
26 | from hydra.utils import to_absolute_path
27 | from omegaconf import DictConfig, OmegaConf
28 | from tensorboardX import SummaryWriter
29 | from torch.utils.data import DataLoader
30 | from tqdm import tqdm
31 | from usfgan.datasets import AudioFeatDataset
32 | from usfgan.utils.features import SignalGenerator
33 |
34 | # set to avoid matplotlib error in CLI environment
35 | matplotlib.use("Agg")
36 |
37 |
38 | # A logger for this file
39 | logger = getLogger(__name__)
40 |
41 |
42 | class Trainer(object):
43 | """Customized trainer module for Unified Source-Filter GAN training."""
44 |
45 | def __init__(
46 | self,
47 | config,
48 | steps,
49 | epochs,
50 | data_loader,
51 | model,
52 | criterion,
53 | optimizer,
54 | scheduler,
55 | device=torch.device("cpu"),
56 | ):
57 | """Initialize trainer.
58 |
59 | Args:
60 | config (dict): Config dict loaded from yaml format configuration file.
61 | steps (int): Initial global steps.
62 | epochs (int): Initial global epochs.
63 |             data_loader (dict): Dict of data loaders. It must contain "train" and "dev" loaders.
64 |             model (dict): Dict of models. It must contain "generator" and "discriminator" models.
65 |             criterion (dict): Dict of criteria. It must contain "stft", "source", "adversarial" and "feat_match" criteria.
66 |             optimizer (dict): Dict of optimizers. It must contain "generator" and "discriminator" optimizers.
67 |             scheduler (dict): Dict of schedulers. It must contain "generator" and "discriminator" schedulers.
68 |             device (torch.device): Pytorch device instance.
69 |
70 | """
71 | self.config = config
72 | self.steps = steps
73 | self.epochs = epochs
74 | self.data_loader = data_loader
75 | self.model = model
76 | self.criterion = criterion
77 | self.optimizer = optimizer
78 | self.scheduler = scheduler
79 | self.device = device
80 | self.finish_train = False
81 | self.writer = SummaryWriter(config.out_dir)
82 | self.total_train_loss = defaultdict(float)
83 | self.total_eval_loss = defaultdict(float)
84 |
85 | def run(self):
86 | """Run training."""
87 | self.tqdm = tqdm(
88 | initial=self.steps, total=self.config.train.train_max_steps, desc="[train]"
89 | )
90 | while True:
91 | # train one epoch
92 | self._train_epoch()
93 |
94 | # check whether training is finished
95 | if self.finish_train:
96 | break
97 |
98 | self.tqdm.close()
99 | logger.info("Finished training.")
100 |
101 | def save_checkpoint(self, checkpoint_path):
102 | """Save checkpoint.
103 |
104 | Args:
105 | checkpoint_path (str): Checkpoint path to be saved.
106 |
107 | """
108 | state_dict = {
109 | "optimizer": {
110 | "generator": self.optimizer["generator"].state_dict(),
111 | "discriminator": self.optimizer["discriminator"].state_dict(),
112 | },
113 | "scheduler": {
114 | "generator": self.scheduler["generator"].state_dict(),
115 | "discriminator": self.scheduler["discriminator"].state_dict(),
116 | },
117 | "steps": self.steps,
118 | "epochs": self.epochs,
119 | }
120 | state_dict["model"] = {
121 | "generator": self.model["generator"].state_dict(),
122 | "discriminator": self.model["discriminator"].state_dict(),
123 | }
124 |
125 | if not os.path.exists(os.path.dirname(checkpoint_path)):
126 | os.makedirs(os.path.dirname(checkpoint_path))
127 | torch.save(state_dict, checkpoint_path)
128 |
129 | def load_checkpoint(self, checkpoint_path, load_only_params=False):
130 | """Load checkpoint.
131 |
132 | Args:
133 | checkpoint_path (str): Checkpoint path to be loaded.
134 | load_only_params (bool): Whether to load only model parameters.
135 |
136 | """
137 | state_dict = torch.load(checkpoint_path, map_location="cpu")
138 | self.model["generator"].load_state_dict(state_dict["model"]["generator"])
139 | self.model["discriminator"].load_state_dict(
140 | state_dict["model"]["discriminator"]
141 | )
142 | if not load_only_params:
143 | self.steps = state_dict["steps"]
144 | self.epochs = state_dict["epochs"]
145 | self.optimizer["generator"].load_state_dict(
146 | state_dict["optimizer"]["generator"]
147 | )
148 | self.optimizer["discriminator"].load_state_dict(
149 | state_dict["optimizer"]["discriminator"]
150 | )
151 | self.scheduler["generator"].load_state_dict(
152 | state_dict["scheduler"]["generator"]
153 | )
154 | self.scheduler["discriminator"].load_state_dict(
155 | state_dict["scheduler"]["discriminator"]
156 | )
157 |
158 | def _train_step(self, batch):
159 | """Train model one step."""
160 | # parse batch
161 | x, y = batch
162 | x = tuple([x_.to(self.device) for x_ in x])
163 | z, c, df, f0 = x
164 | y_real = y.to(self.device)
165 |
166 | # generator forward
167 | y_fake, s = self.model["generator"](z, c, df)[:2]
168 |
169 | # calculate spectral loss
170 | if isinstance(self.criterion["stft"], usfgan.losses.MultiResolutionSTFTLoss):
171 | # Parallel WaveGAN Multi-Resolution STFT Loss
172 | sc_loss, mag_loss = self.criterion["stft"](y_fake, y_real)
173 | gen_loss = self.config.train.lambda_stft * (sc_loss + mag_loss)
174 | self.total_train_loss["train/spectral_convergence_loss"] += sc_loss.item()
175 | self.total_train_loss["train/log_stft_magnitude_loss"] += mag_loss.item()
176 | elif isinstance(
177 | self.criterion["stft"], usfgan.losses.MultiResolutionLogSTFTPowerLoss
178 | ):
179 | # Neural Source-Filter Multi-Resolution STFT Loss
180 | stft_loss = self.criterion["stft"](y_fake, y_real)
181 | gen_loss = self.config.train.lambda_stft * stft_loss
182 | self.total_train_loss["train/log_stft_power_loss"] += stft_loss.item()
183 | elif isinstance(self.criterion["stft"], usfgan.losses.MelSpectralLoss):
184 | # HiFiGAN Mel-Spectrogram Reconstruction Loss
185 | mel_loss = self.criterion["stft"](y_fake, y_real)
186 | gen_loss = self.config.train.lambda_stft * mel_loss
187 | self.total_train_loss["train/log_mel_spec_loss"] += mel_loss.item()
188 |
189 | # calculate source regularization loss for usfgan-based models
190 | if self.config.train.lambda_source > 0:
191 | if isinstance(
192 | self.criterion["source"],
193 | usfgan.losses.ResidualLoss,
194 | ):
195 | source_loss = self.criterion["source"](s, y_real, f0)
196 | gen_loss += self.config.train.lambda_source * source_loss
197 | self.total_train_loss["train/source_loss"] += source_loss.item()
198 | else:
199 | source_loss = self.criterion["source"](s, f0)
200 | gen_loss += self.config.train.lambda_source * source_loss
201 | self.total_train_loss["train/source_loss"] += source_loss.item()
202 |
203 | # calculate discriminator related losses
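        # (before discriminator_train_start_steps, the generator is updated with
        # the spectral/source losses alone and no adversarial term)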
204 | if self.steps > self.config.train.discriminator_train_start_steps:
205 | # calculate feature matching loss
206 | if self.config.train.lambda_feat_match > 0:
207 | p_fake, fmaps_fake = self.model["discriminator"](
208 | y_fake, return_fmaps=True
209 | )
210 | with torch.no_grad():
211 | p_real, fmaps_real = self.model["discriminator"](
212 | y_real, return_fmaps=True
213 | )
214 | fm_loss = self.criterion["feat_match"](fmaps_fake, fmaps_real)
215 | gen_loss += self.config.train.lambda_feat_match * fm_loss
216 | self.total_train_loss["train/feat_match_loss"] += fm_loss.item()
217 | else:
218 | p_fake = self.model["discriminator"](y_fake)
219 | # calculate adversarial loss
220 | adv_loss = self.criterion["adversarial"](p_fake)
221 | gen_loss += self.config.train.lambda_adv * adv_loss
222 | self.total_train_loss["train/adversarial_loss"] += adv_loss.item()
223 |
224 | self.total_train_loss["train/generator_loss"] += gen_loss.item()
225 |
226 | # update generator
227 | self.optimizer["generator"].zero_grad()
228 | gen_loss.backward()
229 | if self.config.train.generator_grad_norm > 0:
230 | torch.nn.utils.clip_grad_norm_(
231 | self.model["generator"].parameters(),
232 | self.config.train.generator_grad_norm,
233 | )
234 | self.optimizer["generator"].step()
235 | self.scheduler["generator"].step()
236 |
237 | # discriminator
238 | if self.steps > self.config.train.discriminator_train_start_steps:
239 | # re-compute y_fake
240 | with torch.no_grad():
241 | y_fake = self.model["generator"](z, c, df)[0]
242 | # calculate discriminator loss
243 | p_fake = self.model["discriminator"](y_fake.detach())
244 | p_real = self.model["discriminator"](y_real)
245 |             # NOTE: the first argument must be the fake samples
246 | fake_loss, real_loss = self.criterion["adversarial"](p_fake, p_real)
247 | dis_loss = fake_loss + real_loss
248 | self.total_train_loss["train/fake_loss"] += fake_loss.item()
249 | self.total_train_loss["train/real_loss"] += real_loss.item()
250 | self.total_train_loss["train/discriminator_loss"] += dis_loss.item()
251 |
252 | # update discriminator
253 | self.optimizer["discriminator"].zero_grad()
254 | dis_loss.backward()
255 | if self.config.train.discriminator_grad_norm > 0:
256 | torch.nn.utils.clip_grad_norm_(
257 | self.model["discriminator"].parameters(),
258 | self.config.train.discriminator_grad_norm,
259 | )
260 | self.optimizer["discriminator"].step()
261 | self.scheduler["discriminator"].step()
262 |
263 | # update counts
264 | self.steps += 1
265 | self.tqdm.update(1)
266 | self._check_train_finish()
267 |
268 | def _train_epoch(self):
269 | """Train model one epoch."""
270 | for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
271 | # train one step
272 | self._train_step(batch)
273 |
274 | # check interval
275 | self._check_log_interval()
276 | self._check_eval_interval()
277 | self._check_save_interval()
278 |
279 | # check whether training is finished
280 | if self.finish_train:
281 | return
282 |
283 | # update
284 | self.epochs += 1
285 | self.train_steps_per_epoch = train_steps_per_epoch
286 | logger.info(
287 | f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
288 | f"({self.train_steps_per_epoch} steps per epoch)."
289 | )
290 |
291 | @torch.no_grad()
292 | def _eval_step(self, batch):
293 | """Evaluate model one step."""
294 | # parse batch
295 | x, y = batch
296 | x = tuple([x_.to(self.device) for x_ in x])
297 | z, c, df, f0 = x
298 | y_real = y.to(self.device)
299 |
300 | # generator forward
301 | y_fake, s = self.model["generator"](z, c, df)[:2]
302 |
303 | # calculate spectral loss
304 | if isinstance(self.criterion["stft"], usfgan.losses.MultiResolutionSTFTLoss):
305 | # Parallel WaveGAN Multi-Resolution STFT Loss
306 | sc_loss, mag_loss = self.criterion["stft"](y_fake, y_real)
307 | gen_loss = self.config.train.lambda_stft * (sc_loss + mag_loss)
308 | self.total_eval_loss["eval/spectral_convergence_loss"] += sc_loss.item()
309 | self.total_eval_loss["eval/log_stft_magnitude_loss"] += mag_loss.item()
310 | elif isinstance(
311 | self.criterion["stft"], usfgan.losses.MultiResolutionLogSTFTPowerLoss
312 | ):
313 | # Neural Source-Filter Multi-Resolution STFT Loss
314 | stft_loss = self.criterion["stft"](y_fake, y_real)
315 | gen_loss = self.config.train.lambda_stft * stft_loss
316 | self.total_eval_loss["eval/log_stft_power_loss"] += stft_loss.item()
317 | elif isinstance(self.criterion["stft"], usfgan.losses.MelSpectralLoss):
318 | # HiFiGAN Mel-Spectrogram Reconstruction Loss
319 | mel_loss = self.criterion["stft"](y_fake, y_real)
320 | gen_loss = self.config.train.lambda_stft * mel_loss
321 | self.total_eval_loss["eval/log_mel_spec_loss"] += mel_loss.item()
322 |
323 | # calculate source regularization loss for usfgan-based models
324 | if self.config.train.lambda_source > 0:
325 | if isinstance(
326 | self.criterion["source"],
327 | usfgan.losses.ResidualLoss,
328 | ):
329 | source_loss = self.criterion["source"](s, y_real, f0)
330 | gen_loss += self.config.train.lambda_source * source_loss
331 | self.total_eval_loss["eval/source_loss"] += source_loss.item()
332 | else:
333 | source_loss = self.criterion["source"](s, f0)
334 | gen_loss += self.config.train.lambda_source * source_loss
335 | self.total_eval_loss["eval/source_loss"] += source_loss.item()
336 |
337 | # calculate discriminator related losses
338 | if self.steps > self.config.train.discriminator_train_start_steps:
339 | # calculate feature matching loss
340 | if self.config.train.lambda_feat_match > 0:
341 | p_fake, fmaps_fake = self.model["discriminator"](
342 | y_fake, return_fmaps=True
343 | )
344 | p_real, fmaps_real = self.model["discriminator"](
345 | y_real, return_fmaps=True
346 | )
347 | fm_loss = self.criterion["feat_match"](fmaps_fake, fmaps_real)
348 | gen_loss += self.config.train.lambda_feat_match * fm_loss
349 | self.total_eval_loss["eval/feat_match_loss"] += fm_loss.item()
350 | else:
351 | p_fake = self.model["discriminator"](y_fake)
352 | # calculate adversarial loss
353 | adv_loss = self.criterion["adversarial"](p_fake)
354 | gen_loss += self.config.train.lambda_adv * adv_loss
355 | self.total_eval_loss["eval/adversarial_loss"] += adv_loss.item()
356 |
357 | self.total_eval_loss["eval/generator_loss"] += gen_loss.item()
358 |
359 | # discriminator
360 | if self.steps > self.config.train.discriminator_train_start_steps:
361 | # calculate discriminator loss
362 | p_real = self.model["discriminator"](y_real)
363 |             # NOTE: the first argument must be the fake samples
364 | fake_loss, real_loss = self.criterion["adversarial"](p_fake, p_real)
365 | dis_loss = fake_loss + real_loss
366 | self.total_eval_loss["eval/fake_loss"] += fake_loss.item()
367 | self.total_eval_loss["eval/real_loss"] += real_loss.item()
368 | self.total_eval_loss["eval/discriminator_loss"] += dis_loss.item()
369 |
370 | def _eval_epoch(self):
371 | """Evaluate model one epoch."""
372 | logger.info(f"(Steps: {self.steps}) Start evaluation.")
373 | # change mode
374 | for key in self.model.keys():
375 | self.model[key].eval()
376 |
377 | # calculate loss for each batch
378 | for eval_steps_per_epoch, batch in enumerate(
379 | tqdm(self.data_loader["valid"], desc="[eval]"), 1
380 | ):
381 | # eval one step
382 | self._eval_step(batch)
383 |
384 | # save intermediate result
385 | if eval_steps_per_epoch == 1:
386 |                 self._generate_and_save_intermediate_result(batch)
387 | if eval_steps_per_epoch == 3:
388 | break
389 |
390 | logger.info(
391 | f"(Steps: {self.steps}) Finished evaluation "
392 | f"({eval_steps_per_epoch} steps per epoch)."
393 | )
394 |
395 | # average loss
396 | for key in self.total_eval_loss.keys():
397 | self.total_eval_loss[key] /= eval_steps_per_epoch
398 | logger.info(
399 | f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
400 | )
401 |
402 | # record
403 | self._write_to_tensorboard(self.total_eval_loss)
404 |
405 | # reset
406 | self.total_eval_loss = defaultdict(float)
407 |
408 | # restore mode
409 | for key in self.model.keys():
410 | self.model[key].train()
411 |
412 | @torch.no_grad()
413 |     def _generate_and_save_intermediate_result(self, batch):
414 | """Generate and save intermediate result."""
415 |         # delayed import to avoid backend-related errors
416 | import matplotlib.pyplot as plt
417 |
418 | x_batch, y_real_batch = batch
419 | # use only the first sample
420 | x_batch = [x[:1].to(self.device) for x in x_batch]
421 | y_real_batch = y_real_batch[:1]
422 | z_batch, c_batch, df_batch = x_batch[:3]
423 |
424 | # generator forward
425 | h_batch, n_batch, a_batch = None, None, None
426 | if isinstance(
427 | self.model["generator"],
428 | (
429 | usfgan.models.ParallelHnUSFGANGenerator,
430 | usfgan.models.CascadeHnUSFGANGenerator,
431 | ),
432 | ):
433 | y_fake_batch, s_batch, h_batch, n_batch, a_batch = self.model["generator"](
434 | z_batch, c_batch, df_batch
435 | )
436 | else:
437 | y_fake_batch, s_batch = self.model["generator"](z_batch, c_batch, df_batch)
438 |
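        # choose a random 50 ms window for the short-term waveform plots
        # (e.g., 1200 samples at the 24 kHz sampling rate of the VCTK recipe)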
439 | len50ms = int(self.config.data.sample_rate * 0.05)
440 | start = np.random.randint(0, self.config.data.batch_max_length - len50ms)
441 | end = start + len50ms
442 |
443 | for audio, name, save_wav in zip(
444 | [y_real_batch, y_fake_batch, s_batch, h_batch, n_batch],
445 | ["real", "fake", "source", "harmonic", "noise"],
446 | [True, True, True, False, False],
447 | ):
448 | if audio is not None:
449 | audio = audio.view(-1).cpu().numpy()
450 |
451 | # plot spectrogram
452 | fig = plt.figure(figsize=(8, 6))
453 | spectrogram = np.abs(
454 | librosa.stft(
455 | y=audio,
456 | n_fft=1024,
457 | hop_length=128,
458 | win_length=1024,
459 | window="hann",
460 | )
461 | )
462 | spectrogram_db = librosa.amplitude_to_db(spectrogram, ref=np.max)
463 | librosa.display.specshow(
464 | spectrogram_db,
465 | sr=self.config.data.sample_rate,
466 | y_axis="linear",
467 | x_axis="time",
468 | hop_length=128,
469 | )
470 | self.writer.add_figure(f"spectrogram/{name}", fig, self.steps)
471 | plt.clf()
472 | plt.close()
473 |
474 | # plot full waveform
475 | fig = plt.figure(figsize=(6, 3))
476 | plt.plot(audio, linewidth=1)
477 | self.writer.add_figure(f"waveform/{name}", fig, self.steps)
478 | plt.clf()
479 | plt.close()
480 |
481 | # plot short term waveform
482 | fig = plt.figure(figsize=(6, 3))
483 | plt.plot(audio[start:end], linewidth=1)
484 | self.writer.add_figure(f"short_waveform/{name}", fig, self.steps)
485 | plt.clf()
486 | plt.close()
487 |
488 | # save as wavfile
489 | if save_wav:
490 | audio = audio / np.max(np.abs(audio))
491 | self.writer.add_audio(
492 | f"audio_{name}.wav",
493 | audio,
494 | self.steps,
495 | self.config.data.sample_rate,
496 | )
497 |
498 | # plot aperiodicity weights
499 | if a_batch is not None:
500 | fig = plt.figure(figsize=(6, 4))
501 | plt.imshow(a_batch.squeeze(0).cpu().numpy(), aspect="auto")
502 | plt.colorbar()
503 |             self.writer.add_figure("aperiodicity", fig, self.steps)
504 | plt.clf()
505 | plt.close()
506 |
507 | def _write_to_tensorboard(self, loss):
508 | """Write to tensorboard."""
509 | for key, value in loss.items():
510 | self.writer.add_scalar(key, value, self.steps)
511 |
512 | def _check_save_interval(self):
513 | if self.steps % self.config.train.save_interval_steps == 0:
514 | self.save_checkpoint(
515 | os.path.join(
516 | self.config.out_dir,
517 | "checkpoints",
518 | f"checkpoint-{self.steps}steps.pkl",
519 | )
520 | )
521 | logger.info(f"Successfully saved checkpoint @ {self.steps} steps.")
522 |
523 | def _check_eval_interval(self):
524 | if self.steps % self.config.train.eval_interval_steps == 0:
525 | self._eval_epoch()
526 |
527 | def _check_log_interval(self):
528 | if self.steps % self.config.train.log_interval_steps == 0:
529 | for key in self.total_train_loss.keys():
530 | self.total_train_loss[key] /= self.config.train.log_interval_steps
531 | logger.info(
532 | f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
533 | )
534 | self._write_to_tensorboard(self.total_train_loss)
535 |
536 | # reset
537 | self.total_train_loss = defaultdict(float)
538 |
539 | def _check_train_finish(self):
540 | if self.steps >= self.config.train.train_max_steps:
541 | self.finish_train = True
542 |
543 |
544 | class Collater(object):
545 |     """Customized collater for PyTorch DataLoader in training."""
546 |
547 | def __init__(
548 | self,
549 | batch_max_length=12000,
550 | sample_rate=24000,
551 | hop_size=120,
552 | aux_context_window=2,
553 | sine_amp=0.1,
554 | noise_amp=0.003,
555 | sine_f0_type="contf0",
556 | signal_types=["sine", "noise"],
557 | ):
558 | """Initialize customized collater for PyTorch DataLoader.
559 |
560 | Args:
561 | batch_max_length (int): The maximum length of batch.
562 | sample_rate (int): Sampling rate.
563 | hop_size (int): Hop size of auxiliary features.
564 | aux_context_window (int): Context window size for auxiliary feature conv.
565 | sine_amp (float): Amplitude of sine signal.
566 | noise_amp (float): Amplitude of random noise signal.
567 | sine_f0_type (str): F0 type for generating sine signal.
568 | signal_types (list): List of types for input signals.
569 |
570 | """
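        # round batch_max_length down to an integer number of frames
        # (e.g., with the default hop_size=120, 12001 samples become 12000)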
571 | if batch_max_length % hop_size != 0:
572 |             batch_max_length -= batch_max_length % hop_size
573 | assert batch_max_length % hop_size == 0
574 | self.batch_max_length = batch_max_length
575 | self.batch_max_frames = batch_max_length // hop_size
576 | self.sample_rate = sample_rate
577 | self.hop_size = hop_size
578 | self.aux_context_window = aux_context_window
579 | self.sine_f0_type = sine_f0_type
580 | self.signal_generator = SignalGenerator(
581 | sample_rate=sample_rate,
582 | hop_size=hop_size,
583 | sine_amp=sine_amp,
584 | noise_amp=noise_amp,
585 | signal_types=signal_types,
586 | )
587 |
588 | def __call__(self, batch):
589 | """Convert into batch tensors.
590 |
591 | Args:
592 | batch (list): list of tuple of the pair of audio and features.
593 |
594 | Returns:
595 | Tensor: Gaussian noise (and sine) batch (B, D, T).
596 | Tensor: Auxiliary feature batch (B, C, T' + 2 * aux_context_window).
597 | Tensor: Dilated factor batch (B, 1, T).
598 | Tensor: F0 sequence batch (B, 1, T').
599 | Tensor: Target signal batch (B, 1, T).
600 |
601 | """
602 |         # check time resolution, randomly crop each sample to batch_max_length,
603 |         # and drop samples that are too short
603 | y_batch, c_batch, df_batch, f0_batch, contf0_batch = [], [], [], [], []
604 | for idx in range(len(batch)):
605 | x, c, df, f0, contf0 = batch[idx]
606 | self._check_length(x, c, df, f0, contf0, 0)
607 | if len(c) - 2 * self.aux_context_window > self.batch_max_frames:
608 |                 # randomly pick up a segment of batch_max_length samples
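                # the crop keeps aux_context_window extra frames on each side of c
                # for the auxiliary conv context; e.g., with the defaults
                # (batch_max_frames=100, aux_context_window=2), start_frame is
                # drawn from [2, len(c) - 102)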
609 | interval_start = self.aux_context_window
610 | interval_end = len(c) - self.batch_max_frames - self.aux_context_window
611 | start_frame = np.random.randint(interval_start, interval_end)
612 | start_step = start_frame * self.hop_size
613 | y = x[start_step : start_step + self.batch_max_length]
614 | c = c[
615 | start_frame
616 | - self.aux_context_window : start_frame
617 | + self.aux_context_window
618 | + self.batch_max_frames
619 | ]
620 | df = df[start_step : start_step + self.batch_max_length]
621 | f0 = f0[start_frame : start_frame + self.batch_max_frames]
622 | contf0 = contf0[start_frame : start_frame + self.batch_max_frames]
623 | self._check_length(
624 | y,
625 | c,
626 | df,
627 | f0,
628 | contf0,
629 | self.aux_context_window,
630 | )
631 | else:
632 |                 logger.warning(f"Removed short sample from batch (length={len(x)}).")
633 | continue
634 | y_batch += [y.astype(np.float32).reshape(-1, 1)] # [(T, 1), ...]
635 | c_batch += [c.astype(np.float32)] # [(T' + 2 * aux_context_window, D), ...]
636 | df_batch += [df.astype(np.float32).reshape(-1, 1)] # [(T, 1), ...]
637 | f0_batch += [f0.astype(np.float32).reshape(-1, 1)] # [(T', 1), ...]
638 | contf0_batch += [contf0.astype(np.float32).reshape(-1, 1)] # [(T', 1), ...]
639 |
640 |         # convert each batch to a tensor, assuming each item in the batch has the same length
641 | y_batch = torch.FloatTensor(np.array(y_batch)).transpose(2, 1) # (B, 1, T)
642 | c_batch = torch.FloatTensor(np.array(c_batch)).transpose(
643 | 2, 1
644 |         )  # (B, C, T' + 2 * aux_context_window)
645 | df_batch = torch.FloatTensor(np.array(df_batch)).transpose(2, 1) # (B, 1, T)
646 | f0_batch = torch.FloatTensor(np.array(f0_batch)).transpose(2, 1) # (B, 1, T')
647 | contf0_batch = torch.FloatTensor(np.array(contf0_batch)).transpose(
648 | 2, 1
649 | ) # (B, 1, T')
650 |
651 | # make input signal batch tensor
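        # contf0 is the continuous (interpolated) F0 sequence; using it keeps the
        # sine excitation defined in unvoiced segments where raw f0 drops to zero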
652 | if self.sine_f0_type == "contf0":
653 | in_batch = self.signal_generator(contf0_batch)
654 | else:
655 | in_batch = self.signal_generator(f0_batch)
656 |
657 | return (in_batch, c_batch, df_batch, f0_batch), y_batch
658 |
659 | def _check_length(self, x, c, df, f0, contf0, context_window):
660 |         """Assert that the audio and feature lengths are correctly adjusted for upsampling."""
661 | assert len(x) == (len(c) - 2 * context_window) * self.hop_size
662 | assert len(x) == len(df)
663 | assert len(x) == len(f0) * self.hop_size
664 | assert len(x) == len(contf0) * self.hop_size
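        # e.g., with the default hop_size=120 and context_window=2, a cropped
        # segment gives len(x)=12000, len(f0)=len(contf0)=100, and len(c)=104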
665 |
666 |
667 | @hydra.main(version_base=None, config_path="config", config_name="train")
668 | def main(config: DictConfig) -> None:
669 | """Run training process."""
670 |
671 | if not torch.cuda.is_available():
672 | print("CPU")
673 | device = torch.device("cpu")
674 | else:
675 | print("GPU")
676 | device = torch.device("cuda")
677 | # effective when using fixed size inputs
678 | # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
679 | torch.backends.cudnn.benchmark = True
680 |
681 | # fix seed
682 | np.random.seed(config.seed)
683 | torch.manual_seed(config.seed)
684 | torch.cuda.manual_seed(config.seed)
685 | os.environ["PYTHONHASHSEED"] = str(config.seed)
686 |
687 | # check directory existence
688 | if not os.path.exists(config.out_dir):
689 | os.makedirs(config.out_dir)
690 |
691 | # write config to yaml file
692 | with open(os.path.join(config.out_dir, "config.yaml"), "w") as f:
693 | f.write(OmegaConf.to_yaml(config))
694 | logger.info(OmegaConf.to_yaml(config))
695 |
696 | # get dataset
697 | if config.data.remove_short_samples:
698 | feat_length_threshold = (
699 | config.data.batch_max_length // config.data.hop_size
700 | + 2 * config.generator.aux_context_window
701 | )
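        # e.g., 12000 // 120 + 2 * 2 = 104 frames with the Collater defaults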
702 | else:
703 | feat_length_threshold = None
704 |
705 | train_dataset = AudioFeatDataset(
706 | stats=to_absolute_path(config.data.stats),
707 | audio_list=to_absolute_path(config.data.train_audio),
708 | feat_list=to_absolute_path(config.data.train_feat),
709 | feat_length_threshold=feat_length_threshold,
710 | allow_cache=config.data.allow_cache,
711 | sample_rate=config.data.sample_rate,
712 | hop_size=config.data.hop_size,
713 | dense_factor=config.data.dense_factor,
714 | df_f0_type=config.data.df_f0_type,
715 | aux_feats=config.data.aux_feats,
716 | )
717 | logger.info(f"The number of training files = {len(train_dataset)}.")
718 |
719 | valid_dataset = AudioFeatDataset(
720 | stats=to_absolute_path(config.data.stats),
721 | audio_list=to_absolute_path(config.data.valid_audio),
722 | feat_list=to_absolute_path(config.data.valid_feat),
723 | feat_length_threshold=feat_length_threshold,
724 | allow_cache=config.data.allow_cache,
725 | sample_rate=config.data.sample_rate,
726 | hop_size=config.data.hop_size,
727 | dense_factor=config.data.dense_factor,
728 | df_f0_type=config.data.df_f0_type,
729 | aux_feats=config.data.aux_feats,
730 | )
731 | logger.info(f"The number of validation files = {len(valid_dataset)}.")
732 |
733 | dataset = {
734 | "train": train_dataset,
735 | "valid": valid_dataset,
736 | }
737 |
738 | # get data loader
739 | collater = Collater(
740 | batch_max_length=config.data.batch_max_length,
741 | aux_context_window=config.generator.aux_context_window,
742 | sample_rate=config.data.sample_rate,
743 | hop_size=config.data.hop_size,
744 | sine_amp=config.data.sine_amp,
745 | noise_amp=config.data.noise_amp,
746 | sine_f0_type=config.data.sine_f0_type,
747 | signal_types=config.data.signal_types,
748 | )
749 | train_sampler, valid_sampler = None, None
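    # both samplers stay None here, so the shuffle=True flags below take effect
    # (a DistributedSampler could be plugged in for multi-GPU training)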
750 |
751 | data_loader = {
752 | "train": DataLoader(
753 | dataset=dataset["train"],
754 | shuffle=True,
755 | collate_fn=collater,
756 | batch_size=config.data.batch_size,
757 | num_workers=config.data.num_workers,
758 | sampler=train_sampler,
759 | pin_memory=config.data.pin_memory,
760 | ),
761 | "valid": DataLoader(
762 | dataset=dataset["valid"],
763 | shuffle=True,
764 | collate_fn=collater,
765 | batch_size=config.data.batch_size,
766 | num_workers=config.data.num_workers,
767 | sampler=valid_sampler,
768 | pin_memory=config.data.pin_memory,
769 | ),
770 | }
771 |
772 | # define models and optimizers
773 | model = {
774 | "generator": hydra.utils.instantiate(config.generator).to(device),
775 | "discriminator": hydra.utils.instantiate(config.discriminator).to(device),
776 | }
777 |
778 | # define training criteria
779 | criterion = {
780 | "stft": hydra.utils.instantiate(config.train.stft_loss).to(device),
781 | "adversarial": hydra.utils.instantiate(config.train.adversarial_loss).to(
782 | device
783 | ),
784 | }
785 | if config.train.lambda_feat_match > 0:
786 | criterion["feat_match"] = hydra.utils.instantiate(
787 | config.train.feat_match_loss
788 | ).to(device)
789 | if config.train.lambda_source > 0:
790 | criterion["source"] = hydra.utils.instantiate(config.train.source_loss).to(
791 | device
792 | )
793 |
794 | # define optimizers and schedulers
795 | optimizer = {
796 | "generator": hydra.utils.instantiate(
797 | config.train.generator_optimizer,
798 | params=model["generator"].parameters(),
799 | ),
800 | "discriminator": hydra.utils.instantiate(
801 | config.train.discriminator_optimizer,
802 | params=model["discriminator"].parameters(),
803 | ),
804 | }
805 | scheduler = {
806 | "generator": hydra.utils.instantiate(
807 | config.train.generator_scheduler,
808 | optimizer=optimizer["generator"],
809 | ),
810 | "discriminator": hydra.utils.instantiate(
811 | config.train.discriminator_scheduler,
812 | optimizer=optimizer["discriminator"],
813 | ),
814 | }
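    # NOTE: both schedulers are stepped once per training iteration in
    # Trainer._train_step, not once per epoch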
815 |
816 | # define trainer
817 | trainer = Trainer(
818 | config=config,
819 | steps=0,
820 | epochs=0,
821 | data_loader=data_loader,
822 | model=model,
823 | criterion=criterion,
824 | optimizer=optimizer,
825 | scheduler=scheduler,
826 | device=device,
827 | )
828 |
829 | # load trained parameters from checkpoint
830 | if config.train.resume:
831 | resume = os.path.join(
832 | config.out_dir, "checkpoints", f"checkpoint-{config.train.resume}steps.pkl"
833 | )
834 | if os.path.exists(resume):
835 | trainer.load_checkpoint(resume)
836 | logger.info(f"Successfully resumed from {resume}.")
837 | else:
838 |             logger.error(f"Failed to resume from {resume}.")
839 |             sys.exit(1)
840 | else:
841 | logger.info("Start a new training process.")
842 |
843 | # run training loop
844 | try:
845 | trainer.run()
846 | except KeyboardInterrupt:
847 | trainer.save_checkpoint(
848 | os.path.join(
849 | config.out_dir, "checkpoints", f"checkpoint-{trainer.steps}steps.pkl"
850 | )
851 | )
852 |         logger.info(f"Successfully saved checkpoint @ {trainer.steps} steps.")
853 |
854 |
855 | if __name__ == "__main__":
856 | main()
857 |
--------------------------------------------------------------------------------