├── Basic.pkl ├── LICENSE.txt ├── README.md ├── config └── config.yaml ├── datasets ├── __init__.py ├── audio_mel_dataset.py └── collater.py ├── distributed └── launch.py ├── encoder ├── audio.py ├── config.py ├── data_objects │ ├── __init__.py │ ├── random_cycler.py │ ├── speaker.py │ ├── speaker_batch.py │ ├── speaker_verification_dataset.py │ └── utterance.py ├── inference.py ├── model.py ├── params_data.py ├── params_model.py ├── plot_umap.py ├── preprocess.py ├── train.py └── visualizations.py ├── frontend ├── audio_preprocess.py ├── audio_world_process.py └── world │ ├── analysis │ └── synthesis ├── inference.py ├── layers ├── __init__.py ├── causal_conv.py ├── pqmf.py ├── residual_block.py ├── residual_stack.py ├── tf_layers.py └── upsample.py ├── losses ├── __init__.py └── stft_loss.py ├── models ├── Discriminator.py ├── Generator.py └── __init__.py ├── optimizers ├── __init__.py └── radam.py ├── preprocess.py ├── pretrained1.pt ├── requirements.txt ├── train.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-38.pyc └── utils.cpython-38.pyc ├── display.py └── utils.py /Basic.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rongjiehuang/Multi-Singer/a6e9f6138a1ddf52ebd4ec29e91795f34c108e42/Basic.pkl -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 SunMail-hub 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following 
conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Singer: Fast Multi-Singer Singing Voice Vocoder With A Large-Scale Corpus 2 | 3 | PyTorch implementation of [Multi-Singer: Fast Multi-Singer Singing Voice Vocoder With A Large-Scale Corpus](https://dl.acm.org/doi/pdf/10.1145/3474085.3475437) (ACM MM'21). 4 | 5 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2112.10358) 6 | [![GitHub Stars](https://img.shields.io/github/stars/Rongjiehuang/Multi-Singer?style=social)](https://github.com/Rongjiehuang/Multi-Singer) 7 | MIT License 8 | 9 | ## Requirements 10 | See requirements.txt for the full list: 11 | - Linux 12 | - Python 3.6 13 | - PyTorch 1.0+ 14 | - librosa 15 | - json, tqdm, logging 16 | 17 | 18 | 19 | ## Getting started 20 | 21 | #### Apply the recipe to your own dataset 22 | 23 | - Put your wav files in the data directory 24 | - Edit the configuration in config/config.yaml 25 | 26 | 27 | ## 1.
Pretrain 28 | [Use our checkpoint](https://github.com/Rongjiehuang/Multi-Singer/blob/main/pretrained1.pt), or\ 29 | train the speaker encoder yourself [here](https://github.com/dipjyoti92/speaker_embeddings_GE2E) and set `enc_model_fpath` in config/config.yaml. When training your own encoder, set its parameters to match those in `encoder/params_data.py` and `encoder/params_model.py`. 30 | 31 | ## 2. Preprocess 32 | 33 | Extract mel-spectrograms: 34 | 35 | ```bash 36 | python preprocess.py -i data/wavs -o data/feature -c config/config.yaml 37 | ``` 38 | 39 | `-i` input audio folder 40 | 41 | `-o` output acoustic feature folder 42 | 43 | `-c` config file 44 | 45 | 46 | 47 | ## 3. Train 48 | 49 | Train the vocoder conditioned on mel-spectrograms: 50 | 51 | ```bash 52 | python train.py -i data/feature -o checkpoints/ --config config/config.yaml 53 | ``` 54 | 55 | `-i` acoustic feature folder 56 | 57 | `-o` directory to save checkpoints 58 | 59 | `--config` config file 60 | 61 | ## 4. Inference 62 | 63 | ```bash 64 | python inference.py -i data/feature -o outputs/ -c checkpoints/*.pkl -g config/config.yaml 65 | ``` 66 | 67 | `-i` acoustic feature folder 68 | 69 | `-o` directory to save the generated audio 70 | 71 | `-c` checkpoint file 72 | 73 | `-g` config file 74 | 75 | ## 5. Singing Voice Synthesis 76 | For singing voice synthesis: 77 | - Use a [modified FastSpeech 2](https://github.com/ming024/FastSpeech2) to synthesize mel-spectrograms. 78 | - Feed the synthesized mel-spectrograms to Multi-Singer for waveform synthesis.
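An external acoustic model (such as the modified FastSpeech 2 above) generally has to extract mel-spectrograms with the same settings as the vocoder, otherwise its output will not match what Multi-Singer was trained on. The snippet below is a minimal, hypothetical helper: `VOCODER_FEATURES`, `feature_mismatches`, and `expected_num_samples` are not part of this repository. The values are copied from config/config.yaml, and the frame-to-sample conversion assumes the generator emits `hop_size` samples per mel frame, as the training collater's crop (`batch_max_steps = batch_max_frames * hop_size`) suggests.

```python
# Hypothetical helper (not part of this repository): check that an external
# acoustic model extracts features the same way the vocoder expects.
VOCODER_FEATURES = {  # values copied from config/config.yaml
    "sampling_rate": 24000,
    "fft_size": 512,
    "hop_size": 128,
    "win_length": 512,
    "num_mels": 80,
    "fmin": 30,
    "fmax": 12000,
}


def feature_mismatches(acoustic_cfg, vocoder_cfg=VOCODER_FEATURES):
    """Return {key: (acoustic_value, vocoder_value)} for every differing setting.

    Keys absent from acoustic_cfg are reported with acoustic_value None.
    """
    return {
        key: (acoustic_cfg.get(key), expected)
        for key, expected in vocoder_cfg.items()
        if acoustic_cfg.get(key) != expected
    }


def expected_num_samples(num_frames, hop_size=VOCODER_FEATURES["hop_size"]):
    """Waveform length for num_frames mel frames, assuming hop_size samples per frame."""
    return num_frames * hop_size
```

For example, `feature_mismatches(dict(VOCODER_FEATURES, hop_size=256))` reports only `{"hop_size": (256, 128)}`, and a 100-frame mel-spectrogram maps to 12800 samples, about 0.53 s at 24 kHz.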
79 | 80 | ## Checkpoint 81 | [Trained on OpenSinger](https://github.com/Rongjiehuang/Multi-Singer/blob/main/Basic.pkl) 82 | 83 | 84 | ## Acknowledgements 85 | [GE2E](https://github.com/dipjyoti92/speaker_embeddings_GE2E)\ 86 | [FastSpeech 2](https://github.com/ming024/FastSpeech2)\ 87 | [Parallel WaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN) 88 | 89 | 90 | ## Citation 91 | ``` 92 | @inproceedings{huang2021multi, 93 | title={Multi-Singer: Fast Multi-Singer Singing Voice Vocoder With A Large-Scale Corpus}, 94 | author={Huang, Rongjie and Chen, Feiyang and Ren, Yi and Liu, Jinglin and Cui, Chenye and Zhao, Zhou}, 95 | booktitle={Proceedings of the 29th ACM International Conference on Multimedia}, 96 | pages={3945--3954}, 97 | year={2021} 98 | } 99 | ``` 100 | 101 | ## Questions 102 | Feel free to contact me at rongjiehuang@zju.edu.cn 103 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | ########################################################### 2 | # FEATURE EXTRACTION SETTING # 3 | ########################################################### 4 | sampling_rate: 24000 # Sampling rate. 5 | fft_size: 512 # FFT size. 6 | hop_size: 128 # Hop size. 7 | win_length: 512 # Window length. 8 | # If set to null, it will be the same as fft_size. 9 | window: "hann" # Window function. 10 | num_mels: 80 # Number of mel basis filters. 11 | fmin: 30 # Minimum frequency in mel basis calculation. 12 | fmax: 12000 # Maximum frequency in mel basis calculation. 13 | global_gain_scale: 1.0 # Gain multiplied into every waveform. 14 | trim_silence: false # Whether to trim leading and trailing silence. 15 | trim_threshold_in_db: 60 # Needs careful tuning if the recording quality is poor. 16 | trim_frame_size: 2048 # Frame size in trimming. 17 | trim_hop_size: 512 # Hop size in trimming. 18 | format: "hdf5" # Feature file format.
"npy" or "hdf5" is supported. 19 | use_f0: false 20 | use_chroma: false 21 | feat_type: librosa 22 | use_noise_input: true 23 | use_embed: true 24 | enc_model_fpath: "encoder/pretrained2.pt" 25 | ########################################################### 26 | # GENERATOR NETWORK ARCHITECTURE SETTING # 27 | ########################################################### 28 | generator_type: "Generator1" 29 | generator_params: 30 | in_channels: 4 # Number of input channels. 31 | out_channels: 1 # Number of output channels. 32 | kernel_size: 5 # Kernel size of dilated convolution. 33 | layers: 30 # Number of residual block layers. 34 | stacks: 3 # Number of stacks, i.e., dilation cycles. 35 | residual_channels: 64 # Number of channels in residual conv. 36 | gate_channels: 128 # Number of channels in gated conv. 37 | skip_channels: 64 # Number of channels in skip conv. 38 | aux_channels: 80 # Number of channels for auxiliary feature conv. 39 | # Must be the same as num_mels. 40 | aux_context_window: 2 # Context window size for auxiliary feature. 41 | # If set to 2, the previous 2 and next 2 frames will be considered. 42 | dropout: 0.0 # Dropout rate. 0.0 means no dropout is applied. 43 | use_weight_norm: true # Whether to use weight norm. 44 | # If set to true, it will be applied to all of the conv layers. 45 | upsample_net: "ConvInUpsampleNetwork" # Upsampling network architecture. 46 | upsample_params: # Upsampling network parameters. 47 | upsample_scales: [2, 4, 4] # Upsampling scales. Product of these must be the same as hop size. 48 | 49 | ########################################################### 50 | # DISCRIMINATOR NETWORK ARCHITECTURE SETTING # 51 | ########################################################### 52 | discriminator_type: "Unconditional_Discriminator" 53 | discriminator_params: 54 | in_channels: 1 # Number of input channels. 55 | out_channels: 1 # Number of output channels. 56 | kernel_size: 3 # Kernel size of conv layers.
57 | layers: 10 # Number of conv layers. 58 | conv_channels: 64 # Number of channels in conv layers. 59 | bias: true # Whether to use bias parameter in conv. 60 | use_weight_norm: true # Whether to use weight norm. 61 | # If set to true, it will be applied to all of the conv layers. 62 | nonlinear_activation: "LeakyReLU" # Nonlinear function after each conv. 63 | nonlinear_activation_params: # Nonlinear function parameters. 64 | negative_slope: 0.2 # Alpha in LeakyReLU. 65 | 66 | embed_discriminator_type: "SingerConditional_Discriminator" 67 | embed_discriminator_params: 68 | in_channels: 1 69 | out_channels: 256 70 | kernel_sizes: [5, 3] 71 | channels: 16 72 | max_downsample_channels: 1024 73 | bias: true 74 | downsample_scales: [4, 4, 4, 4] 75 | nonlinear_activation: "LeakyReLU" 76 | model_hidden_size: 256 77 | model_num_layers: 3 78 | ########################################################### 79 | # STFT LOSS SETTING # 80 | ########################################################### 81 | stft_loss_params: 82 | fft_sizes: [1024, 2048, 512] # List of FFT sizes for STFT-based loss. 83 | hop_sizes: [120, 240, 50] # List of hop sizes for STFT-based loss. 84 | win_lengths: [600, 1200, 240] # List of window lengths for STFT-based loss. 85 | window: "hann_window" # Window function for STFT-based loss. 86 | use_subband_stft_loss: false 87 | subband_stft_loss_params: 88 | fft_sizes: [384, 683, 171] # List of FFT sizes for STFT-based loss. 89 | hop_sizes: [30, 60, 10] # List of hop sizes for STFT-based loss. 90 | win_lengths: [150, 300, 60] # List of window lengths for STFT-based loss. 91 | window: "hann_window" # Window function for STFT-based loss. 92 | 93 | ########################################################### 94 | # ADVERSARIAL LOSS SETTING # 95 | ########################################################### 96 | use_feat_match_loss: false # Whether to use feature matching loss. 97 | lambda_feat_match: 25.0 # Loss balancing coefficient for feature matching loss.
98 | lambda_adv: 4.0 # Loss balancing coefficient for adversarial loss. 99 | lambda_embed: 2.0 100 | ########################################################### 101 | # DATA LOADER SETTING # 102 | ########################################################### 103 | batch_size: 6 # Batch size. 104 | test_num: 50 105 | batch_max_steps: 12800 # Do not change! Length of each audio in batch. Make sure it is divisible by hop_size. 106 | pin_memory: true # Whether to pin memory in PyTorch DataLoader. 107 | num_workers: 0 # Number of workers in PyTorch DataLoader. 108 | remove_short_samples: true # Whether to remove samples shorter than batch_max_steps. 109 | allow_cache: true # Whether to allow cache in dataset. If true, it requires more CPU memory. 110 | interval: 1 # Train the discriminator every {interval} steps. 111 | ########################################################### 112 | # OPTIMIZER & SCHEDULER SETTING # 113 | ########################################################### 114 | generator_optimizer_params: 115 | lr: 0.0001 # Generator's learning rate. 116 | eps: 1.0e-6 # Generator's epsilon. 117 | weight_decay: 0.0 # Generator's weight decay coefficient. 118 | generator_scheduler_params: 119 | step_size: 200000 # Generator's scheduler step size. 120 | gamma: 0.5 # Generator's scheduler gamma. 121 | # At each step size, lr will be multiplied by this parameter. 122 | generator_grad_norm: 10 # Generator's gradient norm. 123 | discriminator_optimizer_params: 124 | lr: 0.00005 # Discriminator's learning rate. 125 | eps: 1.0e-6 # Discriminator's epsilon. 126 | weight_decay: 0.0 # Discriminator's weight decay coefficient. 127 | discriminator_scheduler_params: 128 | step_size: 200000 # Discriminator's scheduler step size. 129 | gamma: 0.5 # Discriminator's scheduler gamma. 130 | # At each step size, lr will be multiplied by this parameter. 131 | discriminator_grad_norm: 1 # Discriminator's gradient norm.
132 | 133 | embed_discriminator_optimizer_params: 134 | lr: 0.00005 # Embed discriminator's learning rate. 135 | eps: 1.0e-6 # Embed discriminator's epsilon. 136 | weight_decay: 0.0 # Embed discriminator's weight decay coefficient. 137 | embed_discriminator_scheduler_params: 138 | step_size: 200000 # Embed discriminator's scheduler step size. 139 | gamma: 0.5 # Embed discriminator's scheduler gamma. 140 | # At each step size, lr will be multiplied by this parameter. 141 | embed_discriminator_grad_norm: 1 # Embed discriminator's gradient norm. 142 | ########################################################### 143 | # INTERVAL SETTING # 144 | ########################################################### 145 | discriminator_train_start_steps: 100000 # Step at which discriminator training starts. 146 | train_max_steps: 430000 # Total number of training steps. 147 | save_interval_steps: 5000 # Interval steps to save checkpoints. 148 | eval_interval_steps: 2000 # Interval steps to evaluate the network. 149 | log_interval_steps: 1000 # Interval steps to record the training log. 150 | ########################################################### 151 | # OTHER SETTING # 152 | ########################################################### 153 | num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
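# Derived quantities (illustrative arithmetic only, assuming the values above are used unchanged):
#   frame shift     = hop_size / sampling_rate = 128 / 24000 s, about 5.33 ms
#   training crop   = batch_max_steps / sampling_rate = 12800 / 24000 s, about 0.533 s
#   frames per crop = batch_max_steps / hop_size = 12800 / 128 = 100
#   mel frames read = 100 + 2 * aux_context_window = 104 per crop
#   fmax            = sampling_rate / 2 = 12000 Hz (the Nyquist frequency)
#   generator lr    = lr * gamma ** (step // step_size), assuming the scheduler is
#                     stepped once per training step: 1.0e-4 before step 200000,
#                     5.0e-5 from step 200000, 2.5e-5 from step 400000 to 430000.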
154 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .audio_mel_dataset import * # NOQA 2 | from .collater import * 3 | -------------------------------------------------------------------------------- /datasets/audio_mel_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Dataset modules.""" 7 | 8 | import logging 9 | import os 10 | 11 | from multiprocessing import Manager 12 | 13 | import numpy as np 14 | 15 | from torch.utils.data import Dataset 16 | 17 | from utils import find_files 18 | from utils import read_hdf5 19 | 20 | 21 | 22 | class AudioMelEmbedDataset(Dataset): # reads audio/mel/embedding HDF5 files 23 | """PyTorch compatible audio, mel, and embedding dataset.""" 24 | 25 | def __init__(self, 26 | root_file, 27 | feat_type='librosa', 28 | audio_length_threshold=None, 29 | frames_threshold=None, 30 | use_f0=False, 31 | use_chroma=False, 32 | use_utt_id=False, 33 | allow_cache=False, 34 | eval=False 35 | ): 36 | """Initialize dataset. 37 | 38 | Args: 39 | root_file (str): Directory searched for "*.h5" files if eval=True; otherwise a metadata file whose lines hold the .h5 path in the second "|"-separated field. 40 | feat_type (str): "librosa" reads h5["mel"]; "world" reads h5["feats"]. 41 | audio_length_threshold (int): Threshold to remove short audio files. 42 | frames_threshold (int): Threshold (in frames) to remove short feature files. 43 | use_f0 (bool): Whether to also load h5["f0_origin"]. 44 | use_chroma (bool): Whether to also load h5["chroma"]. 45 | use_utt_id (bool): Whether to return the utterance id with arrays. 46 | allow_cache (bool): Whether to allow cache of the loaded files. 47 | eval (bool): Whether root_file is a directory (evaluation) rather than a metadata file.
48 | 49 | """ 50 | # find all .h5 files (each holds wav, mel, and embed) 51 | if eval: 52 | files = sorted(find_files(root_file, "*.h5")) 53 | else: 54 | files = [] 55 | with open(root_file, encoding='utf-8') as f: 56 | for line in f: 57 | files.append(line.strip().split('|')[1]) 58 | files = sorted(files) 59 | 60 | audio_load_fn = lambda x: read_hdf5(x, "wav") # audio loader: h5["wav"] 61 | feat_load_fn = lambda x: read_hdf5(x, "mel") # mel-spectrogram loader: h5["mel"] 62 | embed_load_fn = lambda x: read_hdf5(x, "embed") # speaker embedding loader: h5["embed"] 63 | 64 | if feat_type == "world": # features extracted with WORLD 65 | feat_load_fn = lambda x: read_hdf5(x, "feats") # WORLD features: h5["feats"] 66 | 67 | # filter by threshold 68 | if audio_length_threshold is not None: # minimum audio length (in samples) 69 | audio_lengths = [audio_load_fn(f).shape[0] for f in files] 70 | idxs = [idx for idx in range(len(files)) if audio_lengths[idx] > audio_length_threshold] # keep files whose audio is longer than the threshold 71 | if len(files) != len(idxs): 72 | logging.warning(f"Some files were filtered out by audio length threshold " 73 | f"({len(files)} -> {len(idxs)}).") 74 | files = [files[idx] for idx in idxs] 75 | if frames_threshold is not None: 76 | frames = [feat_load_fn(f).shape[0] for f in files] 77 | idxs = [idx for idx in range(len(files)) if frames[idx] > frames_threshold] # keep files whose mel has more frames than the threshold 78 | if len(files) != len(idxs): 79 | logging.warning(f"Some files were filtered out by mel length threshold " 80 | f"({len(files)} -> {len(idxs)}).") 81 | files = [files[idx] for idx in idxs] 82 | 83 | # assert the number of files 84 | assert len(files) != 0, f"No audio files found in {root_file}."
85 | 86 | self.files = files 87 | self.audio_load_fn = audio_load_fn 88 | self.feat_load_fn = feat_load_fn 89 | self.embed_load_fn = embed_load_fn 90 | 91 | self.utt_ids = [os.path.splitext(os.path.basename(f))[0] for f in files] 92 | self.use_f0 = use_f0 93 | self.use_chroma = use_chroma 94 | self.use_utt_id = use_utt_id 95 | self.allow_cache = allow_cache 96 | 97 | if use_f0: 98 | self.f0_origin_load_fn = lambda x: read_hdf5(x, "f0_origin") 99 | # self.uv_load_fn = lambda x: read_hdf5(x, "uv") 100 | # self.f0_load_fn = lambda x: read_hdf5(x, "f0") 101 | 102 | if use_chroma: 103 | self.chroma_load_fn = lambda x: read_hdf5(x, "chroma") 104 | 105 | if allow_cache: 106 | # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0 107 | self.manager = Manager() 108 | self.caches = self.manager.list() 109 | self.caches += [() for _ in range(len(files))] 110 | 111 | def __getitem__(self, idx): 112 | """Get specified idx items. 113 | 114 | Args: 115 | idx (int): Index of the item. 116 | 117 | Returns: 118 | dict: Items with keys "audio" (T,), "feat" (T', C), and "embed" (256,), 119 | plus "utt_id" (if use_utt_id=True), "chroma" (if use_chroma=True), 120 | and "f0_origin" (if use_f0=True). 121 |
122 | """ 123 | if self.allow_cache and len(self.caches[idx]) != 0: 124 | return self.caches[idx] 125 | 126 | audio = self.audio_load_fn(self.files[idx]) 127 | feat = self.feat_load_fn(self.files[idx]) 128 | embed = self.embed_load_fn(self.files[idx]) 129 | items = {'audio': audio, 'feat': feat, 'embed': embed} 130 | 131 | if self.use_utt_id: 132 | items['utt_id'] = self.utt_ids[idx] 133 | if self.use_chroma: 134 | items['chroma'] = self.chroma_load_fn(self.files[idx]) 135 | if self.use_f0: 136 | # items['f0'] = self.f0_load_fn(self.files[idx]) 137 | items['f0_origin'] = self.f0_origin_load_fn(self.files[idx]) 138 | # items['uv'] = self.uv_load_fn(self.files[idx]) 139 | 140 | if self.allow_cache: 141 | self.caches[idx] = items 142 | 143 | return items # audio, mel-spectrogram, embedding, and optional extras (e.g., F0) 144 | 145 | def __len__(self): 146 | """Return dataset length. 147 | 148 | Returns: 149 | int: The length of dataset. 150 | 151 | """ 152 | return len(self.files) 153 | 154 | class AudioDataset(Dataset): 155 | """PyTorch compatible audio dataset.""" 156 | 157 | def __init__(self, 158 | root_dir, 159 | audio_query="*-wave.npy", 160 | audio_length_threshold=None, 161 | audio_load_fn=np.load, 162 | return_utt_id=False, 163 | allow_cache=False, 164 | ): 165 | """Initialize dataset. 166 | 167 | Args: 168 | root_dir (str): Root directory including dumped files. 169 | audio_query (str): Query to find audio files in root_dir. 170 | audio_load_fn (func): Function to load audio file. 171 | audio_length_threshold (int): Threshold to remove short audio files. 172 | return_utt_id (bool): Whether to return the utterance id with arrays. 173 | allow_cache (bool): Whether to allow cache of the loaded files.
174 | 175 | """ 176 | # find all audio files 177 | audio_files = sorted(find_files(root_dir, audio_query)) 178 | 179 | # filter by threshold 180 | if audio_length_threshold is not None: 181 | audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files] 182 | idxs = [idx for idx in range(len(audio_files)) if audio_lengths[idx] > audio_length_threshold] 183 | if len(audio_files) != len(idxs): 184 | logging.warning(f"Some files were filtered out by audio length threshold " 185 | f"({len(audio_files)} -> {len(idxs)}).") 186 | audio_files = [audio_files[idx] for idx in idxs] 187 | 188 | # assert the number of files 189 | assert len(audio_files) != 0, f"No audio files found in {root_dir}." 190 | 191 | self.audio_files = audio_files 192 | self.audio_load_fn = audio_load_fn 193 | self.return_utt_id = return_utt_id 194 | if ".npy" in audio_query: 195 | self.utt_ids = [os.path.basename(f).replace("-wave.npy", "") for f in audio_files] 196 | else: 197 | self.utt_ids = [os.path.splitext(os.path.basename(f))[0] for f in audio_files] 198 | self.allow_cache = allow_cache 199 | if allow_cache: 200 | # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0 201 | self.manager = Manager() 202 | self.caches = self.manager.list() 203 | self.caches += [() for _ in range(len(audio_files))] 204 | 205 | def __getitem__(self, idx): 206 | """Get specified idx items. 207 | 208 | Args: 209 | idx (int): Index of the item. 210 | 211 | Returns: 212 | str: Utterance id (only if return_utt_id=True). 213 | ndarray: Audio (T,).
214 | 215 | """ 216 | if self.allow_cache and len(self.caches[idx]) != 0: 217 | return self.caches[idx] 218 | 219 | utt_id = self.utt_ids[idx] 220 | audio = self.audio_load_fn(self.audio_files[idx]) 221 | 222 | if self.return_utt_id: 223 | items = utt_id, audio 224 | else: 225 | items = audio 226 | 227 | if self.allow_cache: 228 | self.caches[idx] = items 229 | 230 | return items 231 | 232 | def __len__(self): 233 | """Return dataset length. 234 | 235 | Returns: 236 | int: The length of dataset. 237 | 238 | """ 239 | return len(self.audio_files) 240 | 241 | 242 | class MelDataset(Dataset): 243 | """PyTorch compatible mel dataset.""" 244 | 245 | def __init__(self, 246 | root_dir, 247 | mel_query="*-feats.npy", 248 | mel_length_threshold=None, 249 | mel_load_fn=np.load, 250 | return_utt_id=False, 251 | allow_cache=False, 252 | ): 253 | """Initialize dataset. 254 | 255 | Args: 256 | root_dir (str): Root directory including dumped files. 257 | mel_query (str): Query to find feature files in root_dir. 258 | mel_load_fn (func): Function to load feature file. 259 | mel_length_threshold (int): Threshold to remove short feature files. 260 | return_utt_id (bool): Whether to return the utterance id with arrays. 261 | allow_cache (bool): Whether to allow cache of the loaded files. 262 | 263 | """ 264 | # find all mel files 265 | mel_files = sorted(find_files(root_dir, mel_query)) 266 | 267 | # filter by threshold 268 | if mel_length_threshold is not None: 269 | mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files] 270 | idxs = [idx for idx in range(len(mel_files)) if mel_lengths[idx] > mel_length_threshold] 271 | if len(mel_files) != len(idxs): 272 | logging.warning(f"Some files were filtered out by mel length threshold " 273 | f"({len(mel_files)} -> {len(idxs)}).") 274 | mel_files = [mel_files[idx] for idx in idxs] 275 | 276 | # assert the number of files 277 | assert len(mel_files) != 0, f"No mel files found in {root_dir}."
278 | 279 | self.mel_files = mel_files 280 | self.mel_load_fn = mel_load_fn 281 | 282 | if ".npy" in mel_query: 283 | self.utt_ids = [os.path.basename(f).replace("-feats.npy", "") for f in mel_files] 284 | else: 285 | self.utt_ids = [os.path.splitext(os.path.basename(f))[0] for f in mel_files] 286 | self.return_utt_id = return_utt_id 287 | self.allow_cache = allow_cache 288 | if allow_cache: 289 | # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0 290 | self.manager = Manager() 291 | self.caches = self.manager.list() 292 | self.caches += [() for _ in range(len(mel_files))] 293 | 294 | def __getitem__(self, idx): 295 | """Get specified idx items. 296 | 297 | Args: 298 | idx (int): Index of the item. 299 | 300 | Returns: 301 | str: Utterance id (only if return_utt_id=True). 302 | ndarray: Feature (T', C). 303 | 304 | """ 305 | if self.allow_cache and len(self.caches[idx]) != 0: 306 | return self.caches[idx] 307 | 308 | utt_id = self.utt_ids[idx] 309 | mel = self.mel_load_fn(self.mel_files[idx]) 310 | 311 | if self.return_utt_id: 312 | items = utt_id, mel 313 | else: 314 | items = mel 315 | 316 | if self.allow_cache: 317 | self.caches[idx] = items 318 | 319 | return items 320 | 321 | def __len__(self): 322 | """Return dataset length. 323 | 324 | Returns: 325 | int: The length of dataset.
326 | 327 | """ 328 | return len(self.mel_files) 329 | 330 | -------------------------------------------------------------------------------- /datasets/collater.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import os 4 | 5 | from multiprocessing import Manager 6 | 7 | import numpy as np 8 | 9 | from torch.utils.data import Dataset 10 | 11 | from utils import find_files 12 | from utils import read_hdf5 13 | import torch 14 | 15 | 16 | 17 | class Feats_Collater(object): 18 | """Customized collater for PyTorch DataLoader in training.""" # collate function 19 | 20 | def __init__(self, 21 | batch_max_steps=20480, 22 | out_dim=1, 23 | hop_size=256, 24 | aux_context_window=2, 25 | use_noise_input=False, 26 | use_f0=False, 27 | use_chroma=False 28 | ): 29 | """Initialize customized collater for PyTorch DataLoader. 30 | 31 | Args: 32 | batch_max_steps (int): The maximum length of input signal in batch. 33 | hop_size (int): Hop size of auxiliary features. 34 | aux_context_window (int): Context window size for auxiliary feature conv. 35 | use_noise_input (bool): Whether to use noise input.
36 | 37 | """ 38 | if batch_max_steps % hop_size != 0: 39 | batch_max_steps += -(batch_max_steps % hop_size) 40 | assert batch_max_steps % hop_size == 0 41 | self.batch_max_steps = batch_max_steps 42 | self.out_dim = out_dim 43 | self.batch_max_frames = batch_max_steps // hop_size 44 | self.hop_size = hop_size 45 | self.aux_context_window = aux_context_window 46 | self.use_noise_input = use_noise_input 47 | self.use_f0 = use_f0 48 | self.use_chroma = use_chroma 49 | 50 | # set useful values for random cutting 51 | self.start_offset = aux_context_window # start offset = context window size 52 | self.end_offset = -(self.batch_max_frames + aux_context_window) # end offset = -(max frames + context window size) 53 | self.mel_threshold = self.batch_max_frames + 2 * aux_context_window 54 | 55 | def __call__(self, batch): 56 | # check length 57 | # batch = [self._adjust_length(*b) for b in batch if len(b[1]) > self.mel_threshold] 58 | xs, cs = [b['audio'] for b in batch], [b['feat'] for b in batch] # each item holds audio and feat (mel) 59 | 60 | # make batch with random cut 61 | c_lengths = [len(c) for c in cs] 62 | start_frames = np.array([np.random.randint( 63 | self.start_offset, cl + self.end_offset) for cl in c_lengths]) 64 | x_starts = start_frames * self.hop_size # audio start 65 | x_ends = x_starts + self.batch_max_steps # audio end 66 | c_starts = start_frames - self.aux_context_window # mel start 67 | c_ends = start_frames + self.batch_max_frames + self.aux_context_window # mel end 68 | y_batch = [x[start: end] for x, start, end in zip(xs, x_starts, x_ends)] # audio segments 69 | c_batch = [c[start: end] for c, start, end in zip(cs, c_starts, c_ends)] # mel segments 70 | 71 | # convert each batch to tensor, assuming that each item in batch has the same length (numpy -> tensor) 72 | y_batch = torch.tensor(y_batch, dtype=torch.float).unsqueeze(1) # (B, 1, T) 73 | c_batch = torch.tensor(c_batch, dtype=torch.float).transpose(2, 1) # (B, C, T') 74 | 75 | batchs = {'audios': y_batch, 'feats': c_batch} # batch["audios"] and batch["feats"] 76 | 77 | if self.use_f0: 78 | # f0s = [b['f0'] for b in batch if 'f0' in b] 79 | # f0_batch = [f0[start: end] for f0, start, end in zip(f0s, c_starts, c_ends)] 80 | # f0_batch = torch.tensor(f0_batch, dtype=torch.long) 81 | # batchs['f0s'] = f0_batch 82 | 83 | f0_origins = [b['f0_origin'] for b in batch if "f0_origin" in b] 84 | f0_origins_batch = [f0[start+self.aux_context_window: end-self.aux_context_window] for f0, start, end in zip(f0_origins, c_starts, c_ends)] 85 | f0_origins_batch = torch.tensor(f0_origins_batch, dtype=torch.float) 86 | batchs['f0_origins'] = f0_origins_batch 87 | 88 | # vus = [b['uv'] for b in batch if "uv" in b] 89 | # vus_batch = [vu[start: end] for vu, start, end in zip(vus, c_starts, c_ends)] 90 | # vus_batch = torch.tensor(vus_batch, dtype=torch.long) 91 | # batchs['uvs'] = vus_batch 92 | 93 | if self.use_chroma: 94 | chromas = [b['chroma'] for b in batch if 'chroma' in b] 95 | chroma_batch = [chroma[start: end] for chroma, start, end in zip(chromas, c_starts, c_ends)] 96 | chroma_batch = torch.tensor(chroma_batch, dtype=torch.float).transpose(2, 1) # (B, C, T') 97 | batchs['chromas'] = chroma_batch 98 | # make input noise signal batch tensor 99 | if self.use_noise_input: 100 | # z_batch = torch.randn(y_batch.size()) # (B, 1, T) 101 | z_batch = torch.randn(y_batch.size(0), self.out_dim, y_batch.size(2) // self.out_dim) # (B, out_dim, T // out_dim) 102 | batchs['noise'] = z_batch 103 | 104 | return batchs 105 | 106 | 107 | class Embeds_Collater(object): 108 | """Customized collater for PyTorch DataLoader in training.""" # collate function 109 | 110 | def __init__(self, 111 | batch_max_steps=20480, 112 | out_dim=1, 113 | hop_size=256, 114 | aux_context_window=2, 115 | use_noise_input=False, 116 | use_f0=False, 117 | use_chroma=False 118 | ): 119 | """Initialize customized collater for PyTorch DataLoader. 120 | 121 | Args: 122 | batch_max_steps (int): The maximum length of input signal in batch.
123 | hop_size (int): Hop size of auxiliary features. 124 | aux_context_window (int): Context window size for auxiliary feature conv. 125 | use_noise_input (bool): Whether to use noise input. 126 | 127 | """ 128 | if batch_max_steps % hop_size != 0: 129 | batch_max_steps += -(batch_max_steps % hop_size) 130 | assert batch_max_steps % hop_size == 0 131 | self.batch_max_steps = batch_max_steps 132 | self.out_dim = out_dim 133 | self.batch_max_frames = batch_max_steps // hop_size 134 | self.hop_size = hop_size 135 | self.aux_context_window = aux_context_window 136 | self.use_noise_input = use_noise_input 137 | self.use_f0 = use_f0 138 | self.use_chroma = use_chroma 139 | 140 | # set useful values in random cutting 随机截取长度 141 | self.start_offset = aux_context_window # 开始偏移位置 = 窗大小 142 | self.end_offset = -(self.batch_max_frames + aux_context_window) # 结束偏移位置 = -(最大帧长 + 窗大小) 143 | self.mel_threshold = self.batch_max_frames + 2 * aux_context_window 144 | 145 | def __call__(self, batch): 146 | # check length 147 | # batch = [self._adjust_length(*b) for b in batch if len(b[1]) > self.mel_threshold] 148 | xs, cs = [b['audio'] for b in batch], [b['feat'] for b in batch] # batch 包含audio & feat(mel) 149 | embed = [b['embed'] for b in batch] 150 | 151 | # make batch with random cut 随机裁剪窗 152 | c_lengths = [len(c) for c in cs] 153 | start_frames = np.array([np.random.randint( 154 | self.start_offset, cl + self.end_offset) for cl in c_lengths]) 155 | x_starts = start_frames * self.hop_size # audio 起始 156 | x_ends = x_starts + self.batch_max_steps # audio 结束 157 | c_starts = start_frames - self.aux_context_window # mel 起始 158 | c_ends = start_frames + self.batch_max_frames + self.aux_context_window # mel 结束 159 | y_batch = [x[start: end] for x, start, end in zip(xs, x_starts, x_ends)] # 得到audio 160 | c_batch = [c[start: end] for c, start, end in zip(cs, c_starts, c_ends)] # 得到mel 161 | 162 | # convert each batch to tensor, asuume that each item in batch has the same 
length; convert numpy arrays to tensors 163 | y_batch = torch.tensor(y_batch, dtype=torch.float).unsqueeze(1) # (B, 1, T) 164 | c_batch = torch.tensor(c_batch, dtype=torch.float).transpose(2, 1) # (B, C, T') 165 | embed_batch = torch.tensor(embed, dtype=torch.float).unsqueeze(-1) # (B, 128) -> (B, 128, 1) 166 | 167 | batchs = {'audios': y_batch, 'feats': c_batch, 'embed': embed_batch} ################### yields batch['audios'], batch['feats'] and batch['embed'] ################### 168 | 169 | if self.use_f0: 170 | # f0s = [b['f0'] for b in batch if 'f0' in b] 171 | # f0_batch = [f0[start: end] for f0, start, end in zip(f0s, c_starts, c_ends)] 172 | # f0_batch = torch.tensor(f0_batch, dtype=torch.long) 173 | # batchs['f0s'] = f0_batch 174 | 175 | f0_origins = [b['f0_origin'] for b in batch if "f0_origin" in b] 176 | f0_origins_batch = [f0[start+self.aux_context_window: end-self.aux_context_window] for f0, start, end in zip(f0_origins, c_starts, c_ends)] 177 | f0_origins_batch = torch.tensor(f0_origins_batch, dtype=torch.float) 178 | batchs['f0_origins'] = f0_origins_batch 179 | 180 | # vus = [b['uv'] for b in batch if "uv" in b] 181 | # vus_batch = [vu[start: end] for vu, start, end in zip(vus, c_starts, c_ends)] 182 | # vus_batch = torch.tensor(vus_batch, dtype=torch.long) 183 | # batchs['uvs'] = vus_batch 184 | 185 | if self.use_chroma: 186 | chromas = [b['chroma'] for b in batch if 'chroma' in b] 187 | chroma_batch = [chroma[start: end] for chroma, start, end in zip(chromas, c_starts, c_ends)] # slice each item's own chroma 188 | chroma_batch = torch.tensor(chroma_batch, dtype=torch.float).transpose(2, 1) # (B, C, T') 189 | batchs['chromas'] = chroma_batch 190 | # make input noise signal batch tensor 191 | if self.use_noise_input: 192 | # z_batch = torch.randn(y_batch.size()) # (B, 1, T) 193 | z_batch = torch.randn(y_batch.size(0), 1, y_batch.size(2) // self.out_dim) # (B, 1, T // out_dim) 194 | batchs['noise'] = z_batch 195 | 196 | return batchs --------------------------------------------------------------------------------
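Both collaters above tie the audio crop to the mel crop through `hop_size`: a start frame `s` is drawn from `[aux_context_window, len(mel) - batch_max_frames - aux_context_window)`, the audio window is samples `[s * hop_size, s * hop_size + batch_max_steps)`, and the mel window is frames `[s - aux_context_window, s + batch_max_frames + aux_context_window)`. The NumPy sketch below reproduces that index arithmetic for a single utterance so it can be checked in isolation. The function name `random_aligned_crop` is hypothetical (not part of this repository), and it assumes the waveform holds at least `len(mel) * hop_size` samples.

```python
import numpy as np


def random_aligned_crop(audio, mel, batch_max_steps=20480, hop_size=256,
                        aux_context_window=2, rng=None):
    """Hypothetical helper: cut one aligned (audio, mel) training window.

    Mirrors the index arithmetic of Collater / Embeds_Collater.
    audio: (T,) waveform, assumed to hold at least len(mel) * hop_size samples
    mel:   (T', n_mels) features, one frame per hop_size samples
    Returns (audio_window, mel_window_with_context, start_frame).
    """
    rng = np.random.default_rng() if rng is None else rng
    # Round batch_max_steps down to a multiple of hop_size, as __init__ does.
    batch_max_steps -= batch_max_steps % hop_size
    batch_max_frames = batch_max_steps // hop_size

    # Keep aux_context_window frames of context on both sides of the window.
    low = aux_context_window                                   # start_offset
    high = len(mel) - (batch_max_frames + aux_context_window)  # len + end_offset
    if high <= low:
        raise ValueError("utterance too short for one training window")
    # Half-open range [low, high), like np.random.randint(low, high).
    start = int(rng.integers(low, high))

    x_start = start * hop_size
    audio_window = audio[x_start:x_start + batch_max_steps]
    mel_window = mel[start - aux_context_window:
                     start + batch_max_frames + aux_context_window]
    return audio_window, mel_window, start


if __name__ == "__main__":
    mel = np.zeros((100, 80), dtype=np.float32)
    audio = np.zeros(100 * 256, dtype=np.float32)
    y, c, s = random_aligned_crop(audio, mel, rng=np.random.default_rng(0))
    print(y.shape, c.shape)  # (20480,) (84, 80)
```

Frame `aux_context_window` of the cropped mel lines up with the first sample of the cropped audio, which is why `f0_origin` is sliced from `start + aux_context_window` to `end - aux_context_window` in the collaters.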
/distributed/launch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """Distributed process launcher. 5 | 6 | This code is modified from https://github.com/pytorch/pytorch/blob/v1.3.0/torch/distributed/launch.py. 7 | 8 | """ 9 | import os 10 | import subprocess 11 | import sys 12 | 13 | from argparse import ArgumentParser 14 | from argparse import REMAINDER 15 | 16 | 17 | def parse_args(): 18 | """Parse arguments.""" 19 | parser = ArgumentParser(description="PyTorch distributed training launch " 20 | "helper utility that will spawn up " 21 | "multiple distributed processes") 22 | 23 | # Optional arguments for the launch helper 24 | parser.add_argument("--nnodes", type=int, default=1, 25 | help="The number of nodes to use for distributed " 26 | "training") 27 | parser.add_argument("--node_rank", type=int, default=0, 28 | help="The rank of the node for multi-node distributed " 29 | "training") 30 | parser.add_argument("--nproc_per_node", type=int, default=1, 31 | help="The number of processes to launch on each node, " 32 | "for GPU training, this is recommended to be set " 33 | "to the number of GPUs in your system so that " 34 | "each process can be bound to a single GPU.") 35 | parser.add_argument("--master_addr", default="127.0.0.1", type=str, 36 | help="Master node (rank 0)'s address, should be either " 37 | "the IP address or the hostname of node 0, for " 38 | "single node multi-proc training, the " 39 | "--master_addr can simply be 127.0.0.1") 40 | parser.add_argument("--master_port", default=29500, type=int, 41 | help="Master node (rank 0)'s free port that needs to " 42 | "be used for communication during distributed " 43 | "training") 44 | parser.add_argument("--use_env", default=False, action="store_true", 45 | help="Use environment variable to pass " 46 | "'local rank'. For legacy reasons, the default value is False.
" 47 | "If set to True, the script will not pass " 48 | "--local_rank as argument, and will instead set LOCAL_RANK.") 49 | parser.add_argument("-m", "--module", default=False, action="store_true", 50 | help="Changes each process to interpret the launch script " 51 | "as a python module, executing with the same behavior as " 52 | "'python -m'.") 53 | parser.add_argument("-c", "--command", default=False, action="store_true", 54 | help="Changes each process to interpret the launch script " 55 | "as a command.") 56 | parser.add_argument("-s", "--start", default=0, type=int, 57 | help="The local rank to start spawning processes from; " 58 | "processes with a lower local rank are not launched on this node.") 59 | # positional 60 | parser.add_argument("training_script", type=str, 61 | help="The full path to the single GPU training " 62 | "program/script/command to be launched in parallel, " 63 | "followed by all the arguments for the " 64 | "training script") 65 | 66 | # rest from the training program 67 | parser.add_argument('training_script_args', nargs=REMAINDER) 68 | return parser.parse_args() 69 | 70 | 71 | def main(): 72 | """Launch distributed processes.""" 73 | args = parse_args() 74 | 75 | # world size in terms of number of processes 76 | dist_world_size = args.nproc_per_node * args.nnodes 77 | 78 | # set PyTorch distributed related environmental variables 79 | current_env = os.environ.copy() 80 | current_env["MASTER_ADDR"] = args.master_addr 81 | current_env["MASTER_PORT"] = str(args.master_port) 82 | current_env["WORLD_SIZE"] = str(dist_world_size) 83 | 84 | processes = [] 85 | 86 | if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1: 87 | current_env["OMP_NUM_THREADS"] = str(1) 88 | print("*****************************************\n" 89 | "Setting OMP_NUM_THREADS environment variable for each process " 90 | "to {} by default to avoid overloading your system; " 91 | "please further tune the variable for optimal performance in " 92 | "your application as needed.
\n" 93 | "*****************************************".format(current_env["OMP_NUM_THREADS"])) 94 | 95 | for local_rank in range(args.start, args.nproc_per_node): 96 | # each process's rank 97 | dist_rank = args.nproc_per_node * args.node_rank + local_rank 98 | current_env["RANK"] = str(dist_rank) 99 | current_env["LOCAL_RANK"] = str(local_rank) 100 | 101 | # spawn the processes 102 | if args.command: 103 | cmd = [args.training_script] 104 | else: 105 | cmd = [sys.executable, "-u"] 106 | if args.module: 107 | cmd.append("-m") 108 | cmd.append(args.training_script) 109 | 110 | if not args.use_env: 111 | cmd.append("--local_rank={}".format(local_rank)) 112 | 113 | cmd.extend(args.training_script_args) 114 | 115 | process = subprocess.Popen(cmd, env=current_env) 116 | processes.append(process) 117 | 118 | for process in processes: 119 | process.wait() 120 | if process.returncode != 0: 121 | raise subprocess.CalledProcessError( 122 | returncode=process.returncode, cmd=process.args) # report the failing process's own command 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /encoder/audio.py: -------------------------------------------------------------------------------- 1 | from scipy.ndimage import binary_dilation 2 | from encoder.params_data import * 3 | from pathlib import Path 4 | from typing import Optional, Union 5 | import numpy as np 6 | import webrtcvad 7 | import torchaudio 8 | import torch 9 | import librosa 10 | import struct 11 | 12 | int16_max = (2 ** 15) - 1 13 | 14 | 15 | def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], 16 | source_sr: Optional[int] = None): 17 | """ 18 | Applies the preprocessing operations used in training the Speaker Encoder to a waveform 19 | either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
 20 | 21 | :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not 22 | just .wav), or the waveform as a numpy array of floats. 23 | :param source_sr: if passing an audio waveform, the sampling rate of the waveform before 24 | preprocessing. After preprocessing, the waveform's sampling rate will match the data 25 | hyperparameters. If passing a filepath, the sampling rate will be automatically detected and 26 | this argument will be ignored. 27 | """ 28 | # Load the wav from disk if needed 29 | if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): 30 | wav, source_sr = librosa.load(fpath_or_wav, sr=None) 31 | else: 32 | wav = fpath_or_wav 33 | 34 | # Resample the wav if needed 35 | if source_sr is not None and source_sr != sampling_rate: 36 | wav = librosa.resample(wav, orig_sr=source_sr, target_sr=sampling_rate) 37 | 38 | # Apply the preprocessing: normalize volume (long-silence trimming is disabled below) 39 | wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True) 40 | # wav = trim_long_silences(wav) 41 | 42 | return wav 43 | 44 | 45 | def preprocess_wav_torch(wav, source_sr=None): 46 | """ 47 | Applies the volume normalization used in training the Speaker Encoder to a waveform 48 | that is already in memory as a torch Tensor. Unlike preprocess_wav(), this function 49 | neither loads audio from disk nor resamples it. 50 | 51 | :param wav: the waveform as a torch Tensor of floats, assumed to already be at the 52 | sampling rate defined in params_data.py. 53 | :param source_sr: unused; it is kept only for signature compatibility with 54 | preprocess_wav(). 55 | :return: the volume-normalized waveform as a Tensor.
 56 | """ 57 | # Apply the preprocessing: normalize volume (no silence trimming or resampling is done here) 58 | wav = normalize_volume_torch(wav, audio_norm_target_dBFS, increase_only=True) 59 | return wav 60 | 61 | def wav_to_mel_spectrogram(wav): 62 | """ 63 | Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform. 64 | Note: this is not a log-mel spectrogram. 65 | wav: numpy (T,) 66 | return: numpy (T', n_mels) 67 | """ 68 | # frames = librosa.feature.melspectrogram( 69 | # wav, 70 | # sampling_rate, 71 | # n_fft=int(sampling_rate * mel_window_length / 1000), 72 | # hop_length=int(sampling_rate * mel_window_step / 1000), 73 | # n_mels=mel_n_channels 74 | # ) 75 | frames = librosa.feature.melspectrogram( 76 | y=wav, 77 | sr=sampling_rate, 78 | n_fft=n_fft, 79 | hop_length=hop_length, 80 | win_length=win_length, 81 | n_mels=mel_n_channels, 82 | ) 83 | 84 | return frames.astype(np.float32).T 85 | 86 | def wav_to_mel_spectrogram_torch_preprocess(wav): 87 | """ 88 | Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform. 89 | Note: this is not a log-mel spectrogram. 90 | wav: Tensor (T) 91 | return: Tensor (T', n_mels) 92 | """ 93 | 94 | MelSpectrogram = torchaudio.transforms.MelSpectrogram( 95 | sample_rate=sampling_rate, 96 | win_length=win_length, 97 | hop_length=hop_length, 98 | n_mels=mel_n_channels, 99 | n_fft=n_fft, 100 | power=2.0, 101 | ) 102 | frames = MelSpectrogram(wav.float()) 103 | return frames.T 104 | 105 | def wav_to_mel_spectrogram_torch(wav): 106 | """ 107 | Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform. 108 | Note: this is not a log-mel spectrogram.
 109 | wav: Tensor (T), expected on the GPU because the transform below is moved to CUDA 110 | return: Tensor (T', n_mels) 111 | """ 112 | 113 | MelSpectrogram = torchaudio.transforms.MelSpectrogram( 114 | sample_rate=sampling_rate, 115 | win_length=win_length, 116 | hop_length=hop_length, 117 | n_mels=mel_n_channels, 118 | n_fft=n_fft, 119 | power=2.0, 120 | ).cuda() 121 | frames = MelSpectrogram(wav.float()) 122 | return frames.T 123 | 124 | 125 | def trim_long_silences(wav): 126 | """ 127 | Ensures that segments without voice in the waveform remain no longer than a 128 | threshold determined by the VAD parameters in params_data.py. 129 | 130 | :param wav: the raw waveform as a numpy array of floats 131 | :return: the same waveform with silences trimmed away (length <= original wav length) 132 | """ 133 | # Compute the voice detection window size 134 | samples_per_window = (vad_window_length * sampling_rate) // 1000 135 | 136 | # Trim the end of the audio to have a multiple of the window size 137 | wav = wav[:len(wav) - (len(wav) % samples_per_window)] 138 | 139 | # Convert the float waveform to 16-bit mono PCM 140 | pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16)) 141 | 142 | # Perform voice activity detection 143 | voice_flags = [] 144 | vad = webrtcvad.Vad(mode=3) 145 | for window_start in range(0, len(wav), samples_per_window): 146 | window_end = window_start + samples_per_window 147 | voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2], 148 | sample_rate=sampling_rate)) 149 | voice_flags = np.array(voice_flags) 150 | 151 | # Smooth the voice detection with a moving average 152 | def moving_average(array, width): 153 | array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2))) 154 | ret = np.cumsum(array_padded, dtype=float) 155 | ret[width:] = ret[width:] - ret[:-width] 156 | return ret[width - 1:] / width 157 | 158 | audio_mask = moving_average(voice_flags, vad_moving_average_width) 159 | audio_mask =
np.round(audio_mask).astype(bool) 160 | 161 | # Dilate the voiced regions 162 | audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1)) 163 | audio_mask = np.repeat(audio_mask, samples_per_window) 164 | 165 | return wav[audio_mask] 166 | 167 | 168 | def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False): 169 | if increase_only and decrease_only: 170 | raise ValueError("Both increase only and decrease only are set") 171 | dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2)) 172 | 173 | if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only): 174 | return wav 175 | return wav * (10 ** (dBFS_change / 20)) 176 | 177 | def normalize_volume_torch(wav, target_dBFS, increase_only=False, decrease_only=False): 178 | if increase_only and decrease_only: 179 | raise ValueError("Both increase only and decrease only are set") 180 | dBFS_change = target_dBFS - 10 * torch.log10(torch.mean(wav ** 2)) 181 | 182 | if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only): 183 | return wav 184 | return wav * (10 ** (dBFS_change / 20)) 185 | 186 | if __name__ == '__main__': 187 | one = torch.ones((12800)) 188 | two = np.ones(12800) 189 | one_out = wav_to_mel_spectrogram_torch(one.cuda()) # the CUDA transform needs a GPU tensor 190 | two_out = wav_to_mel_spectrogram(two) 191 | print(one_out.shape, two_out.shape) 192 | -------------------------------------------------------------------------------- /encoder/config.py: -------------------------------------------------------------------------------- 1 | librispeech_datasets = { 2 | "train": { 3 | "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"], 4 | "other": ["LibriSpeech/train-other-500"] 5 | }, 6 | "test": { 7 | "clean": ["LibriSpeech/test-clean"], 8 | "other": ["LibriSpeech/test-other"] 9 | }, 10 | "dev": { 11 | "clean": ["LibriSpeech/dev-clean"], 12 | "other": ["LibriSpeech/dev-other"] 13 | }, 14 | } 15 | libritts_datasets = { 16 | "train": { 17 |
"clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"], 18 | "other": ["LibriTTS/train-other-500"] 19 | }, 20 | "test": { 21 | "clean": ["LibriTTS/test-clean"], 22 | "other": ["LibriTTS/test-other"] 23 | }, 24 | "dev": { 25 | "clean": ["LibriTTS/dev-clean"], 26 | "other": ["LibriTTS/dev-other"] 27 | }, 28 | } 29 | voxceleb_datasets = { 30 | "voxceleb1" : { 31 | "train": ["VoxCeleb1/wav"], 32 | "test": ["VoxCeleb1/test_wav"] 33 | }, 34 | "voxceleb2" : { 35 | "train": ["VoxCeleb2/dev/aac"], 36 | "test": ["VoxCeleb2/test_wav"] 37 | } 38 | } 39 | 40 | other_datasets = [ 41 | "LJSpeech-1.1", 42 | "VCTK-Corpus/wav48", 43 | ] 44 | 45 | anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"] 46 | -------------------------------------------------------------------------------- /encoder/data_objects/__init__.py: -------------------------------------------------------------------------------- 1 | from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset 2 | from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader 3 | -------------------------------------------------------------------------------- /encoder/data_objects/random_cycler.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | class RandomCycler: 4 | """ 5 | Creates an internal copy of a sequence and allows access to its items in a constrained random 6 | order. For a source sequence of n items and one or several consecutive queries of a total 7 | of m items, the following guarantees hold (one implies the other): 8 | - Each item will be returned between m // n and ((m - 1) // n) + 1 times. 9 | - Between two appearances of the same item, there may be at most 2 * (n - 1) other items. 
10 | """ 11 | 12 | def __init__(self, source): 13 | if len(source) == 0: 14 | raise Exception("Can't create RandomCycler from an empty collection") 15 | self.all_items = list(source) 16 | self.next_items = [] 17 | 18 | def sample(self, count: int): 19 | shuffle = lambda l: random.sample(l, len(l)) 20 | 21 | out = [] 22 | while count > 0: 23 | if count >= len(self.all_items): 24 | out.extend(shuffle(list(self.all_items))) 25 | count -= len(self.all_items) 26 | continue 27 | n = min(count, len(self.next_items)) 28 | out.extend(self.next_items[:n]) 29 | count -= n 30 | self.next_items = self.next_items[n:] 31 | if len(self.next_items) == 0: 32 | self.next_items = shuffle(list(self.all_items)) 33 | return out 34 | 35 | def __next__(self): 36 | return self.sample(1)[0] 37 | 38 | -------------------------------------------------------------------------------- /encoder/data_objects/speaker.py: -------------------------------------------------------------------------------- 1 | from encoder.data_objects.random_cycler import RandomCycler 2 | from encoder.data_objects.utterance import Utterance 3 | from pathlib import Path 4 | 5 | # Contains the set of utterances of a single speaker 6 | class Speaker: 7 | def __init__(self, root: Path): 8 | self.root = root 9 | self.name = root.name 10 | self.utterances = None 11 | self.utterance_cycler = None 12 | 13 | def _load_utterances(self): 14 | with self.root.joinpath("_sources.txt").open("r") as sources_file: 15 | sources = [l.split(",") for l in sources_file] 16 | sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources} 17 | self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()] 18 | self.utterance_cycler = RandomCycler(self.utterances) 19 | 20 | def random_partial(self, count, n_frames): 21 | """ 22 | Samples a batch of unique partial utterances from the disk in a way that all 23 | utterances come up at least once every two cycles and in a random order every time. 
 24 | 25 | :param count: The number of partial utterances to sample from the set of utterances from 26 | that speaker. Utterances are guaranteed not to be repeated if count is not larger than 27 | the number of utterances available. 28 | :param n_frames: The number of frames in the partial utterance. 29 | :return: A list of tuples (utterance, frames, range) where utterance is an Utterance, 30 | frames are the frames of the partial utterances and range is the range of the partial 31 | utterance with regard to the complete utterance. 32 | """ 33 | if self.utterances is None: 34 | self._load_utterances() 35 | 36 | utterances = self.utterance_cycler.sample(count) 37 | 38 | a = [(u,) + u.random_partial(n_frames) for u in utterances] 39 | 40 | return a 41 | -------------------------------------------------------------------------------- /encoder/data_objects/speaker_batch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import List 3 | from encoder.data_objects.speaker import Speaker 4 | 5 | class SpeakerBatch: 6 | def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int): 7 | self.speakers = speakers 8 | self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers} 9 | 10 | # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g.
for 3 speakers with 11 | # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40) 12 | self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]]) 13 | -------------------------------------------------------------------------------- /encoder/data_objects/speaker_verification_dataset.py: -------------------------------------------------------------------------------- 1 | from encoder.data_objects.random_cycler import RandomCycler 2 | from encoder.data_objects.speaker_batch import SpeakerBatch 3 | from encoder.data_objects.speaker import Speaker 4 | from encoder.params_data import partials_n_frames 5 | from torch.utils.data import Dataset, DataLoader 6 | from pathlib import Path 7 | 8 | # TODO: improve with a pool of speakers for data efficiency 9 | 10 | class SpeakerVerificationDataset(Dataset): 11 | def __init__(self, datasets_root: Path): 12 | self.root = datasets_root 13 | speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()] 14 | if len(speaker_dirs) == 0: 15 | raise Exception("No speakers found. 
Make sure you are pointing to the directory " 16 | "containing all preprocessed speaker directories.") 17 | self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs] 18 | self.speaker_cycler = RandomCycler(self.speakers) 19 | 20 | def __len__(self): 21 | return int(1e10) 22 | 23 | def __getitem__(self, index): 24 | return next(self.speaker_cycler) 25 | 26 | def get_logs(self): 27 | log_string = "" 28 | for log_fpath in self.root.glob("*.txt"): 29 | with log_fpath.open("r") as log_file: 30 | log_string += "".join(log_file.readlines()) 31 | return log_string 32 | 33 | 34 | class SpeakerVerificationDataLoader(DataLoader): 35 | def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None, 36 | batch_sampler=None, num_workers=0, pin_memory=False, timeout=0, 37 | worker_init_fn=None): 38 | self.utterances_per_speaker = utterances_per_speaker 39 | 40 | super().__init__( 41 | dataset=dataset, 42 | batch_size=speakers_per_batch, 43 | shuffle=False, 44 | sampler=sampler, 45 | batch_sampler=batch_sampler, 46 | num_workers=num_workers, 47 | collate_fn=self.collate, 48 | pin_memory=pin_memory, 49 | drop_last=False, 50 | timeout=timeout, 51 | worker_init_fn=worker_init_fn 52 | ) 53 | 54 | def collate(self, speakers): 55 | return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames) 56 | -------------------------------------------------------------------------------- /encoder/data_objects/utterance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Utterance: 5 | def __init__(self, frames_fpath, wave_fpath): 6 | self.frames_fpath = frames_fpath 7 | self.wave_fpath = wave_fpath 8 | 9 | def get_frames(self): 10 | return np.load(self.frames_fpath) 11 | 12 | def random_partial(self, n_frames): 13 | """ 14 | Crops the frames into a partial utterance of n_frames 15 | 16 | :param n_frames: The number of frames of the partial utterance 17 | :return: the partial 
utterance frames and a tuple indicating the start and end of the 18 | partial utterance in the complete utterance. 19 | """ 20 | frames = self.get_frames() 21 | if frames.shape[0] == n_frames: 22 | start = 0 23 | else: 24 | start = np.random.randint(0, frames.shape[0] - n_frames + 1) # +1 so the last valid start can be drawn 25 | end = start + n_frames 26 | return frames[start:end], (start, end) -------------------------------------------------------------------------------- /encoder/inference.py: -------------------------------------------------------------------------------- 1 | from encoder.params_data import * 2 | from encoder.model import SpeakerEncoder 3 | from encoder.audio import preprocess_wav_torch # We want to expose this function from here 4 | from matplotlib import cm 5 | from encoder import audio 6 | from pathlib import Path 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import torch 10 | import os 11 | _model = None # type: SpeakerEncoder 12 | _device = "cpu" # type: torch.device 13 | 14 | 15 | def load_model(weights_fpath: Path, device=None, preprocess=False, rank=0): 16 | """ 17 | Loads the model in memory. This function must be called explicitly before any of the 18 | embed_* functions below, which raise an exception if no model is loaded. 19 | 20 | :param weights_fpath: the path to saved model weights. 21 | :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). 22 | If None, _device currently ends up as "cpu". Unless preprocess is True, the model 23 | itself is moved to the GPU with .cuda() after loading. 24 | """ 25 | # TODO: I think the slow loading of the encoder might have something to do with the device it 26 | # was saved on. Worth investigating.
 27 | global _model, _device 28 | if device is None: 29 | _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | _device = "cpu" 31 | elif isinstance(device, str): 32 | _device = torch.device(device) 33 | # _model = SpeakerEncoder(_device, _device).to(device) 34 | _model = SpeakerEncoder(_device, torch.device("cpu")) 35 | checkpoint = torch.load(weights_fpath) 36 | _model.load_state_dict(checkpoint["model_state"]) 37 | 38 | if not preprocess: _model.cuda() 39 | 40 | print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath, checkpoint["step"])) 41 | 42 | 43 | def is_loaded(): 44 | return _model is not None 45 | 46 | def num_params(): 47 | parameters = filter(lambda p: p.requires_grad, _model.parameters()) 48 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 49 | print('Trainable Parameters: %.3fM' % parameters) 50 | 51 | def embed_frames_batch_torch(frames_batch): 52 | """ 53 | Computes embeddings for a batch of mel spectrograms. 54 | 55 | :param frames_batch: a batch of mel spectrograms as a Tensor of shape 56 | (batch_size, n_frames, n_channels) 57 | :return: the embeddings as a Tensor of shape (batch_size, model_embedding_size) 58 | """ 59 | if _model is None: 60 | raise Exception("Model was not loaded. Call load_model() before inference.") 61 | # embed = _model.forward(frames_batch) 62 | embed = _model.forward(frames_batch) 63 | 64 | return embed 65 | 66 | 67 | def embed_frames_batch_torch_perceptual(frames_batch): 68 | """ 69 | Computes embeddings for a batch of mel spectrograms. 70 | 71 | :param frames_batch: a batch of mel spectrograms as a Tensor of shape 72 | (batch_size, n_frames, n_channels) 73 | :return: the embeddings as a Tensor of shape (batch_size, model_embedding_size) 74 | """ 75 | if _model is None: 76 | raise Exception("Model was not loaded.
Call load_model() before inference.") 77 | # embed = _model.forward(frames_batch) 78 | embed = _model.forward_perceptual2(frames_batch) # (3, 1, 256) 79 | 80 | return embed 81 | 82 | def embed_frames_batch(frames_batch): 83 | """ 84 | Computes embeddings for a batch of mel spectrograms. 85 | 86 | :param frames_batch: a batch of mel spectrograms as a numpy array of float32 of shape 87 | (batch_size, n_frames, n_channels) 88 | :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size) 89 | """ 90 | if _model is None: 91 | raise Exception("Model was not loaded. Call load_model() before inference.") 92 | 93 | frames = torch.from_numpy(frames_batch).to(_device) 94 | embed = _model.forward(frames).detach().cpu().numpy() 95 | return embed 96 | 97 | 98 | def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames, 99 | min_pad_coverage=0.75, overlap=0.5): 100 | """ 101 | Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain 102 | partial utterances of partial_utterance_n_frames frames each. Both the waveform and the mel 103 | spectrogram slices are returned, so as to make each partial utterance waveform correspond to 104 | its spectrogram. This function assumes that the mel spectrogram parameters used are those 105 | defined in params_data.py. 106 | 107 | The returned ranges may extend past the end of the waveform. It is 108 | recommended that you pad the waveform with zeros up to wave_slices[-1].stop. 109 | 110 | :param n_samples: the number of samples in the waveform 111 | :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial 112 | utterance 113 | :param min_pad_coverage: when reaching the last partial utterance, it may or may not have 114 | enough frames. If at least min_pad_coverage of partial_utterance_n_frames frames are present, 115 | then the last partial utterance will be considered, as if we padded the audio. Otherwise, 116 | it will be discarded, as if we trimmed the audio.
If there aren't enough frames for 1 partial 117 | utterance, this parameter is ignored so that the function always returns at least 1 slice. 118 | :param overlap: by how much the partial utterances should overlap. If set to 0, the partial 119 | utterances are entirely disjoint. 120 | :return: the waveform slices and mel spectrogram slices as lists of array slices. Index 121 | the waveform and the mel spectrogram, respectively, with these slices to obtain the partial 122 | utterances. 123 | """ 124 | assert 0 <= overlap < 1 125 | assert 0 < min_pad_coverage <= 1 126 | 127 | samples_per_frame = int((sampling_rate * mel_window_step / 1000)) 128 | n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) 129 | frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1) 130 | 131 | # Compute the slices 132 | wav_slices, mel_slices = [], [] 133 | steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1) 134 | for i in range(0, steps, frame_step): 135 | mel_range = np.array([i, i + partial_utterance_n_frames]) 136 | wav_range = mel_range * samples_per_frame 137 | mel_slices.append(slice(*mel_range)) 138 | wav_slices.append(slice(*wav_range)) 139 | 140 | # Evaluate whether extra padding is warranted or not 141 | last_wav_range = wav_slices[-1] 142 | coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) 143 | if coverage < min_pad_coverage and len(mel_slices) > 1: 144 | mel_slices = mel_slices[:-1] 145 | wav_slices = wav_slices[:-1] 146 | 147 | return wav_slices, mel_slices 148 | 149 | 150 | def embed_utterance_torch_preprocess(wav, using_partials=True, return_partials=False, **kwargs): 151 | """ 152 | Computes an embedding for a single utterance.
 153 | 154 | # TODO: handle multiple wavs to benefit from batching on GPU 155 | :param wav: a preprocessed (see audio.py) utterance waveform as a 1-D torch Tensor of floats (the using_partials=False branch still expects a numpy array) 156 | :param using_partials: if True, then the utterance is split into partial utterances of 157 | partial_utterance_n_frames frames and the utterance embedding is computed from their 158 | normalized average. If False, the utterance embedding is instead computed by feeding the entire 159 | spectrogram to the network. 160 | :param return_partials: if True, the partial embeddings will also be returned along with the 161 | wav slices that correspond to the partial embeddings. 162 | :param kwargs: additional arguments to compute_partial_slices() 163 | :return: the embedding as a Tensor of shape (model_embedding_size,), or a numpy array when using_partials is False. If 164 | return_partials is True, the partial embeddings as a Tensor of shape 165 | (n_partials, model_embedding_size) and the wav partials as a list of slices will also be 166 | returned. If using_partials is simultaneously set to False, both these values will be None 167 | instead.
 168 | """ 169 | # Process the entire utterance if not using partials 170 | if not using_partials: 171 | frames = audio.wav_to_mel_spectrogram(wav) 172 | embed = embed_frames_batch(frames[None, ...])[0] 173 | if return_partials: 174 | return embed, None, None 175 | return embed 176 | 177 | # Compute where to split the utterance into partials and pad if necessary 178 | wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs) 179 | max_wave_length = wave_slices[-1].stop 180 | if max_wave_length >= len(wav): 181 | wav = torch.nn.functional.pad(wav, (0, max_wave_length - len(wav)), "constant") 182 | 183 | # Split the utterance into partials 184 | frames = audio.wav_to_mel_spectrogram_torch_preprocess(wav) # (T, n_mels) 185 | frames_batch = torch.stack([frames[s] for s in mel_slices]) # (batch, short T, n_mels) 186 | partial_embeds = embed_frames_batch_torch(frames_batch) # (batch, n_embeddings(256)) 187 | 188 | # Compute the utterance embedding from the partial embeddings 189 | raw_embed = torch.mean(partial_embeds, dim=0) # (n_embeddings(256)) 190 | embed = raw_embed / torch.norm(raw_embed, 2) 191 | 192 | if return_partials: 193 | return embed, partial_embeds, wave_slices 194 | return embed 195 | 196 | 197 | 198 | 199 | def embed_utterance_torch(wav, using_partials=True, return_partials=False, **kwargs): 200 | """ 201 | Computes an embedding for a single utterance. 202 | 203 | # TODO: handle multiple wavs to benefit from batching on GPU 204 | :param wav: a preprocessed (see audio.py) utterance waveform as a 1-D torch Tensor of floats on the GPU (the using_partials=False branch still expects a numpy array) 205 | :param using_partials: if True, then the utterance is split into partial utterances of 206 | partial_utterance_n_frames frames and the utterance embedding is computed from their 207 | normalized average. If False, the utterance embedding is instead computed by feeding the entire 208 | spectrogram to the network.
    :param return_partials: if True, the partial embeddings will also be returned along with the
    wav slices that correspond to the partial embeddings.
    :param kwargs: additional arguments to compute_partial_slices()
    :return: the embedding as a float32 torch tensor of shape (model_embedding_size,). If
    `return_partials` is True, the partial embeddings as a float32 torch tensor of shape
    (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
    returned. If `using_partials` is simultaneously set to False, both these values will be None
    instead.
    """
    # Process the entire utterance if not using partials (torch pipeline, like the branch below)
    if not using_partials:
        frames = audio.wav_to_mel_spectrogram_torch(wav)
        embed = embed_frames_batch_torch(frames[None, ...])[0]
        if return_partials:
            return embed, None, None
        return embed

    # Compute where to split the utterance into partials and pad if necessary
    wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
    max_wave_length = wave_slices[-1].stop
    if max_wave_length >= len(wav):
        wav = torch.nn.functional.pad(wav, (0, max_wave_length - len(wav)), "constant")

    # Split the utterance into partials
    frames = audio.wav_to_mel_spectrogram_torch(wav)  # (T, n_mels)
    frames_batch = torch.stack([frames[s] for s in mel_slices])  # (batch, short T, n_mels)
    partial_embeds = embed_frames_batch_torch(frames_batch)  # (batch, n_embeddings(256))

    # Compute the utterance embedding from the partial embeddings
    raw_embed = torch.mean(partial_embeds, dim=0)  # (n_embeddings(256))
    embed = raw_embed / torch.norm(raw_embed, 2)

    if return_partials:
        return embed, partial_embeds, wave_slices
    return embed


def embed_utterance_torch_perceptual(wav, using_partials=True, return_partials=False, **kwargs):
    """
    Computes the partial embeddings of a single utterance, without averaging them.
    # TODO: handle multiple wavs to benefit from batching on GPU
    :param wav: a preprocessed (see audio.py) utterance waveform as a 1-D float32 torch tensor
    :param using_partials: if True, the utterance is split into partial utterances of
    `partial_utterance_n_frames` frames and one embedding is computed per partial. If False,
    the entire spectrogram is fed to the network as a single input.
    :param return_partials: only used when `using_partials` is False (see the return value).
    :param kwargs: additional arguments to compute_partial_slices()
    :return: if `using_partials` is True, the output of embed_frames_batch_torch_perceptual()
    for every partial utterance, as a tensor of shape (n_partials, model_embedding_size); the
    wav slices are not returned. If `using_partials` is False, the embedding of the entire
    utterance, followed by None, None when `return_partials` is True.
    """
    # Process the entire utterance if not using partials (torch pipeline, like the branch below)
    if not using_partials:
        frames = audio.wav_to_mel_spectrogram_torch(wav)
        embed = embed_frames_batch_torch(frames[None, ...])[0]
        if return_partials:
            return embed, None, None
        return embed

    # Compute where to split the utterance into partials and pad if necessary
    wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
    max_wave_length = wave_slices[-1].stop
    if max_wave_length >= len(wav):
        wav = torch.nn.functional.pad(wav, (0, max_wave_length - len(wav)), "constant")

    # Split the utterance into partials
    frames = audio.wav_to_mel_spectrogram_torch(wav)  # (T, n_mels)
    frames_batch = torch.stack([frames[s] for s in mel_slices])  # (batch, short T, n_mels)
    partial_embeds = embed_frames_batch_torch_perceptual(frames_batch)  # (batch, n_embeddings(256))

    return partial_embeds


def embed_speaker(wavs, **kwargs):
    raise NotImplementedError()


def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
    if ax is None:
        ax = plt.gca()

    if shape is None:
        height = int(np.sqrt(len(embed)))
        shape = (height, -1)
    embed = embed.reshape(shape)

    # Set the color range on the image itself; the colorbar then follows the same range
    mappable = ax.imshow(embed, vmin=color_range[0], vmax=color_range[1])
    plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)


if __name__ == '__main__':
    x = torch.ones(128000).cuda()
    embed = embed_utterance_torch(x)
    print(embed.shape)
--------------------------------------------------------------------------------
/encoder/model.py:
--------------------------------------------------------------------------------
from encoder.params_model import *
from encoder.params_data import *
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
from torch.nn.utils import clip_grad_norm_
from scipy.optimize import brentq
from torch import nn
import numpy as np
import torch


class SpeakerEncoder(nn.Module):
    def __init__(self, device, loss_device):
        super().__init__()
        self.loss_device = loss_device

        # Network definition
        self.lstm = nn.LSTM(input_size=mel_n_channels,
                            hidden_size=model_hidden_size,
                            num_layers=model_num_layers,
                            batch_first=True).to(device)
        self.linear = nn.Linear(in_features=model_hidden_size,
                                out_features=model_embedding_size).to(device)
        self.relu = torch.nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values). The tensors are created
        # directly on loss_device: calling .to() on an nn.Parameter returns a plain, non-leaf
        # tensor whenever the device changes, which would leave these unregistered and without grad.
        self.similarity_weight = nn.Parameter(torch.tensor([10.], device=loss_device))
        self.similarity_bias = nn.Parameter(torch.tensor([-5.], device=loss_device))

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        # Gradient scale
        self.similarity_weight.grad *= 0.01
        self.similarity_bias.grad *= 0.01

        # Gradient clipping
        clip_grad_norm_(self.parameters(), 3, norm_type=2)

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
        batch_size, hidden_size). Will default to a tensor of zeros if None.
        :return: the embeddings as a tensor of shape (batch_size, embedding_size)
        """
        # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
        # and the final cell state.
        # if not (next(self.lstm.parameters())).is_cuda:
        #     self.lstm.cuda()
        out, (hidden, cell) = self.lstm(utterances, hidden_init)

        # We take only the hidden state of the last layer
        embeds_raw = self.relu(self.linear(hidden[-1]))

        # L2-normalize it
        embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)

        return embeds

    def forward_perceptual(self, utterances, hidden_init=None):
        """
        Computes the unnormalized embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
        batch_size, hidden_size). Will default to a tensor of zeros if None.
        :return: the ReLU outputs of the final linear layer, before L2 normalization, as a tensor
        of shape (batch_size, embedding_size)
        """
        # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
        # and the final cell state.
        # if not (next(self.lstm.parameters())).is_cuda:
        #     self.lstm.cuda()
        out, (hidden, cell) = self.lstm(utterances, hidden_init)

        # We take only the hidden state of the last layer
        embeds_raw = self.relu(self.linear(hidden[-1]))
        return embeds_raw

    def forward_perceptual2(self, utterances, hidden_init=None):
        """
        Computes the final LSTM hidden states of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
        batch_size, hidden_size). Will default to a tensor of zeros if None.
        :return: the final hidden state of every LSTM layer as a tensor of shape (num_layers,
        batch_size, hidden_size)
        """
        # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
        # and the final cell state.
        # if not (next(self.lstm.parameters())).is_cuda:
        #     self.lstm.cuda()
        out, (hidden, cell) = self.lstm(utterances, hidden_init)

        return hidden

    def similarity_matrix(self, embeds):
        """
        Computes the similarity matrix according to section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the similarity matrix as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, speakers_per_batch)
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
        centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
        centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)

        # Exclusive centroids (1 per utterance)
        centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
        centroids_excl /= (utterances_per_speaker - 1)
        centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)

        # Similarity matrix. The cosine similarity of already L2-normalized vectors is simply the
        # dot product of these vectors (an element-wise multiplication reduced by a sum).
        # We vectorize the computation for efficiency.
        sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                 speakers_per_batch).to(self.loss_device)
        mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
        for j in range(speakers_per_batch):
            mask = np.where(mask_matrix[j])[0]
            sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
            sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)

        ## Even more vectorized version (slower maybe because of transpose)
        # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
        #                           ).to(self.loss_device)
        # eye = np.eye(speakers_per_batch, dtype=int)
        # mask = np.where(1 - eye)
        # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
        # mask = np.where(eye)
        # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
        # sim_matrix2 = sim_matrix2.transpose(1, 2)

        sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
        return sim_matrix

    def loss(self, embeds):
        """
        Computes the softmax loss according to section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss
        sim_matrix = self.similarity_matrix(embeds)
        sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
                                         speakers_per_batch))
        ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
        target = torch.from_numpy(ground_truth).long().to(self.loss_device)
        loss = self.loss_fn(sim_matrix, target)

        # EER (not backpropagated)
        with torch.no_grad():
            inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
            labels = np.array([inv_argmax(i) for i in ground_truth])
            preds = sim_matrix.detach().cpu().numpy()

            # Snippet from https://yangcha.github.io/EER-ROC/
            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
            eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

        return loss, eer


if __name__ == '__main__':
    encoder = SpeakerEncoder('cuda', 'cuda')
    inputs = torch.ones(6, 24, 80).cuda()
    output = encoder(inputs)
--------------------------------------------------------------------------------
/encoder/params_data.py:
--------------------------------------------------------------------------------

## Mel-filterbank
mel_n_channels = 80
win_length = 512
hop_length = 128
n_fft = 512
mel_window_length = 25  # In milliseconds
mel_window_step = 10  # In milliseconds

## Audio
sampling_rate = 24000
# Number of spectrogram frames in a partial utterance
partials_n_frames = 240  # 2400 ms, assuming the 10 ms mel_window_step
# Number of spectrogram frames at inference
inference_n_frames = 120  # 1200 ms, assuming the 10 ms mel_window_step


## Voice Activity Detection
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
# This sets the granularity of the VAD. Should not need to be changed.
vad_window_length = 20  # In milliseconds
# Number of frames to average together when performing the moving average smoothing.
# The larger this value, the larger the VAD variations must be to not get smoothed out.
vad_moving_average_width = 8
# Maximum number of consecutive silent frames a segment can have.
vad_max_silence_length = 6


## Audio volume normalization
audio_norm_target_dBFS = -30
--------------------------------------------------------------------------------
/encoder/params_model.py:
--------------------------------------------------------------------------------

## Model parameters
model_hidden_size = 256
model_embedding_size = 256
model_num_layers = 3


## Training parameters
learning_rate_init = 1e-4
speakers_per_batch = 64
utterances_per_speaker = 10
--------------------------------------------------------------------------------
/encoder/plot_umap.py:
--------------------------------------------------------------------------------
from encoder.visualizations import Visualizations
from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
from encoder.params_model import *
from encoder.model import SpeakerEncoder
from utils.profiler import Profiler
from pathlib import Path
import torch


def sync(device: torch.device):
    # FIXME
    return
    # For correct profiling (cuda operations are async)
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
          backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
          no_visdom: bool):
    # Create a dataset and a dataloader
    dataset = SpeakerVerificationDataset(clean_data_root)
    loader = SpeakerVerificationDataLoader(
        dataset,
        speakers_per_batch,
        utterances_per_speaker,
        num_workers=8,
    )

    # Set up the device on which to run the forward pass and the loss. These can be different,
    # because the forward pass is faster on the GPU whereas the loss is often (depending on your
    # hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIXME: currently, the gradient is None if loss_device is cuda
    loss_device = torch.device("cpu")

    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    init_step = 1

    # Configure file path for the model
    state_fpath = models_dir.joinpath(run_id + ".pt")
    backup_dir = models_dir.joinpath(run_id + "_backups")

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print("Found existing model \"%s\", loading it and resuming training." % run_id)
            checkpoint = torch.load(state_fpath)
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." % run_id)
    else:
        print("Starting the training from scratch.")
    model.train()

    # Initialize the visualization environment
    vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(dataset)
    vis.log_params()
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop
    profiler = Profiler(summarize_every=10, disabled=False)
    for step, speaker_batch in enumerate(loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        profiler.tick("Data to %s" % device)
        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")
        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        sync(loss_device)
        profiler.tick("Loss")

        # Backward pass
        model.zero_grad()
        loss.backward()
        profiler.tick("Backward pass")
        model.do_gradient_ops()
        optimizer.step()
        profiler.tick("Parameter update")

        # Update visualizations
        # learning_rate = optimizer.param_groups[0]["lr"]
        vis.update(loss.item(), eer, step)

        # Draw projections and save them to the backup folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            backup_dir.mkdir(exist_ok=True)
            projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
            embeds = embeds.detach().cpu().numpy()
            vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
            vis.save()
--------------------------------------------------------------------------------
/encoder/preprocess.py:
--------------------------------------------------------------------------------
from multiprocess.pool import ThreadPool
from encoder.params_data import *
from encoder.config import librispeech_datasets, anglophone_nationalites
from datetime import datetime
from encoder import audio
from pathlib import Path
from tqdm import tqdm
import numpy as np


class DatasetLog:
    """
    Registers metadata about the dataset in a text file.
    """
    def __init__(self, root, name):
        self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
        self.sample_data = dict()

        start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
        self.write_line("Creating dataset %s on %s" % (name, start_time))
        self.write_line("-----")
        self._log_params()

    def _log_params(self):
        from encoder import params_data
        self.write_line("Parameter values:")
        for param_name in (p for p in dir(params_data) if not p.startswith("__")):
            value = getattr(params_data, param_name)
            self.write_line("\t%s: %s" % (param_name, value))
        self.write_line("-----")

    def write_line(self, line):
        self.text_file.write("%s\n" % line)

    def add_sample(self, **kwargs):
        for param_name, value in kwargs.items():
            if param_name not in self.sample_data:
                self.sample_data[param_name] = []
            self.sample_data[param_name].append(value)

    def finalize(self):
        self.write_line("Statistics:")
        for param_name, values in self.sample_data.items():
            self.write_line("\t%s:" % param_name)
            self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
            self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
        self.write_line("-----")
        end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
        self.write_line("Finished on %s" % end_time)
        self.text_file.close()


def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
    dataset_root = datasets_root.joinpath(dataset_name)
    if not dataset_root.exists():
        print("Couldn't find %s, skipping this dataset." % dataset_root)
        return None, None
    return dataset_root, DatasetLog(out_dir, dataset_name)


def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
                             skip_existing, logger):
    print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))

    # Function to preprocess utterances for one speaker
    def preprocess_speaker(speaker_dir: Path):
        # Give a name to the speaker that includes its dataset
        speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)

        # Create an output directory with that name, as well as a txt file containing a
        # reference to each source file.
        speaker_out_dir = out_dir.joinpath(speaker_name)
        speaker_out_dir.mkdir(exist_ok=True)
        sources_fpath = speaker_out_dir.joinpath("_sources.txt")

        # There's a possibility that the preprocessing was interrupted earlier, check if
        # there already is a sources file.
        if sources_fpath.exists():
            try:
                with sources_fpath.open("r") as sources_file:
                    existing_fnames = {line.split(",")[0] for line in sources_file}
            except Exception:
                existing_fnames = set()
        else:
            existing_fnames = set()

        # Gather all audio files for that speaker recursively
        sources_file = sources_fpath.open("a" if skip_existing else "w")
        for in_fpath in speaker_dir.glob("**/*.%s" % extension):
            # Check if the target output file already exists
            out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
            out_fname = out_fname.replace(".%s" % extension, ".npy")
            if skip_existing and out_fname in existing_fnames:
                continue

            # Load and preprocess the waveform
            wav = audio.preprocess_wav(in_fpath)
            if len(wav) == 0:
                continue

            # Create the mel spectrogram, discard those that are too short
            frames = audio.wav_to_mel_spectrogram(wav)
            if len(frames) < partials_n_frames:
                continue

            out_fpath = speaker_out_dir.joinpath(out_fname)
            np.save(out_fpath, frames)
            logger.add_sample(duration=len(wav) / sampling_rate)
            sources_file.write("%s,%s\n" % (out_fname, in_fpath))

        sources_file.close()

    # Process the utterances for each speaker
    with ThreadPool(8) as pool:
        list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
                  unit="speakers"))
    logger.finalize()
    print("Done preprocessing %s.\n" % dataset_name)


def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
    for dataset_name in librispeech_datasets["train"]["other"]:
        # Initialize the preprocessing
        dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
        if not dataset_root:
            # _init_preprocess_dataset() reports this dataset as skipped; move on to the next one
            continue

        # Preprocess all speakers
        speaker_dirs = list(dataset_root.glob("*"))
        _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
                                 skip_existing, logger)


def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
    # Initialize the preprocessing
    dataset_name = "VoxCeleb1"
    dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
    if not dataset_root:
        return

    # Get the contents of the meta file
    with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
        metadata = [line.split("\t") for line in metafile][1:]

    # Select the ID and the nationality, filter out non-anglophone speakers
    nationalities = {line[0]: line[3] for line in metadata}
    keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
                        nationality.lower() in anglophone_nationalites]
    print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
          (len(keep_speaker_ids), len(nationalities)))

    # Get the speaker directories for anglophone speakers only
    speaker_dirs = dataset_root.joinpath("wav").glob("*")
    speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
                    speaker_dir.name in keep_speaker_ids]
    print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)."
          % (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))

    # Preprocess all speakers
    _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
                             skip_existing, logger)


def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
    # Initialize the preprocessing
    dataset_name = "VoxCeleb2"
    dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
    if not dataset_root:
        return

    # Get the speaker directories
    # Preprocess all speakers
    speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
    _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
                             skip_existing, logger)
--------------------------------------------------------------------------------
/encoder/train.py:
--------------------------------------------------------------------------------
from encoder.visualizations import Visualizations
from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
from encoder.params_model import *
from encoder.model import SpeakerEncoder
from utils.profiler import Profiler
from pathlib import Path
import torch


def sync(device: torch.device):
    # FIXME
    return
    # For correct profiling (cuda operations are async)
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
          backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
          no_visdom: bool):
    # Create a dataset and a dataloader
    dataset = SpeakerVerificationDataset(clean_data_root)
    loader = SpeakerVerificationDataLoader(
        dataset,
        speakers_per_batch,
        utterances_per_speaker,
        num_workers=8,
    )

    # Set up the device on which to run the forward pass and the loss. These can be different,
    # because the forward pass is faster on the GPU whereas the loss is often (depending on your
    # hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIXME: currently, the gradient is None if loss_device is cuda
    loss_device = torch.device("cpu")

    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    init_step = 1

    # Configure file path for the model
    state_fpath = models_dir.joinpath(run_id + ".pt")
    backup_dir = models_dir.joinpath(run_id + "_backups")

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print("Found existing model \"%s\", loading it and resuming training." % run_id)
            checkpoint = torch.load(state_fpath)
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." % run_id)
    else:
        print("Starting the training from scratch.")
    model.train()

    # Initialize the visualization environment
    vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(dataset)
    vis.log_params()
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop
    profiler = Profiler(summarize_every=10, disabled=False)
    for step, speaker_batch in enumerate(loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        profiler.tick("Data to %s" % device)
        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")
        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        sync(loss_device)
        profiler.tick("Loss")

        # Backward pass
        model.zero_grad()
        loss.backward()
        profiler.tick("Backward pass")
        model.do_gradient_ops()
        optimizer.step()
        profiler.tick("Parameter update")

        # Update visualizations
        # learning_rate = optimizer.param_groups[0]["lr"]
        vis.update(loss.item(), eer, step)

        # Draw projections and save them to the backup folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            backup_dir.mkdir(exist_ok=True)
            projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
            embeds = embeds.detach().cpu().numpy()
            vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
            vis.save()

        # Overwrite the latest version of the model
        if save_every != 0 and step % save_every == 0:
            print("Saving the model (step %d)" % step)
            torch.save({
                "step": step + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, state_fpath)

        # Make a backup
        if backup_every != 0 and step % backup_every == 0:
            print("Making a backup (step %d)" % step)
            backup_dir.mkdir(exist_ok=True)
            backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
            torch.save({
                "step": step + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, backup_fpath)

        profiler.tick("Extras (visualizations, saving)")

--------------------------------------------------------------------------------
/encoder/visualizations.py:
--------------------------------------------------------------------------------
from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
from datetime import datetime
from time import perf_counter as timer
import matplotlib.pyplot as plt
import numpy as np
# import webbrowser
import visdom
import umap

colormap = np.array([
    [76, 255, 0],
    [0, 127, 70],
    [255, 0, 0],
    [255, 217, 38],
    [0, 135, 255],
    [165, 0, 165],
    [255, 167, 255],
    [0, 255, 255],
    [255, 96, 38],
    [142, 76, 0],
    [33, 0, 127],
    [0, 0, 0],
    [183, 183, 183],
], dtype=float) / 255  # builtin float: the np.float alias was removed in NumPy 1.24


class Visualizations:
    def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
        # Tracking data
        self.last_update_timestamp = timer()
        self.update_every = update_every
        self.step_times = []
        self.losses = []
        self.eers = []
        print("Updating the visualizations every %d steps." % update_every)

        # If visdom is disabled TODO: use a better paradigm for that
        self.disabled = disabled
        if self.disabled:
            return

        # Set the environment name
        now = str(datetime.now().strftime("%d-%m %Hh%M"))
        if env_name is None:
            self.env_name = now
        else:
            self.env_name = "%s (%s)" % (env_name, now)

        # Connect to visdom and open the corresponding window in the browser
        try:
            self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
        except ConnectionError:
            raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
                            "start it.")
        # webbrowser.open("http://localhost:8097/env/" + self.env_name)

        # Create the windows
        self.loss_win = None
        self.eer_win = None
        # self.lr_win = None
        self.implementation_win = None
        self.projection_win = None
        self.implementation_string = ""

    def log_params(self):
        if self.disabled:
            return
        from encoder import params_data
        from encoder import params_model
        param_string = "Model parameters:<br>"
        for param_name in (p for p in dir(params_model) if not p.startswith("__")):
            value = getattr(params_model, param_name)
            param_string += "\t%s: %s<br>" % (param_name, value)
        param_string += "Data parameters:<br>"
        for param_name in (p for p in dir(params_data) if not p.startswith("__")):
            value = getattr(params_data, param_name)
            param_string += "\t%s: %s<br>" % (param_name, value)
        self.vis.text(param_string, opts={"title": "Parameters"})

    def log_dataset(self, dataset: SpeakerVerificationDataset):
        if self.disabled:
            return
        dataset_string = ""
        dataset_string += "Speakers: %s\n" % len(dataset.speakers)
        dataset_string += "\n" + dataset.get_logs()
        dataset_string = dataset_string.replace("\n", "<br>")
        self.vis.text(dataset_string, opts={"title": "Dataset"})

    def log_implementation(self, params):
        if self.disabled:
            return
        implementation_string = ""
        for param, value in params.items():
            implementation_string += "%s: %s\n" % (param, value)
        implementation_string = implementation_string.replace("\n", "<br>")
        self.implementation_string = implementation_string
        self.implementation_win = self.vis.text(
            implementation_string,
            opts={"title": "Training implementation"}
        )

    def update(self, loss, eer, step):
        # Update the tracking data
        now = timer()
        self.step_times.append(1000 * (now - self.last_update_timestamp))
        self.last_update_timestamp = now
        self.losses.append(loss)
        self.eers.append(eer)
        print(".", end="")

        # Update the plots every `update_every` steps
        if step % self.update_every != 0:
            return
        time_string = "Step time: mean: %5dms std: %5dms" % \
                      (int(np.mean(self.step_times)), int(np.std(self.step_times)))
        print("\nStep %6d Loss: %.4f EER: %.4f %s" %
              (step, np.mean(self.losses), np.mean(self.eers), time_string))
        if not self.disabled:
            self.loss_win = self.vis.line(
                [np.mean(self.losses)],
                [step],
                win=self.loss_win,
                update="append" if self.loss_win else None,
                opts=dict(
                    legend=["Avg. loss"],
                    xlabel="Step",
                    ylabel="Loss",
                    title="Loss",
                )
            )
            self.eer_win = self.vis.line(
                [np.mean(self.eers)],
                [step],
                win=self.eer_win,
                update="append" if self.eer_win else None,
                opts=dict(
                    legend=["Avg. EER"],
                    xlabel="Step",
                    ylabel="EER",
                    title="Equal error rate"
                )
            )
            if self.implementation_win is not None:
                self.vis.text(
                    self.implementation_string + ("%s" % time_string),
                    win=self.implementation_win,
                    opts={"title": "Training implementation"},
                )

        # Reset the tracking
        self.losses.clear()
        self.eers.clear()
        self.step_times.clear()

    def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
                         max_speakers=10):
        max_speakers = min(max_speakers, len(colormap))
        embeds = embeds[:max_speakers * utterances_per_speaker]

        n_speakers = len(embeds) // utterances_per_speaker
        ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
        colors = [colormap[i] for i in ground_truth]

        reducer = umap.UMAP()
        projected = reducer.fit_transform(embeds)
        plt.scatter(projected[:, 0], projected[:, 1], c=colors)
        plt.gca().set_aspect("equal", "datalim")
        plt.title("UMAP projection (step %d)" % step)
        if not self.disabled:
            self.projection_win = self.vis.matplot(plt, win=self.projection_win)
        if out_fpath is not None:
            plt.savefig(out_fpath)
        plt.clf()

    def save(self):
        if not self.disabled:
            self.vis.save([self.env_name])

--------------------------------------------------------------------------------
/frontend/audio_preprocess.py:
--------------------------------------------------------------------------------
import librosa
import librosa.filters
import numpy as np
from scipy import signal
from scipy.io import wavfile
import soundfile as sf
import math


def load_wav(path, sr):
    return librosa.core.load(path, sr=sr)[0]


def save_wav(wav, path, hparams):
    wav = wav / np.abs(wav).max() * 0.999
    f1 = 0.5 * 32767 / max(0.01, np.max(np.abs(wav)))
    f2 = np.sign(wav) * np.power(np.abs(wav), 0.95)
    wav = f1 * f2
    wav = signal.convolve(wav, signal.firwin(hparams['num_freq'],
                                             [hparams['fmin'], hparams['fmax']],
                                             pass_zero=False,
                                             fs=hparams['audio_sample_rate']))
    # proposed by @dsmiller
    wavfile.write(path, hparams['audio_sample_rate'], wav.astype(np.int16))


def save_wavenet_wav(wav, path, sr):
    # librosa.output.write_wav was removed in librosa 0.8; write a float WAV with soundfile instead
    sf.write(path, wav, sr, subtype="FLOAT")


def save_melGAN_wav(file_path, sampling_rate, audio):
    audio = audio.reshape((-1, ))
    sf.write(file_path,
             audio, sampling_rate, "PCM_16")


def preemphasis(wav, k):
    return signal.lfilter([1, -k], [1], wav)


def inv_preemphasis(wav, k):
    return signal.lfilter([1], [1, -k], wav)


def trim_silence(wav, hparams, only_front=True):
    non_silent = librosa.effects._signal_to_frame_nonsilent(wav,
                                                            frame_length=hparams['trim_fft_size'],
                                                            hop_length=hparams['trim_hop_size'],
                                                            ref=np.max,
                                                            top_db=hparams['trim_top_db'])

    nonzero = np.flatnonzero(non_silent)
    if nonzero.size > 0:
        # Compute the start and end positions
        # End position goes one frame past the last non-zero
        start = int(librosa.core.frames_to_samples(nonzero[0], hop_length=hparams['trim_hop_size']))
        end = min(wav.shape[-1],
                  int(librosa.core.frames_to_samples(nonzero[-1] + 1, hop_length=hparams['trim_hop_size'])))
    else:
        # The signal only contains zeros
        start, end = 0, 0
    if only_front:
        end = wav.shape[0]
    full_index = [slice(None)] * wav.ndim
    full_index[-1] = slice(start, end)

    return wav[tuple(full_index)]


def get_hop_size(hparams):
    hop_size = hparams['hop_size']
    if hop_size is None:
        assert hparams['frame_shift_ms'] is not None
        hop_size = int(hparams['frame_shift_ms'] / 1000 * hparams['sampling_rate'])
    return hop_size


def inv_linear_spectrogram(linear_spectrogram, hparams):
    '''Converts linear spectrogram to waveform using librosa'''
    # hparams is a dict everywhere else in this module, so index it instead of using attribute access
    if hparams['signal_normalization']:
        D = _denormalize(linear_spectrogram, hparams)
    else:
        D = linear_spectrogram

    S = _db_to_amp(D + hparams['ref_level_db'])  # Convert back to linear

    wav = _griffin_lim(S ** hparams['power'], hparams)
    # 'preemphasis' is an on/off flag; the filter coefficient is 'preemphasis_value' (see linearspectrogram)
    if hparams['preemphasis']:
        wav = inv_preemphasis(wav, hparams['preemphasis_value'])
    return wav


def _griffin_lim(S, hparams):
    '''librosa implementation of Griffin-Lim
    Based on https://github.com/librosa/librosa/issues/434
    '''
    angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
    S_complex = np.abs(S).astype(np.complex128)  # the np.complex alias was removed in NumPy 1.24
    y = _istft(S_complex * angles, hparams)
    for i in range(hparams['griffin_lim_iters']):
        angles = np.exp(1j * np.angle(_stft(y, hparams)))
        y = _istft(S_complex * angles, hparams)
    return y


def _stft(y, hparams):
    return librosa.stft(y=y, n_fft=hparams['fft_size'], hop_length=get_hop_size(hparams),
                        win_length=hparams['win_length'])

def _istft(y, hparams):
    return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams['win_length'])

# Conversions
_mel_basis = None
_inv_mel_basis = None


def _build_mel_basis(hparams):
    assert hparams['fmax'] <= hparams['sampling_rate'] // 2
    # keyword arguments: librosa >= 0.10 makes sr and n_fft keyword-only
    return librosa.filters.mel(sr=hparams['sampling_rate'], n_fft=hparams['fft_size'], n_mels=hparams['num_mels'],
                               fmin=hparams['fmin'], fmax=hparams['fmax'])  # ,norm=None if hparams['use_same_high'] else 1)


def _linear_to_mel(spectogram, hparams):
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis(hparams)
    return np.dot(_mel_basis, spectogram)


def _mel_to_linear(mel_spectrogram, hparams):
    global _inv_mel_basis
    if _inv_mel_basis is None:
        _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
    return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))


def _amp_to_db(x, hparams):
    min_level = np.exp(hparams['min_level_db'] / 20 * np.log(10))  # np.log() is the natural log; np.exp() raises e to a power
    return 20 * np.log10(np.maximum(min_level, x))  # np.maximum takes the element-wise maximum of its two arguments


def _db_to_amp(x):
    return np.power(10.0, (x) * 0.05)


def _normalize(S, hparams):
    if hparams['allow_clipping_in_normalization']:
        if hparams['symmetric_mels']:
            return np.clip((2 * hparams['max_abs_value']) * (
                    (S - hparams['min_level_db']) / (-hparams['min_level_db'])) - hparams['max_abs_value'],
                           -hparams['max_abs_value'], hparams['max_abs_value'])
        else:
            return np.clip(hparams['max_abs_value'] * ((S - hparams['min_level_db']) / (-hparams['min_level_db'])), 0,
                           hparams['max_abs_value'])

    if hparams['symmetric_mels']:
        return (2 * hparams['max_abs_value']) * (
                (S - hparams['min_level_db']) / (-hparams['min_level_db'])) - hparams['max_abs_value']
    else:
        return hparams['max_abs_value'] * ((S - hparams['min_level_db']) / (-hparams['min_level_db']))


def _denormalize(D, hparams):
    if hparams['allow_clipping_in_normalization']:
        if hparams['symmetric_mels']:
            return (((np.clip(D, -hparams['max_abs_value'],
                              hparams['max_abs_value']) + hparams['max_abs_value']) * -hparams['min_level_db'] / (
                             2 * hparams['max_abs_value']))
                    + hparams['min_level_db'])
        else:
            return ((np.clip(D, 0,
                             hparams['max_abs_value']) * -hparams['min_level_db'] / hparams['max_abs_value']) + hparams['min_level_db'])

    if hparams['symmetric_mels']:
        return (((D + hparams['max_abs_value']) * -hparams['min_level_db'] / (
                2 * hparams['max_abs_value'])) + hparams['min_level_db'])
    else:
        return ((D * -hparams['min_level_db'] / hparams['max_abs_value']) + hparams['min_level_db'])


def linearspectrogram(wav, hparams):
    if hparams['preemphasis']:
        wav = preemphasis(wav, hparams['preemphasis_value'])
    D = _stft(wav, hparams)
    S = _amp_to_db(np.abs(D), hparams) - hparams['ref_level_db']

    if hparams['signal_normalization']:
        return _normalize(S, hparams)
    return S


def melspectrogram(wav, hparams):
    if hparams['preemphasis']:
        wav = preemphasis(wav, hparams['preemphasis_value'])
    D = _stft(wav, hparams)
    S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams['ref_level_db']

    if hparams['signal_normalization']:
        return _normalize(S, hparams)
    return S


def logmelfilterbank(audio, config, eps=1e-10):

    x_stft = librosa.stft(audio, n_fft=config["fft_size"], hop_length=config["hop_size"],  # STFT
                          win_length=config["win_length"], window=config["window"], pad_mode="reflect")
    spc = np.abs(x_stft).T  # (#frames, #bins)

    # get the mel basis (mel filterbank matrix)
    mel_basis = librosa.filters.mel(sr=config["sampling_rate"], n_fft=config["fft_size"],
                                    n_mels=config["num_mels"], fmin=config["fmin"], fmax=config["fmax"])
    # norm=None if config['use_same_high_mel'] else 1)

    return 20 * np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))


# def chroma_stft(x, hparams):
#
#     S = np.abs(_stft(x, hparams))**2
#
#     tuning = estimate_tuning(S=S, sr=hparams['sampling_rate'], bins_per_octave=12)
#
#     # Get the filter bank
#     chromafb = filters.chroma(hparams['sampling_rate'], hparams['fft_size'],
#                               tuning=tuning, n_chroma=12)
#
#     # Compute raw chroma
#     raw_chroma = np.dot(chromafb, S)
#
#     return raw_chroma


def inv_mel_spectrogram(mel_spectrogram, hparams):
    '''Converts mel spectrogram to waveform using librosa'''
    if hparams['signal_normalization']:
        D = _denormalize(mel_spectrogram, hparams)
    else:
        D = mel_spectrogram

    S = _mel_to_linear(_db_to_amp(D + hparams['ref_level_db']), hparams)  # Convert back to linear

    wav = _griffin_lim(S ** hparams['power'], hparams)
    # 'preemphasis' is an on/off flag; the filter coefficient is 'preemphasis_value' (see melspectrogram)
    if hparams['preemphasis']:
        wav = inv_preemphasis(wav, hparams['preemphasis_value'])
    return wav


# WaveRNN waveform functions
def encode_mu_law(x, mu):
    mu = mu - 1
    fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
    return np.floor((fx + 1) / 2 * mu + 0.5)

def decode_mu_law(y, mu, from_labels=True):
    # TODO : get rid of log2 - makes no sense
    if from_labels: y = label_2_float(y, math.log2(mu))
    mu = mu - 1
    x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)
    return x

def float_2_label(x, bits):
    assert abs(x).max() <= 1.0
    x = (x + 1.) * (2 ** bits - 1) / 2
    return x.clip(0, 2 ** bits - 1)

def label_2_float(x, bits):
    return 2 * x / (2 ** bits - 1.) - 1.


def num_frames(length, fsize, fshift):
    """Compute number of time frames of spectrogram
    """
    pad = (fsize - fshift)
    if length % fshift == 0:
        M = (length + pad * 2 - fsize) // fshift + 1
    else:
        M = (length + pad * 2 - fsize) // fshift + 2
    return M


def pad_lr(x, fsize, fshift):
    """Compute left and right padding
    """
    M = num_frames(len(x), fsize, fshift)
    pad = (fsize - fshift)
    T = len(x) + 2 * pad
    r = (M - 1) * fshift + fsize - T
    return pad, pad + r


import matplotlib.pyplot as plt


def plot_spec(spec, path, info=None):
    fig = plt.figure(figsize=(14, 7))
    heatmap = plt.pcolor(spec)
    fig.colorbar(heatmap)

    xlabel = 'Time'
    if info is not None:
        xlabel += '\n\n' + info
    plt.xlabel(xlabel)
    plt.ylabel('Mel filterbank')
    plt.tight_layout()
    plt.savefig(path, format='png')
    plt.close(fig)

# Compute the mel scale spectrogram from the wav
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)

def dynamic_range_decompression(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return np.exp(x) / C


def pitchfeats(wav, hparams):  # extract pitch features
    # keyword arguments: librosa >= 0.10 makes y and sr keyword-only in piptrack
    pitches, magnitudes = librosa.piptrack(y=wav, sr=hparams['sampling_rate'],
                                           n_fft=hparams['fft_size'],
                                           hop_length=hparams['hop_size'],
                                           fmin=hparams['fmin'],
                                           fmax=2000,
                                           win_length=hparams['win_length'])
    pitches = pitches.T
    magnitudes = magnitudes.T
    assert pitches.shape == magnitudes.shape

    # for each frame, take the pitch at the frequency bin selected by find_f0
    pitches = [pitches[i][find_f0(magnitudes[i])] for i, _ in enumerate(pitches)]

    return np.asarray(pitches)


def find_f0(mags):
    tmp = 0
    mags = list(mags)
    for i, mag in enumerate(mags):
        if mag < tmp:  # the magnitude has started to fall
            # return i-1
            if tmp - mag > 2:  # the drop exceeds 2: treat the preceding bins as a peak
                #return i-1
                return mags.index(max(mags[0:i]))  # index of the largest magnitude seen so far
            else:
                return 0
        else:  # still rising (or flat): remember the current magnitude
            tmp = mag
    return 0


def f0_to_coarse(f0, f0_min=35, f0_max=1400, f0_bin=256):

    f0_mel = 1127 * np.log(1 + f0 / 700)
    f0_mel_min = 1127 * np.log(1 + f0_min / 700)
    f0_mel_max = 1127 * np.log(1 + f0_max / 700)
    # f0_mel[f0_mel == 0] = 0
    # map voiced (> 0) values linearly onto bins 1 .. f0_bin - 1 (255 bins by default)
    f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1

    f0_mel[f0_mel < 0] = 1
    f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
    f0_coarse = np.rint(f0_mel).astype(int)  # builtin int: the np.int alias was removed in NumPy 1.24
    # print('Max f0', np.max(f0_coarse), ' ||Min f0', np.min(f0_coarse))
    assert (np.max(f0_coarse) <= 256 and np.min(f0_coarse) >= 0)
    return f0_coarse

--------------------------------------------------------------------------------
/frontend/audio_world_process.py:
--------------------------------------------------------------------------------
import os
import logging
import numpy as np
import copy

from scipy.interpolate import interp1d
from scipy.signal import firwin
from scipy.signal import lfilter
from multiprocessing import Pool
from multiprocessing import cpu_count
#from sprocket.speech import FeatureExtractor


def load_from_file(path, dimension):
    data = np.fromfile(path, dtype=np.float32)
    if len(data) % dimension != 0:
        raise RuntimeError('%s data size is not divisible by %d' % (path, dimension))
    data = data.reshape([-1, dimension])
    return data


def save_to_file(data, path):
    data.astype(np.float32).tofile(path)


def _lf02vuv(data):
    '''
    generate vuv feature by interpolating lf0
    '''
    data = np.reshape(data, (data.size, 1))

    vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
    vuv_vector[data > 0.0] = 1.0
    vuv_vector[data <= 0.0] = 0.0

    ip_data = data

    frame_number = data.size
    last_value = 0.0
    for i in range(frame_number):
        if data[i] <= 0.0:
            j = i+1
            for j in range(i+1, frame_number):
                if data[j] > 0.0:
                    break
            if j < frame_number-1:
                if last_value > 0.0:
                    step = (data[j] - data[i-1]) / float(j - i + 1)
                    for k in range(i, j):
                        ip_data[k] = data[i-1] + step * (k - i + 1)
                else:
                    for k in range(i, j):
                        ip_data[k] = data[j]
            else:
                for k in range(i, frame_number):
                    ip_data[k] = last_value
        else:
            ip_data[i] = data[i]
            last_value = data[i]

    return ip_data, vuv_vector


def _conv1d(data_matrix, kernel):
    '''
    convolve each column in data_matrix with kernel
    (a CNN-style 1-D convolution, computed as a correlation with edge padding)
    '''
    kernel = kernel.reshape([-1, ])
    kernel_width = int(len(kernel) / 2)

    res = []
    for dim in range(data_matrix.shape[1]):
        vector = data_matrix[:, dim].reshape([-1, ])
        vector = np.pad(vector, (kernel_width, kernel_width), 'edge')
        res.append(np.correlate(vector, kernel, mode='valid').reshape([-1, 1]))

    res = np.concatenate(res, axis=-1)
    return res



def extract_feats(world_analysis, wav_dir, feat_dir, filename, mgc_dim=60):
    world_analysis_cmd = "{analyze} {wav} {lf0} {mgc} {bap} {mgc_dim}".format(analyze=world_analysis,
                                                                           wav=os.path.join(wav_dir, filename + '.wav'),
                                                                           lf0=os.path.join(feat_dir, filename + '.lf0'),
                                                                           mgc=os.path.join(feat_dir, filename + '.mgc'),
                                                                           bap=os.path.join(feat_dir, filename + '.bap'),
                                                                           mgc_dim=mgc_dim)
    # run the WORLD analysis binary so that the .lf0/.mgc/.bap files exist for _merge_feat
    os.system(world_analysis_cmd)


def _merge_feat(feat_dir, out_dir, filenames):
    '''
    merge acoustic features
    The output features are concatenated as
    [lf0, lf0 * delta, lf0 * acc, mgc, mgc * delta, mgc * acc, bap, bap * delta, bap * acc, vuv],
    where "* delta" / "* acc" denote convolution with the delta / acceleration windows.
    '''
    for filename in filenames:
        lf0_path = os.path.join(feat_dir, filename + '.lf0')
        mgc_path = os.path.join(feat_dir, filename + '.mgc')
        bap_path = os.path.join(feat_dir, filename + '.bap')
        out_path = os.path.join(out_dir, filename + '.cmp')

        lf0_matrix = load_from_file(lf0_path, 1)
        mgc_matrix = load_from_file(mgc_path, 1)
        bap_matrix = load_from_file(bap_path, 1)

        frame_num = lf0_matrix.shape[0]
        mgc_matrix = mgc_matrix.reshape([frame_num, -1])
        bap_matrix = bap_matrix.reshape([frame_num, -1])

        lf0_matrix, vuv_matrix = _lf02vuv(lf0_matrix)

        delta_win = np.array([-0.5, 0.0, 0.5])
        acc_win = np.array([1.0, -2.0, 1.0])
        res = []
        res.append(lf0_matrix)
        res.append(_conv1d(lf0_matrix, delta_win))
        res.append(_conv1d(lf0_matrix, acc_win))
        res.append(mgc_matrix)
        res.append(_conv1d(mgc_matrix, delta_win))
        res.append(_conv1d(mgc_matrix, acc_win))
        res.append(bap_matrix)
        res.append(_conv1d(bap_matrix, delta_win))
        res.append(_conv1d(bap_matrix, acc_win))
        res.append(vuv_matrix)
        res = np.concatenate(res, axis=-1)

        save_to_file(res, out_path)

    return lf0_matrix.shape[1] * 3, mgc_matrix.shape[1] * 3, bap_matrix.shape[1] * 3, vuv_matrix.shape[1]


def wav_preprocess(data_dir, tmp_dir, world_dir):
    '''
    Extract features from the audio, merge them, and compute the derived (delta/acc) features
    '''
    logger = logging.getLogger('preprocess')
    logger.setLevel(logging.INFO)

    wav_dir = os.path.join(data_dir, 'wavs')
    feat_dir = os.path.join(tmp_dir, 'feats')
    cmp_dir = os.path.join(tmp_dir, 'cmp')
    os.makedirs(feat_dir, exist_ok=True)
    os.makedirs(cmp_dir, exist_ok=True)

    filenames = list(set(filename.split('.')[0] for filename in os.listdir(wav_dir)))
    split_filenames = [filenames[i::cpu_count()] for i in range(cpu_count())]
    world_analysis = os.path.join(world_dir, 'analysis')

    # extract features with WORLD
    logger.info('extract feat from wav')
    p = Pool(cpu_count())
    results = []
    for filename in filenames:
        results.append(p.apply_async(extract_feats, args=[world_analysis, wav_dir, feat_dir, filename]))
    p.close()
    p.join()
    results = [res.get() for res in results]

    # merge lf0, mgc and bap, and compute the derived features
    logger.info('merge lf0 mgc bap feat')
    p = Pool(cpu_count())
    results = []
    for filenames in split_filenames:
        results.append(p.apply_async(_merge_feat, args=[feat_dir, cmp_dir, filenames]))
    p.close()
    p.join()

    logger.info('preprocess wav finish')


def world_feature_extract(wav, config):
    """WORLD feature extraction

    Args:
        wav (ndarray): waveform to analyze
        config (dict): feature extraction config

    Returns:
        ndarray: [uv, low-pass-filtered continuous f0, mcep, codeap] concatenated along the feature axis

    """
    # define feature extractor
    # NOTE: FeatureExtractor comes from sprocket; its import at the top of this file is commented out
    feature_extractor = FeatureExtractor(
        analyzer="world",
        fs=config['sampling_rate'],
        shiftms=config['hop_size'] / config['sampling_rate'] * 1000,
        minf0=config['minf0'],
        maxf0=config['maxf0'],
        fftl=config['fft_size'])
    # extraction

    # extract features
    f0, spc, ap = feature_extractor.analyze(wav)
    codeap = feature_extractor.codeap()
    mcep = feature_extractor.mcep(dim=config['mcep_dim'], alpha=config['mcep_alpha'])
    npow = feature_extractor.npow()
    uv, cont_f0 = convert_continuos_f0(f0)
    lpf_fs = int(config['sampling_rate'] / config['hop_size'])
    cont_f0_lpf = low_pass_filter(cont_f0, lpf_fs, cutoff=20)
    next_cutoff = 70
    while not (cont_f0_lpf >= [0]).all():
        cont_f0_lpf = low_pass_filter(cont_f0, lpf_fs, cutoff=next_cutoff)
        next_cutoff *= 2
    # concatenate
    cont_f0_lpf = np.expand_dims(cont_f0_lpf, axis=-1)
    uv = np.expand_dims(uv, axis=-1)
    feats = np.concatenate([uv, cont_f0_lpf, mcep, codeap], axis=1)

    # return (feats, f0, ap, spc, npow)
    return feats

def convert_continuos_f0(f0):
    """Convert F0 to continuous F0

    Args:
        f0 (ndarray): original f0 sequence with the shape (T)
    Return:
        (ndarray): voiced/unvoiced flags (1.0 = voiced) with the shape (T)
        (ndarray): continuous f0 with the shape (T)

    """
    # get uv information as binary
    uv = np.float32(f0 != 0)
    # get start and end of f0
    if (f0 == 0).all():
        logging.warning("all of the f0 values are 0.")
        return uv, f0
    start_f0 = f0[f0 != 0][0]
    end_f0 = f0[f0 != 0][-1]
    # padding start and end of f0 sequence
    cont_f0 = copy.deepcopy(f0)
    start_idx = np.where(cont_f0 == start_f0)[0][0]
    end_idx = np.where(cont_f0 == end_f0)[0][-1]
    cont_f0[:start_idx] = start_f0
    cont_f0[end_idx:] = end_f0
    # get non-zero frame index
    nz_frames = np.where(cont_f0 != 0)[0]
    # perform linear interpolation
    f = interp1d(nz_frames, cont_f0[nz_frames])
    cont_f0 = f(np.arange(0, cont_f0.shape[0]))

    return uv, cont_f0


def low_pass_filter(x, fs, cutoff=70, padding=True):
    """Low pass filter

    Args:
        x (ndarray): Waveform sequence
        fs (int): Sampling frequency
        cutoff (float): Cutoff frequency of low pass filter
    Return:
        (ndarray): Low pass filtered waveform sequence

    """
    nyquist = fs // 2
    norm_cutoff = cutoff / nyquist
    numtaps = 255
    fil = firwin(numtaps, norm_cutoff)
    x_pad = np.pad(x, (numtaps, numtaps), 'edge')
    lpf_x = lfilter(fil, 1, x_pad)
    lpf_x = lpf_x[numtaps + numtaps // 2: -numtaps // 2]

    return lpf_x

--------------------------------------------------------------------------------
/frontend/world/analysis:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rongjiehuang/Multi-Singer/a6e9f6138a1ddf52ebd4ec29e91795f34c108e42/frontend/world/analysis
--------------------------------------------------------------------------------
/frontend/world/synthesis:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rongjiehuang/Multi-Singer/a6e9f6138a1ddf52ebd4ec29e91795f34c108e42/frontend/world/synthesis
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""Decode with trained Multi-Singer."""

import argparse
import logging
import os
import time

import numpy as np
import soundfile as sf
import torch
import yaml

from tqdm import tqdm

from datasets import MelDataset
from utils import load_model
from utils import read_hdf5

def main():
    """Run decoding process."""
    parser = argparse.ArgumentParser(
        description="Decode dumped features with trained Parallel WaveGAN Generator "
                    "(See detail in parallel_wavegan/bin/decode.py).")
    parser.add_argument("--inputdir", '-i', type=str, required=True,
                        help="directory including feature files.")
    parser.add_argument("--outdir", '-o', type=str, required=True,
                        help="directory to save generated speech.")
    parser.add_argument("--checkpoint", '-c', type=str, required=True,
                        help="checkpoint file to be loaded.")
    parser.add_argument("--config", '-g', default=None, type=str,
                        help="yaml format configuration file. if not explicitly provided, "
                             "it will be searched in the checkpoint directory. (default=None)")
    parser.add_argument("--verbose", type=int, default=1,
                        help="logging level. higher is more logging. (default=1)")
    parser.add_argument("--rank", default=0, type=int,
                        help="rank for distributed training. no need to explicitly specify.")
    # store_true: with type=bool, any non-empty string (even "False") would parse as True
    parser.add_argument("--force_cpu", action="store_true",
                        help="run decoding on CPU even if CUDA is available.")
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load config
    if args.config is None:
        dirname = os.path.dirname(args.checkpoint)
        args.config = os.path.join(dirname, "config.yml")
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # check arguments
    if args.inputdir is None:
        raise ValueError("Please specify --inputdir.")

    # get dataset
    if config["format"] == "hdf5":
        mel_query = "*.h5"
        mel_load_fn = lambda x: read_hdf5(x, "mel")  # NOQA
    elif config["format"] == "npy":
        mel_query = "*-feats.npy"
        mel_load_fn = np.load
    else:
        raise ValueError("Support only hdf5 or npy format.")
    dataset = MelDataset(
        args.inputdir,
        mel_query=mel_query,
        mel_load_fn=mel_load_fn,
        return_utt_id=True,
    )
    logging.info(f"The number of features to be decoded = {len(dataset)}.")

    # setup model
    if torch.cuda.is_available() and not args.force_cpu:
        device = torch.device("cuda")
        torch.cuda.set_device(args.rank)
    else:
        device = torch.device("cpu")
    model = load_model(args.checkpoint, config)
    logging.info(f"Loaded model parameters from {args.checkpoint}.")
    model.remove_weight_norm()
    model = model.eval().to(device)

    # start generation
    total_rtf = 0.0
    num_generated = 0
    with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar:
        for utt_id, c in pbar:  # utt_id: mel id, c: mel feats
            # generate (utterances whose output file already exists are skipped)
            if not (os.path.exists(os.path.join(config["outdir"], f"{utt_id}_gen.wav"))):
                c = torch.tensor(c, dtype=torch.float).to(device)
                start = time.time()
                y = model.inference(c).view(-1)
                rtf = (time.time() - start) / (len(y) / config["sampling_rate"])
                pbar.set_postfix({"RTF": rtf})
                total_rtf += rtf
                num_generated += 1

                # save as PCM 16 bit wav file
                sf.write(os.path.join(config["outdir"], f"{utt_id}_gen.wav"),
                         y.cpu().numpy(), config["sampling_rate"], "PCM_16")
                del c, y
                torch.cuda.empty_cache()

    # report average RTF over the utterances that were actually generated
    logging.info(f"Finished generation of {num_generated} utterances "
                 f"(RTF = {total_rtf / max(num_generated, 1):.03f}).")


if __name__ == "__main__":
    main()

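The decoding loop in `inference.py` above measures a real-time factor (RTF): wall-clock generation time divided by the duration of the audio produced, so an RTF below 1 means synthesis runs faster than real time. The minimal sketch below isolates that arithmetic in plain Python; the helper names `real_time_factor` and `average_rtf` are hypothetical and do not exist in this repository.

```python
import math


def real_time_factor(elapsed_s: float, num_samples: int, sampling_rate: int) -> float:
    """Seconds of compute spent per second of generated audio (lower is faster).

    Hypothetical helper, not part of Multi-Singer; mirrors the RTF formula
    elapsed / (len(y) / sampling_rate) used in the decoding loop.
    """
    if num_samples <= 0 or sampling_rate <= 0:
        raise ValueError("num_samples and sampling_rate must be positive")
    audio_duration_s = num_samples / sampling_rate
    return elapsed_s / audio_duration_s


def average_rtf(rtfs):
    """Mean RTF over the utterances that were actually generated (NaN if none)."""
    rtfs = list(rtfs)
    return sum(rtfs) / len(rtfs) if rtfs else math.nan


if __name__ == "__main__":
    # Illustrative numbers: 0.5 s of compute for 2 s of audio at 24 kHz (48000 samples)
    print(real_time_factor(0.5, 48000, 24000))  # 0.25
```

Averaging only over utterances that were actually synthesized keeps the mean from being diluted when outputs that already exist on disk are skipped.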
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
from .causal_conv import *  # NOQA
from .pqmf import *  # NOQA
from .residual_block import *  # NOQA
from .residual_stack import *  # NOQA
from .upsample import *  # NOQA

--------------------------------------------------------------------------------
/layers/causal_conv.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

# Copyright 2020 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)

"""Causal convolution layer modules."""


import torch


class CausalConv1d(torch.nn.Module):
    """CausalConv1d module with customized initialization."""

    def __init__(self, in_channels, out_channels, kernel_size,
                 dilation=1, bias=True, pad="ConstantPad1d", pad_params={"value": 0.0}):
        """Initialize CausalConv1d module."""
        super(CausalConv1d, self).__init__()
        self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params)
        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size,
                                    dilation=dilation, bias=bias)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).

        Returns:
            Tensor: Output tensor (B, out_channels, T).
31 | 32 | """ 33 | return self.conv(self.pad(x))[:, :, :x.size(2)] 34 | 35 | 36 | class CausalConvTranspose1d(torch.nn.Module): 37 | """CausalConvTranspose1d module with customized initialization.""" 38 | 39 | def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True): 40 | """Initialize CausalConvTranspose1d module.""" 41 | super(CausalConvTranspose1d, self).__init__() 42 | self.deconv = torch.nn.ConvTranspose1d( 43 | in_channels, out_channels, kernel_size, stride, bias=bias) 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | """Calculate forward propagation. 48 | 49 | Args: 50 | x (Tensor): Input tensor (B, in_channels, T_in). 51 | 52 | Returns: 53 | Tensor: Output tensor (B, out_channels, T_out). 54 | 55 | """ 56 | return self.deconv(x)[:, :, :-self.stride] 57 | -------------------------------------------------------------------------------- /layers/pqmf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2020 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Pseudo QMF modules.""" 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from scipy.signal.windows import kaiser 13 | 14 | 15 | def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0): 16 | """Design prototype filter for PQMF. 17 | 18 | This method is based on `A Kaiser window approach for the design of prototype 19 | filters of cosine modulated filterbanks`_. 20 | 21 | Args: 22 | taps (int): The number of filter taps. 23 | cutoff_ratio (float): Cut-off frequency ratio. 24 | beta (float): Beta coefficient for kaiser window. 25 | 26 | Returns: 27 | ndarray: Impulse response of prototype filter (taps + 1,). 28 | 29 | ..
_`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`: 30 | https://ieeexplore.ieee.org/abstract/document/681427 31 | 32 | """ 33 | # check the arguments are valid 34 | assert taps % 2 == 0, "The number of taps must be an even number." 35 | assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0." 36 | 37 | # make initial filter 38 | omega_c = np.pi * cutoff_ratio 39 | with np.errstate(invalid='ignore'): 40 | h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) \ 41 | / (np.pi * (np.arange(taps + 1) - 0.5 * taps)) 42 | h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form 43 | 44 | # apply kaiser window 45 | w = kaiser(taps + 1, beta) 46 | h = h_i * w 47 | 48 | return h 49 | 50 | 51 | class PQMF(torch.nn.Module): 52 | """PQMF module. 53 | 54 | This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_. 55 | 56 | .. _`Near-perfect-reconstruction pseudo-QMF banks`: 57 | https://ieeexplore.ieee.org/document/258122 58 | 59 | """ 60 | 61 | def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0): 62 | """Initialize PQMF module. 63 | 64 | The cutoff_ratio and beta parameters are optimized for #subbands = 4. 65 | See discussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195. 66 | 67 | Args: 68 | subbands (int): The number of subbands. 69 | taps (int): The number of filter taps. 70 | cutoff_ratio (float): Cut-off frequency ratio. 71 | beta (float): Beta coefficient for kaiser window.
72 | 73 | """ 74 | super(PQMF, self).__init__() 75 | 76 | # build analysis & synthesis filter coefficients 77 | h_proto = design_prototype_filter(taps, cutoff_ratio, beta) 78 | h_analysis = np.zeros((subbands, len(h_proto))) 79 | h_synthesis = np.zeros((subbands, len(h_proto))) 80 | for k in range(subbands): 81 | h_analysis[k] = 2 * h_proto * np.cos( 82 | (2 * k + 1) * (np.pi / (2 * subbands)) * 83 | (np.arange(taps + 1) - (taps / 2)) + 84 | (-1) ** k * np.pi / 4) 85 | h_synthesis[k] = 2 * h_proto * np.cos( 86 | (2 * k + 1) * (np.pi / (2 * subbands)) * 87 | (np.arange(taps + 1) - (taps / 2)) - 88 | (-1) ** k * np.pi / 4) 89 | 90 | # convert to tensor 91 | analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1) 92 | synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0) 93 | 94 | # register coefficients as buffer 95 | self.register_buffer("analysis_filter", analysis_filter) 96 | self.register_buffer("synthesis_filter", synthesis_filter) 97 | 98 | # filter for downsampling & upsampling 99 | updown_filter = torch.zeros((subbands, subbands, subbands)).float() 100 | for k in range(subbands): 101 | updown_filter[k, k, 0] = 1.0 102 | self.register_buffer("updown_filter", updown_filter) 103 | self.subbands = subbands 104 | 105 | # keep padding info 106 | self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0) 107 | 108 | def analysis(self, x): 109 | """Analysis with PQMF. 110 | 111 | Args: 112 | x (Tensor): Input tensor (B, 1, T). 113 | 114 | Returns: 115 | Tensor: Output tensor (B, subbands, T // subbands). 116 | 117 | """ 118 | x = F.conv1d(self.pad_fn(x), self.analysis_filter) 119 | return F.conv1d(x, self.updown_filter, stride=self.subbands) 120 | 121 | def synthesis(self, x): 122 | """Synthesis with PQMF. 123 | 124 | Args: 125 | x (Tensor): Input tensor (B, subbands, T // subbands). 126 | 127 | Returns: 128 | Tensor: Output tensor (B, 1, T). 129 | 130 | """ 131 | # NOTE(kan-bayashi): Power will be decreased so here multiply by # subbands.
132 | # Not sure this is the correct way, it is better to check again. 133 | # TODO(kan-bayashi): Understand the reconstruction procedure 134 | x = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands) 135 | return F.conv1d(self.pad_fn(x), self.synthesis_filter) 136 | -------------------------------------------------------------------------------- /layers/residual_block.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Residual block module in WaveNet. 4 | 5 | This code is modified from https://github.com/r9y9/wavenet_vocoder. 6 | 7 | """ 8 | 9 | import math 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | 15 | class Conv1d(torch.nn.Conv1d): 16 | """Conv1d module with customized initialization.""" 17 | 18 | def __init__(self, *args, **kwargs): 19 | """Initialize Conv1d module.""" 20 | super(Conv1d, self).__init__(*args, **kwargs) 21 | 22 | def reset_parameters(self): 23 | """Reset parameters.""" 24 | torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu") 25 | if self.bias is not None: 26 | torch.nn.init.constant_(self.bias, 0.0) 27 | 28 | 29 | class Conv1d1x1(Conv1d): 30 | """1x1 Conv1d with customized initialization.""" 31 | 32 | def __init__(self, in_channels, out_channels, bias): 33 | """Initialize 1x1 Conv1d module.""" 34 | super(Conv1d1x1, self).__init__(in_channels, out_channels, 35 | kernel_size=1, padding=0, 36 | dilation=1, bias=bias) 37 | 38 | 39 | class ResidualBlock(torch.nn.Module): 40 | """Residual block module in WaveNet.""" 41 | 42 | def __init__(self, 43 | kernel_size=3, 44 | residual_channels=64, 45 | gate_channels=128, 46 | skip_channels=64, 47 | aux_channels=80, # conditioning input dimension: 80-dim mel spectrogram 48 | dropout=0.0, 49 | dilation=1, 50 | bias=True, 51 | use_causal_conv=False 52 | ): 53 | """Initialize ResidualBlock module. 54 | 55 | Args: 56 | kernel_size (int): Kernel size of dilation convolution layer.
57 | residual_channels (int): Number of channels for residual connection. 58 | skip_channels (int): Number of channels for skip connection. 59 | aux_channels (int): Local conditioning channels i.e. auxiliary input dimension. 60 | dropout (float): Dropout probability. 61 | dilation (int): Dilation factor. 62 | bias (bool): Whether to add bias parameter in convolution layers. 63 | use_causal_conv (bool): Whether to use causal or non-causal convolution. 64 | 65 | """ 66 | super(ResidualBlock, self).__init__() 67 | self.dropout = dropout 68 | # no future time stamps available 69 | if use_causal_conv: 70 | padding = (kernel_size - 1) * dilation 71 | else: 72 | assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." 73 | padding = (kernel_size - 1) // 2 * dilation 74 | self.use_causal_conv = use_causal_conv 75 | 76 | # dilation conv 77 | self.conv = Conv1d(residual_channels, gate_channels, kernel_size, 78 | padding=padding, dilation=dilation, bias=bias) 79 | 80 | # local conditioning: project the auxiliary input (B, aux_channels, T) -> (B, gate_channels, T) 81 | if aux_channels > 0: 82 | self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False) 83 | else: 84 | self.conv1x1_aux = None 85 | 86 | # conv output is split into two groups; the gated activation unit (GAU) output feeds both the residual and skip connections 87 | gate_out_channels = gate_channels // 2 88 | self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias) 89 | self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias) 90 | 91 | def forward(self, x, c): 92 | """Calculate forward propagation. 93 | 94 | Args: 95 | x (Tensor): Input tensor (B, residual_channels, T). 96 | c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T). 97 | 98 | Returns: 99 | Tensor: Output tensor for residual connection (B, residual_channels, T). 100 | Tensor: Output tensor for skip connection (B, skip_channels, T).
101 | 102 | """ 103 | residual = x 104 | x = F.dropout(x, p=self.dropout, training=self.training) 105 | x = self.conv(x) # dilated conv on x (B, residual_channels, T) -> (B, gate_channels, T) 106 | 107 | # remove future time steps if using causal conv (trim x to the length of residual) 108 | x = x[:, :, :residual.size(-1)] if self.use_causal_conv else x 109 | 110 | # split into two parts for gated activation (B, gate_channels, T) -> 2*(B, gate_channels/2, T) 111 | splitdim = 1 112 | xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) # xa goes through tanh, xb through sigmoid 113 | 114 | # local conditioning: as in WaveNet, the conditioning input is split the same way and added to xa and xb 115 | if c is not None: 116 | assert self.conv1x1_aux is not None 117 | c = self.conv1x1_aux(c) # pass the condition through a 1x1 conv 118 | ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) 119 | xa, xb = xa + ca, xb + cb 120 | 121 | x = torch.tanh(xa) * torch.sigmoid(xb) 122 | 123 | # for skip connection 124 | s = self.conv1x1_skip(x) 125 | 126 | # for residual connection: conv output + residual 127 | x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5) 128 | 129 | return x, s # return residual and skip outputs 130 | 131 | 132 | 133 | class ResidualEmbeddingBlock(torch.nn.Module): 134 | """Residual block module in WaveNet with additional embedding conditioning.""" 135 | 136 | def __init__(self, 137 | kernel_size=3, 138 | residual_channels=64, 139 | gate_channels=128, 140 | skip_channels=64, 141 | aux_channels=80, # conditioning input dimension: 80-dim mel spectrogram 142 | embed_channels=256, 143 | dropout=0.0, 144 | dilation=1, 145 | bias=True, 146 | use_causal_conv=False 147 | ): 148 | """Initialize ResidualEmbeddingBlock module. 149 | 150 | Args: 151 | kernel_size (int): Kernel size of dilation convolution layer. 152 | residual_channels (int): Number of channels for residual connection. 153 | skip_channels (int): Number of channels for skip connection. 
154 | aux_channels (int): Local conditioning channels i.e. auxiliary input dimension. 155 | dropout (float): Dropout probability. 156 | dilation (int): Dilation factor. 157 | bias (bool): Whether to add bias parameter in convolution layers.
158 | use_causal_conv (bool): Whether to use causal or non-causal convolution. 159 | 160 | """ 161 | super(ResidualEmbeddingBlock, self).__init__() 162 | self.dropout = dropout 163 | # no future time stamps available 164 | if use_causal_conv: 165 | padding = (kernel_size - 1) * dilation 166 | else: 167 | assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." 168 | padding = (kernel_size - 1) // 2 * dilation 169 | self.use_causal_conv = use_causal_conv 170 | 171 | # dilation conv 172 | self.conv = Conv1d(residual_channels, gate_channels, kernel_size, 173 | padding=padding, dilation=dilation, bias=bias) 174 | 175 | # local conditioning: project the auxiliary input (B, aux_channels, T) -> (B, gate_channels, T) 176 | if aux_channels > 0: 177 | self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False) 178 | else: 179 | self.conv1x1_aux = None 180 | # embedding conditioning (B, embed_channels, T) -> (B, gate_channels, T) 181 | if embed_channels > 0: 182 | self.conv1x1_embed = Conv1d1x1(embed_channels, gate_channels, bias=False) 183 | else: 184 | self.conv1x1_embed = None 185 | 186 | # conv output is split into two groups; the gated activation unit (GAU) output feeds both the residual and skip connections 187 | gate_out_channels = gate_channels // 2 188 | self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias) 189 | self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias) 190 | 191 | def forward(self, x, c, embed): 192 | """Calculate forward propagation. 193 | 194 | Args: 195 | x (Tensor): Input tensor (B, residual_channels, T). 196 | c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T). 197 | embed (Tensor): Embedding conditioning tensor (B, embed_channels, T). 198 | 199 | Returns: 200 | Tensor: Output tensor for residual connection (B, residual_channels, T). 201 | Tensor: Output tensor for skip connection (B, skip_channels, T).
202 | 203 | """ 204 | residual = x 205 | x = F.dropout(x, p=self.dropout, training=self.training) 206 | x = self.conv(x) # dilated conv on x (B, residual_channels, T) -> (B, gate_channels, T) 207 | 208 | # remove future time steps if using causal conv (trim x to the length of residual) 209 | x = x[:, :, :residual.size(-1)] if self.use_causal_conv else x 210 | 211 | # split into two parts for gated activation (B, gate_channels, T) -> 2*(B, gate_channels/2, T) 212 | splitdim = 1 213 | xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) # xa goes through tanh, xb through sigmoid 214 | 215 | # local conditioning: as in WaveNet, the mel and embedding conditions are split the same way and added to xa and xb 216 | if c is not None and embed is not None: 217 | assert self.conv1x1_aux is not None 218 | assert self.conv1x1_embed is not None 219 | c = self.conv1x1_aux(c) # pass the condition through a 1x1 conv 220 | ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) 221 | 222 | embed = self.conv1x1_embed(embed) # pass the embedding through a 1x1 conv 223 | ea, eb = embed.split(embed.size(splitdim) // 2, dim=splitdim) 224 | xa, xb = xa + ca + ea, xb + cb + eb 225 | 226 | x = torch.tanh(xa) * torch.sigmoid(xb) 227 | 228 | # for skip connection 229 | s = self.conv1x1_skip(x) 230 | 231 | # for residual connection: conv output + residual 232 | x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5) 233 | 234 | return x, s # return residual and skip outputs 235 | -------------------------------------------------------------------------------- /layers/residual_stack.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2020 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Residual stack module in MelGAN.""" 7 | 8 | import torch 9 | 10 | from layers import CausalConv1d 11 | 12 | 13 | class ResidualStack(torch.nn.Module): 14 | """Residual stack module introduced in MelGAN.""" 15 | 16 | def __init__(self, 17 | kernel_size=3, 18 | channels=32, 19 | dilation=1, 20 | bias=True, 21 | nonlinear_activation="LeakyReLU", 22 |
nonlinear_activation_params={"negative_slope": 0.2}, 23 | pad="ReflectionPad1d", 24 | pad_params={}, 25 | use_causal_conv=False, 26 | ): 27 | """Initialize ResidualStack module. 28 | 29 | Args: 30 | kernel_size (int): Kernel size of dilation convolution layer. 31 | channels (int): Number of channels of convolution layers. 32 | dilation (int): Dilation factor. 33 | bias (bool): Whether to add bias parameter in convolution layers. 34 | nonlinear_activation (str): Activation function module name. 35 | nonlinear_activation_params (dict): Hyperparameters for activation function. 36 | pad (str): Padding function module name before dilated convolution layer. 37 | pad_params (dict): Hyperparameters for padding function. 38 | use_causal_conv (bool): Whether to use causal convolution. 39 | 40 | """ 41 | super(ResidualStack, self).__init__() 42 | 43 | # define residual stack part 44 | if not use_causal_conv: 45 | assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." 46 | self.stack = torch.nn.Sequential( 47 | getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), 48 | getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params), 49 | torch.nn.Conv1d(channels, channels, kernel_size, dilation=dilation, bias=bias), 50 | getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), 51 | torch.nn.Conv1d(channels, channels, 1, bias=bias), 52 | ) 53 | else: 54 | self.stack = torch.nn.Sequential( 55 | getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), 56 | CausalConv1d(channels, channels, kernel_size, dilation=dilation, 57 | bias=bias, pad=pad, pad_params=pad_params), 58 | getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), 59 | torch.nn.Conv1d(channels, channels, 1, bias=bias), 60 | ) 61 | 62 | # define extra layer for skip connection 63 | self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias) 64 | 65 | def forward(self, c): 66 | """Calculate forward propagation.
67 | 68 | Args: 69 | c (Tensor): Input tensor (B, channels, T). 70 | 71 | Returns: 72 | Tensor: Output tensor (B, channels, T). 73 | 74 | """ 75 | return self.stack(c) + self.skip_layer(c) 76 | -------------------------------------------------------------------------------- /layers/tf_layers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2020 MINH ANH (@dathudeptrai) 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """TensorFlow layer modules compatible with PyTorch.""" 7 | 8 | import tensorflow as tf 9 | 10 | 11 | class TFReflectionPad1d(tf.keras.layers.Layer): 12 | """Tensorflow ReflectionPad1d module.""" 13 | 14 | def __init__(self, padding_size): 15 | """Initialize TFReflectionPad1d module. 16 | 17 | Args: 18 | padding_size (int): Padding size. 19 | 20 | """ 21 | super(TFReflectionPad1d, self).__init__() 22 | self.padding_size = padding_size 23 | 24 | @tf.function 25 | def call(self, x): 26 | """Calculate forward propagation. 27 | 28 | Args: 29 | x (Tensor): Input tensor (B, T, 1, C). 30 | 31 | Returns: 32 | Tensor: Padded tensor (B, T + 2 * padding_size, 1, C). 33 | 34 | """ 35 | return tf.pad(x, [[0, 0], [self.padding_size, self.padding_size], [0, 0], [0, 0]], "REFLECT") 36 | 37 | 38 | class TFConvTranspose1d(tf.keras.layers.Layer): 39 | """Tensorflow ConvTranspose1d module.""" 40 | 41 | def __init__(self, channels, kernel_size, stride, padding): 42 | """Initialize TFConvTranspose1d module. 43 | 44 | Args: 45 | channels (int): Number of channels. 46 | kernel_size (int): Kernel size. 47 | stride (int): Stride width. 48 | padding (str): Padding type ("same" or "valid").
49 | 50 | """ 51 | super(TFConvTranspose1d, self).__init__() 52 | self.conv1d_transpose = tf.keras.layers.Conv2DTranspose( 53 | filters=channels, 54 | kernel_size=(kernel_size, 1), 55 | strides=(stride, 1), 56 | padding=padding, 57 | ) 58 | 59 | @tf.function 60 | def call(self, x): 61 | """Calculate forward propagation. 62 | 63 | Args: 64 | x (Tensor): Input tensor (B, T, 1, C). 65 | 66 | Returns: 67 | Tensor: Output tensor (B, T', 1, C'). 68 | 69 | """ 70 | x = self.conv1d_transpose(x) 71 | return x 72 | 73 | 74 | class TFResidualStack(tf.keras.layers.Layer): 75 | """Tensorflow ResidualStack module.""" 76 | 77 | def __init__(self, 78 | kernel_size, 79 | channels, 80 | dilation, 81 | bias, 82 | nonlinear_activation, 83 | nonlinear_activation_params, 84 | padding, 85 | ): 86 | """Initialize TFResidualStack module. 87 | 88 | Args: 89 | kernel_size (int): Kernel size. 90 | channels (int): Number of channels. 91 | dilation (int): Dilation factor. 92 | bias (bool): Whether to add bias parameter in convolution layers. 93 | nonlinear_activation (str): Activation function module name. 94 | nonlinear_activation_params (dict): Hyperparameters for activation function. 95 | padding (str): Padding type ("same" or "valid"). 96 | 97 | """ 98 | super(TFResidualStack, self).__init__() 99 | self.block = [ 100 | getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params), 101 | TFReflectionPad1d(dilation), 102 | tf.keras.layers.Conv2D( 103 | filters=channels, 104 | kernel_size=(kernel_size, 1), 105 | dilation_rate=(dilation, 1), 106 | use_bias=bias, 107 | padding="valid", 108 | ), 109 | getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params), 110 | tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias) 111 | ] 112 | self.shortcut = tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias) 113 | 114 | @tf.function 115 | def call(self, x): 116 | """Calculate forward propagation.
117 | 118 | Args: 119 | x (Tensor): Input tensor (B, T, 1, C). 120 | 121 | Returns: 122 | Tensor: Output tensor (B, T, 1, C). 123 | 124 | """ 125 | _x = tf.identity(x) 126 | for i, layer in enumerate(self.block): 127 | _x = layer(_x) 128 | shortcut = self.shortcut(x) 129 | return shortcut + _x 130 | -------------------------------------------------------------------------------- /layers/upsample.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Upsampling module. 4 | 5 | This code is modified from https://github.com/r9y9/wavenet_vocoder. 6 | 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from layers import Conv1d 14 | 15 | 16 | class Stretch2d(torch.nn.Module): # upsampling by repeating frames (nearest mode) 17 | """Stretch2d module.""" 18 | 19 | def __init__(self, x_scale, y_scale, mode="nearest"): 20 | """Initialize Stretch2d module. 21 | 22 | Args: 23 | x_scale (int): X scaling factor (Time axis in spectrogram). 24 | y_scale (int): Y scaling factor (Frequency axis in spectrogram). 25 | mode (str): Interpolation mode. 26 | 27 | """ 28 | super(Stretch2d, self).__init__() 29 | self.x_scale = x_scale 30 | self.y_scale = y_scale 31 | self.mode = mode 32 | 33 | def forward(self, x): 34 | """Calculate forward propagation. 35 | 36 | Args: 37 | x (Tensor): Input tensor (B, C, F, T). 38 | 39 | Returns: 40 | Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale). 41 | 42 | """ 43 | return F.interpolate( 44 | x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) 45 | 46 | 47 | class Conv2d(torch.nn.Conv2d): 48 | """Conv2d module with customized initialization.""" 49 | 50 | def __init__(self, *args, **kwargs): 51 | """Initialize Conv2d module.""" 52 | super(Conv2d, self).__init__(*args, **kwargs) 53 | 54 | def reset_parameters(self): 55 | """Reset parameters.""" 56 | self.weight.data.fill_(1.
/ np.prod(self.kernel_size)) 57 | if self.bias is not None: 58 | torch.nn.init.constant_(self.bias, 0.0) 59 | 60 | 61 | class UpsampleNetwork(torch.nn.Module): 62 | """Upsampling network module.""" 63 | 64 | def __init__(self, 65 | upsample_scales, 66 | nonlinear_activation=None, 67 | nonlinear_activation_params={}, 68 | interpolate_mode="nearest", 69 | freq_axis_kernel_size=1, 70 | use_causal_conv=False, 71 | ): 72 | """Initialize upsampling network module. 73 | 74 | Args: 75 | upsample_scales (list): List of upsampling scales. 76 | nonlinear_activation (str): Activation function name. 77 | nonlinear_activation_params (dict): Arguments for specified activation function. 78 | interpolate_mode (str): Interpolation mode. 79 | freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. 80 | 81 | """ 82 | super(UpsampleNetwork, self).__init__() 83 | self.use_causal_conv = use_causal_conv # whether to use causal convolution 84 | self.up_layers = torch.nn.ModuleList() 85 | for scale in upsample_scales: 86 | # interpolation layer 87 | stretch = Stretch2d(scale, 1, interpolate_mode) 88 | self.up_layers += [stretch] 89 | 90 | # conv layer 91 | assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size." 92 | freq_axis_padding = (freq_axis_kernel_size - 1) // 2 93 | kernel_size = (freq_axis_kernel_size, scale * 2 + 1) 94 | if use_causal_conv: 95 | padding = (freq_axis_padding, scale * 2) 96 | else: 97 | padding = (freq_axis_padding, scale) 98 | conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) 99 | self.up_layers += [conv] 100 | 101 | # nonlinear 102 | if nonlinear_activation is not None: # apply a nonlinear activation if specified 103 | nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params) 104 | self.up_layers += [nonlinear] 105 | 106 | def forward(self, c): 107 | """Calculate forward propagation. 108 | 109 | Args: 110 | c (Tensor): Input tensor (B, C, T).
111 | 112 | Returns: 113 | Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales). 114 | 115 | """ 116 | c = c.unsqueeze(1) # (B, 1, C, T) 117 | for f in self.up_layers: 118 | if self.use_causal_conv and isinstance(f, Conv2d): # if causal and the current layer is a conv layer 119 | c = f(c)[..., :c.size(-1)] # trim the conv output [B, 1, C, T''] back to the input length, making it causal along time 120 | else: 121 | c = f(c) 122 | return c.squeeze(1) # (B, C, T') 123 | 124 | 125 | class ConvInUpsampleNetwork(torch.nn.Module): 126 | """Convolution + upsampling network module.""" 127 | 128 | def __init__(self, 129 | upsample_scales, 130 | nonlinear_activation=None, 131 | nonlinear_activation_params={}, 132 | interpolate_mode="nearest", 133 | freq_axis_kernel_size=1, 134 | aux_channels=80, 135 | aux_context_window=0, 136 | use_causal_conv=False 137 | ): 138 | """Initialize convolution + upsampling network module. 139 | 140 | Args: 141 | upsample_scales (list): List of upsampling scales. 142 | nonlinear_activation (str): Activation function name. 143 | nonlinear_activation_params (dict): Arguments for specified activation function. 144 | interpolate_mode (str): Interpolation mode. 145 | freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. 146 | aux_channels (int): Number of channels of pre-convolutional layer. 147 | aux_context_window (int): Context window size of the pre-convolutional layer. 148 | use_causal_conv (bool): Whether to use causal structure.
149 | 150 | """ 151 | super(ConvInUpsampleNetwork, self).__init__() 152 | self.aux_context_window = aux_context_window 153 | self.use_causal_conv = use_causal_conv and aux_context_window > 0 154 | # To capture wide-context information in conditional features, apply a conv layer before the upsampling network 155 | kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1 156 | # NOTE(kan-bayashi): Here do not use padding because the input is already padded 157 | self.conv_in = Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False) # time steps T' -> (T' - aux_context_window * 2) 158 | self.upsample = UpsampleNetwork( 159 | upsample_scales=upsample_scales, 160 | nonlinear_activation=nonlinear_activation, 161 | nonlinear_activation_params=nonlinear_activation_params, 162 | interpolate_mode=interpolate_mode, 163 | freq_axis_kernel_size=freq_axis_kernel_size, 164 | use_causal_conv=use_causal_conv, 165 | ) 166 | 167 | def forward(self, c): 168 | """Calculate forward propagation. 169 | 170 | Args: 171 | c (Tensor): Input tensor (B, C, T'). 172 | 173 | Returns: 174 | Tensor: Upsampled tensor (B, C, T), 175 | where T = (T' - aux_context_window * 2) * prod(upsample_scales). 176 | 177 | Note: 178 | The length of inputs considers the context window size.
179 | 180 | """ 181 | c_ = self.conv_in(c) 182 | c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_ 183 | return self.upsample(c) 184 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .stft_loss import * # NOQA 2 | -------------------------------------------------------------------------------- /losses/stft_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # MIT License (https://opensource.org/licenses/MIT) 4 | 5 | """STFT-based Loss modules.""" 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from distutils.version import LooseVersion 11 | 12 | is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7") 13 | 14 | 15 | def stft(x, fft_size, hop_size, win_length, window): 16 | """Perform STFT and convert to magnitude spectrogram. 17 | 18 | Args: 19 | x (Tensor): Input signal tensor (B, T). 20 | fft_size (int): FFT size. 21 | hop_size (int): Hop size. 22 | win_length (int): Window length. 23 | window (str): Window function type. 24 | 25 | Returns: 26 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). 
27 | 28 | """ 29 | if is_pytorch_17plus: 30 | x_stft = torch.stft( 31 | x, fft_size, hop_size, win_length, window, return_complex=False 32 | ) 33 | else: 34 | x_stft = torch.stft(x, fft_size, hop_size, win_length, window) 35 | real = x_stft[..., 0] 36 | imag = x_stft[..., 1] 37 | 38 | # NOTE(kan-bayashi): clamp is needed to avoid nan or inf 39 | return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) 40 | 41 | 42 | class SpectralConvergenceLoss(torch.nn.Module): 43 | """Spectral convergence loss module.""" 44 | 45 | def __init__(self): 46 | """Initialize spectral convergence loss module.""" 47 | super(SpectralConvergenceLoss, self).__init__() 48 | 49 | def forward(self, x_mag, y_mag): 50 | """Calculate forward propagation. 51 | 52 | Args: 53 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 54 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 55 | 56 | Returns: 57 | Tensor: Spectral convergence loss value. 58 | 59 | """ 60 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") 61 | 62 | 63 | class LogSTFTMagnitudeLoss(torch.nn.Module): 64 | """Log STFT magnitude loss module.""" 65 | 66 | def __init__(self): 67 | """Initialize log STFT magnitude loss module.""" 68 | super(LogSTFTMagnitudeLoss, self).__init__() 69 | 70 | def forward(self, x_mag, y_mag): 71 | """Calculate forward propagation. 72 | 73 | Args: 74 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 75 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 76 | 77 | Returns: 78 | Tensor: Log STFT magnitude loss value.
79 | 80 | """ 81 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) 82 | 83 | 84 | class STFTLoss(torch.nn.Module): 85 | """STFT loss module.""" 86 | 87 | def __init__( 88 | self, fft_size=1024, shift_size=120, win_length=600, window="hann_window" 89 | ): 90 | """Initialize STFT loss module.""" 91 | super(STFTLoss, self).__init__() 92 | self.fft_size = fft_size 93 | self.shift_size = shift_size 94 | self.win_length = win_length 95 | self.spectral_convergence_loss = SpectralConvergenceLoss() 96 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 97 | # NOTE(kan-bayashi): Use register_buffer to fix #223 98 | self.register_buffer("window", getattr(torch, window)(win_length)) 99 | 100 | def forward(self, x, y): 101 | """Calculate forward propagation. 102 | 103 | Args: 104 | x (Tensor): Predicted signal (B, T). 105 | y (Tensor): Groundtruth signal (B, T). 106 | 107 | Returns: 108 | Tensor: Spectral convergence loss value. 109 | Tensor: Log STFT magnitude loss value. 110 | 111 | """ 112 | x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) 113 | y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) 114 | sc_loss = self.spectral_convergence_loss(x_mag, y_mag) 115 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 116 | 117 | return sc_loss, mag_loss 118 | 119 | 120 | class MultiResolutionSTFTLoss(torch.nn.Module): 121 | """Multi resolution STFT loss module.""" 122 | 123 | def __init__( 124 | self, 125 | fft_sizes=[1024, 2048, 512], 126 | hop_sizes=[120, 240, 50], 127 | win_lengths=[600, 1200, 240], 128 | window="hann_window", 129 | ): 130 | """Initialize Multi resolution STFT loss module. 131 | 132 | Args: 133 | fft_sizes (list): List of FFT sizes. 134 | hop_sizes (list): List of hop sizes. 135 | win_lengths (list): List of window lengths. 136 | window (str): Window function type. 
137 | 138 | """ 139 | super(MultiResolutionSTFTLoss, self).__init__() 140 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 141 | self.stft_losses = torch.nn.ModuleList() 142 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 143 | self.stft_losses += [STFTLoss(fs, ss, wl, window)] 144 | 145 | def forward(self, x, y): 146 | """Calculate forward propagation. 147 | 148 | Args: 149 | x (Tensor): Predicted signal (B, T). 150 | y (Tensor): Groundtruth signal (B, T). 151 | 152 | Returns: 153 | Tensor: Multi resolution spectral convergence loss value. 154 | Tensor: Multi resolution log STFT magnitude loss value. 155 | 156 | """ 157 | sc_loss = 0.0 158 | mag_loss = 0.0 159 | for f in self.stft_losses: 160 | sc_l, mag_l = f(x, y) 161 | sc_loss += sc_l 162 | mag_loss += mag_l 163 | sc_loss /= len(self.stft_losses) 164 | mag_loss /= len(self.stft_losses) 165 | 166 | return sc_loss, mag_loss 167 | -------------------------------------------------------------------------------- /models/Discriminator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | 5 | """Multi-Singer Modules.""" 6 | 7 | import numpy as np 8 | import torch 9 | import logging 10 | from layers import Conv1d 11 | 12 | 13 | 14 | class Unconditional_Discriminator(torch.nn.Module): 15 | """Unconditional Discriminator module.""" 16 | 17 | def __init__(self, 18 | in_channels=1, 19 | out_channels=1, 20 | kernel_size=3, 21 | layers=10, 22 | conv_channels=64, 23 | dilation_factor=1, 24 | nonlinear_activation="LeakyReLU", 25 | nonlinear_activation_params={"negative_slope": 0.2}, 26 | bias=True, 27 | use_weight_norm=True, 28 | ): 29 | """Initialize Unconditional Discriminator module. 30 | 31 | Args: 32 | in_channels (int): Number of input channels. 33 | out_channels (int): Number of output channels. 34 | kernel_size (int): Kernel size of conv layers. 35 | layers (int): Number of conv layers.
36 | conv_channels (int): Number of conv channels. 37 | dilation_factor (int): Dilation factor. If 1, the dilation grows linearly (1, 1, 2, 3, ...); 38 | otherwise it grows exponentially, e.g. 1, 2, 4, 8, ... for dilation_factor = 2. 39 | nonlinear_activation (str): Nonlinear function after each conv. 40 | nonlinear_activation_params (dict): Nonlinear function parameters. 41 | bias (bool): Whether to use bias parameter in conv. 42 | use_weight_norm (bool): Whether to use weight norm. 43 | If set to True, it will be applied to all of the conv layers. 44 | 45 | """ 46 | 47 | super(Unconditional_Discriminator, self).__init__() 48 | assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported." 49 | assert dilation_factor > 0, "Dilation factor must be > 0." 50 | self.conv_layers = torch.nn.ModuleList() 51 | conv_in_channels = in_channels 52 | for i in range(layers - 1): # (B, in_channels, T) -> (B, conv_channels, T) 53 | if i == 0: 54 | dilation = 1 55 | else: 56 | dilation = i if dilation_factor == 1 else dilation_factor ** i 57 | conv_in_channels = conv_channels 58 | padding = (kernel_size - 1) // 2 * dilation 59 | conv_layer = [ 60 | Conv1d(conv_in_channels, conv_channels, 61 | kernel_size=kernel_size, padding=padding, 62 | dilation=dilation, bias=bias), 63 | getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params) 64 | ] 65 | self.conv_layers += conv_layer 66 | padding = (kernel_size - 1) // 2 67 | last_conv_layer = Conv1d( # (B, conv_channels, T) -> (B, out_channels, T) 68 | conv_in_channels, out_channels, 69 | kernel_size=kernel_size, padding=padding, bias=bias) 70 | self.conv_layers += [last_conv_layer] 71 | 72 | # apply weight norm 73 | if use_weight_norm: 74 | self.apply_weight_norm() 75 | 76 | def forward(self, x): 77 | """Calculate forward propagation. 78 | 79 | Args: 80 | x (Tensor): Input waveform (B, 1, T).
81 | 82 | Returns: 83 | Tensor: Output tensor (B, out_channels, T) 84 | 85 | """ 86 | for f in self.conv_layers: 87 | x = f(x) 88 | return x 89 | 90 | def apply_weight_norm(self): 91 | """Apply weight normalization to all of the layers.""" 92 | def _apply_weight_norm(m): 93 | if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): 94 | torch.nn.utils.weight_norm(m) 95 | logging.debug(f"Weight norm is applied to {m}.") 96 | 97 | self.apply(_apply_weight_norm) 98 | 99 | def remove_weight_norm(self): 100 | """Remove weight normalization from all of the layers.""" 101 | def _remove_weight_norm(m): 102 | try: 103 | torch.nn.utils.remove_weight_norm(m) 104 | logging.debug(f"Weight norm is removed from {m}.") 105 | except ValueError: # this module didn't have weight norm 106 | return 107 | 108 | self.apply(_remove_weight_norm) 109 | 110 | 111 | class SingerConditional_Discriminator(torch.nn.Module): 112 | """SingerConditional Discriminator module.""" 113 | 114 | def __init__(self, 115 | in_channels=1, 116 | out_channels=256, 117 | kernel_sizes=[5, 3], 118 | channels=16, 119 | max_downsample_channels=1024, 120 | bias=True, 121 | downsample_scales=[4, 4, 4, 4], 122 | nonlinear_activation="LeakyReLU", 123 | nonlinear_activation_params={"negative_slope": 0.2}, 124 | pad="ReflectionPad1d", 125 | pad_params={}, 126 | model_hidden_size=256, 127 | model_num_layers=3 128 | ): 129 | """Initialize SingerConditional Discriminator module.
130 | """ 131 | super(SingerConditional_Discriminator, self).__init__() 132 | 133 | self.lstm = torch.nn.LSTM(input_size=model_hidden_size, 134 | hidden_size=model_hidden_size, 135 | num_layers=model_num_layers) 136 | 137 | self.linear = torch.nn.Linear(model_hidden_size,1) 138 | self.relu = torch.nn.ReLU() 139 | 140 | # add first layer (B, 1, T) -> (B, channels, T) 141 | self.layers = torch.nn.ModuleList() 142 | self.layers += [ 143 | torch.nn.Sequential( 144 | getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params), 145 | torch.nn.Conv1d(in_channels, channels, np.prod(kernel_sizes), bias=bias), 146 | getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), 147 | ) 148 | ] 149 | 150 | # add downsample layers (B, channels, T) -> (B, channels*downsample_scale[0], T/downsample_scale[0]) 151 | # -> ... -> (B, channels*downsample_scale[0]*...*downsample_scale[3], T/(downsample_scale[0]*...*downsample_scale[3])) 152 | # -> ... -> (B, channels*downsample_scale[0]*...*downsample_scale[3], T/product(downsample_scale)) 153 | in_chs = channels 154 | for downsample_scale in downsample_scales: 155 | out_chs = min(in_chs * downsample_scale, max_downsample_channels) 156 | self.layers += [ 157 | torch.nn.Sequential( 158 | torch.nn.Conv1d( 159 | in_chs, out_chs, 160 | kernel_size=downsample_scale * 10 + 1, 161 | stride=downsample_scale, 162 | padding=downsample_scale * 5, 163 | groups=in_chs // 4, 164 | bias=bias, 165 | ), 166 | getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), 167 | ) 168 | ] 169 | in_chs = out_chs 170 | 171 | # add final layers (B, channels*downsample_scale[0]*...*downsample_scale[3], T/product(downsample_scale)) -> (B, channels*downsample_scale[0]*...*downsample_scale[3], T/product(downsample_scale)) 172 | out_chs = min(in_chs * 2, max_downsample_channels) 173 | self.layers += [ 174 | torch.nn.Sequential( 175 | torch.nn.Conv1d( 176 | in_chs, out_chs, kernel_sizes[0], 177 | padding=(kernel_sizes[0] - 1) // 2, 
178 | bias=bias, 179 | ), 180 | getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), 181 | ) 182 | ] 183 | self.layers += [ # (B, out_chs, T') -> (B, out_channels, T'), where T' = ceil(T / prod(downsample_scales)) 184 | torch.nn.Conv1d( 185 | out_chs, out_channels, kernel_sizes[1], 186 | padding=(kernel_sizes[1] - 1) // 2, 187 | bias=bias, 188 | ), 189 | ] 190 | 191 | 192 | def forward(self, x, embed): 193 | """Calculate forward propagation. 194 | 195 | Args: 196 | x (Tensor): Input waveform (B, 1, T). 197 | embed (Tensor): Singer embedding (B, C, 1), where C must equal model_hidden_size. 198 | 199 | Returns: 200 | Tensor: Output tensor (B, 1). 201 | 202 | """ 203 | 204 | for f in self.layers: # (B, 1, T) -> (B, out_channels, T') 205 | x = f(x) 206 | 207 | frames_batch = x.permute(2,0,1) # (B, out_channels, T') -> (T', B, out_channels) 208 | output, (hn, cn) = self.lstm(frames_batch) # output: (T', B, model_hidden_size), hn/cn: (model_num_layers, B, model_hidden_size) 209 | 210 | p = output[-1] + embed.squeeze(2) # last LSTM step (B, model_hidden_size) + singer embedding (B, model_hidden_size) 211 | p = self.relu(self.linear(p)) # (B, 1) 212 | return p 213 | 214 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .Generator import * # NOQA 2 | from .Discriminator import * # NOQA 3 | -------------------------------------------------------------------------------- /optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.optim import * # NOQA 2 | 3 | from .radam import * # NOQA 4 | -------------------------------------------------------------------------------- /optimizers/radam.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """RAdam
optimizer. 4 | 5 | This code is derived from https://github.com/LiyuanLucasLiu/RAdam. 6 | """ 7 | 8 | import math 9 | import torch 10 | 11 | from torch.optim.optimizer import Optimizer 12 | 13 | 14 | class RAdam(Optimizer): 15 | """Rectified Adam optimizer.""" 16 | 17 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 18 | """Initialize RAdam optimizer.""" 19 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 20 | self.buffer = [[None, None, None] for ind in range(10)] 21 | super(RAdam, self).__init__(params, defaults) 22 | 23 | def __setstate__(self, state): 24 | """Set state.""" 25 | super(RAdam, self).__setstate__(state) 26 | 27 | def step(self, closure=None): 28 | """Run one step.""" 29 | loss = None 30 | if closure is not None: 31 | loss = closure() 32 | 33 | for group in self.param_groups: 34 | 35 | for p in group['params']: 36 | if p.grad is None: 37 | continue 38 | grad = p.grad.data.float() 39 | if grad.is_sparse: 40 | raise RuntimeError('RAdam does not support sparse gradients') 41 | 42 | p_data_fp32 = p.data.float() 43 | 44 | state = self.state[p] 45 | 46 | if len(state) == 0: 47 | state['step'] = 0 48 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 49 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 50 | else: 51 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 52 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 53 | 54 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 55 | beta1, beta2 = group['betas'] 56 | # update biased first and second moment estimates 57 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 58 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 59 | 60 | state['step'] += 1 61 | buffered = self.buffer[int(state['step'] % 10)] 62 | if state['step'] == buffered[0]: 63 | N_sma, step_size = buffered[1], buffered[2] 64 | else: 65 | buffered[0] = state['step'] 66 | beta2_t = beta2 ** state['step'] 67 | N_sma_max = 2 / (1 - beta2) - 1 68 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 -
beta2_t) 69 | buffered[1] = N_sma 70 | 71 | # more conservative since it's an approximated value 72 | if N_sma >= 5: 73 | step_size = math.sqrt( 74 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) # NOQA 75 | else: 76 | step_size = 1.0 / (1 - beta1 ** state['step']) 77 | buffered[2] = step_size 78 | 79 | if group['weight_decay'] != 0: 80 | p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) 81 | 82 | # more conservative since it's an approximated value 83 | if N_sma >= 5: 84 | denom = exp_avg_sq.sqrt().add_(group['eps']) 85 | p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) 86 | else: 87 | p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr']) 88 | 89 | p.data.copy_(p_data_fp32) 90 | 91 | return loss 92 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import torch 6 | import logging 7 | import os 8 | from encoder import inference as encoder 9 | import librosa 10 | import numpy as np 11 | import soundfile as sf 12 | import yaml 13 | import random 14 | from tqdm import tqdm 15 | from multiprocessing.pool import Pool 16 | 17 | from datasets import AudioDataset 18 | from frontend.audio_preprocess import logmelfilterbank, pitchfeats, f0_to_coarse 19 | from frontend.audio_world_process import world_feature_extract, convert_continuos_f0, low_pass_filter 20 | from utils import write_hdf5 21 | from utils import simple_table 22 | 23 | 24 | 25 | def normalize(S): 26 | return np.clip((S + 100) / 100, -2, 2) 27 | 28 | 29 | def extract_feats(wav, outdir, utt_id, config): 30 | 31 | wav = wav / np.abs(wav).max() * 0.5 32 | h5_file = os.path.join(outdir, f"{utt_id}.h5") 33 | if config['feat_type'] == 'librosa': 34 | mel = logmelfilterbank(wav, config) #
frames = len(mel) 36 | mel = normalize(mel) * 2 37 | # mel = melspectrogram(x, config).T 38 | write_hdf5(h5_file, "mel", mel.astype(np.float32)) 39 | if config["use_chroma"]: 40 | chromagram = librosa.feature.chroma_stft(y=wav, 41 | sr=config["sampling_rate"], 42 | hop_length=config["hop_size"]) 43 | write_hdf5(h5_file, "chroma", chromagram.T.astype(np.float32)) 44 | 45 | if config["use_f0"]: 46 | f0 = pitchfeats(wav, config) 47 | write_hdf5(h5_file, "f0_origin", f0.astype(np.float64)) 48 | 49 | if config["use_embed"]: 50 | wav_torch = torch.from_numpy(wav) 51 | preprocessed_wav = encoder.preprocess_wav_torch(wav_torch) 52 | embed = encoder.embed_utterance_torch_preprocess(preprocessed_wav) 53 | embed = embed.detach().numpy() 54 | write_hdf5(h5_file, "embed", embed.astype(np.float32)) 55 | 56 | elif config['feat_type'] == 'world': 57 | feats = world_feature_extract(wav, config) 58 | frames = len(feats) 59 | write_hdf5(h5_file, "feats", feats.astype(np.float32)) 60 | 61 | else: 62 | raise NotImplementedError("Currently, only 'world' and 'librosa' are supported.") 63 | 64 | audio = np.pad(wav, (0, config["fft_size"]), mode="edge") 65 | audio = audio[:frames * config["hop_size"]] 66 | assert frames * config["hop_size"] == len(audio) 67 | 68 | write_hdf5(h5_file, "wav", audio.astype(np.float32)) 69 | 70 | return utt_id, h5_file, frames, len(audio) 71 | 72 | 73 | def write2file(values, config, outdir): 74 | test_nums = config['test_num'] 75 | with open(os.path.join(outdir, 'train.txt'), 'w', encoding='utf-8') as train_text, \ 76 | open(os.path.join(outdir, 'dev.txt'), 'w', encoding='utf-8') as dev_text: 77 | 78 | for v in values[:test_nums]: 79 | dev_text.write('|'.join([str(x) for x in v]) + '\n') 80 | for v in values[test_nums:]: 81 | train_text.write('|'.join([str(x) for x in v]) + '\n') 82 | 83 | mel_frames = sum([int(m[2]) for m in values]) 84 | timesteps = sum([int(m[3]) for m in values]) 85 | sr = config['sampling_rate'] 86 | hours = timesteps / sr / 3600 87 | logging.info('Write
{} utterances, {} mel frames, {} audio timesteps ({:.2f} hours)'.format( 88 | len(values), mel_frames, timesteps, hours)) 89 | logging.info('Max mel frames length: {}'.format(max(int(m[2]) for m in values))) 90 | logging.info('Max audio timesteps length: {}'.format(max(m[3] for m in values))) 91 | 92 | 93 | def main(): 94 | """Run preprocessing process.""" 95 | parser = argparse.ArgumentParser( 96 | description="Preprocess audio and then extract features (see details in parallel_wavegan/bin/preprocess.py).") 97 | parser.add_argument("--inputdir",'-i', type=str, required=True, 98 | help="directory containing wav files.") 99 | parser.add_argument("--dumpdir",'-o', type=str,required=True, 100 | help="directory to dump feature files.") 101 | parser.add_argument("--config",'-c', type=str, required=True, 102 | help="yaml format configuration file.") 103 | parser.add_argument("--verbose", type=int, default=1, 104 | help="logging level. higher is more logging.
(default=1)") 105 | args = parser.parse_args() 106 | 107 | # set logger 108 | if args.verbose > 1: 109 | logging.basicConfig( 110 | level=logging.DEBUG, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") 111 | elif args.verbose > 0: 112 | logging.basicConfig( 113 | level=logging.INFO, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") 114 | else: 115 | logging.basicConfig( 116 | level=logging.WARN, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") 117 | logging.warning('Skip DEBUG/INFO messages') 118 | 119 | # load config 120 | with open(args.config) as f: 121 | config = yaml.load(f, Loader=yaml.Loader) 122 | config.update(vars(args)) 123 | # check arguments 124 | if args.inputdir is None: 125 | raise ValueError("Please specify --inputdir.") 126 | 127 | # get dataset 128 | assert args.inputdir is not None 129 | dataset = AudioDataset( 130 | args.inputdir, "*.wav", 131 | audio_load_fn=sf.read, 132 | return_utt_id=True, 133 | ) 134 | if config["use_embed"]: 135 | print("Preparing the encoder...") 136 | encoder.load_model(config["enc_model_fpath"],preprocess=True) 137 | 138 | # check directory existence 139 | if not os.path.exists(args.dumpdir): 140 | os.makedirs(args.dumpdir, exist_ok=True) 141 | 142 | # process each data 143 | futures = [] 144 | p = Pool(int(os.getenv('N_PROC', os.cpu_count()))) 145 | 146 | simple_table([ 147 | ('Data Path', args.inputdir), 148 | ('Preprocess Path', args.dumpdir), 149 | ('Config File', args.config), 150 | ('CPU Usage', int(os.getenv('N_PROC', os.cpu_count()))) 151 | ]) 152 | 153 | 154 | 155 | for utt_id, (audio, fs) in tqdm(dataset): 156 | # check 157 | assert len(audio.shape) == 1, \ 158 | f"{utt_id} seems to be a multi-channel signal." 159 | assert np.abs(audio).max() <= 1.0, \ 160 | f"{utt_id} seems to be different from 16 bit PCM." 161 | assert fs == config["sampling_rate"], \ 162 | f"{utt_id} seems to have a different sampling rate."
163 | 164 | # trim silence 165 | if config["trim_silence"]: 166 | audio, _ = librosa.effects.trim(audio, 167 | top_db=config["trim_threshold_in_db"], 168 | frame_length=config["trim_frame_size"], 169 | hop_length=config["trim_hop_size"]) 170 | 171 | if "sampling_rate_for_feats" not in config: 172 | x = audio 173 | sampling_rate = config["sampling_rate"] 174 | hop_size = config["hop_size"] 175 | else: 176 | 177 | x = librosa.resample(audio, orig_sr=fs, target_sr=config["sampling_rate_for_feats"]) 178 | sampling_rate = config["sampling_rate_for_feats"] 179 | assert config["hop_size"] * config["sampling_rate_for_feats"] % fs == 0, \ 180 | "hop_size must be an integer. Please check that sampling_rate_for_feats is correct." 181 | hop_size = config["hop_size"] * config["sampling_rate_for_feats"] // fs 182 | 183 | feat_config = dict(config, sampling_rate=sampling_rate, hop_size=hop_size) # per-utterance copy, so the shared config is not mutated 184 | 185 | 186 | feats_dir = os.path.join(args.dumpdir, 'feats') 187 | os.makedirs(feats_dir, exist_ok=True) 188 | 189 | futures.append(p.apply_async(extract_feats, args=(x, feats_dir, utt_id, feat_config))) 190 | 191 | p.close() 192 | values = [] 193 | for future in tqdm(futures): 194 | values.append(future.get()) 195 | 196 | random.seed(2020) 197 | random.shuffle(values) 198 | config["sampling_rate"] = config.get("sampling_rate_for_feats", config["sampling_rate"]) # stored audio lengths are at the feature sampling rate 199 | write2file(values, config, args.dumpdir) 200 | 201 | if __name__ == "__main__": 202 | main() 203 | -------------------------------------------------------------------------------- /pretrained1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rongjiehuang/Multi-Singer/a6e9f6138a1ddf52ebd4ec29e91795f34c108e42/pretrained1.pt -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | SoundFile==0.10.3.post1 2 | ipdb==0.13.4 3 | pypesq==1.2.4 4 | scipy==1.5.4 5 | h5py==2.10.0 6 | torchaudio==0.7.0 7 | apex==0.9.10.dev0 8 | torch==1.7.0 9 |
webrtcvad==2.0.10 10 | librosa==0.8.0 11 | gdown==3.12.2 12 | tensorflow 13 | matplotlib==3.3.3 14 | tqdm==4.54.0 15 | multiprocess==0.70.12.2 16 | numpy==1.21.1 17 | parallel_wavegan==0.4.8 18 | PyYAML==5.4.1 19 | scikit_learn==0.24.2 20 | tensorboardX 21 | umap==0.1.1 22 | visdom==0.1.8.9 23 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * # NOQA 2 | -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rongjiehuang/Multi-Singer/a6e9f6138a1ddf52ebd4ec29e91795f34c108e42/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rongjiehuang/Multi-Singer/a6e9f6138a1ddf52ebd4ec29e91795f34c108e42/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/display.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | mpl.use('agg') # Use non-interactive backend by default 3 | import matplotlib.pyplot as plt 4 | import time 5 | import numpy as np 6 | import sys 7 | 8 | 9 | def progbar(i, n, size=16): 10 | done = (i * size) // n 11 | bar = '' 12 | for i in range(size): 13 | bar += '█' if i <= done else '░' 14 | return bar 15 | 16 | 17 | def stream(message): 18 | sys.stdout.write(f"\r{message}") 19 | 20 | 21 | def simple_table(item_tuples): 22 | 23 | border_pattern = '+---------------------------------------' 24 | whitespace = ' ' 25 | 26 | headings, cells, = [], [] 27 | 28 | for item in 
item_tuples: 29 | 30 | heading, cell = str(item[0]), str(item[1]) 31 | 32 | pad_head = True if len(heading) < len(cell) else False 33 | 34 | pad = abs(len(heading) - len(cell)) 35 | pad = whitespace[:pad] 36 | 37 | pad_left = pad[:len(pad)//2] 38 | pad_right = pad[len(pad)//2:] 39 | 40 | if pad_head: 41 | heading = pad_left + heading + pad_right 42 | else: 43 | cell = pad_left + cell + pad_right 44 | 45 | headings += [heading] 46 | cells += [cell] 47 | 48 | border, head, body = '', '', '' 49 | 50 | for i in range(len(item_tuples)): 51 | 52 | temp_head = f'| {headings[i]} ' 53 | temp_body = f'| {cells[i]} ' 54 | 55 | border += border_pattern[:len(temp_head)] 56 | head += temp_head 57 | body += temp_body 58 | 59 | if i == len(item_tuples) - 1: 60 | head += '|' 61 | body += '|' 62 | border += '+' 63 | 64 | print(border) 65 | print(head) 66 | print(border) 67 | print(body) 68 | print(border) 69 | print(' ') 70 | 71 | 72 | def time_since(started): 73 | elapsed = time.time() - started 74 | m = int(elapsed // 60) 75 | s = int(elapsed % 60) 76 | if m >= 60: 77 | h = int(m // 60) 78 | m = m % 60 79 | return f'{h}h {m}m {s}s' 80 | else: 81 | return f'{m}m {s}s' 82 | 83 | 84 | def save_attention(attn, path): 85 | fig = plt.figure(figsize=(12, 6)) 86 | plt.imshow(attn.T, interpolation='nearest', aspect='auto') 87 | fig.savefig(path.parent/f'{path.stem}.png', bbox_inches='tight') 88 | plt.close(fig) 89 | 90 | 91 | def save_spectrogram(M, path, length=None): 92 | M = np.flip(M, axis=0) 93 | if length: M = M[:, :length] 94 | fig = plt.figure(figsize=(12, 6)) 95 | plt.imshow(M, interpolation='nearest', aspect='auto') 96 | fig.savefig(f'{path}.png', bbox_inches='tight') 97 | plt.close(fig) 98 | 99 | 100 | def plot(array): 101 | mpl.interactive(True) 102 | fig = plt.figure(figsize=(30, 5)) 103 | ax = fig.add_subplot(111) 104 | ax.xaxis.label.set_color('grey') 105 | ax.yaxis.label.set_color('grey') 106 | ax.xaxis.label.set_fontsize(23) 107 | ax.yaxis.label.set_fontsize(23) 108 | 
ax.tick_params(axis='x', colors='grey', labelsize=23) 109 | ax.tick_params(axis='y', colors='grey', labelsize=23) 110 | plt.plot(array) 111 | mpl.interactive(False) 112 | 113 | 114 | def plot_spec(M): 115 | mpl.interactive(True) 116 | M = np.flip(M, axis=0) 117 | plt.figure(figsize=(18,4)) 118 | plt.imshow(M, interpolation='nearest', aspect='auto') 119 | plt.show() 120 | mpl.interactive(False) 121 | 122 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """Utility functions.""" 7 | 8 | import fnmatch 9 | import logging 10 | import os 11 | import sys 12 | import tarfile 13 | 14 | from distutils.version import LooseVersion 15 | 16 | import h5py 17 | import numpy as np 18 | import torch 19 | import yaml 20 | 21 | PRETRAINED_MODEL_LIST = { 22 | "ljspeech_parallel_wavegan.v1": "1PdZv37JhAQH6AwNh31QlqruqrvjTBq7U", 23 | "ljspeech_parallel_wavegan.v1.long": "1A9TsrD9fHxFviJVFjCk5W6lkzWXwhftv", 24 | "ljspeech_parallel_wavegan.v1.no_limit": "1CdWKSiKoFNPZyF1lo7Dsj6cPKmfLJe72", 25 | "ljspeech_parallel_wavegan.v3": "1-oZpwpWZMMolDYsCqeL12dFkXSBD9VBq", 26 | "ljspeech_full_band_melgan.v2": "1Kb7q5zBeQ30Wsnma0X23G08zvgDG5oen", 27 | "ljspeech_multi_band_melgan.v2": "1b70pJefKI8DhGYz4SxbEHpxm92tj1_qC", 28 | "jsut_parallel_wavegan.v1": "1qok91A6wuubuz4be-P9R2zKhNmQXG0VQ", 29 | "jsut_multi_band_melgan.v2": "1chTt-76q2p69WPpZ1t1tt8szcM96IKad", 30 | "csmsc_parallel_wavegan.v1": "1QTOAokhD5dtRnqlMPTXTW91-CG7jf74e", 31 | "csmsc_multi_band_melgan.v2": "1G6trTmt0Szq-jWv2QDhqglMdWqQxiXQT", 32 | "arctic_slt_parallel_wavegan.v1": "1_MXePg40-7DTjD0CDVzyduwQuW_O9aA1", 33 | "jnas_parallel_wavegan.v1": "1D2TgvO206ixdLI90IqG787V6ySoXLsV_", 34 | "vctk_parallel_wavegan.v1": "1bqEFLgAroDcgUy5ZFP4g2O2MwcwWLEca", 35 | 
"vctk_parallel_wavegan.v1.long": "1tO4-mFrZ3aVYotgg7M519oobYkD4O_0-", 36 | "vctk_multi_band_melgan.v2": "10PRQpHMFPE7RjF-MHYqvupK9S0xwBlJ_", 37 | "libritts_parallel_wavegan.v1": "1zHQl8kUYEuZ_i1qEFU6g2MEu99k3sHmR", 38 | "libritts_parallel_wavegan.v1.long": "1b9zyBYGCCaJu0TIus5GXoMF8M3YEbqOw", 39 | "libritts_multi_band_melgan.v2": "1kIDSBjrQvAsRewHPiFwBZ3FDelTWMp64", 40 | } 41 | 42 | 43 | def find_files(root_dir, query="*.wav", include_root_dir=True): 44 | """Find files recursively. 45 | 46 | Args: 47 | root_dir (str): Root root_dir to find. 48 | query (str): Query to find. 49 | include_root_dir (bool): If False, root_dir name is not included. 50 | 51 | Returns: 52 | list: List of found filenames. 53 | 54 | """ 55 | files = [] 56 | for root, dirnames, filenames in os.walk(root_dir, followlinks=True): 57 | for filename in fnmatch.filter(filenames, query): 58 | files.append(os.path.join(root, filename)) 59 | if not include_root_dir: 60 | files = [file_.replace(root_dir + "/", "") for file_ in files] 61 | 62 | return files 63 | 64 | 65 | def read_hdf5(hdf5_name, hdf5_path): 66 | """Read hdf5 dataset. 67 | 68 | Args: 69 | hdf5_name (str): Filename of hdf5 file. 70 | hdf5_path (str): Dataset name in hdf5 file. 71 | 72 | Return: 73 | any: Dataset values. 74 | 75 | """ 76 | if not os.path.exists(hdf5_name): 77 | logging.error(f"There is no such a hdf5 file ({hdf5_name}).") 78 | sys.exit(1) 79 | 80 | hdf5_file = h5py.File(hdf5_name, "r") 81 | 82 | if hdf5_path not in hdf5_file: 83 | logging.error(f"There is no such a data in hdf5 file. ({hdf5_path})") 84 | sys.exit(1) 85 | 86 | hdf5_data = hdf5_file[hdf5_path][()] 87 | hdf5_file.close() 88 | 89 | return hdf5_data 90 | 91 | 92 | def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True): 93 | """Write dataset to hdf5. 94 | 95 | Args: 96 | hdf5_name (str): Hdf5 dataset filename. 97 | hdf5_path (str): Dataset path in hdf5. 98 | write_data (ndarray): Data to write. 
99 | is_overwrite (bool): Whether to overwrite dataset. 100 | 101 | """ 102 | # convert to numpy array 103 | write_data = np.array(write_data) 104 | 105 | # check folder existence 106 | folder_name, _ = os.path.split(hdf5_name) 107 | if not os.path.exists(folder_name) and len(folder_name) != 0: 108 | os.makedirs(folder_name) 109 | 110 | # check hdf5 existence 111 | if os.path.exists(hdf5_name): 112 | # if already exists, open with r+ mode 113 | hdf5_file = h5py.File(hdf5_name, "r+") 114 | # check dataset existence 115 | if hdf5_path in hdf5_file: 116 | if is_overwrite: 117 | logging.warning("Dataset in hdf5 file already exists. " 118 | "Recreating dataset in hdf5.") 119 | del hdf5_file[hdf5_path] 120 | else: 121 | logging.error("Dataset in hdf5 file already exists. " 122 | "If you want to overwrite it, please set is_overwrite=True.") 123 | hdf5_file.close() 124 | sys.exit(1) 125 | else: 126 | # if not exists, open with w mode 127 | hdf5_file = h5py.File(hdf5_name, "w") 128 | 129 | # write data to hdf5 130 | hdf5_file.create_dataset(hdf5_path, data=write_data) 131 | hdf5_file.flush() 132 | hdf5_file.close() 133 | 134 | 135 | class HDF5ScpLoader(object): 136 | """Loader class for a feats.scp file of hdf5 files. 137 | 138 | Examples: 139 | key1 /some/path/a.h5:feats 140 | key2 /some/path/b.h5:feats 141 | key3 /some/path/c.h5:feats 142 | key4 /some/path/d.h5:feats 143 | ... 144 | >>> loader = HDF5ScpLoader("hdf5.scp") 145 | >>> array = loader["key1"] 146 | 147 | key1 /some/path/a.h5 148 | key2 /some/path/b.h5 149 | key3 /some/path/c.h5 150 | key4 /some/path/d.h5 151 | ... 152 | >>> loader = HDF5ScpLoader("hdf5.scp", "feats") 153 | >>> array = loader["key1"] 154 | 155 | key1 /some/path/a.h5:feats_1,feats_2 156 | key2 /some/path/b.h5:feats_1,feats_2 157 | key3 /some/path/c.h5:feats_1,feats_2 158 | key4 /some/path/d.h5:feats_1,feats_2 159 | ...
160 | >>> loader = HDF5ScpLoader("hdf5.scp") 161 | # feats_1 and feats_2 will be concatenated 162 | >>> array = loader["key1"] 163 | 164 | """ 165 | 166 | def __init__(self, feats_scp, default_hdf5_path="feats"): 167 | """Initialize HDF5 scp loader. 168 | 169 | Args: 170 | feats_scp (str): Kaldi-style feats.scp file with hdf5 format. 171 | default_hdf5_path (str): Path in hdf5 file. Ignored if the scp entry already specifies the dataset path. 172 | 173 | """ 174 | self.default_hdf5_path = default_hdf5_path 175 | with open(feats_scp) as f: 176 | lines = [line.replace("\n", "") for line in f.readlines()] 177 | self.data = {} 178 | for line in lines: 179 | key, value = line.split() 180 | self.data[key] = value 181 | 182 | def get_path(self, key): 183 | """Get hdf5 file path for a given key.""" 184 | return self.data[key] 185 | 186 | def __getitem__(self, key): 187 | """Get ndarray for a given key.""" 188 | p = self.data[key] 189 | if ":" in p: 190 | if len(p.split(",")) == 1: 191 | return read_hdf5(*p.split(":")) 192 | else: 193 | p1, p2 = p.split(":") 194 | feats = [read_hdf5(p1, name) for name in p2.split(",")] 195 | return np.concatenate([f if len(f.shape) != 1 else f.reshape(-1, 1) for f in feats], 1) 196 | else: 197 | return read_hdf5(p, self.default_hdf5_path) 198 | 199 | def __len__(self): 200 | """Return the length of the scp file.""" 201 | return len(self.data) 202 | 203 | def __iter__(self): 204 | """Return the iterator of the scp file.""" 205 | return iter(self.data) 206 | 207 | def keys(self): 208 | """Return the keys of the scp file.""" 209 | return self.data.keys() 210 | 211 | def values(self): 212 | """Return the values of the scp file.""" 213 | for key in self.keys(): 214 | yield self[key] 215 | 216 | 217 | class NpyScpLoader(object): 218 | """Loader class for a feats.scp file of npy files. 219 | 220 | Examples: 221 | key1 /some/path/a.npy 222 | key2 /some/path/b.npy 223 | key3 /some/path/c.npy 224 | key4 /some/path/d.npy 225 | ...
226 | >>> loader = NpyScpLoader("feats.scp") 227 | >>> array = loader["key1"] 228 | 229 | """ 230 | 231 | def __init__(self, feats_scp): 232 | """Initialize npy scp loader. 233 | 234 | Args: 235 | feats_scp (str): Kaldi-style feats.scp file with npy format. 236 | 237 | """ 238 | with open(feats_scp) as f: 239 | lines = [line.replace("\n", "") for line in f.readlines()] 240 | self.data = {} 241 | for line in lines: 242 | key, value = line.split() 243 | self.data[key] = value 244 | 245 | def get_path(self, key): 246 | """Get npy file path for a given key.""" 247 | return self.data[key] 248 | 249 | def __getitem__(self, key): 250 | """Get ndarray for a given key.""" 251 | return np.load(self.data[key]) 252 | 253 | def __len__(self): 254 | """Return the length of the scp file.""" 255 | return len(self.data) 256 | 257 | def __iter__(self): 258 | """Return the iterator of the scp file.""" 259 | return iter(self.data) 260 | 261 | def keys(self): 262 | """Return the keys of the scp file.""" 263 | return self.data.keys() 264 | 265 | def values(self): 266 | """Return the values of the scp file.""" 267 | for key in self.keys(): 268 | yield self[key] 269 | 270 | 271 | def load_model(checkpoint, config=None): 272 | """Load trained model. 273 | 274 | Args: 275 | checkpoint (str): Checkpoint path. 276 | config (dict): Configuration dict. 277 | 278 | Return: 279 | torch.nn.Module: Model instance. 
280 | 281 | """ 282 | # load config if not provided 283 | if config is None: 284 | dirname = os.path.dirname(checkpoint) 285 | config = os.path.join(dirname, "config.yml") 286 | with open(config) as f: 287 | config = yaml.load(f, Loader=yaml.Loader) 288 | 289 | # lazy import to avoid a circular import error 290 | import models 291 | 292 | # get model and load parameters 293 | model_class = getattr( 294 | models, 295 | config.get("generator_type", "ParallelWaveGANGenerator") 296 | ) 297 | model = model_class(**config["generator_params"]) 298 | model.load_state_dict( 299 | torch.load(checkpoint, map_location="cpu")["model"]["generator"] 300 | ) 301 | 302 | # add pqmf if needed 303 | if config["generator_params"]["out_channels"] > 1: 304 | # lazy import to avoid a circular import error 305 | from layers import PQMF 306 | 307 | pqmf_params = {} 308 | if LooseVersion(config.get("version", "0.1.0")) <= LooseVersion("0.4.2"): 309 | # For compatibility, here we set default values in version <= 0.4.2 310 | pqmf_params.update(taps=62, cutoff_ratio=0.15, beta=9.0) 311 | model.pqmf = PQMF( 312 | subbands=config["generator_params"]["out_channels"], 313 | **config.get("pqmf_params", pqmf_params), 314 | ) 315 | 316 | return model 317 | 318 | 319 | def download_pretrained_model(tag, download_dir=None): 320 | """Download pretrained model from Google Drive. 321 | 322 | Args: 323 | tag (str): Pretrained model tag. 324 | download_dir (str): Directory to save downloaded files. 325 | 326 | Returns: 327 | str: Path of downloaded model checkpoint. 328 | 329 | """ 330 | assert tag in PRETRAINED_MODEL_LIST, f"{tag} does not exist."
    id_ = PRETRAINED_MODEL_LIST[tag]
    if download_dir is None:
        download_dir = os.path.expanduser("~/.cache/parallel_wavegan")
    output_path = f"{download_dir}/{tag}.tar.gz"
    os.makedirs(f"{download_dir}", exist_ok=True)
    if not os.path.exists(output_path):
        # lazy load for compatibility
        import gdown

        gdown.download(f"https://drive.google.com/uc?id={id_}", output_path, quiet=False)
        with tarfile.open(output_path, 'r:*') as tar:
            for member in tar.getmembers():
                if member.isreg():
                    member.name = os.path.basename(member.name)
                    tar.extract(member, f"{download_dir}/{tag}")
    checkpoint_path = find_files(f"{download_dir}/{tag}", "checkpoint*.pkl")

    return checkpoint_path[0]


def simple_table(item_tuples):
    """Print a one-row ASCII table.

    Args:
        item_tuples (list): Sequence of (heading, value) pairs. Each pair
            becomes one column; the shorter of heading and value is centered.

    """
    headings, cells = [], []

    for item in item_tuples:

        heading, cell = str(item[0]), str(item[1])

        # pad the shorter string so heading and cell have the same width
        pad_head = len(heading) < len(cell)
        pad = ' ' * abs(len(heading) - len(cell))

        pad_left = pad[:len(pad) // 2]
        pad_right = pad[len(pad) // 2:]

        if pad_head:
            heading = pad_left + heading + pad_right
        else:
            cell = pad_left + cell + pad_right

        headings += [heading]
        cells += [cell]

    border, head, body = '', '', ''

    for i in range(len(item_tuples)):

        temp_head = f'| {headings[i]} '
        temp_body = f'| {cells[i]} '

        # border segment always matches the column width
        border += '+' + '-' * (len(temp_head) - 1)
        head += temp_head
        body += temp_body

        if i == len(item_tuples) - 1:
            head += '|'
            body += '|'
            border += '+'

    print(border)
    print(head)
    print(border)
    print(body)
    print(border)
    print(' ')
--------------------------------------------------------------------------------
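The `NpyScpLoader` class above reads a Kaldi-style `feats.scp` in which every line is `<key> <path-to-.npy>`. It keeps only the key-to-path mapping in memory and calls `np.load` lazily on lookup. What follows is a minimal, self-contained sketch of that round trip, assuming NumPy is installed. `write_npy_scp` and `read_npy_scp` are hypothetical helpers written for illustration and are not part of this repository. `read_npy_scp` only mirrors the parsing done in `NpyScpLoader.__init__`.

```python
import os
import tempfile

import numpy as np


def write_npy_scp(arrays, out_dir):
    """Save each array as <key>.npy and list it in a Kaldi-style feats.scp.

    Hypothetical helper for illustration; not part of the repository.
    """
    scp_path = os.path.join(out_dir, "feats.scp")
    with open(scp_path, "w") as f:
        for key, array in arrays.items():
            npy_path = os.path.join(out_dir, f"{key}.npy")
            np.save(npy_path, array)
            # one "<key> <path>" pair per line, separated by whitespace
            f.write(f"{key} {npy_path}\n")
    return scp_path


def read_npy_scp(scp_path):
    """Parse feats.scp into {key: path}, mirroring NpyScpLoader.__init__."""
    with open(scp_path) as f:
        lines = [line.replace("\n", "") for line in f.readlines()]
    data = {}
    for line in lines:
        # split() on any whitespace, so keys and paths must not contain spaces
        key, value = line.split()
        data[key] = value
    return data


with tempfile.TemporaryDirectory() as tmp:
    feats = {"key1": np.zeros((4, 80)), "key2": np.ones((2, 80))}
    scp = write_npy_scp(feats, tmp)
    table = read_npy_scp(scp)
    # loader["key1"] in NpyScpLoader amounts to np.load(table["key1"])
    loaded = {key: np.load(path) for key, path in table.items()}
    print(sorted(table))         # ['key1', 'key2']
    print(loaded["key2"].shape)  # (2, 80)
```

Storing paths rather than arrays keeps memory use flat for large feature sets. The trade-off is that every lookup hits the disk.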