├── .gitignore ├── 1_compute_ctc_att_bnf.py ├── 2_compute_f0.py ├── 3_compute_spk_dvecs.py ├── LICENSE ├── README.md ├── bin ├── solver.py ├── train_linglf02mel_seq2seq_encAddlf0.py ├── train_linglf02mel_seq2seq_lsa.py ├── train_linglf02mel_seq2seq_oneshotvc.py └── train_ppg2mel_oneshotvc.py ├── conf ├── bilstm_ppg2mel_vctk_libri_oneshotvc.yaml └── seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml ├── conformer_ppg_model ├── build_ppg_model.py ├── e2e_asr_common.py ├── en_conformer_ctc_att │ ├── 24epoch.pth │ └── config.yaml ├── encoder │ ├── __init__.py │ ├── attention.py │ ├── conformer_encoder.py │ ├── convolution.py │ ├── embedding.py │ ├── encoder.py │ ├── encoder_layer.py │ ├── layer_norm.py │ ├── multi_layer_conv.py │ ├── positionwise_feed_forward.py │ ├── repeat.py │ ├── subsampling.py │ ├── swish.py │ └── vgg.py ├── encoders.py ├── frontend.py ├── log_mel.py ├── nets_utils.py ├── stft.py └── utterance_mvn.py ├── convert_from_wav.py ├── figs ├── README.md ├── seq2seq_bnf2mel.pdf └── seq2seq_bnf2mel.png ├── main.py ├── path.sh ├── requirements.txt ├── run.sh ├── speaker_encoder ├── __init__.py ├── audio.py ├── ckpt │ └── pretrained_bak_5805000.pt ├── compute_embed.py ├── config.py ├── data_objects │ ├── __init__.py │ ├── random_cycler.py │ ├── speaker.py │ ├── speaker_batch.py │ ├── speaker_verification_dataset.py │ └── utterance.py ├── hparams.py ├── inference.py ├── model.py ├── params_data.py ├── params_model.py ├── preprocess.py ├── train.py ├── visualizations.py └── voice_encoder.py ├── src ├── __init__.py ├── abs_model.py ├── audio_utils.py ├── basic_layers.py ├── cnn_postnet.py ├── data_load.py ├── f0_utils.py ├── loss.py ├── loss_fn.py ├── lsa_attention.py ├── mel_decoder_lsa.py ├── mel_decoder_mol_encAddlf0.py ├── mel_decoder_mol_v2.py ├── module.py ├── mol_attention.py ├── nets_utils.py ├── optim.py ├── option.py ├── rnn_decoder_lsa.py ├── rnn_decoder_mol.py ├── rnn_decoder_mol_add_pitch.py ├── rnn_ppg2mel.py ├── solver.py ├── util.py └── vc_utils.py ├── test.sh ├── tools └── Makefile ├── utils ├── f0_utils.py ├── file_related.py ├── load_yaml.py └── tensor_ops.py └── vocoders ├── __init__.py ├── env.py ├── hifigan_model.py ├── utils.py └── vctk_24k10ms ├── config.json └── g_02830000 /.gitignore: -------------------------------------------------------------------------------- 1 | # general 2 | *~ 3 | *.pyc 4 | \#*\# 5 | .\#* 6 | *DS_Store 7 | out.txt 8 | doc/_build 9 | slurm-*.out 10 | tmp* 11 | .eggs/ 12 | .hypothesis/ 13 | .idea 14 | .backup/ 15 | .pytest_cache/ 16 | __pycache__/ 17 | .coverage* 18 | coverage.xml* 19 | .vscode* 20 | .nfs* 21 | .ipynb_checkpoints 22 | 23 | data/f0s.scp 24 | log 25 | conf/tuning 26 | ckpt 27 | vc_gen_wavs 28 | tools/venv 29 | -------------------------------------------------------------------------------- /1_compute_ctc_att_bnf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compute CTC-Attention Seq2seq ASR encoder bottle-neck features (BNF). 3 | """ 4 | import sys 5 | import os 6 | import argparse 7 | import torch 8 | import glob2 9 | import soundfile 10 | import librosa 11 | 12 | import numpy as np 13 | from tqdm import tqdm 14 | from conformer_ppg_model.build_ppg_model import load_ppg_model 15 | 16 | 17 | SAMPLE_RATE=16000 18 | 19 | 20 | def compute_bnf( 21 | output_dir: str, 22 | wav_dir: str, 23 | train_config: str, 24 | model_file: str, 25 | ): 26 | device = "cuda" 27 | 28 | # 1. 
Build PPG model 29 | ppg_model_local = load_ppg_model(train_config, model_file, device) 30 | 31 | # 2. Glob wav files 32 | wav_file_list = glob2.glob(f"{wav_dir}/**/*.wav") 33 | print(f"Globbing {len(wav_file_list)} wav files.") 34 | 35 | # 3. start to compute ppgs 36 | os.makedirs(output_dir, exist_ok=True) 37 | for wav_file in tqdm(wav_file_list): 38 | audio, sr = soundfile.read(wav_file, always_2d=False) 39 | if sr != SAMPLE_RATE: 40 | audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE) 41 | wav_tensor = torch.from_numpy(audio).float().to(device).unsqueeze(0) 42 | wav_length = torch.LongTensor([audio.shape[0]]).to(device) 43 | with torch.no_grad(): 44 | bnf = ppg_model_local(wav_tensor, wav_length) 45 | # bnf = torch.nn.functional.softmax(asr_model.ctc.ctc_lo(bnf), dim=2) 46 | bnf_npy = bnf.squeeze(0).cpu().numpy() 47 | fid = os.path.basename(wav_file).split(".")[0] 48 | bnf_fname = f"{output_dir}/{fid}.ling_feat.npy" 49 | np.save(bnf_fname, bnf_npy, allow_pickle=False) 50 | 51 | 52 | def get_parser(): 53 | parser = argparse.ArgumentParser(description="compute ppg or ctc-bnf or ctc-att-bnf") 54 | 55 | parser.add_argument( 56 | "--output_dir", 57 | type=str, 58 | required=True, 59 | default=None, 60 | ) 61 | parser.add_argument( 62 | "--wav_dir", 63 | type=str, 64 | required=True, 65 | default=None, 66 | ) 67 | parser.add_argument( 68 | "--train_config", 69 | type=str, 70 | default="./conformer_ppg_model/en_conformer_ctc_att/config.yaml", 71 | ) 72 | parser.add_argument( 73 | "--model_file", 74 | type=str, 75 | default="./conformer_ppg_model/en_conformer_ctc_att/24epoch.pth", 76 | ) 77 | 78 | return parser 79 | 80 | 81 | if __name__ == "__main__": 82 | parser = get_parser() 83 | args = parser.parse_args() 84 | kwargs = vars(args) 85 | compute_bnf(**kwargs) 86 | -------------------------------------------------------------------------------- /2_compute_f0.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob2 3 | import numpy as np 4 | import io 5 | from tqdm import tqdm 6 | import soundfile 7 | import resampy 8 | import pyworld 9 | 10 | import torch 11 | from multiprocessing import cpu_count 12 | from concurrent.futures import ProcessPoolExecutor 13 | from functools import partial 14 | 15 | 16 | def compute_f0( 17 | wav, 18 | sr, 19 | f0_floor=20.0, 20 | f0_ceil=600.0, 21 | frame_period=10.0 22 | ): 23 | wav = wav.astype(np.float64) 24 | f0, timeaxis = pyworld.harvest( 25 | wav, sr, frame_period=frame_period, f0_floor=20.0, f0_ceil=600.0) 26 | return f0.astype(np.float32) 27 | 28 | 29 | def compute_f0_from_wav( 30 | wavfile_path, 31 | sampling_rate, 32 | f0_floor, 33 | f0_ceil, 34 | frame_period_ms, 35 | ): 36 | # try: 37 | wav, sr = soundfile.read(wavfile_path) 38 | if len(wav) < sr: 39 | return None, sr, len(wav) 40 | if sr != sampling_rate: 41 | wav = resampy.resample(wav, sr, sampling_rate) 42 | sr = sampling_rate 43 | f0 = compute_f0(wav, sr, f0_floor, f0_ceil, frame_period_ms) 44 | return f0, sr, len(wav) 45 | 46 | 47 | def process_one( 48 | wav_file_path, 49 | args, 50 | output_dir, 51 | ): 52 | fid = os.path.basename(wav_file_path)[:-4] 53 | save_fname = f"{output_dir}/{fid}.f0.npy" 54 | if os.path.isfile(save_fname): 55 | return 56 | 57 | f0, sr, wav_len = compute_f0_from_wav( 58 | wav_file_path, args.sampling_rate, 59 | args.f0_floor, args.f0_ceil, args.frame_period_ms) 60 | if f0 is None: 61 | return 62 | np.save(save_fname, f0, allow_pickle=False) 63 | 64 | 65 | def run(args): 66 | """Compute merged 
f0 values.""" 67 | output_dir = args.output_dir 68 | os.makedirs(output_dir, exist_ok=True) 69 | 70 | wav_dir = args.wav_dir 71 | # Get file id list 72 | wav_file_list = glob2.glob(f"{wav_dir}/**/*.wav") 73 | print(f"Globbed {len(wav_file_list)} wave files.") 74 | 75 | # Multi-process worker 76 | if args.num_workers < 2 : 77 | for wav_file_path in tqdm(wav_file_list): 78 | process_one(wav_file_path, args, output_dir) 79 | else: 80 | with ProcessPoolExecutor(max_workers=args.num_workers) as executor: 81 | futures = [] 82 | for wav_file_path in wav_file_list: 83 | futures.append(executor.submit( 84 | partial( 85 | process_one, wav_file_path, args, output_dir, 86 | ) 87 | )) 88 | results = [future.result() for future in tqdm(futures)] 89 | 90 | 91 | def get_parser(): 92 | import argparse 93 | parser = argparse.ArgumentParser(description="Compute merged f0 values") 94 | parser.add_argument( 95 | "--wav_dir", 96 | required=True, 97 | type=str, 98 | ) 99 | parser.add_argument( 100 | "--output_dir", 101 | required=True, 102 | type=str, 103 | ) 104 | parser.add_argument( 105 | "--frame_period_ms", 106 | default=10, 107 | type=float, 108 | ) 109 | parser.add_argument( 110 | "--sampling_rate", 111 | default=24000, 112 | type=int, 113 | ) 114 | parser.add_argument( 115 | "--f0_floor", 116 | default=80, 117 | type=int, 118 | ) 119 | parser.add_argument( 120 | "--f0_ceil", 121 | default=600, 122 | type=int 123 | ) 124 | parser.add_argument( 125 | "--num_workers", 126 | default=10, 127 | type=int 128 | ) 129 | return parser 130 | 131 | 132 | def main(): 133 | parser = get_parser() 134 | args = parser.parse_args() 135 | print(args) 136 | run(args) 137 | 138 | 139 | if __name__ == "__main__": 140 | main() 141 | -------------------------------------------------------------------------------- /3_compute_spk_dvecs.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from speaker_encoder.voice_encoder import SpeakerEncoder 3 | from speaker_encoder.audio import preprocess_wav 4 | from pathlib import Path 5 | import numpy as np 6 | from os.path import join, basename, split 7 | from tqdm import tqdm 8 | from multiprocessing import cpu_count 9 | from concurrent.futures import ProcessPoolExecutor 10 | from functools import partial 11 | import glob 12 | import argparse 13 | 14 | 15 | def build_from_path(in_dir, out_dir, weights_fpath, num_workers=1): 16 | executor = ProcessPoolExecutor(max_workers=num_workers) 17 | futures = [] 18 | wavfile_paths = glob.glob(os.path.join(in_dir, '*/*/*.wav')) 19 | wavfile_paths= sorted(wavfile_paths) 20 | for wav_path in wavfile_paths: 21 | futures.append(executor.submit( 22 | partial(_compute_spkEmbed, out_dir, wav_path, weights_fpath))) 23 | return [future.result() for future in tqdm(futures)] 24 | 25 | def _compute_spkEmbed(out_dir, wav_path, weights_fpath): 26 | utt_id = os.path.basename(wav_path).rstrip(".wav") 27 | fpath = Path(wav_path) 28 | wav = preprocess_wav(fpath) 29 | 30 | encoder = SpeakerEncoder(weights_fpath) 31 | embed = encoder.embed_utterance(wav) 32 | fname_save = os.path.join(out_dir, f"{utt_id}.npy") 33 | np.save(fname_save, embed, allow_pickle=False) 34 | return os.path.basename(fname_save) 35 | 36 | def preprocess(in_dir, out_dir, weights_fpath, num_workers): 37 | os.makedirs(out_dir, exist_ok=True) 38 | metadata = build_from_path(in_dir, out_dir, weights_fpath, num_workers) 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--in_dir', type=str, 43 | 
default='/home/shaunxliu/data/datasets/LibriTTS/LibriTTS') 44 | parser.add_argument('--num_workers', type=int, default=20) 45 | parser.add_argument('--out_dir_root', type=str, 46 | default='/home/shaunxliu/data/datasets/LibriTTS') 47 | parser.add_argument('--spk_encoder_ckpt', type=str, \ 48 | default='speaker_encoder/ckpt/pretrained_bak_5805000.pt') 49 | 50 | args = parser.parse_args() 51 | 52 | split_list = ['train-clean-100', 'train-clean-360'] 53 | 54 | # sub_folder_list = os.listdir(args.in_dir) 55 | # sub_folder_list.sort() 56 | 57 | args.num_workers = args.num_workers if args.num_workers is not None else cpu_count() 58 | print("Number of workers: ", args.num_workers) 59 | ckpt_step = os.path.basename(args.spk_encoder_ckpt).split('.')[0].split('_')[-1] 60 | spk_embed_out_dir = os.path.join(args.out_dir_root, f"GE2E_spkEmbed_step_{ckpt_step}") 61 | print("[INFO] spk_embed_out_dir: ", spk_embed_out_dir) 62 | os.makedirs(spk_embed_out_dir, exist_ok=True) 63 | 64 | # for data_split in split_list: 65 | # sub_folder_list = os.listdir(args.in_dir, data_split) 66 | # for spk in sub_folder_list: 67 | # print("Preprocessing {} ...".format(spk)) 68 | # in_dir = os.path.join(args.in_dir, dataset, spk) 69 | # if not os.path.isdir(in_dir): 70 | # continue 71 | # # out_dir = os.path.join(args.out_dir, spk) 72 | # preprocess(in_dir, spk_embed_out_dir, args.spk_encoder_ckpt, args.num_workers) 73 | for data_split in split_list: 74 | in_dir = os.path.join(args.in_dir, data_split) 75 | preprocess(in_dir, spk_embed_out_dir, args.spk_encoder_ckpt, args.num_workers) 76 | 77 | print("DONE!") 78 | sys.exit(0) 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # One-shot Phonetic PosteriorGram (PPG)-Based Voice Conversion (PPG-VC): Any-to-Many Voice Conversion with Location-Relative Sequence-to-Sequence Modeling (TASLP 2021) 2 | 3 | ### [Paper](https://arxiv.org/abs/2009.02725v3) | [Pre-trained models](https://drive.google.com/drive/folders/1JeFntg2ax9gX4POFbQwcS85eC9hyQ6W6?usp=sharing) | [Paper Demo](https://liusongxiang.github.io/BNE-Seq2SeqMoL-VC/) 4 | 5 | This paper proposes an any-to-many location-relative, sequence-to-sequence (seq2seq) based, non-parallel voice conversion approach. In this approach, we combine a bottle-neck feature extractor (BNE) with a seq2seq based synthesis module. During the training stage, an encoder-decoder based hybrid connectionist-temporal-classification-attention (CTC-attention) phoneme recognizer is trained, whose encoder has a bottle-neck layer. A BNE is obtained from the phoneme recognizer and is utilized to extract speaker-independent, dense and rich linguistic representations from spectral features. Then a multi-speaker location-relative attention based seq2seq synthesis model is trained to reconstruct spectral features from the bottle-neck features, conditioning on speaker representations for speaker identity control in the generated speech. To mitigate the difficulties of using seq2seq based models to align long sequences, we down-sample the input spectral feature along the temporal dimension and equip the synthesis model with a discretized mixture of logistic (MoL) attention mechanism. Since the phoneme recognizer is trained with large speech recognition data corpus, the proposed approach can conduct any-to-many voice conversion. 
Objective and subjective evaluations show that the proposed any-to-many approach has superior voice conversion performance in terms of both naturalness and speaker similarity. Ablation studies are conducted to confirm the effectiveness of feature selection and model design strategies in the proposed approach. The proposed VC approach can readily be extended to support any-to-any VC (also known as one/few-shot VC), and achieve high performance according to objective and subjective evaluations. 6 | 7 | ![BNE-Seq2seqMoL diagram](figs/seq2seq_bnf2mel.png) 8 | 9 | 10 | 11 | Diagram of the BNE-Seq2seqMoL system. 12 |
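The bottle-neck feature (BNF/PPG) extraction shown in the diagram is what `1_compute_ctc_att_bnf.py` (listed above) performs for every wav in a directory. A minimal single-utterance sketch of the same flow, assuming the default checkpoint paths under `conformer_ppg_model/en_conformer_ctc_att/` and a hypothetical input file `source.wav`:

```python
import numpy as np
import torch
import soundfile
import librosa
from conformer_ppg_model.build_ppg_model import load_ppg_model

device = "cuda" if torch.cuda.is_available() else "cpu"
ppg_model = load_ppg_model(
    "./conformer_ppg_model/en_conformer_ctc_att/config.yaml",
    "./conformer_ppg_model/en_conformer_ctc_att/24epoch.pth",
    device,
)

# Read the waveform and make sure it is 16 kHz, shape (num_samples,)
audio, sr = soundfile.read("source.wav", always_2d=False)
if sr != 16000:
    audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)

wav = torch.from_numpy(audio).float().unsqueeze(0).to(device)   # (1, num_samples)
wav_len = torch.LongTensor([audio.shape[0]]).to(device)         # (1,)
with torch.no_grad():
    bnf = ppg_model(wav, wav_len)                               # (1, num_frames, 144)

np.save("source.ling_feat.npy", bnf.squeeze(0).cpu().numpy(), allow_pickle=False)
```

The resulting frame-level features have a 10 ms frame shift and 144 dimensions; they are the linguistic input that the seq2seq synthesis model reconstructs mel-spectrograms from.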
13 | 14 | 15 | This repo implements an updated version of PPG-based VC models. 16 | 17 | Notes: 18 | 19 | - The PPG model provided in `conformer_ppg_model` is based on Hybrid CTC-Attention phoneme recognizer, trained with LibriSpeech (960hrs). PPGs have frame-shift of 10 ms, with dimensionality of 144. This modelis very much similar to the one used in [this paper](https://arxiv.org/pdf/2011.05731v2.pdf). 20 | 21 | - This repo uses [Hifi-GAN V1](https://github.com/jik876/hifi-gan) as the vocoder model, sampling rate of synthesized audio is 24kHz. 22 | 23 | ## Updates! 24 | - We provide an audio sample uttered by Barack Obama ([link](https://drive.google.com/file/d/10Cgtw14UtVf2jTqKtR-C1y5bZqq6Ue7U/view?usp=sharing)), you can convert any voice into Obama's voice using this sample as reference. Please have a try! 25 | - BNE-Seq2seqMoL One-shot VC model are uploaded ([link](https://drive.google.com/drive/folders/1JeFntg2ax9gX4POFbQwcS85eC9hyQ6W6?usp=sharing)) 26 | - BiLSTM-based One-shot VC model are uploaded ([link](https://drive.google.com/drive/folders/1JeFntg2ax9gX4POFbQwcS85eC9hyQ6W6?usp=sharing)) 27 | 28 | 29 | ## How to use 30 | ### Setup with virtualenv 31 | ``` 32 | $ cd tools 33 | $ make 34 | ``` 35 | 36 | Note: If you want to specify Python version, CUDA version or PyTorch version, please run for example: 37 | 38 | ``` 39 | $ make PYTHON=3.7 CUDA_VERSION=10.1 PYTORCH_VERSION=1.6 40 | ``` 41 | 42 | ### Conversion with a pretrained model 43 | 1. Download a model from [here](https://drive.google.com/drive/folders/1JeFntg2ax9gX4POFbQwcS85eC9hyQ6W6?usp=sharing), we recommend to first try the model `bneSeq2seqMoL-vctk-libritts460-oneshot`. Put the config file and the checkpoint file in a folder ``. 44 | 2. Prepare a source wav directory ``, where the wavs inside are what you want to convert. 45 | 3. Prepare a reference audio sample (i.e., the target voice you want convert to) ``. 46 | 4. Run `test.sh` as: 47 | ``` 48 | sh test.sh /seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml /best_loss_step_304000.pth \ 49 | 50 | ``` 51 | The converted wavs are saved in the folder `vc_gen_wavs`. 52 | 53 | ### Data preprocessing 54 | Activate the virtual env py `source tools/venv/bin/activate`, then: 55 | - Please run `1_compute_ctc_att_bnf.py` to compute PPG features. 56 | - Please run `2_compute_f0.py` to compute fundamental frequency. 57 | - Please run `3_compute_spk_dvecs.py` to compute speaker d-vectors. 58 | 59 | ### Training 60 | - Please refer to `run.sh` 61 | 62 | ## Citations 63 | If you use this repo for your research, please consider of citing the following related papers. 64 | ``` 65 | @ARTICLE{liu2021any, 66 | author={Liu, Songxiang and Cao, Yuewen and Wang, Disong and Wu, Xixin and Liu, Xunying and Meng, Helen}, 67 | journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing}, 68 | title={Any-to-Many Voice Conversion With Location-Relative Sequence-to-Sequence Modeling}, 69 | year={2021}, 70 | volume={29}, 71 | number={}, 72 | pages={1717-1728}, 73 | doi={10.1109/TASLP.2021.3076867} 74 | } 75 | 76 | @inproceedings{Liu2018, 77 | author={Songxiang Liu and Jinghua Zhong and Lifa Sun and Xixin Wu and Xunying Liu and Helen Meng}, 78 | title={Voice Conversion Across Arbitrary Speakers Based on a Single Target-Speaker Utterance}, 79 | year=2018, 80 | booktitle={Proc. 
Interspeech 2018}, 81 | pages={496--500}, 82 | doi={10.21437/Interspeech.2018-1504}, 83 | url={http://dx.doi.org/10.21437/Interspeech.2018-1504} 84 | } 85 | -------------------------------------------------------------------------------- /bin/solver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import abc 4 | import math 5 | import yaml 6 | import torch 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | from src.option import default_hparas 10 | from src.util import human_format, Timer 11 | from utils.load_yaml import HpsYaml 12 | 13 | 14 | class BaseSolver(): 15 | ''' 16 | Prototype Solver for all kinds of tasks 17 | Arguments 18 | config - yaml-styled config 19 | paras - argparse outcome 20 | mode - "train"/"test" 21 | ''' 22 | 23 | def __init__(self, config, paras, mode="train"): 24 | # General Settings 25 | self.config = config # load from yaml file 26 | self.paras = paras # command line args 27 | self.mode = mode # 'train' or 'test' 28 | for k, v in default_hparas.items(): 29 | setattr(self, k, v) 30 | self.device = torch.device('cuda') if self.paras.gpu and torch.cuda.is_available() \ 31 | else torch.device('cpu') 32 | 33 | # Name experiment 34 | self.exp_name = paras.name 35 | if self.exp_name is None: 36 | if 'exp_name' in self.config: 37 | self.exp_name = self.config.exp_name 38 | else: 39 | # By default, exp is named after config file 40 | self.exp_name = paras.config.split('/')[-1].replace('.yaml', '') 41 | if mode == 'train': 42 | self.exp_name += '_seed{}'.format(paras.seed) 43 | 44 | 45 | if mode == 'train': 46 | # Filepath setup 47 | os.makedirs(paras.ckpdir, exist_ok=True) 48 | self.ckpdir = os.path.join(paras.ckpdir, self.exp_name) 49 | os.makedirs(self.ckpdir, exist_ok=True) 50 | 51 | # Logger settings 52 | self.logdir = os.path.join(paras.logdir, self.exp_name) 53 | self.log = SummaryWriter( 54 | self.logdir, flush_secs=self.TB_FLUSH_FREQ) 55 | self.timer = Timer() 56 | 57 | # Hyper-parameters 58 | self.step = 0 59 | self.valid_step = config.hparas.valid_step 60 | self.max_step = config.hparas.max_step 61 | 62 | self.verbose('Exp. name : {}'.format(self.exp_name)) 63 | self.verbose('Loading data... large corpus may took a while.') 64 | 65 | # elif mode == 'test': 66 | # # Output path 67 | # os.makedirs(paras.outdir, exist_ok=True) 68 | # self.ckpdir = os.path.join(paras.outdir, self.exp_name) 69 | 70 | # Load training config to get acoustic feat and build model 71 | # self.src_config = HpsYaml(config.src.config) 72 | # self.paras.load = config.src.ckpt 73 | 74 | # self.verbose('Evaluating result of tr. 
config @ {}'.format( 75 | # config.src.config)) 76 | 77 | def backward(self, loss): 78 | ''' 79 | Standard backward step with self.timer and debugger 80 | Arguments 81 | loss - the loss to perform loss.backward() 82 | ''' 83 | self.timer.set() 84 | loss.backward() 85 | grad_norm = torch.nn.utils.clip_grad_norm_( 86 | self.model.parameters(), self.GRAD_CLIP) 87 | if math.isnan(grad_norm): 88 | self.verbose('Error : grad norm is NaN @ step '+str(self.step)) 89 | else: 90 | self.optimizer.step() 91 | self.timer.cnt('bw') 92 | return grad_norm 93 | 94 | def load_ckpt(self): 95 | ''' Load ckpt if --load option is specified ''' 96 | if self.paras.load is not None: 97 | if self.paras.warm_start: 98 | self.verbose(f"Warm starting model from checkpoint {self.paras.load}.") 99 | ckpt = torch.load( 100 | self.paras.load, map_location=self.device if self.mode == 'train' 101 | else 'cpu') 102 | model_dict = ckpt['model'] 103 | if len(self.config.model.ignore_layers) > 0: 104 | model_dict = {k:v for k, v in model_dict.items() 105 | if k not in self.config.model.ignore_layers} 106 | dummy_dict = self.model.state_dict() 107 | dummy_dict.update(model_dict) 108 | model_dict = dummy_dict 109 | self.model.load_state_dict(model_dict) 110 | else: 111 | # Load weights 112 | ckpt = torch.load( 113 | self.paras.load, map_location=self.device if self.mode == 'train' 114 | else 'cpu') 115 | self.model.load_state_dict(ckpt['model']) 116 | 117 | # Load task-dependent items 118 | if self.mode == 'train': 119 | self.step = ckpt['global_step'] 120 | self.optimizer.load_opt_state_dict(ckpt['optimizer']) 121 | self.verbose('Load ckpt from {}, restarting at step {}'.format( 122 | self.paras.load, self.step)) 123 | else: 124 | for k, v in ckpt.items(): 125 | if type(v) is float: 126 | metric, score = k, v 127 | self.model.eval() 128 | self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format( 129 | self.paras.load, metric, score)) 130 | 131 | def verbose(self, msg): 132 | ''' Verbose function for print information to stdout''' 133 | if self.paras.verbose: 134 | if type(msg) == list: 135 | for m in msg: 136 | print('[INFO]', m.ljust(100)) 137 | else: 138 | print('[INFO]', msg.ljust(100)) 139 | 140 | def progress(self, msg): 141 | ''' Verbose function for updating progress on stdout (do not include newline) ''' 142 | if self.paras.verbose: 143 | sys.stdout.write("\033[K") # Clear line 144 | print('[{}] {}'.format(human_format(self.step), msg), end='\r') 145 | 146 | def write_log(self, log_name, log_dict): 147 | ''' 148 | Write log to TensorBoard 149 | log_name - Name of tensorboard variable 150 | log_value - / Value of variable (e.g. 
dict of losses), passed if value = None 151 | ''' 152 | if type(log_dict) is dict: 153 | log_dict = {key: val for key, val in log_dict.items() if ( 154 | val is not None and not math.isnan(val))} 155 | if log_dict is None: 156 | pass 157 | elif len(log_dict) > 0: 158 | if 'align' in log_name or 'spec' in log_name: 159 | img, form = log_dict 160 | self.log.add_image( 161 | log_name, img, global_step=self.step, dataformats=form) 162 | elif 'text' in log_name or 'hyp' in log_name: 163 | self.log.add_text(log_name, log_dict, self.step) 164 | else: 165 | self.log.add_scalars(log_name, log_dict, self.step) 166 | 167 | def save_checkpoint(self, f_name, metric, score, show_msg=True): 168 | '''' 169 | Ckpt saver 170 | f_name - the name of ckpt file (w/o prefix) to store, overwrite if existed 171 | score - The value of metric used to evaluate model 172 | ''' 173 | ckpt_path = os.path.join(self.ckpdir, f_name) 174 | full_dict = { 175 | "model": self.model.state_dict(), 176 | "optimizer": self.optimizer.get_opt_state_dict(), 177 | "global_step": self.step, 178 | metric: score 179 | } 180 | 181 | torch.save(full_dict, ckpt_path) 182 | if show_msg: 183 | self.verbose("Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}". 184 | format(human_format(self.step), metric, score, ckpt_path)) 185 | 186 | 187 | # ----------------------------------- Abtract Methods ------------------------------------------ # 188 | @abc.abstractmethod 189 | def load_data(self): 190 | ''' 191 | Called by main to load all data 192 | After this call, data related attributes should be setup (e.g. self.tr_set, self.dev_set) 193 | No return value 194 | ''' 195 | raise NotImplementedError 196 | 197 | @abc.abstractmethod 198 | def set_model(self): 199 | ''' 200 | Called by main to set models 201 | After this call, model related attributes should be setup (e.g. self.l2_loss) 202 | The followings MUST be setup 203 | - self.model (torch.nn.Module) 204 | - self.optimizer (src.Optimizer), 205 | init. 
w/ self.optimizer = src.Optimizer(self.model.parameters(),**self.config['hparas']) 206 | Loading pre-trained model should also be performed here 207 | No return value 208 | ''' 209 | raise NotImplementedError 210 | 211 | @abc.abstractmethod 212 | def exec(self): 213 | ''' 214 | Called by main to execute training/inference 215 | ''' 216 | raise NotImplementedError 217 | -------------------------------------------------------------------------------- /bin/train_ppg2mel_oneshotvc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import numpy as np 4 | from src.solver import BaseSolver 5 | # from src.data_load import VcDataset, VcCollate 6 | from src.data_load import OneshotVcDataset, MultiSpkVcCollate 7 | from src.rnn_ppg2mel import BiRnnPpg2MelModel 8 | from src.optim import Optimizer 9 | from src.util import human_format, feat_to_fig 10 | from src.loss_fn import MaskedMSELoss 11 | 12 | 13 | class Solver(BaseSolver): 14 | """Customized Solver.""" 15 | def __init__(self, config, paras, mode): 16 | super().__init__(config, paras, mode) 17 | self.best_loss = np.inf 18 | 19 | def fetch_data(self, data): 20 | """Move data to device""" 21 | data = [i.to(self.device) for i in data] 22 | return data 23 | 24 | def load_data(self): 25 | """ Load data for training/validation/plotting.""" 26 | train_dataset = OneshotVcDataset( 27 | meta_file=self.config.data.train_fid_list, 28 | vctk_ppg_dir=self.config.data.vctk_ppg_dir, 29 | libri_ppg_dir=self.config.data.libri_ppg_dir, 30 | vctk_f0_dir=self.config.data.vctk_f0_dir, 31 | libri_f0_dir=self.config.data.libri_f0_dir, 32 | vctk_wav_dir=self.config.data.vctk_wav_dir, 33 | libri_wav_dir=self.config.data.libri_wav_dir, 34 | vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir, 35 | libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir, 36 | ppg_file_ext=self.config.data.ppg_file_ext, 37 | min_max_norm_mel=self.config.data.min_max_norm_mel, 38 | mel_min=self.config.data.mel_min, 39 | mel_max=self.config.data.mel_max, 40 | ) 41 | dev_dataset = OneshotVcDataset( 42 | meta_file=self.config.data.dev_fid_list, 43 | vctk_ppg_dir=self.config.data.vctk_ppg_dir, 44 | libri_ppg_dir=self.config.data.libri_ppg_dir, 45 | vctk_f0_dir=self.config.data.vctk_f0_dir, 46 | libri_f0_dir=self.config.data.libri_f0_dir, 47 | vctk_wav_dir=self.config.data.vctk_wav_dir, 48 | libri_wav_dir=self.config.data.libri_wav_dir, 49 | vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir, 50 | libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir, 51 | ppg_file_ext=self.config.data.ppg_file_ext, 52 | min_max_norm_mel=self.config.data.min_max_norm_mel, 53 | mel_min=self.config.data.mel_min, 54 | mel_max=self.config.data.mel_max, 55 | ) 56 | self.train_dataloader = DataLoader( 57 | train_dataset, 58 | num_workers=self.paras.njobs, 59 | shuffle=True, 60 | batch_size=self.config.hparas.batch_size, 61 | pin_memory=False, 62 | drop_last=True, 63 | collate_fn=MultiSpkVcCollate(n_frames_per_step=1, 64 | f02ppg_length_ratio=1, 65 | use_spk_dvec=True), 66 | ) 67 | self.dev_dataloader = DataLoader( 68 | dev_dataset, 69 | num_workers=self.paras.njobs, 70 | shuffle=False, 71 | batch_size=self.config.hparas.batch_size, 72 | pin_memory=False, 73 | drop_last=False, 74 | collate_fn=MultiSpkVcCollate(n_frames_per_step=1, 75 | f02ppg_length_ratio=1, 76 | use_spk_dvec=True), 77 | ) 78 | msg = "Have prepared training set and dev set." 
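        # Note (added explanatory comment, not in the original script): with
        # use_spk_dvec=True, the batches produced by MultiSpkVcCollate are unpacked
        # downstream as
        #     ppgs, lf0_uvs, mels, in_lengths, out_lengths, spk_ids, _ = self.fetch_data(data)
        # where spk_ids carries per-utterance GE2E speaker d-vectors (loaded from the
        # *_spk_dvec_dir directories) rather than integer speaker IDs.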
79 | self.verbose(msg) 80 | 81 | def load_pretrained_params(self): 82 | prefix = "ppg2mel_model" 83 | ignore_layers = ["ppg2mel_model.spk_embedding.weight"] 84 | pretrain_model_file = self.config.data.pretrain_model_file 85 | pretrain_ckpt = torch.load( 86 | pretrain_model_file, map_location=self.device 87 | ) 88 | model_dict = self.model.state_dict() 89 | 90 | # 1. filter out unnecessrary keys 91 | pretrain_dict = {k.split(".", maxsplit=1)[1]: v 92 | for k, v in pretrain_ckpt.items() if "spk_embedding" not in k 93 | and "wav2ppg_model" not in k and "reduce_proj" not in k} 94 | # assert len(pretrain_dict.keys()) == len(model_dict.keys()) 95 | 96 | # 2. overwrite entries in the existing state dict 97 | model_dict.update(pretrain_dict) 98 | 99 | # 3. load the new state dict 100 | self.model.load_state_dict(model_dict) 101 | 102 | def set_model(self): 103 | """Setup model and optimizer""" 104 | # Model 105 | self.model = BiRnnPpg2MelModel(**self.config["model"]).to(self.device) 106 | if "pretrain_model_file" in self.config.data: 107 | self.load_pretrained_params() 108 | 109 | # model_params = [{'params': self.model.spk_embedding.weight}] 110 | model_params = [{'params': self.model.parameters()}] 111 | 112 | # Loss criterion 113 | self.loss_criterion = MaskedMSELoss() 114 | 115 | # Optimizer 116 | self.optimizer = Optimizer(model_params, **self.config["hparas"]) 117 | self.verbose(self.optimizer.create_msg()) 118 | 119 | # Automatically load pre-trained model if self.paras.load is given 120 | self.load_ckpt() 121 | 122 | def exec(self): 123 | self.verbose("Total training steps {}.".format( 124 | human_format(self.max_step))) 125 | 126 | mel_loss = None 127 | n_epochs = 0 128 | # Set as current time 129 | self.timer.set() 130 | 131 | while self.step < self.max_step: 132 | for data in self.train_dataloader: 133 | # Pre-step: updata lr_rate and do zero_grad 134 | lr_rate = self.optimizer.pre_step(self.step) 135 | total_loss = 0 136 | # data to device 137 | ppgs, lf0_uvs, mels, in_lengths, \ 138 | out_lengths, spk_ids, _ = self.fetch_data(data) 139 | self.timer.cnt("rd") 140 | mel_pred = self.model( 141 | ppg=ppgs, 142 | ppg_lengths=out_lengths, 143 | logf0_uv=lf0_uvs, 144 | spembs=spk_ids, 145 | ) 146 | loss = self.loss_criterion(mel_pred, mels, out_lengths) 147 | 148 | self.timer.cnt("fw") 149 | 150 | # Back-prop 151 | grad_norm = self.backward(loss) 152 | self.step += 1 153 | 154 | # Logger 155 | if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0): 156 | self.progress("Tr stat | Loss - {:.4f} | Grad. 
Norm - {:.2f} | {}" 157 | .format(loss.cpu().item(), grad_norm, self.timer.show())) 158 | self.write_log('loss', {'tr': loss}) 159 | 160 | # Validation 161 | if (self.step == 1) or (self.step % self.valid_step == 0): 162 | self.validate() 163 | 164 | # End of step 165 | # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354 166 | torch.cuda.empty_cache() 167 | self.timer.set() 168 | if self.step > self.max_step: 169 | break 170 | n_epochs += 1 171 | self.log.close() 172 | 173 | def validate(self): 174 | self.model.eval() 175 | dev_loss = 0.0 176 | 177 | for i, data in enumerate(self.dev_dataloader): 178 | self.progress('Valid step - {}/{}'.format(i+1, len(self.dev_dataloader))) 179 | # Fetch data 180 | # ppgs, lf0_uvs, mels, lengths = self.fetch_data(data) 181 | ppgs, lf0_uvs, mels, in_lengths, \ 182 | out_lengths, spk_ids, _ = self.fetch_data(data) 183 | 184 | with torch.no_grad(): 185 | mel_pred = self.model( 186 | ppg=ppgs, 187 | ppg_lengths=out_lengths, 188 | logf0_uv=lf0_uvs, 189 | spembs=spk_ids, 190 | ) 191 | loss = self.loss_criterion(mel_pred, mels, out_lengths) 192 | dev_loss += loss.cpu().item() 193 | 194 | dev_loss = dev_loss / (i + 1) 195 | self.save_checkpoint(f'step_{self.step}.pth', 'loss', dev_loss, show_msg=False) 196 | if dev_loss < self.best_loss: 197 | self.best_loss = dev_loss 198 | self.save_checkpoint(f'best_loss_step_{self.step}.pth', 'loss', dev_loss) 199 | self.write_log('loss', {'dv_loss': dev_loss}) 200 | 201 | # Resume training 202 | self.model.train() 203 | 204 | -------------------------------------------------------------------------------- /conf/bilstm_ppg2mel_vctk_libri_oneshotvc.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train_fid_list: "/home/shaunxliu/data/vctk/fidlists/train_fidlist.new" 3 | dev_fid_list: "/home/shaunxliu/data/vctk/fidlists/dev_fidlist.new" 4 | eval_fid_list: "/home/shaunxliu/data/vctk/fidlists/eval_fidlist.txt" 5 | vctk_ppg_dir: "/home/shaunxliu/data/vctk/conformer_bnf10ms" 6 | libri_ppg_dir: "/home/shaunxliu/data/LibriTTS/conformer_bnf10ms" 7 | vctk_f0_dir: "/home/shaunxliu/data/vctk/merged_f0s" 8 | libri_f0_dir: "/home/shaunxliu/data/LibriTTS/f0s" 9 | vctk_wav_dir: "/home/shaunxliu/data/vctk/wav_mono_24k_16b_norm-6db" 10 | libri_wav_dir: "/home/shaunxliu/data/LibriTTS/LibriTTS/train-wavs-clean460/" 11 | vctk_spk_dvec_dir: "/home/shaunxliu/data/vctk/GE2E_spkEmbed_step_5805000_perSpk" 12 | libri_spk_dvec_dir: "/home/shaunxliu/data/LibriTTS/GE2E_spkEmbed_step_5805000_perSpk" 13 | ppg_file_ext: "ling_feat.npy" 14 | f0_file_ext: "f0.npy" 15 | wav_file_ext: "wav" 16 | min_max_norm_mel: true 17 | mel_min: -12.0 18 | mel_max: 2.5 19 | 20 | hparas: 21 | batch_size: 32 22 | valid_step: 1000 23 | max_step: 1000000 24 | optimizer: 'Adam' 25 | lr: 0.001 26 | eps: 1.0e-8 27 | weight_decay: 1.0e-6 28 | lr_scheduler: 'warmup' # "fixed", "warmup" 29 | 30 | model_name: "bilstm" 31 | model: 32 | input_size: 146 # 144 ppg-dim and 2 pitch 33 | multi_spk: True 34 | use_spk_dvec: True # for one-shot VC 35 | 36 | 37 | -------------------------------------------------------------------------------- /conf/seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | #train_fid_list: "../wavernn-vc/dump/vc_train_set/f0s.scp" 3 | train_fid_list: "/home/shaunxliu/data/vctk/fidlists/train_fidlist.new" 4 | dev_fid_list: "/home/shaunxliu/data/vctk/fidlists/dev_fidlist.new" 5 | eval_fid_list: 
"/home/shaunxliu/data/vctk/fidlists/eval_fidlist.txt" 6 | vctk_ppg_dir: "/home/shaunxliu/data/vctk/conformer_bnf10ms" 7 | libri_ppg_dir: "/home/shaunxliu/data/LibriTTS/conformer_bnf10ms" 8 | vctk_f0_dir: "/home/shaunxliu/data/vctk/merged_f0s" 9 | libri_f0_dir: "/home/shaunxliu/data/LibriTTS/f0s" 10 | vctk_wav_dir: "/home/shaunxliu/data/vctk/wav_mono_24k_16b_norm-6db" 11 | libri_wav_dir: "/home/shaunxliu/data/LibriTTS/LibriTTS/train-wavs-clean460/" 12 | vctk_spk_dvec_dir: "/home/shaunxliu/data/vctk/GE2E_spkEmbed_step_5805000_perSpk" 13 | libri_spk_dvec_dir: "/home/shaunxliu/data/LibriTTS/GE2E_spkEmbed_step_5805000_perSpk" 14 | ppg_file_ext: "ling_feat.npy" 15 | f0_file_ext: "f0.npy" 16 | wav_file_ext: "wav" 17 | min_max_norm_mel: true 18 | mel_min: -12.0 19 | mel_max: 2.5 20 | 21 | hparas: 22 | batch_size: 32 23 | valid_step: 2000 24 | max_step: 1000000 25 | optimizer: 'Adam' 26 | lr: 0.001 27 | eps: 1.0e-8 28 | weight_decay: 1.0e-6 29 | lr_scheduler: 'warmup' # "fixed", "warmup" 30 | 31 | model_name: "seq2seqmolv2" 32 | model: 33 | num_speakers: 1250 34 | spk_embed_dim: 256 35 | bottle_neck_feature_dim: 144 36 | encoder_downsample_rates: [2, 2] 37 | attention_rnn_dim: 512 38 | attention_dim: 512 39 | decoder_rnn_dim: 512 40 | num_decoder_rnn_layer: 1 41 | concat_context_to_last: True 42 | prenet_dims: [256, 128] 43 | prenet_dropout: 0.5 44 | num_mixtures: 5 45 | frames_per_step: 4 46 | postnet_num_layers: 5 47 | postnet_hidden_dim: 512 48 | mask_padding: True 49 | use_spk_dvec: True 50 | -------------------------------------------------------------------------------- /conformer_ppg_model/build_ppg_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from pathlib import Path 4 | import yaml 5 | 6 | 7 | from .frontend import DefaultFrontend 8 | from .utterance_mvn import UtteranceMVN 9 | from .encoder.conformer_encoder import ConformerEncoder 10 | 11 | 12 | class PPGModel(torch.nn.Module): 13 | def __init__( 14 | self, 15 | frontend, 16 | normalizer, 17 | encoder, 18 | ): 19 | super().__init__() 20 | self.frontend = frontend 21 | self.normalize = normalizer 22 | self.encoder = encoder 23 | 24 | def forward(self, speech, speech_lengths): 25 | """ 26 | 27 | Args: 28 | speech (tensor): (B, L) 29 | speech_lengths (tensor): (B, ) 30 | 31 | Returns: 32 | bottle_neck_feats (tensor): (B, L//hop_size, 144) 33 | 34 | """ 35 | feats, feats_lengths = self._extract_feats(speech, speech_lengths) 36 | feats, feats_lengths = self.normalize(feats, feats_lengths) 37 | encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) 38 | return encoder_out 39 | 40 | def _extract_feats( 41 | self, speech: torch.Tensor, speech_lengths: torch.Tensor 42 | ): 43 | assert speech_lengths.dim() == 1, speech_lengths.shape 44 | 45 | # for data-parallel 46 | speech = speech[:, : speech_lengths.max()] 47 | 48 | if self.frontend is not None: 49 | # Frontend 50 | # e.g. 
STFT and Feature extract 51 | # data_loader may send time-domain signal in this case 52 | # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) 53 | feats, feats_lengths = self.frontend(speech, speech_lengths) 54 | else: 55 | # No frontend and no feature extract 56 | feats, feats_lengths = speech, speech_lengths 57 | return feats, feats_lengths 58 | 59 | 60 | def build_model(args): 61 | normalizer = UtteranceMVN(**args.normalize_conf) 62 | frontend = DefaultFrontend(**args.frontend_conf) 63 | encoder = ConformerEncoder(input_size=80, **args.encoder_conf) 64 | model = PPGModel(frontend, normalizer, encoder) 65 | 66 | return model 67 | 68 | 69 | def load_ppg_model(train_config, model_file, device): 70 | config_file = Path(train_config) 71 | with config_file.open("r", encoding="utf-8") as f: 72 | args = yaml.safe_load(f) 73 | 74 | args = argparse.Namespace(**args) 75 | 76 | model = build_model(args) 77 | model_state_dict = model.state_dict() 78 | 79 | ckpt_state_dict = torch.load(model_file, map_location='cpu') 80 | ckpt_state_dict = {k:v for k,v in ckpt_state_dict.items() if 'encoder' in k} 81 | 82 | model_state_dict.update(ckpt_state_dict) 83 | model.load_state_dict(model_state_dict) 84 | 85 | return model.eval().to(device) 86 | -------------------------------------------------------------------------------- /conformer_ppg_model/en_conformer_ctc_att/24epoch.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/conformer_ppg_model/en_conformer_ctc_att/24epoch.pth -------------------------------------------------------------------------------- /conformer_ppg_model/en_conformer_ctc_att/config.yaml: -------------------------------------------------------------------------------- 1 | config: ./conf/train_asr_conformer_nodownsample.v1.yaml 2 | print_config: false 3 | log_level: INFO 4 | dry_run: false 5 | iterator_type: sequence 6 | output_dir: exp/asr_train_asr_conformer_nodownsample.v1_raw_sp 7 | ngpu: 1 8 | seed: 0 9 | num_workers: 1 10 | num_att_plot: 5 11 | dist_backend: nccl 12 | dist_init_method: env:// 13 | dist_world_size: 4 14 | dist_rank: 1 15 | local_rank: 1 16 | dist_master_addr: localhost 17 | dist_master_port: 50826 18 | dist_launcher: null 19 | multiprocessing_distributed: true 20 | cudnn_enabled: true 21 | cudnn_benchmark: false 22 | cudnn_deterministic: true 23 | collect_stats: false 24 | write_collected_feats: false 25 | max_epoch: 2000 26 | patience: 200 27 | val_scheduler_criterion: 28 | - valid 29 | - acc 30 | early_stopping_criterion: 31 | - valid 32 | - loss 33 | - min 34 | best_model_criterion: 35 | - - valid 36 | - acc 37 | - max 38 | keep_nbest_models: 10 39 | grad_clip: 5.0 40 | grad_noise: false 41 | accum_grad: 8 42 | no_forward_run: false 43 | resume: true 44 | train_dtype: float32 45 | log_interval: null 46 | pretrain_path: [] 47 | pretrain_key: [] 48 | num_iters_per_epoch: null 49 | batch_size: 32 50 | valid_batch_size: 32 51 | batch_bins: 1000000 52 | valid_batch_bins: 1000000 53 | train_shape_file: 54 | - exp/asr_stats_raw_sp/train/speech_shape 55 | - exp/asr_stats_raw_sp/train/text_shape.phn 56 | valid_shape_file: 57 | - exp/asr_stats_raw_sp/valid/speech_shape 58 | - exp/asr_stats_raw_sp/valid/text_shape.phn 59 | batch_type: folded 60 | valid_batch_type: folded 61 | fold_length: 62 | - 128000 63 | - 150 64 | sort_in_batch: descending 65 | sort_batch: descending 66 | chunk_length: 500 67 | chunk_shift_ratio: 0.5 68 | 
num_cache_chunks: 1024 69 | train_data_path_and_name_and_type: 70 | - - dump/raw/train_960_sp/wav.scp 71 | - speech 72 | - sound 73 | - - dump/raw/train_960_sp/text 74 | - text 75 | - text 76 | valid_data_path_and_name_and_type: 77 | - - dump/raw/dev_set/wav.scp 78 | - speech 79 | - sound 80 | - - dump/raw/dev_set/text 81 | - text 82 | - text 83 | allow_variable_data_keys: false 84 | max_cache_size: 0.0 85 | valid_max_cache_size: 0.0 86 | optim: adam 87 | optim_conf: 88 | lr: 0.0015 89 | scheduler: warmuplr 90 | scheduler_conf: 91 | warmup_steps: 25000 92 | token_list: 93 | - 94 | - 95 | - AA0 96 | - AA1 97 | - AA2 98 | - AE0 99 | - AE1 100 | - AE2 101 | - AH0 102 | - AH1 103 | - AH2 104 | - AO0 105 | - AO1 106 | - AO2 107 | - AW0 108 | - AW1 109 | - AW2 110 | - AY0 111 | - AY1 112 | - AY2 113 | - B 114 | - CH 115 | - D 116 | - DH 117 | - EH0 118 | - EH1 119 | - EH2 120 | - ER0 121 | - ER1 122 | - ER2 123 | - EY0 124 | - EY1 125 | - EY2 126 | - F 127 | - G 128 | - HH 129 | - IH0 130 | - IH1 131 | - IH2 132 | - IY0 133 | - IY1 134 | - IY2 135 | - JH 136 | - K 137 | - L 138 | - M 139 | - N 140 | - NG 141 | - OW0 142 | - OW1 143 | - OW2 144 | - OY0 145 | - OY1 146 | - OY2 147 | - P 148 | - R 149 | - S 150 | - SH 151 | - T 152 | - TH 153 | - UH0 154 | - UH1 155 | - UH2 156 | - UW0 157 | - UW1 158 | - UW2 159 | - V 160 | - W 161 | - Y 162 | - Z 163 | - ZH 164 | - sil 165 | - sp 166 | - spn 167 | - 168 | init: null 169 | input_size: null 170 | ctc_conf: 171 | dropout_rate: 0.0 172 | ctc_type: builtin 173 | reduce: true 174 | model_conf: 175 | ctc_weight: 0.5 176 | lsm_weight: 0.1 177 | length_normalized_loss: false 178 | use_preprocessor: true 179 | token_type: phn 180 | bpemodel: null 181 | non_linguistic_symbols: null 182 | frontend: default 183 | frontend_conf: 184 | fs: 16000 185 | specaug: specaug 186 | specaug_conf: 187 | apply_time_warp: true 188 | time_warp_window: 5 189 | time_warp_mode: bicubic 190 | apply_freq_mask: true 191 | freq_mask_width_range: 192 | - 0 193 | - 30 194 | num_freq_mask: 2 195 | apply_time_mask: true 196 | time_mask_width_range: 197 | - 0 198 | - 40 199 | num_time_mask: 2 200 | normalize: utterance_mvn 201 | normalize_conf: 202 | norm_means: true 203 | norm_vars: true 204 | encoder: conformer 205 | encoder_conf: 206 | attention_dim: 144 207 | attention_heads: 4 208 | linear_units: 576 209 | num_blocks: 16 210 | dropout_rate: 0.1 211 | positional_dropout_rate: 0.1 212 | attention_dropout_rate: 0.0 213 | input_layer: conv2d 214 | normalize_before: true 215 | concat_after: false 216 | positionwise_layer_type: linear 217 | positionwise_conv_kernel_size: 1 218 | macaron_style: true 219 | pos_enc_layer_type: rel_pos 220 | selfattention_layer_type: rel_selfattn 221 | activation_type: swish 222 | use_cnn_module: true 223 | cnn_module_kernel: 15 224 | no_subsample: true 225 | subsample_by_2: false 226 | decoder: rnn 227 | decoder_conf: 228 | rnn_type: lstm 229 | num_layers: 1 230 | hidden_size: 320 231 | sampling_probability: 0.0 232 | dropout: 0.0 233 | att_conf: 234 | adim: 320 235 | required: 236 | - output_dir 237 | - token_list 238 | distributed: true 239 | -------------------------------------------------------------------------------- /conformer_ppg_model/encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/conformer_ppg_model/encoder/__init__.py -------------------------------------------------------------------------------- 
/conformer_ppg_model/encoder/attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Multi-Head Attention layer definition.""" 8 | 9 | import math 10 | 11 | import numpy 12 | import torch 13 | from torch import nn 14 | 15 | 16 | class MultiHeadedAttention(nn.Module): 17 | """Multi-Head Attention layer. 18 | 19 | :param int n_head: the number of head s 20 | :param int n_feat: the number of features 21 | :param float dropout_rate: dropout rate 22 | 23 | """ 24 | 25 | def __init__(self, n_head, n_feat, dropout_rate): 26 | """Construct an MultiHeadedAttention object.""" 27 | super(MultiHeadedAttention, self).__init__() 28 | assert n_feat % n_head == 0 29 | # We assume d_v always equals d_k 30 | self.d_k = n_feat // n_head 31 | self.h = n_head 32 | self.linear_q = nn.Linear(n_feat, n_feat) 33 | self.linear_k = nn.Linear(n_feat, n_feat) 34 | self.linear_v = nn.Linear(n_feat, n_feat) 35 | self.linear_out = nn.Linear(n_feat, n_feat) 36 | self.attn = None 37 | self.dropout = nn.Dropout(p=dropout_rate) 38 | 39 | def forward_qkv(self, query, key, value): 40 | """Transform query, key and value. 41 | 42 | :param torch.Tensor query: (batch, time1, size) 43 | :param torch.Tensor key: (batch, time2, size) 44 | :param torch.Tensor value: (batch, time2, size) 45 | :return torch.Tensor transformed query, key and value 46 | 47 | """ 48 | n_batch = query.size(0) 49 | q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) 50 | k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) 51 | v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) 52 | q = q.transpose(1, 2) # (batch, head, time1, d_k) 53 | k = k.transpose(1, 2) # (batch, head, time2, d_k) 54 | v = v.transpose(1, 2) # (batch, head, time2, d_k) 55 | 56 | return q, k, v 57 | 58 | def forward_attention(self, value, scores, mask): 59 | """Compute attention context vector. 60 | 61 | :param torch.Tensor value: (batch, head, time2, size) 62 | :param torch.Tensor scores: (batch, head, time1, time2) 63 | :param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2) 64 | :return torch.Tensor transformed `value` (batch, time1, d_model) 65 | weighted by the attention score (batch, time1, time2) 66 | 67 | """ 68 | n_batch = value.size(0) 69 | if mask is not None: 70 | mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) 71 | min_value = float( 72 | numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min 73 | ) 74 | scores = scores.masked_fill(mask, min_value) 75 | self.attn = torch.softmax(scores, dim=-1).masked_fill( 76 | mask, 0.0 77 | ) # (batch, head, time1, time2) 78 | else: 79 | self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) 80 | 81 | p_attn = self.dropout(self.attn) 82 | x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) 83 | x = ( 84 | x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) 85 | ) # (batch, time1, d_model) 86 | 87 | return self.linear_out(x) # (batch, time1, d_model) 88 | 89 | def forward(self, query, key, value, mask): 90 | """Compute 'Scaled Dot Product Attention'. 
91 | 92 | :param torch.Tensor query: (batch, time1, size) 93 | :param torch.Tensor key: (batch, time2, size) 94 | :param torch.Tensor value: (batch, time2, size) 95 | :param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2) 96 | :param torch.nn.Dropout dropout: 97 | :return torch.Tensor: attention output (batch, time1, d_model) 98 | """ 99 | q, k, v = self.forward_qkv(query, key, value) 100 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) 101 | return self.forward_attention(v, scores, mask) 102 | 103 | 104 | class RelPositionMultiHeadedAttention(MultiHeadedAttention): 105 | """Multi-Head Attention layer with relative position encoding. 106 | 107 | Paper: https://arxiv.org/abs/1901.02860 108 | 109 | :param int n_head: the number of head s 110 | :param int n_feat: the number of features 111 | :param float dropout_rate: dropout rate 112 | 113 | """ 114 | 115 | def __init__(self, n_head, n_feat, dropout_rate): 116 | """Construct an RelPositionMultiHeadedAttention object.""" 117 | super().__init__(n_head, n_feat, dropout_rate) 118 | # linear transformation for positional ecoding 119 | self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) 120 | # these two learnable bias are used in matrix c and matrix d 121 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3 122 | self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) 123 | self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) 124 | torch.nn.init.xavier_uniform_(self.pos_bias_u) 125 | torch.nn.init.xavier_uniform_(self.pos_bias_v) 126 | 127 | def rel_shift(self, x, zero_triu=False): 128 | """Compute relative positinal encoding. 129 | 130 | :param torch.Tensor x: (batch, time, size) 131 | :param bool zero_triu: return the lower triangular part of the matrix 132 | """ 133 | zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) 134 | x_padded = torch.cat([zero_pad, x], dim=-1) 135 | 136 | x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) 137 | x = x_padded[:, :, 1:].view_as(x) 138 | 139 | if zero_triu: 140 | ones = torch.ones((x.size(2), x.size(3))) 141 | x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] 142 | 143 | return x 144 | 145 | def forward(self, query, key, value, pos_emb, mask): 146 | """Compute 'Scaled Dot Product Attention' with rel. positional encoding. 
147 | 148 | :param torch.Tensor query: (batch, time1, size) 149 | :param torch.Tensor key: (batch, time2, size) 150 | :param torch.Tensor value: (batch, time2, size) 151 | :param torch.Tensor pos_emb: (batch, time1, size) 152 | :param torch.Tensor mask: (batch, time1, time2) 153 | :param torch.nn.Dropout dropout: 154 | :return torch.Tensor: attention output (batch, time1, d_model) 155 | """ 156 | q, k, v = self.forward_qkv(query, key, value) 157 | q = q.transpose(1, 2) # (batch, time1, head, d_k) 158 | 159 | n_batch_pos = pos_emb.size(0) 160 | p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) 161 | p = p.transpose(1, 2) # (batch, head, time1, d_k) 162 | 163 | # (batch, head, time1, d_k) 164 | q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) 165 | # (batch, head, time1, d_k) 166 | q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) 167 | 168 | # compute attention score 169 | # first compute matrix a and matrix c 170 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3 171 | # (batch, head, time1, time2) 172 | matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) 173 | 174 | # compute matrix b and matrix d 175 | # (batch, head, time1, time2) 176 | matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) 177 | matrix_bd = self.rel_shift(matrix_bd) 178 | 179 | scores = (matrix_ac + matrix_bd) / math.sqrt( 180 | self.d_k 181 | ) # (batch, head, time1, time2) 182 | 183 | return self.forward_attention(v, scores, mask) 184 | -------------------------------------------------------------------------------- /conformer_ppg_model/encoder/convolution.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2020 Johns Hopkins University (Shinji Watanabe) 5 | # Northwestern Polytechnical University (Pengcheng Guo) 6 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 7 | 8 | """ConvolutionModule definition.""" 9 | 10 | from torch import nn 11 | 12 | 13 | class ConvolutionModule(nn.Module): 14 | """ConvolutionModule in Conformer model. 15 | 16 | :param int channels: channels of cnn 17 | :param int kernel_size: kernerl size of cnn 18 | 19 | """ 20 | 21 | def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True): 22 | """Construct an ConvolutionModule object.""" 23 | super(ConvolutionModule, self).__init__() 24 | # kernerl_size should be a odd number for 'SAME' padding 25 | assert (kernel_size - 1) % 2 == 0 26 | 27 | self.pointwise_conv1 = nn.Conv1d( 28 | channels, 29 | 2 * channels, 30 | kernel_size=1, 31 | stride=1, 32 | padding=0, 33 | bias=bias, 34 | ) 35 | self.depthwise_conv = nn.Conv1d( 36 | channels, 37 | channels, 38 | kernel_size, 39 | stride=1, 40 | padding=(kernel_size - 1) // 2, 41 | groups=channels, 42 | bias=bias, 43 | ) 44 | self.norm = nn.BatchNorm1d(channels) 45 | self.pointwise_conv2 = nn.Conv1d( 46 | channels, 47 | channels, 48 | kernel_size=1, 49 | stride=1, 50 | padding=0, 51 | bias=bias, 52 | ) 53 | self.activation = activation 54 | 55 | def forward(self, x): 56 | """Compute convolution module. 
57 | 58 | :param torch.Tensor x: (batch, time, size) 59 | :return torch.Tensor: convoluted `value` (batch, time, d_model) 60 | """ 61 | # exchange the temporal dimension and the feature dimension 62 | x = x.transpose(1, 2) 63 | 64 | # GLU mechanism 65 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim) 66 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 67 | 68 | # 1D Depthwise Conv 69 | x = self.depthwise_conv(x) 70 | x = self.activation(self.norm(x)) 71 | 72 | x = self.pointwise_conv2(x) 73 | 74 | return x.transpose(1, 2) 75 | -------------------------------------------------------------------------------- /conformer_ppg_model/encoder/embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Positonal Encoding Module.""" 8 | 9 | import math 10 | 11 | import torch 12 | 13 | 14 | def _pre_hook( 15 | state_dict, 16 | prefix, 17 | local_metadata, 18 | strict, 19 | missing_keys, 20 | unexpected_keys, 21 | error_msgs, 22 | ): 23 | """Perform pre-hook in load_state_dict for backward compatibility. 24 | 25 | Note: 26 | We saved self.pe until v.0.5.2 but we have omitted it later. 27 | Therefore, we remove the item "pe" from `state_dict` for backward compatibility. 28 | 29 | """ 30 | k = prefix + "pe" 31 | if k in state_dict: 32 | state_dict.pop(k) 33 | 34 | 35 | class PositionalEncoding(torch.nn.Module): 36 | """Positional encoding. 37 | 38 | :param int d_model: embedding dim 39 | :param float dropout_rate: dropout rate 40 | :param int max_len: maximum input length 41 | :param reverse: whether to reverse the input position 42 | 43 | """ 44 | 45 | def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): 46 | """Construct an PositionalEncoding object.""" 47 | super(PositionalEncoding, self).__init__() 48 | self.d_model = d_model 49 | self.reverse = reverse 50 | self.xscale = math.sqrt(self.d_model) 51 | self.dropout = torch.nn.Dropout(p=dropout_rate) 52 | self.pe = None 53 | self.extend_pe(torch.tensor(0.0).expand(1, max_len)) 54 | self._register_load_state_dict_pre_hook(_pre_hook) 55 | 56 | def extend_pe(self, x): 57 | """Reset the positional encodings.""" 58 | if self.pe is not None: 59 | if self.pe.size(1) >= x.size(1): 60 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 61 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 62 | return 63 | pe = torch.zeros(x.size(1), self.d_model) 64 | if self.reverse: 65 | position = torch.arange( 66 | x.size(1) - 1, -1, -1.0, dtype=torch.float32 67 | ).unsqueeze(1) 68 | else: 69 | position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) 70 | div_term = torch.exp( 71 | torch.arange(0, self.d_model, 2, dtype=torch.float32) 72 | * -(math.log(10000.0) / self.d_model) 73 | ) 74 | pe[:, 0::2] = torch.sin(position * div_term) 75 | pe[:, 1::2] = torch.cos(position * div_term) 76 | pe = pe.unsqueeze(0) 77 | self.pe = pe.to(device=x.device, dtype=x.dtype) 78 | 79 | def forward(self, x: torch.Tensor): 80 | """Add positional encoding. 81 | 82 | Args: 83 | x (torch.Tensor): Input. Its shape is (batch, time, ...) 84 | 85 | Returns: 86 | torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) 
87 | 88 | """ 89 | self.extend_pe(x) 90 | x = x * self.xscale + self.pe[:, : x.size(1)] 91 | return self.dropout(x) 92 | 93 | 94 | class ScaledPositionalEncoding(PositionalEncoding): 95 | """Scaled positional encoding module. 96 | 97 | See also: Sec. 3.2 https://arxiv.org/pdf/1809.08895.pdf 98 | 99 | """ 100 | 101 | def __init__(self, d_model, dropout_rate, max_len=5000): 102 | """Initialize class. 103 | 104 | :param int d_model: embedding dim 105 | :param float dropout_rate: dropout rate 106 | :param int max_len: maximum input length 107 | 108 | """ 109 | super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len) 110 | self.alpha = torch.nn.Parameter(torch.tensor(1.0)) 111 | 112 | def reset_parameters(self): 113 | """Reset parameters.""" 114 | self.alpha.data = torch.tensor(1.0) 115 | 116 | def forward(self, x): 117 | """Add positional encoding. 118 | 119 | Args: 120 | x (torch.Tensor): Input. Its shape is (batch, time, ...) 121 | 122 | Returns: 123 | torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) 124 | 125 | """ 126 | self.extend_pe(x) 127 | x = x + self.alpha * self.pe[:, : x.size(1)] 128 | return self.dropout(x) 129 | 130 | 131 | class RelPositionalEncoding(PositionalEncoding): 132 | """Relitive positional encoding module. 133 | 134 | See : Appendix B in https://arxiv.org/abs/1901.02860 135 | 136 | :param int d_model: embedding dim 137 | :param float dropout_rate: dropout rate 138 | :param int max_len: maximum input length 139 | 140 | """ 141 | 142 | def __init__(self, d_model, dropout_rate, max_len=5000): 143 | """Initialize class. 144 | 145 | :param int d_model: embedding dim 146 | :param float dropout_rate: dropout rate 147 | :param int max_len: maximum input length 148 | 149 | """ 150 | super().__init__(d_model, dropout_rate, max_len, reverse=True) 151 | 152 | def forward(self, x): 153 | """Compute positional encoding. 154 | 155 | Args: 156 | x (torch.Tensor): Input. Its shape is (batch, time, ...) 157 | 158 | Returns: 159 | torch.Tensor: x. Its shape is (batch, time, ...) 160 | torch.Tensor: pos_emb. Its shape is (1, time, ...) 161 | 162 | """ 163 | self.extend_pe(x) 164 | x = x * self.xscale 165 | pos_emb = self.pe[:, : x.size(1)] 166 | return self.dropout(x), self.dropout(pos_emb) 167 | -------------------------------------------------------------------------------- /conformer_ppg_model/encoder/encoder_layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2020 Johns Hopkins University (Shinji Watanabe) 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Encoder self-attention layer definition.""" 8 | 9 | import torch 10 | 11 | from torch import nn 12 | 13 | from .layer_norm import LayerNorm 14 | 15 | 16 | class EncoderLayer(nn.Module): 17 | """Encoder layer module. 18 | 19 | :param int size: input dim 20 | :param espnet.nets.pytorch_backend.transformer.attention. 21 | MultiHeadedAttention self_attn: self attention module 22 | RelPositionMultiHeadedAttention self_attn: self attention module 23 | :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward. 24 | PositionwiseFeedForward feed_forward: 25 | feed forward module 26 | :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward 27 | for macaron style 28 | PositionwiseFeedForward feed_forward: 29 | feed forward module 30 | :param espnet.nets.pytorch_backend.conformer.convolution. 
31 | ConvolutionModule feed_foreard: 32 | feed forward module 33 | :param float dropout_rate: dropout rate 34 | :param bool normalize_before: whether to use layer_norm before the first block 35 | :param bool concat_after: whether to concat attention layer's input and output 36 | if True, additional linear will be applied. 37 | i.e. x -> x + linear(concat(x, att(x))) 38 | if False, no additional linear will be applied. i.e. x -> x + att(x) 39 | 40 | """ 41 | 42 | def __init__( 43 | self, 44 | size, 45 | self_attn, 46 | feed_forward, 47 | feed_forward_macaron, 48 | conv_module, 49 | dropout_rate, 50 | normalize_before=True, 51 | concat_after=False, 52 | ): 53 | """Construct an EncoderLayer object.""" 54 | super(EncoderLayer, self).__init__() 55 | self.self_attn = self_attn 56 | self.feed_forward = feed_forward 57 | self.feed_forward_macaron = feed_forward_macaron 58 | self.conv_module = conv_module 59 | self.norm_ff = LayerNorm(size) # for the FNN module 60 | self.norm_mha = LayerNorm(size) # for the MHA module 61 | if feed_forward_macaron is not None: 62 | self.norm_ff_macaron = LayerNorm(size) 63 | self.ff_scale = 0.5 64 | else: 65 | self.ff_scale = 1.0 66 | if self.conv_module is not None: 67 | self.norm_conv = LayerNorm(size) # for the CNN module 68 | self.norm_final = LayerNorm(size) # for the final output of the block 69 | self.dropout = nn.Dropout(dropout_rate) 70 | self.size = size 71 | self.normalize_before = normalize_before 72 | self.concat_after = concat_after 73 | if self.concat_after: 74 | self.concat_linear = nn.Linear(size + size, size) 75 | 76 | def forward(self, x_input, mask, cache=None): 77 | """Compute encoded features. 78 | 79 | :param torch.Tensor x_input: encoded source features, w/o pos_emb 80 | tuple((batch, max_time_in, size), (1, max_time_in, size)) 81 | or (batch, max_time_in, size) 82 | :param torch.Tensor mask: mask for x (batch, max_time_in) 83 | :param torch.Tensor cache: cache for x (batch, max_time_in - 1, size) 84 | :rtype: Tuple[torch.Tensor, torch.Tensor] 85 | """ 86 | if isinstance(x_input, tuple): 87 | x, pos_emb = x_input[0], x_input[1] 88 | else: 89 | x, pos_emb = x_input, None 90 | 91 | # whether to use macaron style 92 | if self.feed_forward_macaron is not None: 93 | residual = x 94 | if self.normalize_before: 95 | x = self.norm_ff_macaron(x) 96 | x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) 97 | if not self.normalize_before: 98 | x = self.norm_ff_macaron(x) 99 | 100 | # multi-headed self-attention module 101 | residual = x 102 | if self.normalize_before: 103 | x = self.norm_mha(x) 104 | 105 | if cache is None: 106 | x_q = x 107 | else: 108 | assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) 109 | x_q = x[:, -1:, :] 110 | residual = residual[:, -1:, :] 111 | mask = None if mask is None else mask[:, -1:, :] 112 | 113 | if pos_emb is not None: 114 | x_att = self.self_attn(x_q, x, x, pos_emb, mask) 115 | else: 116 | x_att = self.self_attn(x_q, x, x, mask) 117 | 118 | if self.concat_after: 119 | x_concat = torch.cat((x, x_att), dim=-1) 120 | x = residual + self.concat_linear(x_concat) 121 | else: 122 | x = residual + self.dropout(x_att) 123 | if not self.normalize_before: 124 | x = self.norm_mha(x) 125 | 126 | # convolution module 127 | if self.conv_module is not None: 128 | residual = x 129 | if self.normalize_before: 130 | x = self.norm_conv(x) 131 | x = residual + self.dropout(self.conv_module(x)) 132 | if not self.normalize_before: 133 | x = self.norm_conv(x) 134 | 135 | # feed forward module 136 | residual 
= x 137 | if self.normalize_before: 138 | x = self.norm_ff(x) 139 | x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) 140 | if not self.normalize_before: 141 | x = self.norm_ff(x) 142 | 143 | if self.conv_module is not None: 144 | x = self.norm_final(x) 145 | 146 | if cache is not None: 147 | x = torch.cat([cache, x], dim=1) 148 | 149 | if pos_emb is not None: 150 | return (x, pos_emb), mask 151 | 152 | return x, mask 153 | -------------------------------------------------------------------------------- /conformer_ppg_model/encoder/layer_norm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Layer normalization module.""" 8 | 9 | import torch 10 | 11 | 12 | class LayerNorm(torch.nn.LayerNorm): 13 | """Layer normalization module. 14 | 15 | :param int nout: output dim size 16 | :param int dim: dimension to be normalized 17 | """ 18 | 19 | def __init__(self, nout, dim=-1): 20 | """Construct an LayerNorm object.""" 21 | super(LayerNorm, self).__init__(nout, eps=1e-12) 22 | self.dim = dim 23 | 24 | def forward(self, x): 25 | """Apply layer normalization. 26 | 27 | :param torch.Tensor x: input tensor 28 | :return: layer normalized tensor 29 | :rtype torch.Tensor 30 | """ 31 | if self.dim == -1: 32 | return super(LayerNorm, self).forward(x) 33 | return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) 34 | -------------------------------------------------------------------------------- /conformer_ppg_model/encoder/multi_layer_conv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Tomoki Hayashi 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Layer modules for FFT block in FastSpeech (Feed-forward Transformer).""" 8 | 9 | import torch 10 | 11 | 12 | class MultiLayeredConv1d(torch.nn.Module): 13 | """Multi-layered conv1d for Transformer block. 14 | 15 | This is a module of multi-leyered conv1d designed 16 | to replace positionwise feed-forward network 17 | in Transforner block, which is introduced in 18 | `FastSpeech: Fast, Robust and Controllable Text to Speech`_. 19 | 20 | .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: 21 | https://arxiv.org/pdf/1905.09263.pdf 22 | 23 | """ 24 | 25 | def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): 26 | """Initialize MultiLayeredConv1d module. 27 | 28 | Args: 29 | in_chans (int): Number of input channels. 30 | hidden_chans (int): Number of hidden channels. 31 | kernel_size (int): Kernel size of conv1d. 32 | dropout_rate (float): Dropout rate. 33 | 34 | """ 35 | super(MultiLayeredConv1d, self).__init__() 36 | self.w_1 = torch.nn.Conv1d( 37 | in_chans, 38 | hidden_chans, 39 | kernel_size, 40 | stride=1, 41 | padding=(kernel_size - 1) // 2, 42 | ) 43 | self.w_2 = torch.nn.Conv1d( 44 | hidden_chans, 45 | in_chans, 46 | kernel_size, 47 | stride=1, 48 | padding=(kernel_size - 1) // 2, 49 | ) 50 | self.dropout = torch.nn.Dropout(dropout_rate) 51 | 52 | def forward(self, x): 53 | """Calculate forward propagation. 54 | 55 | Args: 56 | x (Tensor): Batch of input tensors (B, ..., in_chans). 57 | 58 | Returns: 59 | Tensor: Batch of output tensors (B, ..., hidden_chans). 
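        Note that w_2 projects back to in_chans, so the output keeps the same
        channel dimension as the input. A minimal sketch (the sizes are
        illustrative assumptions):

            >>> import torch
            >>> ffn = MultiLayeredConv1d(in_chans=256, hidden_chans=1024, kernel_size=3, dropout_rate=0.1)
            >>> x = torch.randn(2, 50, 256)    # (B, T, in_chans)
            >>> y = ffn(x)                     # (2, 50, 256)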
60 | 61 | """ 62 | x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) 63 | return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1) 64 | 65 | 66 | class Conv1dLinear(torch.nn.Module): 67 | """Conv1D + Linear for Transformer block. 68 | 69 | A variant of MultiLayeredConv1d, which replaces second conv-layer to linear. 70 | 71 | """ 72 | 73 | def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): 74 | """Initialize Conv1dLinear module. 75 | 76 | Args: 77 | in_chans (int): Number of input channels. 78 | hidden_chans (int): Number of hidden channels. 79 | kernel_size (int): Kernel size of conv1d. 80 | dropout_rate (float): Dropout rate. 81 | 82 | """ 83 | super(Conv1dLinear, self).__init__() 84 | self.w_1 = torch.nn.Conv1d( 85 | in_chans, 86 | hidden_chans, 87 | kernel_size, 88 | stride=1, 89 | padding=(kernel_size - 1) // 2, 90 | ) 91 | self.w_2 = torch.nn.Linear(hidden_chans, in_chans) 92 | self.dropout = torch.nn.Dropout(dropout_rate) 93 | 94 | def forward(self, x): 95 | """Calculate forward propagation. 96 | 97 | Args: 98 | x (Tensor): Batch of input tensors (B, ..., in_chans). 99 | 100 | Returns: 101 | Tensor: Batch of output tensors (B, ..., hidden_chans). 102 | 103 | """ 104 | x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) 105 | return self.w_2(self.dropout(x)) 106 | -------------------------------------------------------------------------------- /conformer_ppg_model/encoder/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Positionwise feed forward layer definition.""" 8 | 9 | import torch 10 | 11 | 12 | class PositionwiseFeedForward(torch.nn.Module): 13 | """Positionwise feed forward layer. 14 | 15 | :param int idim: input dimenstion 16 | :param int hidden_units: number of hidden units 17 | :param float dropout_rate: dropout rate 18 | 19 | """ 20 | 21 | def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()): 22 | """Construct an PositionwiseFeedForward object.""" 23 | super(PositionwiseFeedForward, self).__init__() 24 | self.w_1 = torch.nn.Linear(idim, hidden_units) 25 | self.w_2 = torch.nn.Linear(hidden_units, idim) 26 | self.dropout = torch.nn.Dropout(dropout_rate) 27 | self.activation = activation 28 | 29 | def forward(self, x): 30 | """Forward funciton.""" 31 | return self.w_2(self.dropout(self.activation(self.w_1(x)))) 32 | -------------------------------------------------------------------------------- /conformer_ppg_model/encoder/repeat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Repeat the same layer definition.""" 8 | 9 | import torch 10 | 11 | 12 | class MultiSequential(torch.nn.Sequential): 13 | """Multi-input multi-output torch.nn.Sequential.""" 14 | 15 | def forward(self, *args): 16 | """Repeat.""" 17 | for m in self: 18 | args = m(*args) 19 | return args 20 | 21 | 22 | def repeat(N, fn): 23 | """Repeat module N times. 
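    Each generated module is chained inside a MultiSequential (defined above),
    so every repeated layer must accept and return the same tuple of arguments.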
24 | 25 | :param int N: repeat time 26 | :param function fn: function to generate module 27 | :return: repeated modules 28 | :rtype: MultiSequential 29 | """ 30 | return MultiSequential(*[fn(n) for n in range(N)]) 31 | -------------------------------------------------------------------------------- /conformer_ppg_model/encoder/subsampling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Subsampling layer definition.""" 8 | import logging 9 | import torch 10 | 11 | from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding 12 | 13 | 14 | class Conv2dSubsampling(torch.nn.Module): 15 | """Convolutional 2D subsampling (to 1/4 length or 1/2 length). 16 | 17 | :param int idim: input dim 18 | :param int odim: output dim 19 | :param flaot dropout_rate: dropout rate 20 | :param torch.nn.Module pos_enc: custom position encoding layer 21 | 22 | """ 23 | 24 | def __init__(self, idim, odim, dropout_rate, pos_enc=None, 25 | subsample_by_2=False, 26 | ): 27 | """Construct an Conv2dSubsampling object.""" 28 | super(Conv2dSubsampling, self).__init__() 29 | self.subsample_by_2 = subsample_by_2 30 | if subsample_by_2: 31 | self.conv = torch.nn.Sequential( 32 | torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2), 33 | torch.nn.ReLU(), 34 | torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1), 35 | torch.nn.ReLU(), 36 | ) 37 | self.out = torch.nn.Sequential( 38 | torch.nn.Linear(odim * (idim // 2), odim), 39 | pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), 40 | ) 41 | else: 42 | self.conv = torch.nn.Sequential( 43 | torch.nn.Conv2d(1, odim, kernel_size=4, stride=2, padding=1), 44 | torch.nn.ReLU(), 45 | torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1), 46 | torch.nn.ReLU(), 47 | ) 48 | self.out = torch.nn.Sequential( 49 | torch.nn.Linear(odim * (idim // 4), odim), 50 | pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), 51 | ) 52 | 53 | def forward(self, x, x_mask): 54 | """Subsample x. 55 | 56 | :param torch.Tensor x: input tensor 57 | :param torch.Tensor x_mask: input mask 58 | :return: subsampled x and mask 59 | :rtype Tuple[torch.Tensor, torch.Tensor] 60 | 61 | """ 62 | x = x.unsqueeze(1) # (b, c, t, f) 63 | x = self.conv(x) 64 | b, c, t, f = x.size() 65 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 66 | if x_mask is None: 67 | return x, None 68 | if self.subsample_by_2: 69 | return x, x_mask[:, :, ::2] 70 | else: 71 | return x, x_mask[:, :, ::2][:, :, ::2] 72 | 73 | def __getitem__(self, key): 74 | """Subsample x. 75 | 76 | When reset_parameters() is called, if use_scaled_pos_enc is used, 77 | return the positioning encoding. 78 | 79 | """ 80 | if key != -1: 81 | raise NotImplementedError("Support only `-1` (for `reset_parameters`).") 82 | return self.out[key] 83 | 84 | 85 | class Conv2dNoSubsampling(torch.nn.Module): 86 | """Convolutional 2D without subsampling. 
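    Both convolutions use stride 1 with 'SAME'-style padding, so the time
    resolution of the input features is preserved; the output is projected to
    odim and passed through the positional encoding, mirroring Conv2dSubsampling
    but without any down-sampling of the mask.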
87 | 88 | :param int idim: input dim 89 | :param int odim: output dim 90 | :param flaot dropout_rate: dropout rate 91 | :param torch.nn.Module pos_enc: custom position encoding layer 92 | 93 | """ 94 | 95 | def __init__(self, idim, odim, dropout_rate, pos_enc=None): 96 | """Construct an Conv2dSubsampling object.""" 97 | super().__init__() 98 | logging.info("Encoder does not do down-sample on mel-spectrogram.") 99 | self.conv = torch.nn.Sequential( 100 | torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2), 101 | torch.nn.ReLU(), 102 | torch.nn.Conv2d(odim, odim, kernel_size=5, stride=1, padding=2), 103 | torch.nn.ReLU(), 104 | ) 105 | self.out = torch.nn.Sequential( 106 | torch.nn.Linear(odim * idim, odim), 107 | pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), 108 | ) 109 | 110 | def forward(self, x, x_mask): 111 | """Subsample x. 112 | 113 | :param torch.Tensor x: input tensor 114 | :param torch.Tensor x_mask: input mask 115 | :return: subsampled x and mask 116 | :rtype Tuple[torch.Tensor, torch.Tensor] 117 | 118 | """ 119 | x = x.unsqueeze(1) # (b, c, t, f) 120 | x = self.conv(x) 121 | b, c, t, f = x.size() 122 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 123 | if x_mask is None: 124 | return x, None 125 | return x, x_mask 126 | 127 | def __getitem__(self, key): 128 | """Subsample x. 129 | 130 | When reset_parameters() is called, if use_scaled_pos_enc is used, 131 | return the positioning encoding. 132 | 133 | """ 134 | if key != -1: 135 | raise NotImplementedError("Support only `-1` (for `reset_parameters`).") 136 | return self.out[key] 137 | 138 | 139 | class Conv2dSubsampling6(torch.nn.Module): 140 | """Convolutional 2D subsampling (to 1/6 length). 141 | 142 | :param int idim: input dim 143 | :param int odim: output dim 144 | :param flaot dropout_rate: dropout rate 145 | 146 | """ 147 | 148 | def __init__(self, idim, odim, dropout_rate): 149 | """Construct an Conv2dSubsampling object.""" 150 | super(Conv2dSubsampling6, self).__init__() 151 | self.conv = torch.nn.Sequential( 152 | torch.nn.Conv2d(1, odim, 3, 2), 153 | torch.nn.ReLU(), 154 | torch.nn.Conv2d(odim, odim, 5, 3), 155 | torch.nn.ReLU(), 156 | ) 157 | self.out = torch.nn.Sequential( 158 | torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim), 159 | PositionalEncoding(odim, dropout_rate), 160 | ) 161 | 162 | def forward(self, x, x_mask): 163 | """Subsample x. 164 | 165 | :param torch.Tensor x: input tensor 166 | :param torch.Tensor x_mask: input mask 167 | :return: subsampled x and mask 168 | :rtype Tuple[torch.Tensor, torch.Tensor] 169 | """ 170 | x = x.unsqueeze(1) # (b, c, t, f) 171 | x = self.conv(x) 172 | b, c, t, f = x.size() 173 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 174 | if x_mask is None: 175 | return x, None 176 | return x, x_mask[:, :, :-2:2][:, :, :-4:3] 177 | 178 | 179 | class Conv2dSubsampling8(torch.nn.Module): 180 | """Convolutional 2D subsampling (to 1/8 length). 
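    Time is reduced by three stride-2 convolutions without padding, so the
    output is roughly T/8 frames and the frequency axis shrinks accordingly
    before the final linear projection.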
181 | 182 | :param int idim: input dim 183 | :param int odim: output dim 184 | :param flaot dropout_rate: dropout rate 185 | 186 | """ 187 | 188 | def __init__(self, idim, odim, dropout_rate): 189 | """Construct an Conv2dSubsampling object.""" 190 | super(Conv2dSubsampling8, self).__init__() 191 | self.conv = torch.nn.Sequential( 192 | torch.nn.Conv2d(1, odim, 3, 2), 193 | torch.nn.ReLU(), 194 | torch.nn.Conv2d(odim, odim, 3, 2), 195 | torch.nn.ReLU(), 196 | torch.nn.Conv2d(odim, odim, 3, 2), 197 | torch.nn.ReLU(), 198 | ) 199 | self.out = torch.nn.Sequential( 200 | torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim), 201 | PositionalEncoding(odim, dropout_rate), 202 | ) 203 | 204 | def forward(self, x, x_mask): 205 | """Subsample x. 206 | 207 | :param torch.Tensor x: input tensor 208 | :param torch.Tensor x_mask: input mask 209 | :return: subsampled x and mask 210 | :rtype Tuple[torch.Tensor, torch.Tensor] 211 | """ 212 | x = x.unsqueeze(1) # (b, c, t, f) 213 | x = self.conv(x) 214 | b, c, t, f = x.size() 215 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 216 | if x_mask is None: 217 | return x, None 218 | return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2] 219 | -------------------------------------------------------------------------------- /conformer_ppg_model/encoder/swish.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2020 Johns Hopkins University (Shinji Watanabe) 5 | # Northwestern Polytechnical University (Pengcheng Guo) 6 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 7 | 8 | """Swish() activation function for Conformer.""" 9 | 10 | import torch 11 | 12 | 13 | class Swish(torch.nn.Module): 14 | """Construct an Swish object.""" 15 | 16 | def forward(self, x): 17 | """Return Swich activation function.""" 18 | return x * torch.sigmoid(x) 19 | -------------------------------------------------------------------------------- /conformer_ppg_model/encoder/vgg.py: -------------------------------------------------------------------------------- 1 | """VGG2L definition for transformer-transducer.""" 2 | 3 | import torch 4 | 5 | 6 | class VGG2L(torch.nn.Module): 7 | """VGG2L module for transformer-transducer encoder.""" 8 | 9 | def __init__(self, idim, odim): 10 | """Construct a VGG2L object. 11 | 12 | Args: 13 | idim (int): dimension of inputs 14 | odim (int): dimension of outputs 15 | 16 | """ 17 | super(VGG2L, self).__init__() 18 | 19 | self.vgg2l = torch.nn.Sequential( 20 | torch.nn.Conv2d(1, 64, 3, stride=1, padding=1), 21 | torch.nn.ReLU(), 22 | torch.nn.Conv2d(64, 64, 3, stride=1, padding=1), 23 | torch.nn.ReLU(), 24 | torch.nn.MaxPool2d((3, 2)), 25 | torch.nn.Conv2d(64, 128, 3, stride=1, padding=1), 26 | torch.nn.ReLU(), 27 | torch.nn.Conv2d(128, 128, 3, stride=1, padding=1), 28 | torch.nn.ReLU(), 29 | torch.nn.MaxPool2d((2, 2)), 30 | ) 31 | 32 | self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim) 33 | 34 | def forward(self, x, x_mask): 35 | """VGG2L forward for x. 
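        The two MaxPool2d stages reduce the time axis by a factor of 3 and then
        2 (6x overall), which is mirrored by the mask subsampling in
        create_new_mask() below.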
36 | 37 | Args: 38 | x (torch.Tensor): input torch (B, T, idim) 39 | x_mask (torch.Tensor): (B, 1, T) 40 | 41 | Returns: 42 | x (torch.Tensor): input torch (B, sub(T), attention_dim) 43 | x_mask (torch.Tensor): (B, 1, sub(T)) 44 | 45 | """ 46 | x = x.unsqueeze(1) 47 | x = self.vgg2l(x) 48 | 49 | b, c, t, f = x.size() 50 | 51 | x = self.output(x.transpose(1, 2).contiguous().view(b, t, c * f)) 52 | 53 | if x_mask is None: 54 | return x, None 55 | else: 56 | x_mask = self.create_new_mask(x_mask, x) 57 | 58 | return x, x_mask 59 | 60 | def create_new_mask(self, x_mask, x): 61 | """Create a subsampled version of x_mask. 62 | 63 | Args: 64 | x_mask (torch.Tensor): (B, 1, T) 65 | x (torch.Tensor): (B, sub(T), attention_dim) 66 | 67 | Returns: 68 | x_mask (torch.Tensor): (B, 1, sub(T)) 69 | 70 | """ 71 | x_t1 = x_mask.size(2) - (x_mask.size(2) % 3) 72 | x_mask = x_mask[:, :, :x_t1][:, :, ::3] 73 | 74 | x_t2 = x_mask.size(2) - (x_mask.size(2) % 2) 75 | x_mask = x_mask[:, :, :x_t2][:, :, ::2] 76 | 77 | return x_mask 78 | -------------------------------------------------------------------------------- /conformer_ppg_model/frontend.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional 3 | from typing import Tuple 4 | from typing import Union 5 | 6 | import humanfriendly 7 | import numpy as np 8 | import torch 9 | from torch_complex.tensor import ComplexTensor 10 | 11 | from .log_mel import LogMel 12 | from .stft import Stft 13 | 14 | 15 | class DefaultFrontend(torch.nn.Module): 16 | """Conventional frontend structure for ASR 17 | 18 | Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN 19 | """ 20 | 21 | def __init__( 22 | self, 23 | fs: Union[int, str] = 16000, 24 | n_fft: int = 1024, 25 | win_length: int = 800, 26 | hop_length: int = 160, 27 | center: bool = True, 28 | pad_mode: str = "reflect", 29 | normalized: bool = False, 30 | onesided: bool = True, 31 | n_mels: int = 80, 32 | fmin: int = None, 33 | fmax: int = None, 34 | htk: bool = False, 35 | norm=1, 36 | frontend_conf=None, #Optional[dict] = get_default_kwargs(Frontend), 37 | kaldi_padding_mode=False, 38 | downsample_rate: int = 1, 39 | ): 40 | super().__init__() 41 | if isinstance(fs, str): 42 | fs = humanfriendly.parse_size(fs) 43 | self.downsample_rate = downsample_rate 44 | 45 | # Deepcopy (In general, dict shouldn't be used as default arg) 46 | frontend_conf = copy.deepcopy(frontend_conf) 47 | 48 | self.stft = Stft( 49 | n_fft=n_fft, 50 | win_length=win_length, 51 | hop_length=hop_length, 52 | center=center, 53 | pad_mode=pad_mode, 54 | normalized=normalized, 55 | onesided=onesided, 56 | kaldi_padding_mode=kaldi_padding_mode, 57 | ) 58 | if frontend_conf is not None: 59 | self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf) 60 | else: 61 | self.frontend = None 62 | 63 | self.logmel = LogMel( 64 | fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm, 65 | ) 66 | self.n_mels = n_mels 67 | 68 | def output_size(self) -> int: 69 | return self.n_mels 70 | 71 | def forward( 72 | self, input: torch.Tensor, input_lengths: torch.Tensor 73 | ) -> Tuple[torch.Tensor, torch.Tensor]: 74 | # 1. Domain-conversion: e.g. 
Stft: time -> time-freq 75 | input_stft, feats_lens = self.stft(input, input_lengths) 76 | 77 | assert input_stft.dim() >= 4, input_stft.shape 78 | # "2" refers to the real/imag parts of Complex 79 | assert input_stft.shape[-1] == 2, input_stft.shape 80 | 81 | # Change torch.Tensor to ComplexTensor 82 | # input_stft: (..., F, 2) -> (..., F) 83 | input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1]) 84 | 85 | # 2. [Option] Speech enhancement 86 | if self.frontend is not None: 87 | assert isinstance(input_stft, ComplexTensor), type(input_stft) 88 | # input_stft: (Batch, Length, [Channel], Freq) 89 | input_stft, _, mask = self.frontend(input_stft, feats_lens) 90 | 91 | # 3. [Multi channel case]: Select a channel 92 | if input_stft.dim() == 4: 93 | # h: (B, T, C, F) -> h: (B, T, F) 94 | if self.training: 95 | # Select 1ch randomly 96 | ch = np.random.randint(input_stft.size(2)) 97 | input_stft = input_stft[:, :, ch, :] 98 | else: 99 | # Use the first channel 100 | input_stft = input_stft[:, :, 0, :] 101 | 102 | # 4. STFT -> Power spectrum 103 | # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F) 104 | input_power = input_stft.real ** 2 + input_stft.imag ** 2 105 | 106 | # 5. Feature transform e.g. Stft -> Log-Mel-Fbank 107 | # input_power: (Batch, [Channel,] Length, Freq) 108 | # -> input_feats: (Batch, Length, Dim) 109 | input_feats, _ = self.logmel(input_power, feats_lens) 110 | 111 | # NOTE(sx): pad 112 | max_len = input_feats.size(1) 113 | if self.downsample_rate > 1 and max_len % self.downsample_rate != 0: 114 | padding = self.downsample_rate - max_len % self.downsample_rate 115 | # print("Logmel: ", input_feats.size()) 116 | input_feats = torch.nn.functional.pad(input_feats, (0, 0, 0, padding), 117 | "constant", 0) 118 | # print("Logmel(after padding): ",input_feats.size()) 119 | feats_lens[torch.argmax(feats_lens)] = max_len + padding 120 | 121 | return input_feats, feats_lens 122 | -------------------------------------------------------------------------------- /conformer_ppg_model/log_mel.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import torch 4 | from typing import Tuple 5 | 6 | from .nets_utils import make_pad_mask 7 | 8 | 9 | class LogMel(torch.nn.Module): 10 | """Convert STFT to fbank feats 11 | 12 | The arguments is same as librosa.filters.mel 13 | 14 | Args: 15 | fs: number > 0 [scalar] sampling rate of the incoming signal 16 | n_fft: int > 0 [scalar] number of FFT components 17 | n_mels: int > 0 [scalar] number of Mel bands to generate 18 | fmin: float >= 0 [scalar] lowest frequency (in Hz) 19 | fmax: float >= 0 [scalar] highest frequency (in Hz). 20 | If `None`, use `fmax = fs / 2.0` 21 | htk: use HTK formula instead of Slaney 22 | norm: {None, 1, np.inf} [scalar] 23 | if 1, divide the triangular mel weights by the width of the mel band 24 | (area normalization). Otherwise, leave all the triangles aiming for 25 | a peak value of 1.0 26 | 27 | """ 28 | 29 | def __init__( 30 | self, 31 | fs: int = 16000, 32 | n_fft: int = 512, 33 | n_mels: int = 80, 34 | fmin: float = None, 35 | fmax: float = None, 36 | htk: bool = False, 37 | norm=1, 38 | ): 39 | super().__init__() 40 | 41 | fmin = 0 if fmin is None else fmin 42 | fmax = fs / 2 if fmax is None else fmax 43 | _mel_options = dict( 44 | sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm 45 | ) 46 | self.mel_options = _mel_options 47 | 48 | # Note(kamo): The mel matrix of librosa is different from kaldi. 
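        # With the defaults above, librosa returns a (n_mels=80, n_fft//2+1=257)
        # filterbank; it is stored transposed so that forward() can compute
        # feat (B, T, 257) @ melmat (257, 80). Both melmat and its pseudo-inverse
        # are registered as buffers, so they follow .to(device) and are saved in
        # the state_dict without being trained.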
49 | melmat = librosa.filters.mel(**_mel_options) 50 | # melmat: (D2, D1) -> (D1, D2) 51 | self.register_buffer("melmat", torch.from_numpy(melmat.T).float()) 52 | inv_mel = np.linalg.pinv(melmat) 53 | self.register_buffer("inv_melmat", torch.from_numpy(inv_mel.T).float()) 54 | 55 | def extra_repr(self): 56 | return ", ".join(f"{k}={v}" for k, v in self.mel_options.items()) 57 | 58 | def forward( 59 | self, feat: torch.Tensor, ilens: torch.Tensor = None, 60 | ) -> Tuple[torch.Tensor, torch.Tensor]: 61 | # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2) 62 | mel_feat = torch.matmul(feat, self.melmat) 63 | 64 | logmel_feat = (mel_feat + 1e-20).log() 65 | # Zero padding 66 | if ilens is not None: 67 | logmel_feat = logmel_feat.masked_fill( 68 | make_pad_mask(ilens, logmel_feat, 1), 0.0 69 | ) 70 | else: 71 | ilens = feat.new_full( 72 | [feat.size(0)], fill_value=feat.size(1), dtype=torch.long 73 | ) 74 | return logmel_feat, ilens 75 | -------------------------------------------------------------------------------- /conformer_ppg_model/stft.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from typing import Tuple 3 | from typing import Union 4 | 5 | import torch 6 | 7 | from .nets_utils import make_pad_mask 8 | 9 | 10 | class Stft(torch.nn.Module): 11 | def __init__( 12 | self, 13 | n_fft: int = 512, 14 | win_length: Union[int, None] = 512, 15 | hop_length: int = 128, 16 | center: bool = True, 17 | pad_mode: str = "reflect", 18 | normalized: bool = False, 19 | onesided: bool = True, 20 | kaldi_padding_mode=False, 21 | ): 22 | super().__init__() 23 | self.n_fft = n_fft 24 | if win_length is None: 25 | self.win_length = n_fft 26 | else: 27 | self.win_length = win_length 28 | self.hop_length = hop_length 29 | self.center = center 30 | self.pad_mode = pad_mode 31 | self.normalized = normalized 32 | self.onesided = onesided 33 | self.kaldi_padding_mode = kaldi_padding_mode 34 | if self.kaldi_padding_mode: 35 | self.win_length = 400 36 | 37 | def extra_repr(self): 38 | return ( 39 | f"n_fft={self.n_fft}, " 40 | f"win_length={self.win_length}, " 41 | f"hop_length={self.hop_length}, " 42 | f"center={self.center}, " 43 | f"pad_mode={self.pad_mode}, " 44 | f"normalized={self.normalized}, " 45 | f"onesided={self.onesided}" 46 | ) 47 | 48 | def forward( 49 | self, input: torch.Tensor, ilens: torch.Tensor = None 50 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 51 | """STFT forward function. 
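        A minimal shape sketch under the torch/torchaudio versions pinned in
        requirements.txt (the 1-second, 16 kHz input is an illustrative
        assumption):

            >>> import torch
            >>> stft = Stft()                          # n_fft=512, win_length=512, hop_length=128
            >>> wav = torch.randn(1, 16000)
            >>> spec, olens = stft(wav, torch.LongTensor([16000]))
            >>> # spec: (1, 126, 257, 2) = (batch, frames, freq, real/imag), olens: tensor([126])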
52 | 53 | Args: 54 | input: (Batch, Nsamples) or (Batch, Nsample, Channels) 55 | ilens: (Batch) 56 | Returns: 57 | output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2) 58 | 59 | """ 60 | bs = input.size(0) 61 | if input.dim() == 3: 62 | multi_channel = True 63 | # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample) 64 | input = input.transpose(1, 2).reshape(-1, input.size(1)) 65 | else: 66 | multi_channel = False 67 | 68 | # output: (Batch, Freq, Frames, 2=real_imag) 69 | # or (Batch, Channel, Freq, Frames, 2=real_imag) 70 | if not self.kaldi_padding_mode: 71 | output = torch.stft( 72 | input, 73 | n_fft=self.n_fft, 74 | win_length=self.win_length, 75 | hop_length=self.hop_length, 76 | center=self.center, 77 | pad_mode=self.pad_mode, 78 | normalized=self.normalized, 79 | onesided=self.onesided, 80 | ) 81 | else: 82 | # NOTE(sx): Use Kaldi-fasion padding, maybe wrong 83 | num_pads = self.n_fft - self.win_length 84 | input = torch.nn.functional.pad(input, (num_pads, 0)) 85 | output = torch.stft( 86 | input, 87 | n_fft=self.n_fft, 88 | win_length=self.win_length, 89 | hop_length=self.hop_length, 90 | center=False, 91 | pad_mode=self.pad_mode, 92 | normalized=self.normalized, 93 | onesided=self.onesided, 94 | ) 95 | 96 | # output: (Batch, Freq, Frames, 2=real_imag) 97 | # -> (Batch, Frames, Freq, 2=real_imag) 98 | output = output.transpose(1, 2) 99 | if multi_channel: 100 | # output: (Batch * Channel, Frames, Freq, 2=real_imag) 101 | # -> (Batch, Frame, Channel, Freq, 2=real_imag) 102 | output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose( 103 | 1, 2 104 | ) 105 | 106 | if ilens is not None: 107 | if self.center: 108 | pad = self.win_length // 2 109 | ilens = ilens + 2 * pad 110 | 111 | olens = (ilens - self.win_length) // self.hop_length + 1 112 | output.masked_fill_(make_pad_mask(olens, output, 1), 0.0) 113 | else: 114 | olens = None 115 | 116 | return output, olens 117 | -------------------------------------------------------------------------------- /conformer_ppg_model/utterance_mvn.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | from .nets_utils import make_pad_mask 6 | 7 | 8 | class UtteranceMVN(torch.nn.Module): 9 | def __init__( 10 | self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20, 11 | ): 12 | super().__init__() 13 | self.norm_means = norm_means 14 | self.norm_vars = norm_vars 15 | self.eps = eps 16 | 17 | def extra_repr(self): 18 | return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}" 19 | 20 | def forward( 21 | self, x: torch.Tensor, ilens: torch.Tensor = None 22 | ) -> Tuple[torch.Tensor, torch.Tensor]: 23 | """Forward function 24 | 25 | Args: 26 | x: (B, L, ...) 
27 | ilens: (B,) 28 | 29 | """ 30 | return utterance_mvn( 31 | x, 32 | ilens, 33 | norm_means=self.norm_means, 34 | norm_vars=self.norm_vars, 35 | eps=self.eps, 36 | ) 37 | 38 | 39 | def utterance_mvn( 40 | x: torch.Tensor, 41 | ilens: torch.Tensor = None, 42 | norm_means: bool = True, 43 | norm_vars: bool = False, 44 | eps: float = 1.0e-20, 45 | ) -> Tuple[torch.Tensor, torch.Tensor]: 46 | """Apply utterance mean and variance normalization 47 | 48 | Args: 49 | x: (B, T, D), assumed zero padded 50 | ilens: (B,) 51 | norm_means: 52 | norm_vars: 53 | eps: 54 | 55 | """ 56 | if ilens is None: 57 | ilens = x.new_full([x.size(0)], x.size(1)) 58 | ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)]) 59 | # Zero padding 60 | if x.requires_grad: 61 | x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0) 62 | else: 63 | x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0) 64 | # mean: (B, 1, D) 65 | mean = x.sum(dim=1, keepdim=True) / ilens_ 66 | 67 | if norm_means: 68 | x -= mean 69 | 70 | if norm_vars: 71 | var = x.pow(2).sum(dim=1, keepdim=True) / ilens_ 72 | std = torch.clamp(var.sqrt(), min=eps) 73 | x = x / std.sqrt() 74 | return x, ilens 75 | else: 76 | if norm_vars: 77 | y = x - mean 78 | y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0) 79 | var = y.pow(2).sum(dim=1, keepdim=True) / ilens_ 80 | std = torch.clamp(var.sqrt(), min=eps) 81 | x /= std 82 | return x, ilens 83 | -------------------------------------------------------------------------------- /convert_from_wav.py: -------------------------------------------------------------------------------- 1 | import time 2 | import sys 3 | import os 4 | import argparse 5 | import torch 6 | import numpy as np 7 | import glob 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | from conformer_ppg_model.build_ppg_model import load_ppg_model 11 | from src.mel_decoder_mol_encAddlf0 import MelDecoderMOL 12 | from src.mel_decoder_lsa import MelDecoderLSA 13 | from src.rnn_ppg2mel import BiRnnPpg2MelModel 14 | import pyworld 15 | import librosa 16 | import resampy 17 | import soundfile as sf 18 | from utils.f0_utils import get_cont_lf0 19 | from utils.load_yaml import HpsYaml 20 | 21 | from vocoders.hifigan_model import load_hifigan_generator 22 | 23 | from speaker_encoder.voice_encoder import SpeakerEncoder 24 | from speaker_encoder.audio import preprocess_wav 25 | from src import build_model 26 | 27 | 28 | def compute_spk_dvec( 29 | wav_path, weights_fpath="speaker_encoder/ckpt/pretrained_bak_5805000.pt", 30 | ): 31 | fpath = Path(wav_path) 32 | wav = preprocess_wav(fpath) 33 | encoder = SpeakerEncoder(weights_fpath) 34 | spk_dvec = encoder.embed_utterance(wav) 35 | return spk_dvec 36 | 37 | 38 | def compute_f0(wav, sr=16000, frame_period=10.0): 39 | wav = wav.astype(np.float64) 40 | f0, timeaxis = pyworld.harvest( 41 | wav, sr, frame_period=frame_period, f0_floor=20.0, f0_ceil=600.0) 42 | return f0 43 | 44 | 45 | def compute_mean_std(lf0): 46 | nonzero_indices = np.nonzero(lf0) 47 | mean = np.mean(lf0[nonzero_indices]) 48 | std = np.std(lf0[nonzero_indices]) 49 | return mean, std 50 | 51 | 52 | def f02lf0(f0): 53 | lf0 = f0.copy() 54 | nonzero_indices = np.nonzero(f0) 55 | lf0[nonzero_indices] = np.log(f0[nonzero_indices]) 56 | return lf0 57 | 58 | 59 | def get_converted_lf0uv( 60 | wav, 61 | lf0_mean_trg, 62 | lf0_std_trg, 63 | convert=True, 64 | ): 65 | f0_src = compute_f0(wav) 66 | if not convert: 67 | uv, cont_lf0 = get_cont_lf0(f0_src) 68 | lf0_uv = np.concatenate([cont_lf0[:, np.newaxis], uv[:, np.newaxis]], axis=1) 
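        # lf0_uv: (T, 2) -- column 0 is the continuous log-F0 contour, column 1
        # the voiced/unvoiced flag; in this branch the source statistics are kept
        # and no mean/std conversion towards the target speaker is applied.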
69 | return lf0_uv 70 | 71 | lf0_src = f02lf0(f0_src) 72 | lf0_mean_src, lf0_std_src = compute_mean_std(lf0_src) 73 | 74 | lf0_vc = lf0_src.copy() 75 | lf0_vc[lf0_src > 0.0] = (lf0_src[lf0_src > 0.0] - lf0_mean_src) / lf0_std_src * lf0_std_trg + lf0_mean_trg 76 | f0_vc = lf0_vc.copy() 77 | f0_vc[lf0_src > 0.0] = np.exp(lf0_vc[lf0_src > 0.0]) 78 | 79 | uv, cont_lf0_vc = get_cont_lf0(f0_vc) 80 | lf0_uv = np.concatenate([cont_lf0_vc[:, np.newaxis], uv[:, np.newaxis]], axis=1) 81 | return lf0_uv 82 | 83 | 84 | def build_ppg2mel_model(model_config, model_file, device): 85 | model_class = build_model(model_config["model_name"]) 86 | ppg2mel_model = model_class( 87 | **model_config["model"] 88 | ).to(device) 89 | ckpt = torch.load(model_file, map_location=device) 90 | ppg2mel_model.load_state_dict(ckpt["model"]) 91 | ppg2mel_model.eval() 92 | return ppg2mel_model 93 | 94 | 95 | @torch.no_grad() 96 | def convert(args): 97 | device = 'cuda' 98 | ppg2mel_config = HpsYaml(args.ppg2mel_model_train_config) 99 | output_dir = args.output_dir 100 | os.makedirs(output_dir, exist_ok=True) 101 | 102 | step = os.path.basename(args.ppg2mel_model_file)[:-4].split("_")[-1] 103 | 104 | # Build models 105 | print("Load PPG-model, PPG2Mel-model, Vocoder-model...") 106 | ppg_model = load_ppg_model( 107 | './conformer_ppg_model/en_conformer_ctc_att/config.yaml', 108 | './conformer_ppg_model/en_conformer_ctc_att/24epoch.pth', 109 | device, 110 | ) 111 | ppg2mel_model = build_ppg2mel_model(ppg2mel_config, args.ppg2mel_model_file, device) 112 | hifigan_model = load_hifigan_generator(device) 113 | 114 | # Data related 115 | ref_wav_path = args.ref_wav_path 116 | ref_fid = os.path.basename(ref_wav_path)[:-4] 117 | ref_spk_dvec = compute_spk_dvec(ref_wav_path) 118 | ref_spk_dvec = torch.from_numpy(ref_spk_dvec).unsqueeze(0).to(device) 119 | ref_wav, _ = librosa.load(ref_wav_path, sr=16000) 120 | ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav))) 121 | 122 | source_file_list = sorted(glob.glob(f"{args.src_wav_dir}/*.wav")) 123 | print(f"Number of source utterances: {len(source_file_list)}.") 124 | 125 | total_rtf = 0.0 126 | cnt = 0 127 | for src_wav_path in tqdm(source_file_list): 128 | # Load the audio to a numpy array: 129 | src_wav, _ = librosa.load(src_wav_path, sr=16000) 130 | src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(device) 131 | src_wav_lengths = torch.LongTensor([len(src_wav)]).to(device) 132 | ppg = ppg_model(src_wav_tensor, src_wav_lengths) 133 | 134 | lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True) 135 | min_len = min(ppg.shape[1], len(lf0_uv)) 136 | 137 | ppg = ppg[:, :min_len] 138 | lf0_uv = lf0_uv[:min_len] 139 | 140 | start = time.time() 141 | if isinstance(ppg2mel_model, BiRnnPpg2MelModel): 142 | ppg_length = torch.LongTensor([ppg.shape[1]]).to(device) 143 | logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device) 144 | mel_pred = ppg2mel_model(ppg, ppg_length, logf0_uv, ref_spk_dvec) 145 | else: 146 | _, mel_pred, att_ws = ppg2mel_model.inference( 147 | ppg, 148 | logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device), 149 | spembs=ref_spk_dvec, 150 | use_stop_tokens=True, 151 | ) 152 | # if ppg2mel_config.data.min_max_norm_mel: 153 | # mel_min = ppg2mel_config.data.mel_min 154 | # mel_max = ppg2mel_config.data.mel_max 155 | # mel_pred = (mel_pred + 4.0) / 8.0 * (mel_max - mel_min) + mel_min 156 | src_fid = os.path.basename(src_wav_path)[:-4] 157 | wav_fname = 
f"{output_dir}/vc_{src_fid}_ref_{ref_fid}_step{step}.wav" 158 | mel_len = mel_pred.shape[0] 159 | rtf = (time.time() - start) / (0.01 * mel_len) 160 | total_rtf += rtf 161 | cnt += 1 162 | # continue 163 | y = hifigan_model(mel_pred.view(1, -1, 80).transpose(1, 2)) 164 | sf.write(wav_fname, y.squeeze().cpu().numpy(), 24000, "PCM_16") 165 | 166 | print("RTF:") 167 | print(total_rtf / cnt) 168 | 169 | 170 | def get_parser(): 171 | parser = argparse.ArgumentParser(description="Conversion from wave input") 172 | parser.add_argument( 173 | "--src_wav_dir", 174 | type=str, 175 | default=None, 176 | required=True, 177 | help="Source wave directory.", 178 | ) 179 | parser.add_argument( 180 | "--ref_wav_path", 181 | type=str, 182 | required=True, 183 | help="Reference wave file path.", 184 | ) 185 | parser.add_argument( 186 | "--ppg2mel_model_train_config", "-c", 187 | type=str, 188 | default=None, 189 | required=True, 190 | help="Training config file (yaml file)", 191 | ) 192 | parser.add_argument( 193 | "--ppg2mel_model_file", "-m", 194 | type=str, 195 | default=None, 196 | required=True, 197 | help="ppg2mel model checkpoint file path" 198 | ) 199 | parser.add_argument( 200 | "--output_dir", "-o", 201 | type=str, 202 | default="vc_gens_vctk_oneshot", 203 | help="Output folder to save the converted wave." 204 | ) 205 | 206 | return parser 207 | 208 | 209 | def main(): 210 | parser = get_parser() 211 | args = parser.parse_args() 212 | convert(args) 213 | 214 | 215 | if __name__ == "__main__": 216 | main() 217 | -------------------------------------------------------------------------------- /figs/README.md: -------------------------------------------------------------------------------- 1 | images 2 | -------------------------------------------------------------------------------- /figs/seq2seq_bnf2mel.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/figs/seq2seq_bnf2mel.pdf -------------------------------------------------------------------------------- /figs/seq2seq_bnf2mel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/figs/seq2seq_bnf2mel.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import sys 4 | import torch 5 | import argparse 6 | import numpy as np 7 | from utils.load_yaml import HpsYaml 8 | 9 | # For reproducibility, comment these may speed up training 10 | torch.backends.cudnn.deterministic = True 11 | torch.backends.cudnn.benchmark = False 12 | 13 | # Arguments 14 | parser = argparse.ArgumentParser(description= 15 | 'Training PPG2Mel VC model.') 16 | parser.add_argument('--config', type=str, 17 | help='Path to experiment config, e.g., config/vc.yaml') 18 | parser.add_argument('--name', default=None, type=str, help='Name for logging.') 19 | parser.add_argument('--logdir', default='log/', type=str, 20 | help='Logging path.', required=False) 21 | parser.add_argument('--ckpdir', default='ckpt/', type=str, 22 | help='Checkpoint path.', required=False) 23 | parser.add_argument('--outdir', default='result/', type=str, 24 | help='Decode output path.', required=False) 25 | parser.add_argument('--load', default=None, type=str, 26 | help='Load pre-trained 
model (for training only)', required=False) 27 | parser.add_argument('--warm_start', action='store_true', 28 | help='Load model weights only, ignore specified layers.') 29 | parser.add_argument('--seed', default=0, type=int, 30 | help='Random seed for reproducable results.', required=False) 31 | parser.add_argument('--njobs', default=8, type=int, 32 | help='Number of threads for dataloader/decoding.', required=False) 33 | parser.add_argument('--cpu', action='store_true', help='Disable GPU training.') 34 | parser.add_argument('--no-pin', action='store_true', 35 | help='Disable pin-memory for dataloader') 36 | parser.add_argument('--test', action='store_true', help='Test the model.') 37 | parser.add_argument('--no-msg', action='store_true', help='Hide all messages.') 38 | parser.add_argument('--finetune', action='store_true', help='Finetune model') 39 | parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model') 40 | parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model') 41 | parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)') 42 | 43 | ### 44 | 45 | paras = parser.parse_args() 46 | setattr(paras, 'gpu', not paras.cpu) 47 | setattr(paras, 'pin_memory', not paras.no_pin) 48 | setattr(paras, 'verbose', not paras.no_msg) 49 | # Make the config dict dot visitable 50 | config = HpsYaml(paras.config) # yaml.load(open(paras.config, 'r'), Loader=yaml.FullLoader) 51 | 52 | np.random.seed(paras.seed) 53 | torch.manual_seed(paras.seed) 54 | if torch.cuda.is_available(): 55 | torch.cuda.manual_seed_all(paras.seed) 56 | # For debug use 57 | # torch.autograd.set_detect_anomaly(True) 58 | 59 | # Hack to preserve GPU ram just in case OOM later on server 60 | # if paras.gpu and paras.reserve_gpu > 0: 61 | # buff = torch.randn(int(paras.reserve_gpu*1e9//4)).cuda() 62 | # del buff 63 | 64 | if paras.oneshotvc: 65 | print(">>> OneShot VC training ...") 66 | if paras.bilstm: 67 | from bin.train_ppg2mel_oneshotvc import Solver 68 | else: 69 | from bin.train_linglf02mel_seq2seq_oneshotvc import Solver 70 | mode = "train" 71 | solver = Solver(config, paras, mode) 72 | solver.load_data() 73 | solver.set_model() 74 | solver.exec() 75 | print(">>> Oneshot VC train finished!") 76 | sys.exit(0) 77 | elif paras.lsa: 78 | print(">>> Use location sensitive attention based seq2seq model") 79 | from bin.train_linglf02mel_seq2seq_lsa import Solver 80 | mode = "train" 81 | solver = Solver(config, paras, mode) 82 | solver.load_data() 83 | solver.set_model() 84 | solver.exec() 85 | print(">>> VC train finished!") 86 | sys.exit(0) 87 | else: 88 | if paras.finetune: 89 | from bin.finetune_linglf02mel_seq2seq_encAddlf0 import Solver 90 | else: 91 | from bin.train_linglf02mel_seq2seq_encAddlf0 import Solver 92 | mode = 'train' 93 | solver = Solver(config, paras, mode) 94 | solver.load_data() 95 | solver.set_model() 96 | solver.exec() 97 | print(">>> VC train finished!") 98 | sys.exit(0) 99 | -------------------------------------------------------------------------------- /path.sh: -------------------------------------------------------------------------------- 1 | # cuda related 2 | export CUDA_HOME=/usr/local/cuda 3 | export LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}" 4 | 5 | # path related 6 | export PRJ_ROOT="./" 7 | if [ -e "${PRJ_ROOT}/tool/venv/bin/activate" ]; then 8 | # shellcheck disable=SC1090 9 | . 
"${PRJ_ROOT}/tool/venv/bin/activate" 10 | fi 11 | 12 | # python related 13 | export OMP_NUM_THREADS=1 14 | export PYTHONIOENCODING=UTF-8 15 | export MPL_BACKEND=Agg 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchaudio==0.6.0 2 | tqdm==4.49.0 3 | librosa==0.8.0 4 | soundfile==0.10.3.post1 5 | resampy==0.2.2 6 | pyworld==0.2.11.post0 7 | PyYAML>=5.4 8 | typeguard==2.12.1 9 | matplotlib==3.3.2 10 | numpy==1.19.2 11 | scipy==1.3.3 12 | webrtcvad==2.0.10 13 | glob2 14 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | . ./path.sh || exit 1; 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | ########## Train BiLSTM oneshot VC model ########## 7 | python main.py --config ./conf/bilstm_ppg2mel_vctk_libri_oneshotvc.yaml \ 8 | --oneshotvc \ 9 | --bilstm 10 | ################################################### 11 | 12 | ########## Train Seq2seq oneshot VC model ########### 13 | #python main.py --config ./conf/seq2seq_mol_ppg2mel_vctk_libri_oneshotvc.yaml \ 14 | #--oneshotvc \ 15 | ################################################### 16 | -------------------------------------------------------------------------------- /speaker_encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/speaker_encoder/__init__.py -------------------------------------------------------------------------------- /speaker_encoder/audio.py: -------------------------------------------------------------------------------- 1 | from scipy.ndimage.morphology import binary_dilation 2 | from speaker_encoder.params_data import * 3 | from pathlib import Path 4 | from typing import Optional, Union 5 | import numpy as np 6 | import webrtcvad 7 | import librosa 8 | import struct 9 | 10 | int16_max = (2 ** 15) - 1 11 | 12 | 13 | def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], 14 | source_sr: Optional[int] = None): 15 | """ 16 | Applies the preprocessing operations used in training the Speaker Encoder to a waveform 17 | either on disk or in memory. The waveform will be resampled to match the data hyperparameters. 18 | 19 | :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not 20 | just .wav), either the waveform as a numpy array of floats. 21 | :param source_sr: if passing an audio waveform, the sampling rate of the waveform before 22 | preprocessing. After preprocessing, the waveform's sampling rate will match the data 23 | hyperparameters. If passing a filepath, the sampling rate will be automatically detected and 24 | this argument will be ignored. 
25 | """ 26 | # Load the wav from disk if needed 27 | if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): 28 | wav, source_sr = librosa.load(fpath_or_wav, sr=None) 29 | else: 30 | wav = fpath_or_wav 31 | 32 | # Resample the wav if needed 33 | if source_sr is not None and source_sr != sampling_rate: 34 | wav = librosa.resample(wav, source_sr, sampling_rate) 35 | 36 | # Apply the preprocessing: normalize volume and shorten long silences 37 | wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True) 38 | wav = trim_long_silences(wav) 39 | 40 | return wav 41 | 42 | 43 | def wav_to_mel_spectrogram(wav): 44 | """ 45 | Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform. 46 | Note: this not a log-mel spectrogram. 47 | """ 48 | frames = librosa.feature.melspectrogram( 49 | wav, 50 | sampling_rate, 51 | n_fft=int(sampling_rate * mel_window_length / 1000), 52 | hop_length=int(sampling_rate * mel_window_step / 1000), 53 | n_mels=mel_n_channels 54 | ) 55 | return frames.astype(np.float32).T 56 | 57 | 58 | def trim_long_silences(wav): 59 | """ 60 | Ensures that segments without voice in the waveform remain no longer than a 61 | threshold determined by the VAD parameters in params.py. 62 | 63 | :param wav: the raw waveform as a numpy array of floats 64 | :return: the same waveform with silences trimmed away (length <= original wav length) 65 | """ 66 | # Compute the voice detection window size 67 | samples_per_window = (vad_window_length * sampling_rate) // 1000 68 | 69 | # Trim the end of the audio to have a multiple of the window size 70 | wav = wav[:len(wav) - (len(wav) % samples_per_window)] 71 | 72 | # Convert the float waveform to 16-bit mono PCM 73 | pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16)) 74 | 75 | # Perform voice activation detection 76 | voice_flags = [] 77 | vad = webrtcvad.Vad(mode=3) 78 | for window_start in range(0, len(wav), samples_per_window): 79 | window_end = window_start + samples_per_window 80 | voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2], 81 | sample_rate=sampling_rate)) 82 | voice_flags = np.array(voice_flags) 83 | 84 | # Smooth the voice detection with a moving average 85 | def moving_average(array, width): 86 | array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2))) 87 | ret = np.cumsum(array_padded, dtype=float) 88 | ret[width:] = ret[width:] - ret[:-width] 89 | return ret[width - 1:] / width 90 | 91 | audio_mask = moving_average(voice_flags, vad_moving_average_width) 92 | audio_mask = np.round(audio_mask).astype(np.bool) 93 | 94 | # Dilate the voiced regions 95 | audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1)) 96 | audio_mask = np.repeat(audio_mask, samples_per_window) 97 | 98 | return wav[audio_mask == True] 99 | 100 | 101 | def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False): 102 | if increase_only and decrease_only: 103 | raise ValueError("Both increase only and decrease only are set") 104 | dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2)) 105 | if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only): 106 | return wav 107 | return wav * (10 ** (dBFS_change / 20)) 108 | -------------------------------------------------------------------------------- /speaker_encoder/ckpt/pretrained_bak_5805000.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/speaker_encoder/ckpt/pretrained_bak_5805000.pt -------------------------------------------------------------------------------- /speaker_encoder/compute_embed.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder import inference as encoder 2 | from multiprocessing.pool import Pool 3 | from functools import partial 4 | from pathlib import Path 5 | # from utils import logmmse 6 | # from tqdm import tqdm 7 | # import numpy as np 8 | # import librosa 9 | 10 | 11 | def embed_utterance(fpaths, encoder_model_fpath): 12 | if not encoder.is_loaded(): 13 | encoder.load_model(encoder_model_fpath) 14 | 15 | # Compute the speaker embedding of the utterance 16 | wav_fpath, embed_fpath = fpaths 17 | wav = np.load(wav_fpath) 18 | wav = encoder.preprocess_wav(wav) 19 | embed = encoder.embed_utterance(wav) 20 | np.save(embed_fpath, embed, allow_pickle=False) 21 | 22 | 23 | def create_embeddings(outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int): 24 | 25 | wav_dir = outdir_root.joinpath("audio") 26 | metadata_fpath = synthesizer_root.joinpath("train.txt") 27 | assert wav_dir.exists() and metadata_fpath.exists() 28 | embed_dir = synthesizer_root.joinpath("embeds") 29 | embed_dir.mkdir(exist_ok=True) 30 | 31 | # Gather the input wave filepath and the target output embed filepath 32 | with metadata_fpath.open("r") as metadata_file: 33 | metadata = [line.split("|") for line in metadata_file] 34 | fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata] 35 | 36 | # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here. 37 | # Embed the utterances in separate threads 38 | func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath) 39 | job = Pool(n_processes).imap(func, fpaths) 40 | list(tqdm(job, "Embedding", len(fpaths), unit="utterances")) -------------------------------------------------------------------------------- /speaker_encoder/config.py: -------------------------------------------------------------------------------- 1 | librispeech_datasets = { 2 | "train": { 3 | "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"], 4 | "other": ["LibriSpeech/train-other-500"] 5 | }, 6 | "test": { 7 | "clean": ["LibriSpeech/test-clean"], 8 | "other": ["LibriSpeech/test-other"] 9 | }, 10 | "dev": { 11 | "clean": ["LibriSpeech/dev-clean"], 12 | "other": ["LibriSpeech/dev-other"] 13 | }, 14 | } 15 | libritts_datasets = { 16 | "train": { 17 | "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"], 18 | "other": ["LibriTTS/train-other-500"] 19 | }, 20 | "test": { 21 | "clean": ["LibriTTS/test-clean"], 22 | "other": ["LibriTTS/test-other"] 23 | }, 24 | "dev": { 25 | "clean": ["LibriTTS/dev-clean"], 26 | "other": ["LibriTTS/dev-other"] 27 | }, 28 | } 29 | voxceleb_datasets = { 30 | "voxceleb1" : { 31 | "train": ["VoxCeleb1/wav"], 32 | "test": ["VoxCeleb1/test_wav"] 33 | }, 34 | "voxceleb2" : { 35 | "train": ["VoxCeleb2/dev/aac"], 36 | "test": ["VoxCeleb2/test_wav"] 37 | } 38 | } 39 | 40 | other_datasets = [ 41 | "LJSpeech-1.1", 42 | "VCTK-Corpus/wav48", 43 | ] 44 | 45 | anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"] 46 | -------------------------------------------------------------------------------- /speaker_encoder/data_objects/__init__.py: -------------------------------------------------------------------------------- 1 | 
from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset 2 | from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader 3 | -------------------------------------------------------------------------------- /speaker_encoder/data_objects/random_cycler.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | class RandomCycler: 4 | """ 5 | Creates an internal copy of a sequence and allows access to its items in a constrained random 6 | order. For a source sequence of n items and one or several consecutive queries of a total 7 | of m items, the following guarantees hold (one implies the other): 8 | - Each item will be returned between m // n and ((m - 1) // n) + 1 times. 9 | - Between two appearances of the same item, there may be at most 2 * (n - 1) other items. 10 | """ 11 | 12 | def __init__(self, source): 13 | if len(source) == 0: 14 | raise Exception("Can't create RandomCycler from an empty collection") 15 | self.all_items = list(source) 16 | self.next_items = [] 17 | 18 | def sample(self, count: int): 19 | shuffle = lambda l: random.sample(l, len(l)) 20 | 21 | out = [] 22 | while count > 0: 23 | if count >= len(self.all_items): 24 | out.extend(shuffle(list(self.all_items))) 25 | count -= len(self.all_items) 26 | continue 27 | n = min(count, len(self.next_items)) 28 | out.extend(self.next_items[:n]) 29 | count -= n 30 | self.next_items = self.next_items[n:] 31 | if len(self.next_items) == 0: 32 | self.next_items = shuffle(list(self.all_items)) 33 | return out 34 | 35 | def __next__(self): 36 | return self.sample(1)[0] 37 | 38 | -------------------------------------------------------------------------------- /speaker_encoder/data_objects/speaker.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder.data_objects.random_cycler import RandomCycler 2 | from speaker_encoder.data_objects.utterance import Utterance 3 | from pathlib import Path 4 | 5 | # Contains the set of utterances of a single speaker 6 | class Speaker: 7 | def __init__(self, root: Path): 8 | self.root = root 9 | self.name = root.name 10 | self.utterances = None 11 | self.utterance_cycler = None 12 | 13 | def _load_utterances(self): 14 | with self.root.joinpath("_sources.txt").open("r") as sources_file: 15 | sources = [l.split(",") for l in sources_file] 16 | sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources} 17 | self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()] 18 | self.utterance_cycler = RandomCycler(self.utterances) 19 | 20 | def random_partial(self, count, n_frames): 21 | """ 22 | Samples a batch of unique partial utterances from the disk in a way that all 23 | utterances come up at least once every two cycles and in a random order every time. 24 | 25 | :param count: The number of partial utterances to sample from the set of utterances from 26 | that speaker. Utterances are guaranteed not to be repeated if is not larger than 27 | the number of utterances available. 28 | :param n_frames: The number of frames in the partial utterance. 29 | :return: A list of tuples (utterance, frames, range) where utterance is an Utterance, 30 | frames are the frames of the partial utterances and range is the range of the partial 31 | utterance with regard to the complete utterance. 
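        A small illustrative example (the directory name "encoder_data/p225" is hypothetical):
        Speaker(Path("encoder_data/p225")).random_partial(4, 160) returns 4 such tuples, each
        holding an Utterance, a (160, mel_n_channels) array of mel frames and its (start, end)
        frame range within the full utterance.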
32 | """ 33 | if self.utterances is None: 34 | self._load_utterances() 35 | 36 | utterances = self.utterance_cycler.sample(count) 37 | 38 | a = [(u,) + u.random_partial(n_frames) for u in utterances] 39 | 40 | return a 41 | -------------------------------------------------------------------------------- /speaker_encoder/data_objects/speaker_batch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import List 3 | from speaker_encoder.data_objects.speaker import Speaker 4 | 5 | class SpeakerBatch: 6 | def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int): 7 | self.speakers = speakers 8 | self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers} 9 | 10 | # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with 11 | # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40) 12 | self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]]) 13 | -------------------------------------------------------------------------------- /speaker_encoder/data_objects/speaker_verification_dataset.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder.data_objects.random_cycler import RandomCycler 2 | from speaker_encoder.data_objects.speaker_batch import SpeakerBatch 3 | from speaker_encoder.data_objects.speaker import Speaker 4 | from speaker_encoder.params_data import partials_n_frames 5 | from torch.utils.data import Dataset, DataLoader 6 | from pathlib import Path 7 | 8 | # TODO: improve with a pool of speakers for data efficiency 9 | 10 | class SpeakerVerificationDataset(Dataset): 11 | def __init__(self, datasets_root: Path): 12 | self.root = datasets_root 13 | speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()] 14 | if len(speaker_dirs) == 0: 15 | raise Exception("No speakers found. 
Make sure you are pointing to the directory " 16 | "containing all preprocessed speaker directories.") 17 | self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs] 18 | self.speaker_cycler = RandomCycler(self.speakers) 19 | 20 | def __len__(self): 21 | return int(1e10) 22 | 23 | def __getitem__(self, index): 24 | return next(self.speaker_cycler) 25 | 26 | def get_logs(self): 27 | log_string = "" 28 | for log_fpath in self.root.glob("*.txt"): 29 | with log_fpath.open("r") as log_file: 30 | log_string += "".join(log_file.readlines()) 31 | return log_string 32 | 33 | 34 | class SpeakerVerificationDataLoader(DataLoader): 35 | def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None, 36 | batch_sampler=None, num_workers=0, pin_memory=False, timeout=0, 37 | worker_init_fn=None): 38 | self.utterances_per_speaker = utterances_per_speaker 39 | 40 | super().__init__( 41 | dataset=dataset, 42 | batch_size=speakers_per_batch, 43 | shuffle=False, 44 | sampler=sampler, 45 | batch_sampler=batch_sampler, 46 | num_workers=num_workers, 47 | collate_fn=self.collate, 48 | pin_memory=pin_memory, 49 | drop_last=False, 50 | timeout=timeout, 51 | worker_init_fn=worker_init_fn 52 | ) 53 | 54 | def collate(self, speakers): 55 | return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames) 56 | -------------------------------------------------------------------------------- /speaker_encoder/data_objects/utterance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Utterance: 5 | def __init__(self, frames_fpath, wave_fpath): 6 | self.frames_fpath = frames_fpath 7 | self.wave_fpath = wave_fpath 8 | 9 | def get_frames(self): 10 | return np.load(self.frames_fpath) 11 | 12 | def random_partial(self, n_frames): 13 | """ 14 | Crops the frames into a partial utterance of n_frames 15 | 16 | :param n_frames: The number of frames of the partial utterance 17 | :return: the partial utterance frames and a tuple indicating the start and end of the 18 | partial utterance in the complete utterance. 19 | """ 20 | frames = self.get_frames() 21 | if frames.shape[0] == n_frames: 22 | start = 0 23 | else: 24 | start = np.random.randint(0, frames.shape[0] - n_frames) 25 | end = start + n_frames 26 | return frames[start:end], (start, end) -------------------------------------------------------------------------------- /speaker_encoder/hparams.py: -------------------------------------------------------------------------------- 1 | ## Mel-filterbank 2 | mel_window_length = 25 # In milliseconds 3 | mel_window_step = 10 # In milliseconds 4 | mel_n_channels = 40 5 | 6 | 7 | ## Audio 8 | sampling_rate = 16000 9 | # Number of spectrogram frames in a partial utterance 10 | partials_n_frames = 160 # 1600 ms 11 | 12 | 13 | ## Voice Activation Detection 14 | # Window size of the VAD. Must be either 10, 20 or 30 milliseconds. 15 | # This sets the granularity of the VAD. Should not need to be changed. 16 | vad_window_length = 30 # In milliseconds 17 | # Number of frames to average together when performing the moving average smoothing. 18 | # The larger this value, the larger the VAD variations must be to not get smoothed out. 19 | vad_moving_average_width = 8 20 | # Maximum number of consecutive silent frames a segment can have. 
21 | vad_max_silence_length = 6 22 | 23 | 24 | ## Audio volume normalization 25 | audio_norm_target_dBFS = -30 26 | 27 | 28 | ## Model parameters 29 | model_hidden_size = 256 30 | model_embedding_size = 256 31 | model_num_layers = 3 -------------------------------------------------------------------------------- /speaker_encoder/inference.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder.params_data import * 2 | from speaker_encoder.model import SpeakerEncoder 3 | from speaker_encoder.audio import preprocess_wav # We want to expose this function from here 4 | from matplotlib import cm 5 | from speaker_encoder import audio 6 | from pathlib import Path 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import torch 10 | 11 | _model = None # type: SpeakerEncoder 12 | _device = None # type: torch.device 13 | 14 | 15 | def load_model(weights_fpath: Path, device=None): 16 | """ 17 | Loads the model in memory. If this function is not explicitely called, it will be run on the 18 | first call to embed_frames() with the default weights file. 19 | 20 | :param weights_fpath: the path to saved model weights. 21 | :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The 22 | model will be loaded and will run on this device. Outputs will however always be on the cpu. 23 | If None, will default to your GPU if it"s available, otherwise your CPU. 24 | """ 25 | # TODO: I think the slow loading of the encoder might have something to do with the device it 26 | # was saved on. Worth investigating. 27 | global _model, _device 28 | if device is None: 29 | _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | elif isinstance(device, str): 31 | _device = torch.device(device) 32 | _model = SpeakerEncoder(_device, torch.device("cpu")) 33 | checkpoint = torch.load(weights_fpath) 34 | _model.load_state_dict(checkpoint["model_state"]) 35 | _model.eval() 36 | print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"])) 37 | 38 | 39 | def is_loaded(): 40 | return _model is not None 41 | 42 | 43 | def embed_frames_batch(frames_batch): 44 | """ 45 | Computes embeddings for a batch of mel spectrogram. 46 | 47 | :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape 48 | (batch_size, n_frames, n_channels) 49 | :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size) 50 | """ 51 | if _model is None: 52 | raise Exception("Model was not loaded. Call load_model() before inference.") 53 | 54 | frames = torch.from_numpy(frames_batch).to(_device) 55 | embed = _model.forward(frames).detach().cpu().numpy() 56 | return embed 57 | 58 | 59 | def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames, 60 | min_pad_coverage=0.75, overlap=0.5): 61 | """ 62 | Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain 63 | partial utterances of each. Both the waveform and the mel 64 | spectrogram slices are returned, so as to make each partial utterance waveform correspond to 65 | its spectrogram. This function assumes that the mel spectrogram parameters used are those 66 | defined in params_data.py. 67 | 68 | The returned ranges may be indexing further than the length of the waveform. It is 69 | recommended that you pad the waveform with zeros up to wave_slices[-1].stop. 
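    As an example, with the defaults used here (160-frame partials, 10 ms frame step at 16 kHz,
    50% overlap, min_pad_coverage of 0.75), a 2-second waveform yields two partial slices whose
    mel ranges start at frames 0 and 80; the second slice runs past the end of the waveform,
    which is why the zero-padding above is recommended.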
70 | 71 | :param n_samples: the number of samples in the waveform 72 | :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial 73 | utterance 74 | :param min_pad_coverage: when reaching the last partial utterance, it may or may not have 75 | enough frames. If at least of are present, 76 | then the last partial utterance will be considered, as if we padded the audio. Otherwise, 77 | it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial 78 | utterance, this parameter is ignored so that the function always returns at least 1 slice. 79 | :param overlap: by how much the partial utterance should overlap. If set to 0, the partial 80 | utterances are entirely disjoint. 81 | :return: the waveform slices and mel spectrogram slices as lists of array slices. Index 82 | respectively the waveform and the mel spectrogram with these slices to obtain the partial 83 | utterances. 84 | """ 85 | assert 0 <= overlap < 1 86 | assert 0 < min_pad_coverage <= 1 87 | 88 | samples_per_frame = int((sampling_rate * mel_window_step / 1000)) 89 | n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) 90 | frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1) 91 | 92 | # Compute the slices 93 | wav_slices, mel_slices = [], [] 94 | steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1) 95 | for i in range(0, steps, frame_step): 96 | mel_range = np.array([i, i + partial_utterance_n_frames]) 97 | wav_range = mel_range * samples_per_frame 98 | mel_slices.append(slice(*mel_range)) 99 | wav_slices.append(slice(*wav_range)) 100 | 101 | # Evaluate whether extra padding is warranted or not 102 | last_wav_range = wav_slices[-1] 103 | coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) 104 | if coverage < min_pad_coverage and len(mel_slices) > 1: 105 | mel_slices = mel_slices[:-1] 106 | wav_slices = wav_slices[:-1] 107 | 108 | return wav_slices, mel_slices 109 | 110 | 111 | def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs): 112 | """ 113 | Computes an embedding for a single utterance. 114 | 115 | # TODO: handle multiple wavs to benefit from batching on GPU 116 | :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32 117 | :param using_partials: if True, then the utterance is split in partial utterances of 118 | frames and the utterance embedding is computed from their 119 | normalized average. If False, the utterance is instead computed from feeding the entire 120 | spectogram to the network. 121 | :param return_partials: if True, the partial embeddings will also be returned along with the 122 | wav slices that correspond to the partial embeddings. 123 | :param kwargs: additional arguments to compute_partial_splits() 124 | :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If 125 | is True, the partial utterances as a numpy array of float32 of shape 126 | (n_partials, model_embedding_size) and the wav partials as a list of slices will also be 127 | returned. If is simultaneously set to False, both these values will be None 128 | instead. 
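    A minimal usage sketch (assuming load_model() has already been called and that `wav` is a
    16 kHz float32 waveform, e.g. the output of preprocess_wav):

        embed = embed_utterance(wav)    # shape (model_embedding_size,), L2-normalized
        embed, partial_embeds, wav_slices = embed_utterance(wav, return_partials=True)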
129 | """ 130 | # Process the entire utterance if not using partials 131 | if not using_partials: 132 | frames = audio.wav_to_mel_spectrogram(wav) 133 | embed = embed_frames_batch(frames[None, ...])[0] 134 | if return_partials: 135 | return embed, None, None 136 | return embed 137 | 138 | # Compute where to split the utterance into partials and pad if necessary 139 | wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs) 140 | max_wave_length = wave_slices[-1].stop 141 | if max_wave_length >= len(wav): 142 | wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") 143 | 144 | # Split the utterance into partials 145 | frames = audio.wav_to_mel_spectrogram(wav) 146 | frames_batch = np.array([frames[s] for s in mel_slices]) 147 | partial_embeds = embed_frames_batch(frames_batch) 148 | 149 | # Compute the utterance embedding from the partial embeddings 150 | raw_embed = np.mean(partial_embeds, axis=0) 151 | embed = raw_embed / np.linalg.norm(raw_embed, 2) 152 | 153 | if return_partials: 154 | return embed, partial_embeds, wave_slices 155 | return embed 156 | 157 | 158 | def embed_speaker(wavs, **kwargs): 159 | raise NotImplemented() 160 | 161 | 162 | def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)): 163 | if ax is None: 164 | ax = plt.gca() 165 | 166 | if shape is None: 167 | height = int(np.sqrt(len(embed))) 168 | shape = (height, -1) 169 | embed = embed.reshape(shape) 170 | 171 | cmap = cm.get_cmap() 172 | mappable = ax.imshow(embed, cmap=cmap) 173 | cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04) 174 | cbar.set_clim(*color_range) 175 | 176 | ax.set_xticks([]), ax.set_yticks([]) 177 | ax.set_title(title) 178 | -------------------------------------------------------------------------------- /speaker_encoder/model.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder.params_model import * 2 | from speaker_encoder.params_data import * 3 | from scipy.interpolate import interp1d 4 | from sklearn.metrics import roc_curve 5 | from torch.nn.utils import clip_grad_norm_ 6 | from scipy.optimize import brentq 7 | from torch import nn 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class SpeakerEncoder(nn.Module): 13 | def __init__(self, device, loss_device): 14 | super().__init__() 15 | self.loss_device = loss_device 16 | 17 | # Network defition 18 | self.lstm = nn.LSTM(input_size=mel_n_channels, # 40 19 | hidden_size=model_hidden_size, # 256 20 | num_layers=model_num_layers, # 3 21 | batch_first=True).to(device) 22 | self.linear = nn.Linear(in_features=model_hidden_size, 23 | out_features=model_embedding_size).to(device) 24 | self.relu = torch.nn.ReLU().to(device) 25 | 26 | # Cosine similarity scaling (with fixed initial parameter values) 27 | self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device) 28 | self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device) 29 | 30 | # Loss 31 | self.loss_fn = nn.CrossEntropyLoss().to(loss_device) 32 | 33 | def do_gradient_ops(self): 34 | # Gradient scale 35 | self.similarity_weight.grad *= 0.01 36 | self.similarity_bias.grad *= 0.01 37 | 38 | # Gradient clipping 39 | clip_grad_norm_(self.parameters(), 3, norm_type=2) 40 | 41 | def forward(self, utterances, hidden_init=None): 42 | """ 43 | Computes the embeddings of a batch of utterance spectrograms. 
44 | 45 | :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape 46 | (batch_size, n_frames, n_channels) 47 | :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, 48 | batch_size, hidden_size). Will default to a tensor of zeros if None. 49 | :return: the embeddings as a tensor of shape (batch_size, embedding_size) 50 | """ 51 | # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state 52 | # and the final cell state. 53 | out, (hidden, cell) = self.lstm(utterances, hidden_init) 54 | 55 | # We take only the hidden state of the last layer 56 | embeds_raw = self.relu(self.linear(hidden[-1])) 57 | 58 | # L2-normalize it 59 | embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) 60 | 61 | return embeds 62 | 63 | def similarity_matrix(self, embeds): 64 | """ 65 | Computes the similarity matrix according the section 2.1 of GE2E. 66 | 67 | :param embeds: the embeddings as a tensor of shape (speakers_per_batch, 68 | utterances_per_speaker, embedding_size) 69 | :return: the similarity matrix as a tensor of shape (speakers_per_batch, 70 | utterances_per_speaker, speakers_per_batch) 71 | """ 72 | speakers_per_batch, utterances_per_speaker = embeds.shape[:2] 73 | 74 | # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation 75 | centroids_incl = torch.mean(embeds, dim=1, keepdim=True) 76 | centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True) 77 | 78 | # Exclusive centroids (1 per utterance) 79 | centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds) 80 | centroids_excl /= (utterances_per_speaker - 1) 81 | centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True) 82 | 83 | # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot 84 | # product of these vectors (which is just an element-wise multiplication reduced by a sum). 85 | # We vectorize the computation for efficiency. 86 | sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker, 87 | speakers_per_batch).to(self.loss_device) 88 | mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int) 89 | for j in range(speakers_per_batch): 90 | mask = np.where(mask_matrix[j])[0] 91 | sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2) 92 | sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1) 93 | 94 | ## Even more vectorized version (slower maybe because of transpose) 95 | # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker 96 | # ).to(self.loss_device) 97 | # eye = np.eye(speakers_per_batch, dtype=np.int) 98 | # mask = np.where(1 - eye) 99 | # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2) 100 | # mask = np.where(eye) 101 | # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2) 102 | # sim_matrix2 = sim_matrix2.transpose(1, 2) 103 | 104 | sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias 105 | return sim_matrix 106 | 107 | def loss(self, embeds): 108 | """ 109 | Computes the softmax loss according the section 2.1 of GE2E. 110 | 111 | :param embeds: the embeddings as a tensor of shape (speakers_per_batch, 112 | utterances_per_speaker, embedding_size) 113 | :return: the loss and the EER for this batch of embeddings. 
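        Concretely, this is the softmax variant of the GE2E objective: similarity_matrix()
        scores every utterance embedding against every speaker centroid, the resulting
        (speakers_per_batch * utterances_per_speaker, speakers_per_batch) matrix is used as
        logits, and cross-entropy is taken with each utterance's own speaker index as the
        target. The EER is computed on the same scores purely for monitoring.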
114 | """ 115 | speakers_per_batch, utterances_per_speaker = embeds.shape[:2] 116 | 117 | # Loss 118 | sim_matrix = self.similarity_matrix(embeds) 119 | sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker, 120 | speakers_per_batch)) 121 | ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker) 122 | target = torch.from_numpy(ground_truth).long().to(self.loss_device) 123 | loss = self.loss_fn(sim_matrix, target) 124 | 125 | # EER (not backpropagated) 126 | with torch.no_grad(): 127 | inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0] 128 | labels = np.array([inv_argmax(i) for i in ground_truth]) 129 | preds = sim_matrix.detach().cpu().numpy() 130 | 131 | # Snippet from https://yangcha.github.io/EER-ROC/ 132 | fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten()) 133 | eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.) 134 | 135 | return loss, eer -------------------------------------------------------------------------------- /speaker_encoder/params_data.py: -------------------------------------------------------------------------------- 1 | 2 | ## Mel-filterbank 3 | mel_window_length = 25 # In milliseconds 4 | mel_window_step = 10 # In milliseconds 5 | mel_n_channels = 40 6 | 7 | 8 | ## Audio 9 | sampling_rate = 16000 10 | # Number of spectrogram frames in a partial utterance 11 | partials_n_frames = 160 # 1600 ms 12 | # Number of spectrogram frames at inference 13 | inference_n_frames = 80 # 800 ms 14 | 15 | 16 | ## Voice Activation Detection 17 | # Window size of the VAD. Must be either 10, 20 or 30 milliseconds. 18 | # This sets the granularity of the VAD. Should not need to be changed. 19 | vad_window_length = 30 # In milliseconds 20 | # Number of frames to average together when performing the moving average smoothing. 21 | # The larger this value, the larger the VAD variations must be to not get smoothed out. 22 | vad_moving_average_width = 8 23 | # Maximum number of consecutive silent frames a segment can have. 
24 | vad_max_silence_length = 6 25 | 26 | 27 | ## Audio volume normalization 28 | audio_norm_target_dBFS = -30 29 | 30 | -------------------------------------------------------------------------------- /speaker_encoder/params_model.py: -------------------------------------------------------------------------------- 1 | 2 | ## Model parameters 3 | model_hidden_size = 256 4 | model_embedding_size = 256 5 | model_num_layers = 3 6 | 7 | 8 | ## Training parameters 9 | learning_rate_init = 1e-4 10 | speakers_per_batch = 64 11 | utterances_per_speaker = 10 12 | -------------------------------------------------------------------------------- /speaker_encoder/train.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder.visualizations import Visualizations 2 | from speaker_encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset 3 | from speaker_encoder.params_model import * 4 | from speaker_encoder.model import SpeakerEncoder 5 | from utils.profiler import Profiler 6 | from pathlib import Path 7 | import torch 8 | 9 | def sync(device: torch.device): 10 | # FIXME 11 | return 12 | # For correct profiling (cuda operations are async) 13 | if device.type == "cuda": 14 | torch.cuda.synchronize(device) 15 | 16 | def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int, 17 | backup_every: int, vis_every: int, force_restart: bool, visdom_server: str, 18 | no_visdom: bool): 19 | # Create a dataset and a dataloader 20 | dataset = SpeakerVerificationDataset(clean_data_root) 21 | loader = SpeakerVerificationDataLoader( 22 | dataset, 23 | speakers_per_batch, # 64 24 | utterances_per_speaker, # 10 25 | num_workers=8, 26 | ) 27 | 28 | # Setup the device on which to run the forward pass and the loss. These can be different, 29 | # because the forward pass is faster on the GPU whereas the loss is often (depending on your 30 | # hyperparameters) faster on the CPU. 31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 32 | # FIXME: currently, the gradient is None if loss_device is cuda 33 | loss_device = torch.device("cpu") 34 | 35 | # Create the model and the optimizer 36 | model = SpeakerEncoder(device, loss_device) 37 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init) 38 | init_step = 1 39 | 40 | # Configure file path for the model 41 | state_fpath = models_dir.joinpath(run_id + ".pt") 42 | backup_dir = models_dir.joinpath(run_id + "_backups") 43 | 44 | # Load any existing model 45 | if not force_restart: 46 | if state_fpath.exists(): 47 | print("Found existing model \"%s\", loading it and resuming training." % run_id) 48 | checkpoint = torch.load(state_fpath) 49 | init_step = checkpoint["step"] 50 | model.load_state_dict(checkpoint["model_state"]) 51 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 52 | optimizer.param_groups[0]["lr"] = learning_rate_init 53 | else: 54 | print("No model \"%s\" found, starting training from scratch." 
% run_id) 55 | else: 56 | print("Starting the training from scratch.") 57 | model.train() 58 | 59 | # Initialize the visualization environment 60 | vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom) 61 | vis.log_dataset(dataset) 62 | vis.log_params() 63 | device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU") 64 | vis.log_implementation({"Device": device_name}) 65 | 66 | # Training loop 67 | profiler = Profiler(summarize_every=10, disabled=False) 68 | for step, speaker_batch in enumerate(loader, init_step): 69 | profiler.tick("Blocking, waiting for batch (threaded)") 70 | 71 | # Forward pass 72 | inputs = torch.from_numpy(speaker_batch.data).to(device) 73 | sync(device) 74 | profiler.tick("Data to %s" % device) 75 | embeds = model(inputs) 76 | sync(device) 77 | profiler.tick("Forward pass") 78 | embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device) 79 | loss, eer = model.loss(embeds_loss) 80 | sync(loss_device) 81 | profiler.tick("Loss") 82 | 83 | # Backward pass 84 | model.zero_grad() 85 | loss.backward() 86 | profiler.tick("Backward pass") 87 | model.do_gradient_ops() 88 | optimizer.step() 89 | profiler.tick("Parameter update") 90 | 91 | # Update visualizations 92 | # learning_rate = optimizer.param_groups[0]["lr"] 93 | vis.update(loss.item(), eer, step) 94 | 95 | # Draw projections and save them to the backup folder 96 | if umap_every != 0 and step % umap_every == 0: 97 | print("Drawing and saving projections (step %d)" % step) 98 | backup_dir.mkdir(exist_ok=True) 99 | projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step)) 100 | embeds = embeds.detach().cpu().numpy() 101 | vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath) 102 | vis.save() 103 | 104 | # Overwrite the latest version of the model 105 | if save_every != 0 and step % save_every == 0: 106 | print("Saving the model (step %d)" % step) 107 | torch.save({ 108 | "step": step + 1, 109 | "model_state": model.state_dict(), 110 | "optimizer_state": optimizer.state_dict(), 111 | }, state_fpath) 112 | 113 | # Make a backup 114 | if backup_every != 0 and step % backup_every == 0: 115 | print("Making a backup (step %d)" % step) 116 | backup_dir.mkdir(exist_ok=True) 117 | backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step)) 118 | torch.save({ 119 | "step": step + 1, 120 | "model_state": model.state_dict(), 121 | "optimizer_state": optimizer.state_dict(), 122 | }, backup_fpath) 123 | 124 | profiler.tick("Extras (visualizations, saving)") 125 | -------------------------------------------------------------------------------- /speaker_encoder/visualizations.py: -------------------------------------------------------------------------------- 1 | from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset 2 | from datetime import datetime 3 | from time import perf_counter as timer 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | # import webbrowser 7 | import visdom 8 | import umap 9 | 10 | colormap = np.array([ 11 | [76, 255, 0], 12 | [0, 127, 70], 13 | [255, 0, 0], 14 | [255, 217, 38], 15 | [0, 135, 255], 16 | [165, 0, 165], 17 | [255, 167, 255], 18 | [0, 255, 255], 19 | [255, 96, 38], 20 | [142, 76, 0], 21 | [33, 0, 127], 22 | [0, 0, 0], 23 | [183, 183, 183], 24 | ], dtype=np.float) / 255 25 | 26 | 27 | class Visualizations: 28 | def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False): 29 | 
# Tracking data 30 | self.last_update_timestamp = timer() 31 | self.update_every = update_every 32 | self.step_times = [] 33 | self.losses = [] 34 | self.eers = [] 35 | print("Updating the visualizations every %d steps." % update_every) 36 | 37 | # If visdom is disabled TODO: use a better paradigm for that 38 | self.disabled = disabled 39 | if self.disabled: 40 | return 41 | 42 | # Set the environment name 43 | now = str(datetime.now().strftime("%d-%m %Hh%M")) 44 | if env_name is None: 45 | self.env_name = now 46 | else: 47 | self.env_name = "%s (%s)" % (env_name, now) 48 | 49 | # Connect to visdom and open the corresponding window in the browser 50 | try: 51 | self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True) 52 | except ConnectionError: 53 | raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to " 54 | "start it.") 55 | # webbrowser.open("http://localhost:8097/env/" + self.env_name) 56 | 57 | # Create the windows 58 | self.loss_win = None 59 | self.eer_win = None 60 | # self.lr_win = None 61 | self.implementation_win = None 62 | self.projection_win = None 63 | self.implementation_string = "" 64 | 65 | def log_params(self): 66 | if self.disabled: 67 | return 68 | from speaker_encoder import params_data 69 | from speaker_encoder import params_model 70 | param_string = "Model parameters:
" 71 | for param_name in (p for p in dir(params_model) if not p.startswith("__")): 72 | value = getattr(params_model, param_name) 73 | param_string += "\t%s: %s
" % (param_name, value) 74 | param_string += "Data parameters:
" 75 | for param_name in (p for p in dir(params_data) if not p.startswith("__")): 76 | value = getattr(params_data, param_name) 77 | param_string += "\t%s: %s
" % (param_name, value) 78 | self.vis.text(param_string, opts={"title": "Parameters"}) 79 | 80 | def log_dataset(self, dataset: SpeakerVerificationDataset): 81 | if self.disabled: 82 | return 83 | dataset_string = "" 84 | dataset_string += "Speakers: %s\n" % len(dataset.speakers) 85 | dataset_string += "\n" + dataset.get_logs() 86 | dataset_string = dataset_string.replace("\n", "
") 87 | self.vis.text(dataset_string, opts={"title": "Dataset"}) 88 | 89 | def log_implementation(self, params): 90 | if self.disabled: 91 | return 92 | implementation_string = "" 93 | for param, value in params.items(): 94 | implementation_string += "%s: %s\n" % (param, value) 95 | implementation_string = implementation_string.replace("\n", "
") 96 | self.implementation_string = implementation_string 97 | self.implementation_win = self.vis.text( 98 | implementation_string, 99 | opts={"title": "Training implementation"} 100 | ) 101 | 102 | def update(self, loss, eer, step): 103 | # Update the tracking data 104 | now = timer() 105 | self.step_times.append(1000 * (now - self.last_update_timestamp)) 106 | self.last_update_timestamp = now 107 | self.losses.append(loss) 108 | self.eers.append(eer) 109 | print(".", end="") 110 | 111 | # Update the plots every steps 112 | if step % self.update_every != 0: 113 | return 114 | time_string = "Step time: mean: %5dms std: %5dms" % \ 115 | (int(np.mean(self.step_times)), int(np.std(self.step_times))) 116 | print("\nStep %6d Loss: %.4f EER: %.4f %s" % 117 | (step, np.mean(self.losses), np.mean(self.eers), time_string)) 118 | if not self.disabled: 119 | self.loss_win = self.vis.line( 120 | [np.mean(self.losses)], 121 | [step], 122 | win=self.loss_win, 123 | update="append" if self.loss_win else None, 124 | opts=dict( 125 | legend=["Avg. loss"], 126 | xlabel="Step", 127 | ylabel="Loss", 128 | title="Loss", 129 | ) 130 | ) 131 | self.eer_win = self.vis.line( 132 | [np.mean(self.eers)], 133 | [step], 134 | win=self.eer_win, 135 | update="append" if self.eer_win else None, 136 | opts=dict( 137 | legend=["Avg. EER"], 138 | xlabel="Step", 139 | ylabel="EER", 140 | title="Equal error rate" 141 | ) 142 | ) 143 | if self.implementation_win is not None: 144 | self.vis.text( 145 | self.implementation_string + ("%s" % time_string), 146 | win=self.implementation_win, 147 | opts={"title": "Training implementation"}, 148 | ) 149 | 150 | # Reset the tracking 151 | self.losses.clear() 152 | self.eers.clear() 153 | self.step_times.clear() 154 | 155 | def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, 156 | max_speakers=10): 157 | max_speakers = min(max_speakers, len(colormap)) 158 | embeds = embeds[:max_speakers * utterances_per_speaker] 159 | 160 | n_speakers = len(embeds) // utterances_per_speaker 161 | ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker) 162 | colors = [colormap[i] for i in ground_truth] 163 | 164 | reducer = umap.UMAP() 165 | projected = reducer.fit_transform(embeds) 166 | plt.scatter(projected[:, 0], projected[:, 1], c=colors) 167 | plt.gca().set_aspect("equal", "datalim") 168 | plt.title("UMAP projection (step %d)" % step) 169 | if not self.disabled: 170 | self.projection_win = self.vis.matplot(plt, win=self.projection_win) 171 | if out_fpath is not None: 172 | plt.savefig(out_fpath) 173 | plt.clf() 174 | 175 | def save(self): 176 | if not self.disabled: 177 | self.vis.save([self.env_name]) 178 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from .mel_decoder_mol_encAddlf0 import MelDecoderMOL 2 | from .mel_decoder_mol_v2 import MelDecoderMOLv2 3 | from .rnn_ppg2mel import BiRnnPpg2MelModel 4 | from .mel_decoder_lsa import MelDecoderLSA 5 | 6 | 7 | def build_model(model_name: str): 8 | if model_name == "seq2seqmol": 9 | return MelDecoderMOL 10 | elif model_name == "seq2seqmolv2": 11 | return MelDecoderMOLv2 12 | elif model_name == "bilstm": 13 | return BiRnnPpg2MelModel 14 | elif model_name == "seq2seqlsa": 15 | return MelDecoderLSA 16 | else: 17 | raise ValueError(f"Unknown model name: {model_name}.") 18 | -------------------------------------------------------------------------------- /src/abs_model.py: 
-------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from abc import abstractmethod 3 | from typing import Tuple 4 | 5 | import torch 6 | 7 | 8 | class AbsMelDecoder(torch.nn.Module, ABC): 9 | """The abstract PPG-based voice conversion class 10 | This "model" is one of mediator objects for "Task" class. 11 | 12 | """ 13 | 14 | @abstractmethod 15 | def forward( 16 | self, 17 | bottle_neck_features: torch.Tensor, 18 | feature_lengths: torch.Tensor, 19 | speech: torch.Tensor, 20 | speech_lengths: torch.Tensor, 21 | logf0_uv: torch.Tensor = None, 22 | spembs: torch.Tensor = None, 23 | styleembs: torch.Tensor = None, 24 | ) -> torch.Tensor: 25 | raise NotImplementedError 26 | -------------------------------------------------------------------------------- /src/audio_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | import torch.utils.data 6 | import numpy as np 7 | from librosa.util import normalize 8 | from scipy.io.wavfile import read 9 | from librosa.filters import mel as librosa_mel_fn 10 | 11 | MAX_WAV_VALUE = 32768.0 12 | 13 | 14 | def load_wav(full_path): 15 | sampling_rate, data = read(full_path) 16 | return data, sampling_rate 17 | 18 | 19 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 20 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 21 | 22 | 23 | def dynamic_range_decompression(x, C=1): 24 | return np.exp(x) / C 25 | 26 | 27 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 28 | return torch.log(torch.clamp(x, min=clip_val) * C) 29 | 30 | 31 | def dynamic_range_decompression_torch(x, C=1): 32 | return torch.exp(x) / C 33 | 34 | 35 | def spectral_normalize_torch(magnitudes): 36 | output = dynamic_range_compression_torch(magnitudes) 37 | return output 38 | 39 | 40 | def spectral_de_normalize_torch(magnitudes): 41 | output = dynamic_range_decompression_torch(magnitudes) 42 | return output 43 | 44 | 45 | mel_basis = {} 46 | hann_window = {} 47 | 48 | 49 | def mel_spectrogram( 50 | y, 51 | n_fft=1024, 52 | num_mels=80, 53 | sampling_rate=24000, 54 | hop_size=240, 55 | win_size=1024, 56 | fmin=0, 57 | fmax=8000, 58 | center=False, 59 | output_energy=False, 60 | ): 61 | if torch.min(y) < -1.: 62 | print('min value is ', torch.min(y)) 63 | if torch.max(y) > 1.: 64 | print('max value is ', torch.max(y)) 65 | 66 | global mel_basis, hann_window 67 | if fmax not in mel_basis: 68 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 69 | mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) 70 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 71 | 72 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 73 | y = y.squeeze(1) 74 | 75 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], 76 | center=center, pad_mode='reflect', normalized=False, onesided=True) 77 | 78 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) 79 | mel_spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) 80 | mel_spec = spectral_normalize_torch(mel_spec) 81 | if output_energy: 82 | energy = torch.norm(spec, dim=1) 83 | return mel_spec, energy 84 | else: 85 | return mel_spec 86 | 87 | 88 | # def get_dataset_filelist(a): 89 | # with open(a.input_training_file, 'r', encoding='utf-8') as fi: 90 | # training_files = [os.path.join(a.input_wavs_dir, 
x.split('|')[0] + '.wav') 91 | # for x in fi.read().split('\n') if len(x) > 0] 92 | 93 | # with open(a.input_validation_file, 'r', encoding='utf-8') as fi: 94 | # validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav') 95 | # for x in fi.read().split('\n') if len(x) > 0] 96 | # return training_files, validation_files 97 | 98 | 99 | # class MelDataset(torch.utils.data.Dataset): 100 | # def __init__(self, training_files, segment_size, n_fft, num_mels, 101 | # hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1, 102 | # device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None): 103 | # self.audio_files = training_files 104 | # random.seed(1234) 105 | # if shuffle: 106 | # random.shuffle(self.audio_files) 107 | # self.segment_size = segment_size 108 | # self.sampling_rate = sampling_rate 109 | # self.split = split 110 | # self.n_fft = n_fft 111 | # self.num_mels = num_mels 112 | # self.hop_size = hop_size 113 | # self.win_size = win_size 114 | # self.fmin = fmin 115 | # self.fmax = fmax 116 | # self.fmax_loss = fmax_loss 117 | # self.cached_wav = None 118 | # self.n_cache_reuse = n_cache_reuse 119 | # self._cache_ref_count = 0 120 | # self.device = device 121 | # self.fine_tuning = fine_tuning 122 | # self.base_mels_path = base_mels_path 123 | 124 | # def __getitem__(self, index): 125 | # filename = self.audio_files[index] 126 | # if self._cache_ref_count == 0: 127 | # audio, sampling_rate = load_wav(filename) 128 | # audio = audio / MAX_WAV_VALUE 129 | # if not self.fine_tuning: 130 | # audio = normalize(audio) * 0.95 131 | # self.cached_wav = audio 132 | # if sampling_rate != self.sampling_rate: 133 | # raise ValueError("{} SR doesn't match target {} SR".format( 134 | # sampling_rate, self.sampling_rate)) 135 | # self._cache_ref_count = self.n_cache_reuse 136 | # else: 137 | # audio = self.cached_wav 138 | # self._cache_ref_count -= 1 139 | 140 | # audio = torch.FloatTensor(audio) 141 | # audio = audio.unsqueeze(0) 142 | 143 | # if not self.fine_tuning: 144 | # if self.split: 145 | # if audio.size(1) >= self.segment_size: 146 | # max_audio_start = audio.size(1) - self.segment_size 147 | # audio_start = random.randint(0, max_audio_start) 148 | # audio = audio[:, audio_start:audio_start+self.segment_size] 149 | # else: 150 | # audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') 151 | 152 | # mel = mel_spectrogram(audio, self.n_fft, self.num_mels, 153 | # self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax, 154 | # center=False) 155 | # else: 156 | # mel = np.load( 157 | # os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy')) 158 | # mel = torch.from_numpy(mel) 159 | 160 | # if len(mel.shape) < 3: 161 | # mel = mel.unsqueeze(0) 162 | 163 | # if self.split: 164 | # frames_per_seg = math.ceil(self.segment_size / self.hop_size) 165 | 166 | # if audio.size(1) >= self.segment_size: 167 | # mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) 168 | # mel = mel[:, :, mel_start:mel_start + frames_per_seg] 169 | # audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size] 170 | # else: 171 | # mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant') 172 | # audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant') 173 | 174 | # mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels, 175 | # self.sampling_rate, self.hop_size, 
self.win_size, self.fmin, self.fmax_loss, 176 | # center=False) 177 | 178 | # return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) 179 | 180 | # def __len__(self): 181 | # return len(self.audio_files) 182 | -------------------------------------------------------------------------------- /src/basic_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.autograd import Function 5 | 6 | def tile(x, count, dim=0): 7 | """ 8 | Tiles x on dimension dim count times. 9 | """ 10 | perm = list(range(len(x.size()))) 11 | if dim != 0: 12 | perm[0], perm[dim] = perm[dim], perm[0] 13 | x = x.permute(perm).contiguous() 14 | out_size = list(x.size()) 15 | out_size[0] *= count 16 | batch = x.size(0) 17 | x = x.view(batch, -1) \ 18 | .transpose(0, 1) \ 19 | .repeat(count, 1) \ 20 | .transpose(0, 1) \ 21 | .contiguous() \ 22 | .view(*out_size) 23 | if dim != 0: 24 | x = x.permute(perm).contiguous() 25 | return x 26 | 27 | class Linear(torch.nn.Module): 28 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 29 | super(Linear, self).__init__() 30 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 31 | 32 | torch.nn.init.xavier_uniform_( 33 | self.linear_layer.weight, 34 | gain=torch.nn.init.calculate_gain(w_init_gain)) 35 | 36 | def forward(self, x): 37 | return self.linear_layer(x) 38 | 39 | class Conv1d(torch.nn.Module): 40 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 41 | padding=None, dilation=1, bias=True, w_init_gain='linear', param=None): 42 | super(Conv1d, self).__init__() 43 | if padding is None: 44 | assert(kernel_size % 2 == 1) 45 | padding = int(dilation * (kernel_size - 1)/2) 46 | 47 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 48 | kernel_size=kernel_size, stride=stride, 49 | padding=padding, dilation=dilation, 50 | bias=bias) 51 | torch.nn.init.xavier_uniform_( 52 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param)) 53 | 54 | def forward(self, x): 55 | # x: BxDxT 56 | return self.conv(x) 57 | 58 | class Flatten(nn.Module): 59 | def forward(self, x): 60 | return x.view(x.size(0), -1) 61 | 62 | # Reshape layer 63 | class Reshape(nn.Module): 64 | def __init__(self, outer_shape): 65 | super(Reshape, self).__init__() 66 | self.outer_shape = outer_shape 67 | def forward(self, x): 68 | return x.view(x.size(0), *self.outer_shape) 69 | 70 | # Sample from the Gumbel-Softmax distribution and optionally discretize. 
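# A note on the "hard" branch used below: it applies the straight-through estimator
#   y_hard = (y_hard - y).detach() + y
# so the forward pass emits a discrete one-hot sample while the backward pass receives the
# gradients of the underlying soft gumbel_softmax_sample(), keeping the layer differentiable.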
71 | class GumbelSoftmax(nn.Module): 72 | 73 | def __init__(self, f_dim, c_dim): 74 | super(GumbelSoftmax, self).__init__() 75 | self.logits = nn.Linear(f_dim, c_dim) 76 | self.f_dim = f_dim 77 | self.c_dim = c_dim 78 | 79 | def sample_gumbel(self, shape, is_cuda=False, eps=1e-20): 80 | U = torch.rand(shape) 81 | if is_cuda: 82 | U = U.cuda() 83 | return -torch.log(-torch.log(U + eps) + eps) 84 | 85 | def gumbel_softmax_sample(self, logits, temperature): 86 | y = logits + self.sample_gumbel(logits.size(), logits.is_cuda) 87 | return F.softmax(y / temperature, dim=-1) 88 | 89 | def gumbel_softmax(self, logits, temperature, hard=False): 90 | """ 91 | ST-gumple-softmax 92 | input: [*, n_class] 93 | return: flatten --> [*, n_class] an one-hot vector 94 | """ 95 | #categorical_dim = 10 96 | y = self.gumbel_softmax_sample(logits, temperature) 97 | 98 | if not hard: 99 | return y 100 | 101 | shape = y.size() 102 | _, ind = y.max(dim=-1) 103 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 104 | y_hard.scatter_(1, ind.view(-1, 1), 1) 105 | y_hard = y_hard.view(*shape) 106 | # Set gradients w.r.t. y_hard gradients w.r.t. y 107 | y_hard = (y_hard - y).detach() + y 108 | return y_hard 109 | 110 | def forward(self, x, temperature=1.0, hard=False): 111 | logits = self.logits(x).view(-1, self.c_dim) 112 | prob = F.softmax(logits, dim=-1) 113 | y = self.gumbel_softmax(logits, temperature, hard) 114 | return logits, prob, y 115 | 116 | class Softmax(nn.Module): 117 | 118 | def __init__(self, in_dim, out_dim): 119 | super().__init__() 120 | self.logits = nn.Linear(in_dim, out_dim) 121 | self.in_dim = in_dim 122 | self.out_dim = out_dim 123 | 124 | def forward(self, x): 125 | logits = self.logits(x).view(-1, self.out_dim) 126 | prob = F.softmax(logits, dim=-1) 127 | return logits, prob 128 | 129 | # Sample from a Gaussian distribution 130 | class Gaussian(nn.Module): 131 | def __init__(self, in_dim, z_dim, use_bias=True): 132 | super(Gaussian, self).__init__() 133 | self.mu = nn.Linear(in_dim, z_dim, bias=use_bias) 134 | self.log_std = nn.Linear(in_dim, z_dim, bias=use_bias) 135 | # self.mu.weight.data.fill_(0.0) 136 | # self.log_std.weight.data.fill_(0.0) 137 | # if use_bias: 138 | # self.mu.bias.data.fill_(0.0) 139 | # self.log_std.bias.data.fill_(0.0) 140 | 141 | def reparameterize(self, mu, log_std): 142 | std = torch.exp(log_std) 143 | noise = torch.randn_like(std) 144 | z = mu + noise * std 145 | return z 146 | 147 | def forward(self, x): 148 | mu = self.mu(x) 149 | log_std = self.log_std(x) 150 | z = self.reparameterize(mu, log_std) 151 | return mu, log_std, z 152 | 153 | # https://github.com/fungtion/DANN/blob/master/models/functions.py 154 | class ReverseLayerF(Function): 155 | 156 | @staticmethod 157 | def forward(ctx, x, alpha): 158 | ctx.alpha = alpha 159 | return x.view_as(x) 160 | 161 | @staticmethod 162 | def backward(ctx, grad_output): 163 | output = grad_output.neg() * ctx.alpha 164 | return output, None 165 | # class Gaussian(nn.Module): 166 | # def __init__(self, in_dim, z_dim): 167 | # super(Gaussian, self).__init__() 168 | # self.mu = nn.Linear(in_dim, z_dim) 169 | # self.var = nn.Linear(in_dim, z_dim) 170 | 171 | # def reparameterize(self, mu, var): 172 | # std = torch.sqrt(var + 1e-10) 173 | # noise = torch.randn_like(std) 174 | # z = mu + noise * std 175 | # return z 176 | 177 | # def forward(self, x): 178 | # mu = self.mu(x) 179 | # var = F.softplus(self.var(x)) 180 | # z = self.reparameterize(mu, var) 181 | # return mu, var, z 182 | 183 | def tile(x, count, dim=0): 184 | """ 185 
| Tiles x on dimension dim count times. 186 | """ 187 | perm = list(range(len(x.size()))) 188 | if dim != 0: 189 | perm[0], perm[dim] = perm[dim], perm[0] 190 | x = x.permute(perm).contiguous() 191 | out_size = list(x.size()) 192 | out_size[0] *= count 193 | batch = x.size(0) 194 | x = x.view(batch, -1) \ 195 | .transpose(0, 1) \ 196 | .repeat(count, 1) \ 197 | .transpose(0, 1) \ 198 | .contiguous() \ 199 | .view(*out_size) 200 | if dim != 0: 201 | x = x.permute(perm).contiguous() 202 | return x 203 | 204 | 205 | def sort_batch(data, lengths): 206 | ''' 207 | sort data by length 208 | sorted_data[initial_index] == data 209 | ''' 210 | sorted_lengths, sorted_index = lengths.sort(0, descending=True) 211 | sorted_data = data[sorted_index] 212 | _, initial_index = sorted_index.sort(0, descending=False) 213 | 214 | return sorted_data, sorted_lengths, initial_index 215 | -------------------------------------------------------------------------------- /src/cnn_postnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .basic_layers import Linear, Conv1d 5 | 6 | 7 | class Postnet(nn.Module): 8 | """Postnet 9 | - Five 1-d convolution with 512 channels and kernel size 5 10 | """ 11 | def __init__(self, num_mels=80, 12 | num_layers=5, 13 | hidden_dim=512, 14 | kernel_size=5): 15 | super(Postnet, self).__init__() 16 | self.convolutions = nn.ModuleList() 17 | 18 | self.convolutions.append( 19 | nn.Sequential( 20 | Conv1d( 21 | num_mels, hidden_dim, 22 | kernel_size=kernel_size, stride=1, 23 | padding=int((kernel_size - 1) / 2), 24 | dilation=1, w_init_gain='tanh'), 25 | nn.BatchNorm1d(hidden_dim))) 26 | 27 | for i in range(1, num_layers - 1): 28 | self.convolutions.append( 29 | nn.Sequential( 30 | Conv1d( 31 | hidden_dim, 32 | hidden_dim, 33 | kernel_size=kernel_size, stride=1, 34 | padding=int((kernel_size - 1) / 2), 35 | dilation=1, w_init_gain='tanh'), 36 | nn.BatchNorm1d(hidden_dim))) 37 | 38 | self.convolutions.append( 39 | nn.Sequential( 40 | Conv1d( 41 | hidden_dim, num_mels, 42 | kernel_size=kernel_size, stride=1, 43 | padding=int((kernel_size - 1) / 2), 44 | dilation=1, w_init_gain='linear'), 45 | nn.BatchNorm1d(num_mels))) 46 | 47 | def forward(self, x): 48 | # x: (B, num_mels, T_dec) 49 | for i in range(len(self.convolutions) - 1): 50 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) 51 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training) 52 | return x 53 | -------------------------------------------------------------------------------- /src/f0_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import pyworld 4 | from scipy.interpolate import interp1d 5 | from scipy.signal import firwin, get_window, lfilter 6 | 7 | 8 | def compute_f0(wav, sr=16000, frame_period=10.0): 9 | """Compute f0 from wav using pyworld harvest algorithm.""" 10 | wav = wav.astype(np.float64) 11 | f0, _ = pyworld.harvest( 12 | wav, sr, frame_period=frame_period, f0_floor=20.0, f0_ceil=600.0) 13 | return f0.astype(np.float32) 14 | 15 | 16 | def low_pass_filter(x, fs, cutoff=70, padding=True): 17 | """FUNCTION TO APPLY LOW PASS FILTER 18 | 19 | Args: 20 | x (ndarray): Waveform sequence 21 | fs (int): Sampling frequency 22 | cutoff (float): Cutoff frequency of low pass filter 23 | 24 | Return: 25 | (ndarray): Low pass filtered waveform sequence 26 | """ 27 | 28 | nyquist = fs // 2 29 | 
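    # firwin (with its default sampling convention) expects the cutoff as a fraction of the
    # Nyquist frequency, i.e. 0 < norm_cutoff < 1, hence the division below.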
norm_cutoff = cutoff / nyquist 30 | 31 | # low cut filter 32 | numtaps = 255 33 | fil = firwin(numtaps, norm_cutoff) 34 | x_pad = np.pad(x, (numtaps, numtaps), 'edge') 35 | lpf_x = lfilter(fil, 1, x_pad) 36 | lpf_x = lpf_x[numtaps + numtaps // 2: -numtaps // 2] 37 | 38 | return lpf_x 39 | 40 | 41 | def convert_continuos_f0(f0): 42 | """CONVERT F0 TO CONTINUOUS F0 43 | 44 | Args: 45 | f0 (ndarray): original f0 sequence with the shape (T) 46 | 47 | Return: 48 | (ndarray): continuous f0 with the shape (T) 49 | """ 50 | # get uv information as binary 51 | uv = np.float32(f0 != 0) 52 | 53 | # get start and end of f0 54 | if (f0 == 0).all(): 55 | logging.warn("all of the f0 values are 0.") 56 | return uv, f0 57 | start_f0 = f0[f0 != 0][0] 58 | end_f0 = f0[f0 != 0][-1] 59 | 60 | # padding start and end of f0 sequence 61 | start_idx = np.where(f0 == start_f0)[0][0] 62 | end_idx = np.where(f0 == end_f0)[0][-1] 63 | f0[:start_idx] = start_f0 64 | f0[end_idx:] = end_f0 65 | 66 | # get non-zero frame index 67 | nz_frames = np.where(f0 != 0)[0] 68 | 69 | # perform linear interpolation 70 | f = interp1d(nz_frames, f0[nz_frames]) 71 | cont_f0 = f(np.arange(0, f0.shape[0])) 72 | 73 | return uv, cont_f0 74 | 75 | 76 | def get_cont_lf0(f0, frame_period=10.0, lpf=False): 77 | uv, cont_f0 = convert_continuos_f0(f0) 78 | if lpf: 79 | cont_f0_lpf = low_pass_filter(cont_f0, int(1.0 / (frame_period * 0.001)), cutoff=20) 80 | cont_lf0_lpf = cont_f0_lpf.copy() 81 | nonzero_indices = np.nonzero(cont_lf0_lpf) 82 | cont_lf0_lpf[nonzero_indices] = np.log(cont_f0_lpf[nonzero_indices]) 83 | # cont_lf0_lpf = np.log(cont_f0_lpf) 84 | return uv, cont_lf0_lpf 85 | else: 86 | nonzero_indices = np.nonzero(cont_f0) 87 | cont_lf0 = cont_f0.copy() 88 | cont_lf0[cont_f0>0] = np.log(cont_f0[cont_f0>0]) 89 | return uv, cont_lf0 90 | -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from typing import Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .nets_utils import make_pad_mask 9 | 10 | 11 | class MaskedMSELoss(nn.Module): 12 | def __init__(self, frames_per_step): 13 | super().__init__() 14 | self.frames_per_step = frames_per_step 15 | self.mel_loss_criterion = nn.MSELoss(reduction='none') 16 | # self.loss = nn.MSELoss() 17 | self.stop_loss_criterion = nn.BCEWithLogitsLoss(reduction='none') 18 | 19 | def get_mask(self, lengths, max_len=None): 20 | # lengths: [B,] 21 | if max_len is None: 22 | max_len = torch.max(lengths) 23 | batch_size = lengths.size(0) 24 | seq_range = torch.arange(0, max_len).long() 25 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len).to(lengths.device) 26 | seq_length_expand = lengths.unsqueeze(1).expand_as(seq_range_expand) 27 | return (seq_range_expand < seq_length_expand).float() 28 | 29 | def forward(self, mel_pred, mel_pred_postnet, mel_trg, lengths, 30 | stop_target, stop_pred): 31 | ## process stop_target 32 | B = stop_target.size(0) 33 | stop_target = stop_target.reshape(B, -1, self.frames_per_step)[:, :, 0] 34 | stop_lengths = torch.ceil(lengths.float() / self.frames_per_step).long() 35 | stop_mask = self.get_mask(stop_lengths, int(mel_trg.size(1)/self.frames_per_step)) 36 | 37 | mel_trg.requires_grad = False 38 | # (B, T, 1) 39 | mel_mask = self.get_mask(lengths, mel_trg.size(1)).unsqueeze(-1) 40 | # (B, T, D) 41 | mel_mask = mel_mask.expand_as(mel_trg) 42 | mel_loss_pre 
= (self.mel_loss_criterion(mel_pred, mel_trg) * mel_mask).sum() / mel_mask.sum() 43 | mel_loss_post = (self.mel_loss_criterion(mel_pred_postnet, mel_trg) * mel_mask).sum() / mel_mask.sum() 44 | 45 | mel_loss = mel_loss_pre + mel_loss_post 46 | 47 | # stop token loss 48 | stop_loss = torch.sum(self.stop_loss_criterion(stop_pred, stop_target) * stop_mask) / stop_mask.sum() 49 | 50 | return mel_loss, stop_loss 51 | -------------------------------------------------------------------------------- /src/loss_fn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MaskedMSELoss(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | self.loss = nn.MSELoss(reduction='none') 9 | 10 | def get_mask(self, lengths): 11 | # lengths: [B,] 12 | max_len = torch.max(lengths) 13 | batch_size = lengths.size(0) 14 | seq_range = torch.arange(0, max_len).long() 15 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len).to(lengths.device) 16 | seq_length_expand = lengths.unsqueeze(1).expand_as(seq_range_expand) 17 | return (seq_range_expand < seq_length_expand).float() 18 | 19 | def forward(self, mel_pred, mel_trg, lengths): 20 | # (B, T, 1) 21 | mask = self.get_mask(lengths).unsqueeze(-1) 22 | # (B, T, D) 23 | mask_ = mask.expand_as(mel_trg) 24 | loss = self.loss(mel_pred, mel_trg) 25 | return ((loss * mask_).sum()) / mask_.sum() 26 | -------------------------------------------------------------------------------- /src/lsa_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .basic_layers import Linear, Conv1d 5 | 6 | 7 | 8 | class LocationLayer(nn.Module): 9 | def __init__(self, attention_n_filters, attention_kernel_size, attention_dim): 10 | super().__init__() 11 | padding = int((attention_kernel_size - 1) / 2) 12 | self.location_conv = Conv1d(2, attention_n_filters, 13 | kernel_size=attention_kernel_size, 14 | padding=padding, bias=False, 15 | stride=1, dilation=1) 16 | self.location_dense = Linear(attention_n_filters, attention_dim, 17 | bias=False, w_init_gain='tanh') 18 | 19 | def forward(self, attention_weights_cat): 20 | processed_attention = self.location_conv(attention_weights_cat) 21 | processed_attention = processed_attention.transpose(1, 2) 22 | processed_attention = self.location_dense(processed_attention) 23 | return processed_attention 24 | 25 | class LocationSensitiveAttention(nn.Module): 26 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 27 | attention_location_n_filters, attention_location_kernel_size): 28 | super().__init__() 29 | self.query_layer = Linear(attention_rnn_dim, attention_dim, 30 | bias=False, w_init_gain='tanh') 31 | self.memory_layer = Linear(embedding_dim, attention_dim, 32 | bias=False, w_init_gain='tanh') 33 | self.v = Linear(attention_dim, 1, bias=False) 34 | self.location_layer = LocationLayer(attention_location_n_filters, 35 | attention_location_kernel_size, 36 | attention_dim) 37 | self.score_mask_value = -float('inf') 38 | 39 | def get_alignment_energies(self, query, processed_memory, 40 | attention_weights_cat): 41 | """ 42 | Args: 43 | query: attention rnn output (B, att_rnn_dim) 44 | processed_memory: processed encoder outputs (B, T_in, enc_dim) 45 | attention_weights_cat: cumulative and previous attention weights (B, 2, T_in) 46 | Returns: 47 | attention alignment (B, T_in) 48 | """ 49 | 
processed_query = self.query_layer(query.unsqueeze(1)) 50 | processed_attention_weights = self.location_layer(attention_weights_cat) 51 | energies = self.v(torch.tanh( 52 | processed_query + processed_attention_weights + processed_memory)) 53 | energies = energies.squeeze(-1) 54 | return energies 55 | 56 | def forward(self, attention_hidden_state, memory, processed_memory, attention_weights_cat, mask): 57 | """ 58 | Args: 59 | attention_hidden_state: attention rnn last output 60 | memory: encoder outputs 61 | processed_memory: processed encoder outputs 62 | attention_weights_cat: previous and cumulative attention weights 63 | mask: binary mask for padded data 64 | """ 65 | alignment = self.get_alignment_energies( 66 | attention_hidden_state, processed_memory, attention_weights_cat) 67 | 68 | if mask is not None: 69 | alignment.data.masked_fill_(mask, self.score_mask_value) 70 | 71 | attention_weights = F.softmax(alignment, dim=1) 72 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) 73 | attention_context = attention_context.squeeze(1) 74 | 75 | return attention_context, attention_weights 76 | 77 | 78 | class ForwardAttentionV2(nn.Module): 79 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 80 | attention_location_n_filters, attention_location_kernel_size): 81 | super(ForwardAttentionV2, self).__init__() 82 | self.query_layer = Linear(attention_rnn_dim, attention_dim, 83 | bias=False, w_init_gain='tanh') 84 | self.memory_layer = Linear(embedding_dim, attention_dim, bias=False, 85 | w_init_gain='tanh') 86 | self.v = Linear(attention_dim, 1, bias=False) 87 | self.location_layer = LocationLayer(attention_location_n_filters, 88 | attention_location_kernel_size, 89 | attention_dim) 90 | self.score_mask_value = -float(1e20) 91 | 92 | def get_alignment_energies(self, query, processed_memory, 93 | attention_weights_cat): 94 | """ 95 | PARAMS 96 | ------ 97 | query: decoder output (batch, n_mel_channels * n_frames_per_step) 98 | processed_memory: processed encoder outputs (B, T_in, attention_dim) 99 | attention_weights_cat: prev. 
and cumulative att weights (B, 2, max_time) 100 | RETURNS 101 | ------- 102 | alignment (batch, max_time) 103 | """ 104 | 105 | processed_query = self.query_layer(query.unsqueeze(1)) 106 | processed_attention_weights = self.location_layer(attention_weights_cat) 107 | energies = self.v(torch.tanh( 108 | processed_query + processed_attention_weights + processed_memory)) 109 | 110 | energies = energies.squeeze(-1) 111 | return energies 112 | 113 | def forward(self, attention_hidden_state, memory, processed_memory, 114 | attention_weights_cat, mask, log_alpha): 115 | """ 116 | PARAMS 117 | ------ 118 | attention_hidden_state: attention rnn last output 119 | memory: encoder outputs 120 | processed_memory: processed encoder outputs 121 | attention_weights_cat: previous and cummulative attention weights 122 | mask: binary mask for padded data 123 | """ 124 | log_energy = self.get_alignment_energies( 125 | attention_hidden_state, processed_memory, attention_weights_cat) 126 | 127 | #log_energy = 128 | 129 | if mask is not None: 130 | log_energy.data.masked_fill_(mask, self.score_mask_value) 131 | 132 | #attention_weights = F.softmax(alignment, dim=1) 133 | 134 | #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME] 135 | #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1] 136 | 137 | #log_total_score = log_alpha + content_score 138 | 139 | #previous_attention_weights = attention_weights_cat[:,0,:] 140 | 141 | log_alpha_shift_padded = [] 142 | max_time = log_energy.size(1) 143 | for sft in range(2): 144 | shifted = log_alpha[:,:max_time-sft] 145 | shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value) 146 | log_alpha_shift_padded.append(shift_padded.unsqueeze(2)) 147 | 148 | biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2) 149 | 150 | log_alpha_new = biased + log_energy 151 | 152 | attention_weights = F.softmax(log_alpha_new, dim=1) 153 | 154 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) 155 | attention_context = attention_context.squeeze(1) 156 | 157 | return attention_context, attention_weights, log_alpha_new 158 | 159 | -------------------------------------------------------------------------------- /src/mol_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MOLAttention(nn.Module): 7 | """ Discretized Mixture of Logistic (MOL) attention. 8 | C.f. Section 5 of "MelNet: A Generative Model for Audio in the Frequency Domain" and 9 | GMMv2b model in "Location-relative attention mechanisms for robust long-form speech synthesis". 10 | """ 11 | def __init__( 12 | self, 13 | query_dim, 14 | r=1, 15 | M=5, 16 | ): 17 | """ 18 | Args: 19 | query_dim: attention_rnn_dim. 20 | M: number of mixtures. 
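r: reduction factor (presumably the number of output frames per decoding step); it is only used to choose the initial Delta bias in initialize_bias().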
21 | """ 22 | super().__init__() 23 | if r < 1: 24 | self.r = float(r) 25 | else: 26 | self.r = int(r) 27 | self.M = M 28 | self.score_mask_value = 0.0 # -float("inf") 29 | self.eps = 1e-5 30 | # Position arrary for encoder time steps 31 | self.J = None 32 | # Query layer: [w, sigma,] 33 | self.query_layer = torch.nn.Sequential( 34 | nn.Linear(query_dim, 256, bias=True), 35 | nn.ReLU(), 36 | nn.Linear(256, 3*M, bias=True) 37 | ) 38 | self.mu_prev = None 39 | self.initialize_bias() 40 | 41 | def initialize_bias(self): 42 | """Initialize sigma and Delta.""" 43 | # sigma 44 | torch.nn.init.constant_(self.query_layer[2].bias[self.M:2*self.M], 1.0) 45 | # Delta: softplus(1.8545) = 2.0; softplus(3.9815) = 4.0; softplus(0.5413) = 1.0 46 | # softplus(-0.432) = 0.5003 47 | if self.r == 2: 48 | torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 1.8545) 49 | elif self.r == 4: 50 | torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 3.9815) 51 | elif self.r == 1: 52 | torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 0.5413) 53 | else: 54 | torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], -0.432) 55 | 56 | 57 | def init_states(self, memory): 58 | """Initialize mu_prev and J. 59 | This function should be called by the decoder before decoding one batch. 60 | Args: 61 | memory: (B, T, D_enc) encoder output. 62 | """ 63 | B, T_enc, _ = memory.size() 64 | device = memory.device 65 | self.J = torch.arange(0, T_enc + 2.0).to(device) + 0.5 # NOTE: for discretize usage 66 | # self.J = memory.new_tensor(np.arange(T_enc), dtype=torch.float) 67 | self.mu_prev = torch.zeros(B, self.M).to(device) 68 | 69 | def forward(self, att_rnn_h, memory, memory_pitch=None, mask=None): 70 | """ 71 | att_rnn_h: attetion rnn hidden state. 72 | memory: encoder outputs (B, T_enc, D). 73 | mask: binary mask for padded data (B, T_enc). 74 | """ 75 | # [B, 3M] 76 | mixture_params = self.query_layer(att_rnn_h) 77 | 78 | # [B, M] 79 | w_hat = mixture_params[:, :self.M] 80 | sigma_hat = mixture_params[:, self.M:2*self.M] 81 | Delta_hat = mixture_params[:, 2*self.M:3*self.M] 82 | 83 | # print("w_hat: ", w_hat) 84 | # print("sigma_hat: ", sigma_hat) 85 | # print("Delta_hat: ", Delta_hat) 86 | 87 | # Dropout to de-correlate attention heads 88 | w_hat = F.dropout(w_hat, p=0.5, training=self.training) # NOTE(sx): needed? 
89 | 90 | # Mixture parameters 91 | w = torch.softmax(w_hat, dim=-1) + self.eps 92 | sigma = F.softplus(sigma_hat) + self.eps 93 | Delta = F.softplus(Delta_hat) 94 | mu_cur = self.mu_prev + Delta 95 | # print("w:", w) 96 | j = self.J[:memory.size(1) + 1] 97 | 98 | # Attention weights 99 | # CDF of logistic distribution 100 | phi_t = w.unsqueeze(-1) * (1 / (1 + torch.sigmoid( 101 | (mu_cur.unsqueeze(-1) - j) / sigma.unsqueeze(-1)))) 102 | # print("phi_t:", phi_t) 103 | 104 | # Discretize attention weights 105 | # (B, T_enc + 1) 106 | alpha_t = torch.sum(phi_t, dim=1) 107 | alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1] 108 | alpha_t[alpha_t == 0] = self.eps 109 | # print("alpha_t: ", alpha_t.size()) 110 | # Apply masking 111 | if mask is not None: 112 | alpha_t.data.masked_fill_(mask, self.score_mask_value) 113 | 114 | context = torch.bmm(alpha_t.unsqueeze(1), memory).squeeze(1) 115 | if memory_pitch is not None: 116 | context_pitch = torch.bmm(alpha_t.unsqueeze(1), memory_pitch).squeeze(1) 117 | 118 | self.mu_prev = mu_cur 119 | 120 | if memory_pitch is not None: 121 | return context, context_pitch, alpha_t 122 | return context, alpha_t 123 | 124 | -------------------------------------------------------------------------------- /src/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class Optimizer(): 6 | def __init__(self, parameters, optimizer, lr, eps, lr_scheduler, 7 | **kwargs): 8 | 9 | # Setup torch optimizer 10 | self.opt_type = optimizer 11 | self.init_lr = lr 12 | self.sch_type = lr_scheduler 13 | opt = getattr(torch.optim, optimizer) 14 | if lr_scheduler == 'warmup': 15 | warmup_step = 4000.0 16 | init_lr = lr 17 | self.lr_scheduler = lambda step: init_lr * warmup_step ** 0.5 * \ 18 | np.minimum((step+1)*warmup_step**-1.5, (step+1)**-0.5) 19 | self.opt = opt(parameters, lr=1.0) 20 | else: 21 | self.lr_scheduler = None 22 | self.opt = opt(parameters, lr=lr, eps=eps) # ToDo: 1e-8 better? 23 | 24 | def get_opt_state_dict(self): 25 | return self.opt.state_dict() 26 | 27 | def load_opt_state_dict(self, state_dict): 28 | self.opt.load_state_dict(state_dict) 29 | 30 | def pre_step(self, step): 31 | if self.lr_scheduler is not None: 32 | cur_lr = self.lr_scheduler(step) 33 | for param_group in self.opt.param_groups: 34 | param_group['lr'] = cur_lr 35 | else: 36 | cur_lr = self.init_lr 37 | self.opt.zero_grad() 38 | return cur_lr 39 | 40 | def step(self): 41 | self.opt.step() 42 | 43 | def create_msg(self): 44 | return ['Optim.Info.| Algo. = {}\t| Lr = {}\t (schedule = {})' 45 | .format(self.opt_type, self.init_lr, self.sch_type)] 46 | -------------------------------------------------------------------------------- /src/option.py: -------------------------------------------------------------------------------- 1 | # Default parameters which will be imported by solver 2 | default_hparas = { 3 | 'GRAD_CLIP': 5.0, # Grad. clip threshold 4 | 'PROGRESS_STEP': 100, # Std. output refresh freq. 
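# Shared defaults imported by the solver, as noted in the comment at the top of this file.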
5 | # Decode steps for objective validation (step = ratio*input_txt_len) 6 | 'DEV_STEP_RATIO': 1.2, 7 | # Number of examples (alignment/text) to show in tensorboard 8 | 'DEV_N_EXAMPLE': 4, 9 | 'TB_FLUSH_FREQ': 180 # Update frequency of tensorboard (secs) 10 | } 11 | -------------------------------------------------------------------------------- /src/rnn_ppg2mel.py: -------------------------------------------------------------------------------- 1 | """Bidirectional RNN (LSTM/GRU) based PPG-to-Mel model for voice conversion.""" 2 | from typing import Tuple 3 | from typing import Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | from typeguard import check_argument_types 8 | 9 | 10 | class BiRnnPpg2MelModel(torch.nn.Module): 11 | """ Bidirectional RNN-based PPG-to-Mel Model for voice conversion tasks. 12 | RNN could be LSTM-based or GRU-based. 13 | """ 14 | def __init__( 15 | self, 16 | input_size: int, 17 | multi_spk: bool = False, 18 | num_speakers: int = 1, 19 | spk_embed_dim: int = 256, 20 | use_spk_dvec: bool = False, 21 | multi_styles: bool = False, 22 | num_styles: int = 3, 23 | style_embed_dim: int = 256, 24 | dense_layer_size: int = 256, 25 | num_layers: int = 4, 26 | bidirectional: bool = True, 27 | hidden_dim: int = 256, 28 | dropout_rate: float = 0.5, 29 | output_size: int = 80, 30 | rnn_type: str = "lstm" 31 | ): 32 | assert check_argument_types() 33 | super().__init__() 34 | 35 | self.multi_spk = multi_spk 36 | self.spk_embed_dim = spk_embed_dim 37 | self.use_spk_dvec = use_spk_dvec 38 | self.multi_styles = multi_styles 39 | self.style_embed_dim = style_embed_dim 40 | self.hidden_dim = hidden_dim 41 | self.num_layers = num_layers 42 | self.num_direction = 2 if bidirectional else 1 43 | 44 | self.ppg_dense_layer = nn.Linear(input_size - 2, hidden_dim) 45 | self.logf0_uv_layer = nn.Linear(2, hidden_dim) 46 | 47 | projection_input_size = hidden_dim 48 | if self.multi_spk: 49 | if not self.use_spk_dvec: 50 | self.spk_embedding = nn.Embedding(num_speakers, spk_embed_dim) 51 | projection_input_size += self.spk_embed_dim 52 | if self.multi_styles: 53 | self.style_embedding = nn.Embedding(num_styles, style_embed_dim) 54 | projection_input_size += self.style_embed_dim 55 | 56 | self.reduce_proj = nn.Sequential( 57 | nn.Linear(projection_input_size, hidden_dim), 58 | nn.ReLU(), 59 | nn.Dropout(dropout_rate) 60 | ) 61 | 62 | rnn_type = rnn_type.upper() 63 | if rnn_type in ["LSTM", "GRU"]: 64 | rnn_class = getattr(nn, rnn_type) 65 | self.rnn = rnn_class( 66 | hidden_dim, hidden_dim, num_layers, 67 | bidirectional=bidirectional, 68 | dropout=dropout_rate, 69 | batch_first=True) 70 | else: 71 | # Default: use BiLSTM 72 | self.rnn = nn.LSTM( 73 | hidden_dim, hidden_dim, num_layers, 74 | bidirectional=bidirectional, 75 | dropout=dropout_rate, 76 | batch_first=True) 77 | # Fully connected layers 78 | self.hidden2out_layers = nn.Sequential( 79 | nn.Linear(self.num_direction * hidden_dim, dense_layer_size), 80 | nn.ReLU(), 81 | nn.Dropout(dropout_rate), 82 | nn.Linear(dense_layer_size, output_size) 83 | ) 84 | 85 | def forward( 86 | self, 87 | ppg: torch.Tensor, 88 | ppg_lengths: torch.Tensor, 89 | logf0_uv: torch.Tensor, 90 | spembs: torch.Tensor = None, 91 | styleembs: torch.Tensor = None, 92 | ) -> torch.Tensor: 93 | """ 94 | Args: 95 | ppg (tensor): [B, T, D_ppg] 96 | ppg_lengths (tensor): [B,] 97 | logf0_uv (tensor): [B, T, 2], concatenated logf0 and u/v flags. 98 | spembs (tensor): [B,] speaker indices, or [B, spk_embed_dim] speaker d-vectors when use_spk_dvec is True. 99 | styleembs (tensor): [B,] index-represented speaking style (e.g.
emotion). 100 | """ 101 | ppg = self.ppg_dense_layer(ppg) 102 | logf0_uv = self.logf0_uv_layer(logf0_uv) 103 | 104 | ## Concatenate/add ppg and logf0_uv 105 | x = ppg + logf0_uv 106 | B, T, _ = x.size() 107 | 108 | if self.multi_spk: 109 | assert spembs is not None 110 | # spk_embs = self.spk_embedding(torch.LongTensor([0,]*ppg.size(0)).to(ppg.device)) 111 | if not self.use_spk_dvec: 112 | spk_embs = self.spk_embedding(spembs) 113 | spk_embs = torch.nn.functional.normalize( 114 | spk_embs).unsqueeze(1).expand(-1, T, -1) 115 | else: 116 | spk_embs = torch.nn.functional.normalize( 117 | spembs).unsqueeze(1).expand(-1, T, -1) 118 | x = torch.cat([x, spk_embs], dim=2) 119 | 120 | if self.multi_styles and styleembs is not None: 121 | style_embs = self.style_embedding(styleembs) 122 | style_embs = torch.nn.functional.normalize( 123 | style_embs).unsqueeze(1).expand(-1, T, -1) 124 | x = torch.cat([x, style_embs], dim=2) 125 | ## FC projection 126 | x = self.reduce_proj(x) 127 | 128 | if ppg_lengths is not None: 129 | x = torch.nn.utils.rnn.pack_padded_sequence(x, ppg_lengths, 130 | batch_first=True, 131 | enforce_sorted=False) 132 | x, _ = self.rnn(x) 133 | if ppg_lengths is not None: 134 | x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True) 135 | x = self.hidden2out_layers(x) 136 | 137 | return x 138 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import matplotlib.pyplot as plt 4 | import math 5 | import time 6 | import torch 7 | import numpy as np 8 | from torch import nn 9 | import editdistance as ed 10 | 11 | 12 | class Timer(): 13 | ''' Timer for recording training time distribution. ''' 14 | 15 | def __init__(self): 16 | self.prev_t = time.time() 17 | self.clear() 18 | 19 | def set(self): 20 | self.prev_t = time.time() 21 | 22 | def cnt(self, mode): 23 | self.time_table[mode] += time.time()-self.prev_t 24 | self.set() 25 | if mode == 'bw': 26 | self.click += 1 27 | 28 | def show(self): 29 | total_time = sum(self.time_table.values()) 30 | self.time_table['avg'] = total_time/self.click 31 | self.time_table['rd'] = 100*self.time_table['rd']/total_time 32 | self.time_table['fw'] = 100*self.time_table['fw']/total_time 33 | self.time_table['bw'] = 100*self.time_table['bw']/total_time 34 | msg = '{avg:.3f} sec/step (rd {rd:.1f}% | fw {fw:.1f}% | bw {bw:.1f}%)'.format( 35 | **self.time_table) 36 | self.clear() 37 | return msg 38 | 39 | def clear(self): 40 | self.time_table = {'rd': 0, 'fw': 0, 'bw': 0} 41 | self.click = 0 42 | 43 | # Reference : https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/e2e_asr.py#L168 44 | 45 | 46 | def init_weights(module): 47 | # Exceptions 48 | if type(module) == nn.Embedding: 49 | module.weight.data.normal_(0, 1) 50 | else: 51 | for p in module.parameters(): 52 | data = p.data 53 | if data.dim() == 1: 54 | # bias 55 | data.zero_() 56 | elif data.dim() == 2: 57 | # linear weight 58 | n = data.size(1) 59 | stdv = 1. / math.sqrt(n) 60 | data.normal_(0, stdv) 61 | elif data.dim() in [3, 4]: 62 | # conv weight 63 | n = data.size(1) 64 | for k in data.size()[2:]: 65 | n *= k 66 | stdv = 1. / math.sqrt(n) 67 | data.normal_(0, stdv) 68 | else: 69 | raise NotImplementedError 70 | 71 | 72 | def init_gate(bias): 73 | n = bias.size(0) 74 | start, end = n // 4, n // 2 75 | bias.data[start:end].fill_(1.) 
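# In PyTorch's LSTM bias layout (input, forget, cell, output gates), the n//4:n//2 slice above is the
# forget-gate bias, so this initializes the forget gates to 1.0 to ease gradient flow early in training.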
76 | return bias 77 | 78 | # Convert Tensor to Figure on tensorboard 79 | 80 | 81 | def feat_to_fig(feat): 82 | # feat TxD tensor 83 | data = _save_canvas(feat.numpy()) 84 | return torch.FloatTensor(data), "HWC" 85 | 86 | 87 | def _save_canvas(data, meta=None): 88 | fig, ax = plt.subplots(figsize=(16, 8)) 89 | if meta is None: 90 | ax.imshow(data, aspect="auto", origin="lower") 91 | else: 92 | ax.bar(meta[0], data[0], tick_label=meta[1], fc=(0, 0, 1, 0.5)) 93 | ax.bar(meta[0], data[1], tick_label=meta[1], fc=(1, 0, 0, 0.5)) 94 | fig.canvas.draw() 95 | # Note : torch tb add_image takes color as [0,1] 96 | data = np.array(fig.canvas.renderer._renderer)[:, :, :-1]/255.0 97 | plt.close(fig) 98 | return data 99 | 100 | # Reference : https://stackoverflow.com/questions/579310/formatting-long-numbers-as-strings-in-python 101 | 102 | 103 | def human_format(num): 104 | magnitude = 0 105 | while num >= 1000: 106 | magnitude += 1 107 | num /= 1000.0 108 | # add more suffixes if you need them 109 | return '{:3.1f}{}'.format(num, [' ', 'K', 'M', 'G', 'T', 'P'][magnitude]) 110 | 111 | 112 | def cal_er(tokenizer, pred, truth, mode='wer', ctc=False): 113 | # Calculate error rate of a batch 114 | if pred is None: 115 | return np.nan 116 | elif len(pred.shape) >= 3: 117 | pred = pred.argmax(dim=-1) 118 | er = [] 119 | for p, t in zip(pred, truth): 120 | p = tokenizer.decode(p.tolist(), ignore_repeat=ctc) 121 | t = tokenizer.decode(t.tolist()) 122 | if mode == 'wer': 123 | p = p.split(' ') 124 | t = t.split(' ') 125 | er.append(float(ed.eval(p, t))/len(t)) 126 | return sum(er)/len(er) 127 | 128 | 129 | def load_embedding(text_encoder, embedding_filepath): 130 | with open(embedding_filepath, "r") as f: 131 | vocab_size, embedding_size = [int(x) 132 | for x in f.readline().strip().split()] 133 | embeddings = np.zeros((text_encoder.vocab_size, embedding_size)) 134 | 135 | unk_count = 0 136 | 137 | for line in f: 138 | vocab, emb = line.strip().split(" ", 1) 139 | # fasttext's is 140 | if vocab == "": 141 | vocab = "" 142 | 143 | if text_encoder.token_type == "subword": 144 | idx = text_encoder.spm.piece_to_id(vocab) 145 | else: 146 | # get rid of 147 | idx = text_encoder.encode(vocab)[0] 148 | 149 | if idx == text_encoder.unk_idx: 150 | unk_count += 1 151 | embeddings[idx] += np.asarray([float(x) 152 | for x in emb.split(" ")]) 153 | else: 154 | # Suppose there is only one (w, v) pair in embedding file 155 | embeddings[idx] = np.asarray( 156 | [float(x) for x in emb.split(" ")]) 157 | 158 | # Average vector 159 | if unk_count != 0: 160 | embeddings[text_encoder.unk_idx] /= unk_count 161 | 162 | return embeddings 163 | -------------------------------------------------------------------------------- /src/vc_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def gcd(a, b): 6 | """Greatest common divisor.""" 7 | a, b = (a, b) if a >=b else (b, a) 8 | if a%b == 0: 9 | return b 10 | else : 11 | return gcd(b, a%b) 12 | 13 | def lcm(a, b): 14 | """Least common multiple""" 15 | return a * b // gcd(a, b) 16 | 17 | def get_mask_from_lengths(lengths, max_len=None): 18 | if max_len is None: 19 | max_len = torch.max(lengths).item() 20 | ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len)) 21 | mask = (ids < lengths.unsqueeze(1)).bool() 22 | return mask 23 | 24 | -------------------------------------------------------------------------------- /test.sh: 
-------------------------------------------------------------------------------- 1 | . ./path.sh || exit 1; 2 | export CUDA_VISIBLE_DEVICES=7 3 | 4 | stage=1 5 | stop_stage=1 6 | config=$1 7 | model_file=$2 8 | src_wav_dir=$3 9 | ref_wav_path=$4 10 | echo ${config} 11 | 12 | # =============== One-shot VC ================ 13 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 14 | exp_name="$(basename "${config}" .yaml)" 15 | echo Experiment name: "${exp_name}" 16 | # src_wav_dir="/home/shaunxliu/data/cmu_arctic/cmu_us_rms_arctic/wav" 17 | # ref_wav_path="/home/shaunxliu/data/cmu_arctic/cmu_us_slt_arctic/wav/arctic_a0001.wav" 18 | output_dir="vc_gen_wavs/$(basename "${config}" .yaml)" 19 | 20 | python convert_from_wav.py \ 21 | --ppg2mel_model_train_config ${config} \ 22 | --ppg2mel_model_file ${model_file} \ 23 | --src_wav_dir "${src_wav_dir}" \ 24 | --ref_wav_path "${ref_wav_path}" \ 25 | -o "${output_dir}" 26 | fi 27 | -------------------------------------------------------------------------------- /tools/Makefile: -------------------------------------------------------------------------------- 1 | PYTHON:= python3.7 2 | CUDA_VERSION:= 10.1 3 | PYTORCH_VERSION:= 1.6 4 | DOT:= . 5 | .PHONY: all clean 6 | 7 | all: virtualenv 8 | 9 | virtualenv: 10 | test -d venv || virtualenv -p $(PYTHON) venv 11 | . venv/bin/activate; pip install torch==$(PYTORCH_VERSION) \ 12 | -f https://download.pytorch.org/whl/cu$(subst $(DOT),,$(CUDA_VERSION))/torch_stable.html 13 | . venv/bin/activate; cd ../; pip install -r requirements.txt 14 | touch venv/bin/activate 15 | 16 | clean: 17 | rm -fr venv 18 | find -iname "*.pyc" -delete 19 | -------------------------------------------------------------------------------- /utils/f0_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import pyworld 4 | from scipy.interpolate import interp1d 5 | from scipy.signal import firwin, get_window, lfilter 6 | 7 | 8 | def compute_f0(wav, sr=16000, frame_period=10.0): 9 | """Compute f0 from wav using pyworld harvest algorithm.""" 10 | wav = wav.astype(np.float64) 11 | f0, _ = pyworld.harvest( 12 | wav, sr, frame_period=frame_period, f0_floor=20.0, f0_ceil=600.0) 13 | return f0.astype(np.float32) 14 | 15 | 16 | def low_pass_filter(x, fs, cutoff=70, padding=True): 17 | """FUNCTION TO APPLY LOW PASS FILTER 18 | 19 | Args: 20 | x (ndarray): Waveform sequence 21 | fs (int): Sampling frequency 22 | cutoff (float): Cutoff frequency of low pass filter 23 | 24 | Return: 25 | (ndarray): Low pass filtered waveform sequence 26 | """ 27 | 28 | nyquist = fs // 2 29 | norm_cutoff = cutoff / nyquist 30 | 31 | # low cut filter 32 | numtaps = 255 33 | fil = firwin(numtaps, norm_cutoff) 34 | x_pad = np.pad(x, (numtaps, numtaps), 'edge') 35 | lpf_x = lfilter(fil, 1, x_pad) 36 | lpf_x = lpf_x[numtaps + numtaps // 2: -numtaps // 2] 37 | 38 | return lpf_x 39 | 40 | 41 | def convert_continuous_f0(f0): 42 | """CONVERT F0 TO CONTINUOUS F0 43 | 44 | Args: 45 | f0 (ndarray): original f0 sequence with the shape (T) 46 | 47 | Return: 48 | (ndarray): continuous f0 with the shape (T) 49 | """ 50 | # get uv information as binary 51 | uv = np.float32(f0 != 0) 52 | 53 | # get start and end of f0 54 | if (f0 == 0).all(): 55 | logging.warn("all of the f0 values are 0.") 56 | return uv, f0 57 | start_f0 = f0[f0 != 0][0] 58 | end_f0 = f0[f0 != 0][-1] 59 | 60 | # padding start and end of f0 sequence 61 | start_idx = np.where(f0 == start_f0)[0][0] 62 | end_idx = np.where(f0 == 
end_f0)[0][-1] 63 | f0[:start_idx] = start_f0 64 | f0[end_idx:] = end_f0 65 | 66 | # get non-zero frame index 67 | nz_frames = np.where(f0 != 0)[0] 68 | 69 | # perform linear interpolation 70 | f = interp1d(nz_frames, f0[nz_frames]) 71 | cont_f0 = f(np.arange(0, f0.shape[0])) 72 | 73 | return uv, cont_f0 74 | 75 | 76 | def get_cont_lf0(f0, frame_period=10.0, lpf=False): 77 | uv, cont_f0 = convert_continuous_f0(f0) 78 | if lpf: 79 | cont_f0_lpf = low_pass_filter(cont_f0, int(1.0 / (frame_period * 0.001)), cutoff=20) 80 | cont_lf0_lpf = cont_f0_lpf.copy() 81 | nonzero_indices = np.nonzero(cont_lf0_lpf) 82 | cont_lf0_lpf[nonzero_indices] = np.log(cont_f0_lpf[nonzero_indices]) 83 | # cont_lf0_lpf = np.log(cont_f0_lpf) 84 | return uv, cont_lf0_lpf 85 | else: 86 | nonzero_indices = np.nonzero(cont_f0) 87 | cont_lf0 = cont_f0.copy() 88 | cont_lf0[cont_f0>0] = np.log(cont_f0[cont_f0>0]) 89 | return uv, cont_lf0 90 | -------------------------------------------------------------------------------- /utils/file_related.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def load_filepaths_and_text(filename, split="|"): 6 | with open(filename, encoding='utf-8') as f: 7 | filepaths_and_text = [line.strip().split(split) for line in f] 8 | return filepaths_and_text 9 | 10 | 11 | -------------------------------------------------------------------------------- /utils/load_yaml.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | 4 | def load_hparams(filename): 5 | stream = open(filename, 'r') 6 | docs = yaml.safe_load_all(stream) 7 | hparams_dict = dict() 8 | for doc in docs: 9 | for k, v in doc.items(): 10 | hparams_dict[k] = v 11 | return hparams_dict 12 | 13 | def merge_dict(user, default): 14 | if isinstance(user, dict) and isinstance(default, dict): 15 | for k, v in default.items(): 16 | if k not in user: 17 | user[k] = v 18 | else: 19 | user[k] = merge_dict(user[k], v) 20 | return user 21 | 22 | class Dotdict(dict): 23 | """ 24 | a dictionary that supports dot notation 25 | as well as dictionary access notation 26 | usage: d = DotDict() or d = DotDict({'val1':'first'}) 27 | set attributes: d.val2 = 'second' or d['val2'] = 'second' 28 | get attributes: d.val2 or d['val2'] 29 | """ 30 | __getattr__ = dict.__getitem__ 31 | __setattr__ = dict.__setitem__ 32 | __delattr__ = dict.__delitem__ 33 | 34 | def __init__(self, dct=None): 35 | dct = dict() if not dct else dct 36 | for key, value in dct.items(): 37 | if hasattr(value, 'keys'): 38 | value = Dotdict(value) 39 | self[key] = value 40 | 41 | class HpsYaml(Dotdict): 42 | def __init__(self, yaml_file): 43 | super(Dotdict, self).__init__() 44 | hps = load_hparams(yaml_file) 45 | hp_dict = Dotdict(hps) 46 | for k, v in hp_dict.items(): 47 | setattr(self, k, v) 48 | 49 | __getattr__ = Dotdict.__getitem__ 50 | __setattr__ = Dotdict.__setitem__ 51 | __delattr__ = Dotdict.__delitem__ 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /utils/tensor_ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def pad(inputs, max_length=None): 7 | 8 | if max_length: 9 | out_list = list() 10 | for i, mat in enumerate(inputs): 11 | mat_padded = F.pad( 12 | mat, (0, 0, 0, max_length-mat.size(0)), "constant", 0.0) 13 | out_list.append(mat_padded) 
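# F.pad(mat, (0, 0, 0, max_length - mat.size(0))) zero-pads each (T, D) matrix along its time axis;
# the stack below then yields a (B, max_length, D) batch tensor.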
14 | out_padded = torch.stack(out_list) 15 | return out_padded 16 | else: 17 | out_list = list() 18 | max_length = max([inputs[i].size(0)for i in range(len(inputs))]) 19 | 20 | for i, mat in enumerate(inputs): 21 | mat_padded = F.pad( 22 | mat, (0, 0, 0, max_length-mat.size(0)), "constant", 0.0) 23 | out_list.append(mat_padded) 24 | out_padded = torch.stack(out_list) 25 | return out_padded 26 | -------------------------------------------------------------------------------- /vocoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/vocoders/__init__.py -------------------------------------------------------------------------------- /vocoders/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | 10 | 11 | def build_env(config, config_name, path): 12 | t_path = os.path.join(path, config_name) 13 | if config != t_path: 14 | os.makedirs(path, exist_ok=True) 15 | shutil.copyfile(config, os.path.join(path, config_name)) 16 | -------------------------------------------------------------------------------- /vocoders/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import matplotlib 4 | import torch 5 | from torch.nn.utils import weight_norm 6 | matplotlib.use("Agg") 7 | import matplotlib.pylab as plt 8 | 9 | 10 | def plot_spectrogram(spectrogram): 11 | fig, ax = plt.subplots(figsize=(10, 2)) 12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 13 | interpolation='none') 14 | plt.colorbar(im, ax=ax) 15 | 16 | fig.canvas.draw() 17 | plt.close() 18 | 19 | return fig 20 | 21 | 22 | def init_weights(m, mean=0.0, std=0.01): 23 | classname = m.__class__.__name__ 24 | if classname.find("Conv") != -1: 25 | m.weight.data.normal_(mean, std) 26 | 27 | 28 | def apply_weight_norm(m): 29 | classname = m.__class__.__name__ 30 | if classname.find("Conv") != -1: 31 | weight_norm(m) 32 | 33 | 34 | def get_padding(kernel_size, dilation=1): 35 | return int((kernel_size*dilation - dilation)/2) 36 | 37 | 38 | def load_checkpoint(filepath, device): 39 | assert os.path.isfile(filepath) 40 | print("Loading '{}'".format(filepath)) 41 | checkpoint_dict = torch.load(filepath, map_location=device) 42 | print("Complete.") 43 | return checkpoint_dict 44 | 45 | 46 | def save_checkpoint(filepath, obj): 47 | print("Saving checkpoint to {}".format(filepath)) 48 | torch.save(obj, filepath) 49 | print("Complete.") 50 | 51 | 52 | def scan_checkpoint(cp_dir, prefix): 53 | pattern = os.path.join(cp_dir, prefix + '????????') 54 | cp_list = glob.glob(pattern) 55 | if len(cp_list) == 0: 56 | return None 57 | return sorted(cp_list)[-1] 58 | 59 | -------------------------------------------------------------------------------- /vocoders/vctk_24k10ms/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,5,3,2], 12 | "upsample_kernel_sizes": [15,15,5,5], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | 
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 12000, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 240, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 24000, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 4, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://localhost:54321", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /vocoders/vctk_24k10ms/g_02830000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/vocoders/vctk_24k10ms/g_02830000 --------------------------------------------------------------------------------