├── .gitignore
├── 1_compute_ctc_att_bnf.py
├── 2_compute_f0.py
├── 3_compute_spk_dvecs.py
├── LICENSE
├── README.md
├── bin
│   ├── solver.py
│   ├── train_linglf02mel_seq2seq_encAddlf0.py
│   ├── train_linglf02mel_seq2seq_lsa.py
│   ├── train_linglf02mel_seq2seq_oneshotvc.py
│   └── train_ppg2mel_oneshotvc.py
├── conf
│   ├── bilstm_ppg2mel_vctk_libri_oneshotvc.yaml
│   └── seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml
├── conformer_ppg_model
│   ├── build_ppg_model.py
│   ├── e2e_asr_common.py
│   ├── en_conformer_ctc_att
│   │   ├── 24epoch.pth
│   │   └── config.yaml
│   ├── encoder
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── conformer_encoder.py
│   │   ├── convolution.py
│   │   ├── embedding.py
│   │   ├── encoder.py
│   │   ├── encoder_layer.py
│   │   ├── layer_norm.py
│   │   ├── multi_layer_conv.py
│   │   ├── positionwise_feed_forward.py
│   │   ├── repeat.py
│   │   ├── subsampling.py
│   │   ├── swish.py
│   │   └── vgg.py
│   ├── encoders.py
│   ├── frontend.py
│   ├── log_mel.py
│   ├── nets_utils.py
│   ├── stft.py
│   └── utterance_mvn.py
├── convert_from_wav.py
├── figs
│   ├── README.md
│   ├── seq2seq_bnf2mel.pdf
│   └── seq2seq_bnf2mel.png
├── main.py
├── path.sh
├── requirements.txt
├── run.sh
├── speaker_encoder
│   ├── __init__.py
│   ├── audio.py
│   ├── ckpt
│   │   └── pretrained_bak_5805000.pt
│   ├── compute_embed.py
│   ├── config.py
│   ├── data_objects
│   │   ├── __init__.py
│   │   ├── random_cycler.py
│   │   ├── speaker.py
│   │   ├── speaker_batch.py
│   │   ├── speaker_verification_dataset.py
│   │   └── utterance.py
│   ├── hparams.py
│   ├── inference.py
│   ├── model.py
│   ├── params_data.py
│   ├── params_model.py
│   ├── preprocess.py
│   ├── train.py
│   ├── visualizations.py
│   └── voice_encoder.py
├── src
│   ├── __init__.py
│   ├── abs_model.py
│   ├── audio_utils.py
│   ├── basic_layers.py
│   ├── cnn_postnet.py
│   ├── data_load.py
│   ├── f0_utils.py
│   ├── loss.py
│   ├── loss_fn.py
│   ├── lsa_attention.py
│   ├── mel_decoder_lsa.py
│   ├── mel_decoder_mol_encAddlf0.py
│   ├── mel_decoder_mol_v2.py
│   ├── module.py
│   ├── mol_attention.py
│   ├── nets_utils.py
│   ├── optim.py
│   ├── option.py
│   ├── rnn_decoder_lsa.py
│   ├── rnn_decoder_mol.py
│   ├── rnn_decoder_mol_add_pitch.py
│   ├── rnn_ppg2mel.py
│   ├── solver.py
│   ├── util.py
│   └── vc_utils.py
├── test.sh
├── tools
│   └── Makefile
├── utils
│   ├── f0_utils.py
│   ├── file_related.py
│   ├── load_yaml.py
│   └── tensor_ops.py
└── vocoders
    ├── __init__.py
    ├── env.py
    ├── hifigan_model.py
    ├── utils.py
    └── vctk_24k10ms
        ├── config.json
        └── g_02830000
/.gitignore:
--------------------------------------------------------------------------------
1 | # general
2 | *~
3 | *.pyc
4 | \#*\#
5 | .\#*
6 | *DS_Store
7 | out.txt
8 | doc/_build
9 | slurm-*.out
10 | tmp*
11 | .eggs/
12 | .hypothesis/
13 | .idea
14 | .backup/
15 | .pytest_cache/
16 | __pycache__/
17 | .coverage*
18 | coverage.xml*
19 | .vscode*
20 | .nfs*
21 | .ipynb_checkpoints
22 |
23 | data/f0s.scp
24 | log
25 | conf/tuning
26 | ckpt
27 | vc_gen_wavs
28 | tools/venv
29 |
--------------------------------------------------------------------------------
/1_compute_ctc_att_bnf.py:
--------------------------------------------------------------------------------
1 | """
2 | Compute CTC-Attention Seq2seq ASR encoder bottle-neck features (BNF).
3 | """
4 | import sys
5 | import os
6 | import argparse
7 | import torch
8 | import glob2
9 | import soundfile
10 | import librosa
11 |
12 | import numpy as np
13 | from tqdm import tqdm
14 | from conformer_ppg_model.build_ppg_model import load_ppg_model
15 |
16 |
17 | SAMPLE_RATE=16000
18 |
19 |
20 | def compute_bnf(
21 | output_dir: str,
22 | wav_dir: str,
23 | train_config: str,
24 | model_file: str,
25 | ):
26 | device = "cuda"
27 |
28 | # 1. Build PPG model
29 | ppg_model_local = load_ppg_model(train_config, model_file, device)
30 |
31 | # 2. Glob wav files
32 | wav_file_list = glob2.glob(f"{wav_dir}/**/*.wav")
33 | print(f"Globbing {len(wav_file_list)} wav files.")
34 |
35 | # 3. start to compute ppgs
36 | os.makedirs(output_dir, exist_ok=True)
37 | for wav_file in tqdm(wav_file_list):
38 | audio, sr = soundfile.read(wav_file, always_2d=False)
39 | if sr != SAMPLE_RATE:
40 | audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
41 | wav_tensor = torch.from_numpy(audio).float().to(device).unsqueeze(0)
42 | wav_length = torch.LongTensor([audio.shape[0]]).to(device)
43 | with torch.no_grad():
44 | bnf = ppg_model_local(wav_tensor, wav_length)
45 | # bnf = torch.nn.functional.softmax(asr_model.ctc.ctc_lo(bnf), dim=2)
46 | bnf_npy = bnf.squeeze(0).cpu().numpy()
47 | fid = os.path.basename(wav_file).split(".")[0]
48 | bnf_fname = f"{output_dir}/{fid}.ling_feat.npy"
49 | np.save(bnf_fname, bnf_npy, allow_pickle=False)
50 |
51 |
52 | def get_parser():
53 | parser = argparse.ArgumentParser(description="compute ppg or ctc-bnf or ctc-att-bnf")
54 |
55 | parser.add_argument(
56 | "--output_dir",
57 | type=str,
58 | required=True,
59 | default=None,
60 | )
61 | parser.add_argument(
62 | "--wav_dir",
63 | type=str,
64 | required=True,
65 | default=None,
66 | )
67 | parser.add_argument(
68 | "--train_config",
69 | type=str,
70 | default="./conformer_ppg_model/en_conformer_ctc_att/config.yaml",
71 | )
72 | parser.add_argument(
73 | "--model_file",
74 | type=str,
75 | default="./conformer_ppg_model/en_conformer_ctc_att/24epoch.pth",
76 | )
77 |
78 | return parser
79 |
80 |
81 | if __name__ == "__main__":
82 | parser = get_parser()
83 | args = parser.parse_args()
84 | kwargs = vars(args)
85 | compute_bnf(**kwargs)
86 |
--------------------------------------------------------------------------------
/2_compute_f0.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob2
3 | import numpy as np
4 | import io
5 | from tqdm import tqdm
6 | import soundfile
7 | import resampy
8 | import pyworld
9 |
10 | import torch
11 | from multiprocessing import cpu_count
12 | from concurrent.futures import ProcessPoolExecutor
13 | from functools import partial
14 |
15 |
16 | def compute_f0(
17 | wav,
18 | sr,
19 | f0_floor=20.0,
20 | f0_ceil=600.0,
21 | frame_period=10.0
22 | ):
23 | wav = wav.astype(np.float64)
24 | f0, timeaxis = pyworld.harvest(
 25 |         wav, sr, frame_period=frame_period, f0_floor=f0_floor, f0_ceil=f0_ceil)
26 | return f0.astype(np.float32)
27 |
28 |
29 | def compute_f0_from_wav(
30 | wavfile_path,
31 | sampling_rate,
32 | f0_floor,
33 | f0_ceil,
34 | frame_period_ms,
35 | ):
36 | # try:
37 | wav, sr = soundfile.read(wavfile_path)
38 | if len(wav) < sr:
39 | return None, sr, len(wav)
40 | if sr != sampling_rate:
41 | wav = resampy.resample(wav, sr, sampling_rate)
42 | sr = sampling_rate
43 | f0 = compute_f0(wav, sr, f0_floor, f0_ceil, frame_period_ms)
44 | return f0, sr, len(wav)
45 |
46 |
47 | def process_one(
48 | wav_file_path,
49 | args,
50 | output_dir,
51 | ):
52 | fid = os.path.basename(wav_file_path)[:-4]
53 | save_fname = f"{output_dir}/{fid}.f0.npy"
54 | if os.path.isfile(save_fname):
55 | return
56 |
57 | f0, sr, wav_len = compute_f0_from_wav(
58 | wav_file_path, args.sampling_rate,
59 | args.f0_floor, args.f0_ceil, args.frame_period_ms)
60 | if f0 is None:
61 | return
62 | np.save(save_fname, f0, allow_pickle=False)
63 |
64 |
65 | def run(args):
66 | """Compute merged f0 values."""
67 | output_dir = args.output_dir
68 | os.makedirs(output_dir, exist_ok=True)
69 |
70 | wav_dir = args.wav_dir
71 | # Get file id list
72 | wav_file_list = glob2.glob(f"{wav_dir}/**/*.wav")
73 | print(f"Globbed {len(wav_file_list)} wave files.")
74 |
75 | # Multi-process worker
76 | if args.num_workers < 2 :
77 | for wav_file_path in tqdm(wav_file_list):
78 | process_one(wav_file_path, args, output_dir)
79 | else:
80 | with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
81 | futures = []
82 | for wav_file_path in wav_file_list:
83 | futures.append(executor.submit(
84 | partial(
85 | process_one, wav_file_path, args, output_dir,
86 | )
87 | ))
88 | results = [future.result() for future in tqdm(futures)]
89 |
90 |
91 | def get_parser():
92 | import argparse
93 | parser = argparse.ArgumentParser(description="Compute merged f0 values")
94 | parser.add_argument(
95 | "--wav_dir",
96 | required=True,
97 | type=str,
98 | )
99 | parser.add_argument(
100 | "--output_dir",
101 | required=True,
102 | type=str,
103 | )
104 | parser.add_argument(
105 | "--frame_period_ms",
106 | default=10,
107 | type=float,
108 | )
109 | parser.add_argument(
110 | "--sampling_rate",
111 | default=24000,
112 | type=int,
113 | )
114 | parser.add_argument(
115 | "--f0_floor",
116 | default=80,
117 | type=int,
118 | )
119 | parser.add_argument(
120 | "--f0_ceil",
121 | default=600,
122 | type=int
123 | )
124 | parser.add_argument(
125 | "--num_workers",
126 | default=10,
127 | type=int
128 | )
129 | return parser
130 |
131 |
132 | def main():
133 | parser = get_parser()
134 | args = parser.parse_args()
135 | print(args)
136 | run(args)
137 |
138 |
139 | if __name__ == "__main__":
140 | main()
141 |
--------------------------------------------------------------------------------
/3_compute_spk_dvecs.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | from speaker_encoder.voice_encoder import SpeakerEncoder
3 | from speaker_encoder.audio import preprocess_wav
4 | from pathlib import Path
5 | import numpy as np
6 | from os.path import join, basename, split
7 | from tqdm import tqdm
8 | from multiprocessing import cpu_count
9 | from concurrent.futures import ProcessPoolExecutor
10 | from functools import partial
11 | import glob
12 | import argparse
13 |
14 |
15 | def build_from_path(in_dir, out_dir, weights_fpath, num_workers=1):
16 | executor = ProcessPoolExecutor(max_workers=num_workers)
17 | futures = []
18 | wavfile_paths = glob.glob(os.path.join(in_dir, '*/*/*.wav'))
19 | wavfile_paths= sorted(wavfile_paths)
20 | for wav_path in wavfile_paths:
21 | futures.append(executor.submit(
22 | partial(_compute_spkEmbed, out_dir, wav_path, weights_fpath)))
23 | return [future.result() for future in tqdm(futures)]
24 |
25 | def _compute_spkEmbed(out_dir, wav_path, weights_fpath):
 26 |     utt_id = os.path.splitext(os.path.basename(wav_path))[0]  # avoid rstrip(".wav"), which strips characters, not the suffix
27 | fpath = Path(wav_path)
28 | wav = preprocess_wav(fpath)
29 |
30 | encoder = SpeakerEncoder(weights_fpath)
31 | embed = encoder.embed_utterance(wav)
32 | fname_save = os.path.join(out_dir, f"{utt_id}.npy")
33 | np.save(fname_save, embed, allow_pickle=False)
34 | return os.path.basename(fname_save)
35 |
36 | def preprocess(in_dir, out_dir, weights_fpath, num_workers):
37 | os.makedirs(out_dir, exist_ok=True)
38 | metadata = build_from_path(in_dir, out_dir, weights_fpath, num_workers)
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument('--in_dir', type=str,
43 | default='/home/shaunxliu/data/datasets/LibriTTS/LibriTTS')
44 | parser.add_argument('--num_workers', type=int, default=20)
45 | parser.add_argument('--out_dir_root', type=str,
46 | default='/home/shaunxliu/data/datasets/LibriTTS')
47 | parser.add_argument('--spk_encoder_ckpt', type=str, \
48 | default='speaker_encoder/ckpt/pretrained_bak_5805000.pt')
49 |
50 | args = parser.parse_args()
51 |
52 | split_list = ['train-clean-100', 'train-clean-360']
53 |
54 | # sub_folder_list = os.listdir(args.in_dir)
55 | # sub_folder_list.sort()
56 |
57 | args.num_workers = args.num_workers if args.num_workers is not None else cpu_count()
58 | print("Number of workers: ", args.num_workers)
59 | ckpt_step = os.path.basename(args.spk_encoder_ckpt).split('.')[0].split('_')[-1]
60 | spk_embed_out_dir = os.path.join(args.out_dir_root, f"GE2E_spkEmbed_step_{ckpt_step}")
61 | print("[INFO] spk_embed_out_dir: ", spk_embed_out_dir)
62 | os.makedirs(spk_embed_out_dir, exist_ok=True)
63 |
64 | # for data_split in split_list:
65 | # sub_folder_list = os.listdir(args.in_dir, data_split)
66 | # for spk in sub_folder_list:
67 | # print("Preprocessing {} ...".format(spk))
68 | # in_dir = os.path.join(args.in_dir, dataset, spk)
69 | # if not os.path.isdir(in_dir):
70 | # continue
71 | # # out_dir = os.path.join(args.out_dir, spk)
72 | # preprocess(in_dir, spk_embed_out_dir, args.spk_encoder_ckpt, args.num_workers)
73 | for data_split in split_list:
74 | in_dir = os.path.join(args.in_dir, data_split)
75 | preprocess(in_dir, spk_embed_out_dir, args.spk_encoder_ckpt, args.num_workers)
76 |
77 | print("DONE!")
78 | sys.exit(0)
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # One-shot Phonetic PosteriorGram (PPG)-Based Voice Conversion (PPG-VC): Any-to-Many Voice Conversion with Location-Relative Sequence-to-Sequence Modeling (TASLP 2021)
2 |
3 | ### [Paper](https://arxiv.org/abs/2009.02725v3) | [Pre-trained models](https://drive.google.com/drive/folders/1JeFntg2ax9gX4POFbQwcS85eC9hyQ6W6?usp=sharing) | [Paper Demo](https://liusongxiang.github.io/BNE-Seq2SeqMoL-VC/)
4 |
5 | This paper proposes an any-to-many location-relative, sequence-to-sequence (seq2seq) based, non-parallel voice conversion approach. In this approach, we combine a bottle-neck feature extractor (BNE) with a seq2seq based synthesis module. During the training stage, an encoder-decoder based hybrid connectionist-temporal-classification-attention (CTC-attention) phoneme recognizer is trained, whose encoder has a bottle-neck layer. A BNE is obtained from the phoneme recognizer and is utilized to extract speaker-independent, dense and rich linguistic representations from spectral features. Then a multi-speaker location-relative attention based seq2seq synthesis model is trained to reconstruct spectral features from the bottle-neck features, conditioning on speaker representations for speaker identity control in the generated speech. To mitigate the difficulties of using seq2seq based models to align long sequences, we down-sample the input spectral feature along the temporal dimension and equip the synthesis model with a discretized mixture of logistic (MoL) attention mechanism. Since the phoneme recognizer is trained with a large speech recognition corpus, the proposed approach can conduct any-to-many voice conversion. Objective and subjective evaluations show that the proposed any-to-many approach has superior voice conversion performance in terms of both naturalness and speaker similarity. Ablation studies are conducted to confirm the effectiveness of feature selection and model design strategies in the proposed approach. The proposed VC approach can readily be extended to support any-to-any VC (also known as one/few-shot VC), and achieves high performance according to objective and subjective evaluations.
6 |
7 |
8 |
9 |
10 |
11 | ![Diagram of the BNE-Seq2seqMoL system](figs/seq2seq_bnf2mel.png)
12 |
13 |
14 |
15 | This repo implements an updated version of PPG-based VC models.
16 |
17 | Notes:
18 |
19 | - The PPG model provided in `conformer_ppg_model` is based on a hybrid CTC-attention phoneme recognizer, trained on LibriSpeech (960 hours). PPGs have a frame shift of 10 ms and a dimensionality of 144. This model is very similar to the one used in [this paper](https://arxiv.org/pdf/2011.05731v2.pdf). A minimal extraction sketch is shown below these notes.
20 | 
21 | - This repo uses [Hifi-GAN V1](https://github.com/jik876/hifi-gan) as the vocoder; the sampling rate of the synthesized audio is 24 kHz.
22 |
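A minimal sketch of PPG extraction, mirroring `1_compute_ctc_att_bnf.py`; the config and checkpoint paths are the defaults shipped in this repo, and `example.wav` is a placeholder input:

```
import torch
import soundfile
import librosa

from conformer_ppg_model.build_ppg_model import load_ppg_model

device = "cuda" if torch.cuda.is_available() else "cpu"
ppg_model = load_ppg_model(
    "./conformer_ppg_model/en_conformer_ctc_att/config.yaml",
    "./conformer_ppg_model/en_conformer_ctc_att/24epoch.pth",
    device,
)

# The PPG model expects 16 kHz mono audio.
audio, sr = soundfile.read("example.wav", always_2d=False)  # placeholder path
if sr != 16000:
    audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)

wav = torch.from_numpy(audio).float().unsqueeze(0).to(device)  # (1, n_samples)
wav_len = torch.LongTensor([audio.shape[0]]).to(device)
with torch.no_grad():
    ppg = ppg_model(wav, wav_len)  # (1, n_frames, 144), 10 ms frame shift
```
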
23 | ## Updates!
24 | - We provide an audio sample uttered by Barack Obama ([link](https://drive.google.com/file/d/10Cgtw14UtVf2jTqKtR-C1y5bZqq6Ue7U/view?usp=sharing)); you can convert any voice into Obama's voice using this sample as the reference. Please give it a try!
25 | - The BNE-Seq2seqMoL one-shot VC model has been uploaded ([link](https://drive.google.com/drive/folders/1JeFntg2ax9gX4POFbQwcS85eC9hyQ6W6?usp=sharing)).
26 | - The BiLSTM-based one-shot VC model has been uploaded ([link](https://drive.google.com/drive/folders/1JeFntg2ax9gX4POFbQwcS85eC9hyQ6W6?usp=sharing)).
27 |
28 |
29 | ## How to use
30 | ### Setup with virtualenv
31 | ```
32 | $ cd tools
33 | $ make
34 | ```
35 |
36 | Note: If you want to specify the Python, CUDA, or PyTorch version, run, for example:
37 |
38 | ```
39 | $ make PYTHON=3.7 CUDA_VERSION=10.1 PYTORCH_VERSION=1.6
40 | ```
41 |
42 | ### Conversion with a pretrained model
43 | 1. Download a model from [here](https://drive.google.com/drive/folders/1JeFntg2ax9gX4POFbQwcS85eC9hyQ6W6?usp=sharing); we recommend first trying the model `bneSeq2seqMoL-vctk-libritts460-oneshot`. Put the config file and the checkpoint file in a folder `<exp_dir>`.
44 | 2. Prepare a source wav directory `<src_wav_dir>`, where the wavs inside are what you want to convert.
45 | 3. Prepare a reference audio sample (i.e., the target voice you want to convert to) `<ref_wav_path>`; its speaker d-vector conditions the model (see the sketch after this subsection).
46 | 4. Run `test.sh` as:
47 | ```
48 | sh test.sh <exp_dir>/seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml <exp_dir>/best_loss_step_304000.pth \
49 |     <src_wav_dir> <ref_wav_path>
50 | ```
51 | The converted wavs are saved in the folder `vc_gen_wavs`.
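A speaker d-vector for the reference audio is computed along the lines of the minimal sketch below, which mirrors `3_compute_spk_dvecs.py` and `speaker_encoder/`; the output filename `ref_dvec.npy` is only illustrative:

```
from pathlib import Path
import numpy as np

from speaker_encoder.voice_encoder import SpeakerEncoder
from speaker_encoder.audio import preprocess_wav

# Load the pre-trained GE2E speaker encoder shipped with this repo.
encoder = SpeakerEncoder("speaker_encoder/ckpt/pretrained_bak_5805000.pt")

# Trim/normalize the reference audio, then embed it (spk_embed_dim: 256 in the configs).
wav = preprocess_wav(Path("<ref_wav_path>"))
embed = encoder.embed_utterance(wav)
np.save("ref_dvec.npy", embed, allow_pickle=False)  # illustrative output path
```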
52 |
53 | ### Data preprocessing
54 | Activate the virtual env via `source tools/venv/bin/activate`, then:
55 | - Please run `1_compute_ctc_att_bnf.py` to compute PPG features.
56 | - Please run `2_compute_f0.py` to compute the fundamental frequency (a short sketch of the underlying extraction follows this list).
57 | - Please run `3_compute_spk_dvecs.py` to compute speaker d-vectors.
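The sketch below illustrates the per-utterance F0 extraction that `2_compute_f0.py` performs (WORLD `harvest` with a 10 ms frame period and the script's default F0 range); the wav and output paths are placeholders:

```
import numpy as np
import soundfile
import pyworld

wav, sr = soundfile.read("<some_utterance>.wav")  # placeholder path
f0, _ = pyworld.harvest(
    wav.astype(np.float64), sr,
    f0_floor=80.0, f0_ceil=600.0, frame_period=10.0,  # argparse defaults of 2_compute_f0.py
)
np.save("<fid>.f0.npy", f0.astype(np.float32), allow_pickle=False)  # placeholder path
```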
58 |
59 | ### Training
60 | - Please refer to `run.sh`
61 |
62 | ## Citations
63 | If you use this repo for your research, please consider citing the following related papers.
64 | ```
65 | @ARTICLE{liu2021any,
66 | author={Liu, Songxiang and Cao, Yuewen and Wang, Disong and Wu, Xixin and Liu, Xunying and Meng, Helen},
67 | journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
68 | title={Any-to-Many Voice Conversion With Location-Relative Sequence-to-Sequence Modeling},
69 | year={2021},
70 | volume={29},
71 | number={},
72 | pages={1717-1728},
73 | doi={10.1109/TASLP.2021.3076867}
74 | }
75 |
76 | @inproceedings{Liu2018,
77 | author={Songxiang Liu and Jinghua Zhong and Lifa Sun and Xixin Wu and Xunying Liu and Helen Meng},
78 | title={Voice Conversion Across Arbitrary Speakers Based on a Single Target-Speaker Utterance},
79 | year=2018,
80 | booktitle={Proc. Interspeech 2018},
81 | pages={496--500},
82 | doi={10.21437/Interspeech.2018-1504},
83 | url={http://dx.doi.org/10.21437/Interspeech.2018-1504}
84 | }
85 | ```
--------------------------------------------------------------------------------
/bin/solver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import abc
4 | import math
5 | import yaml
6 | import torch
7 | from torch.utils.tensorboard import SummaryWriter
8 |
9 | from src.option import default_hparas
10 | from src.util import human_format, Timer
11 | from utils.load_yaml import HpsYaml
12 |
13 |
14 | class BaseSolver():
15 | '''
16 | Prototype Solver for all kinds of tasks
17 | Arguments
18 | config - yaml-styled config
19 | paras - argparse outcome
20 | mode - "train"/"test"
21 | '''
22 |
23 | def __init__(self, config, paras, mode="train"):
24 | # General Settings
25 | self.config = config # load from yaml file
26 | self.paras = paras # command line args
27 | self.mode = mode # 'train' or 'test'
28 | for k, v in default_hparas.items():
29 | setattr(self, k, v)
30 | self.device = torch.device('cuda') if self.paras.gpu and torch.cuda.is_available() \
31 | else torch.device('cpu')
32 |
33 | # Name experiment
34 | self.exp_name = paras.name
35 | if self.exp_name is None:
36 | if 'exp_name' in self.config:
37 | self.exp_name = self.config.exp_name
38 | else:
39 | # By default, exp is named after config file
40 | self.exp_name = paras.config.split('/')[-1].replace('.yaml', '')
41 | if mode == 'train':
42 | self.exp_name += '_seed{}'.format(paras.seed)
43 |
44 |
45 | if mode == 'train':
46 | # Filepath setup
47 | os.makedirs(paras.ckpdir, exist_ok=True)
48 | self.ckpdir = os.path.join(paras.ckpdir, self.exp_name)
49 | os.makedirs(self.ckpdir, exist_ok=True)
50 |
51 | # Logger settings
52 | self.logdir = os.path.join(paras.logdir, self.exp_name)
53 | self.log = SummaryWriter(
54 | self.logdir, flush_secs=self.TB_FLUSH_FREQ)
55 | self.timer = Timer()
56 |
57 | # Hyper-parameters
58 | self.step = 0
59 | self.valid_step = config.hparas.valid_step
60 | self.max_step = config.hparas.max_step
61 |
62 | self.verbose('Exp. name : {}'.format(self.exp_name))
 63 |         self.verbose('Loading data... a large corpus may take a while.')
64 |
65 | # elif mode == 'test':
66 | # # Output path
67 | # os.makedirs(paras.outdir, exist_ok=True)
68 | # self.ckpdir = os.path.join(paras.outdir, self.exp_name)
69 |
70 | # Load training config to get acoustic feat and build model
71 | # self.src_config = HpsYaml(config.src.config)
72 | # self.paras.load = config.src.ckpt
73 |
74 | # self.verbose('Evaluating result of tr. config @ {}'.format(
75 | # config.src.config))
76 |
77 | def backward(self, loss):
78 | '''
79 | Standard backward step with self.timer and debugger
80 | Arguments
81 | loss - the loss to perform loss.backward()
82 | '''
83 | self.timer.set()
84 | loss.backward()
85 | grad_norm = torch.nn.utils.clip_grad_norm_(
86 | self.model.parameters(), self.GRAD_CLIP)
87 | if math.isnan(grad_norm):
88 | self.verbose('Error : grad norm is NaN @ step '+str(self.step))
89 | else:
90 | self.optimizer.step()
91 | self.timer.cnt('bw')
92 | return grad_norm
93 |
94 | def load_ckpt(self):
95 | ''' Load ckpt if --load option is specified '''
96 | if self.paras.load is not None:
97 | if self.paras.warm_start:
98 | self.verbose(f"Warm starting model from checkpoint {self.paras.load}.")
99 | ckpt = torch.load(
100 | self.paras.load, map_location=self.device if self.mode == 'train'
101 | else 'cpu')
102 | model_dict = ckpt['model']
103 | if len(self.config.model.ignore_layers) > 0:
104 | model_dict = {k:v for k, v in model_dict.items()
105 | if k not in self.config.model.ignore_layers}
106 | dummy_dict = self.model.state_dict()
107 | dummy_dict.update(model_dict)
108 | model_dict = dummy_dict
109 | self.model.load_state_dict(model_dict)
110 | else:
111 | # Load weights
112 | ckpt = torch.load(
113 | self.paras.load, map_location=self.device if self.mode == 'train'
114 | else 'cpu')
115 | self.model.load_state_dict(ckpt['model'])
116 |
117 | # Load task-dependent items
118 | if self.mode == 'train':
119 | self.step = ckpt['global_step']
120 | self.optimizer.load_opt_state_dict(ckpt['optimizer'])
121 | self.verbose('Load ckpt from {}, restarting at step {}'.format(
122 | self.paras.load, self.step))
123 | else:
124 | for k, v in ckpt.items():
125 | if type(v) is float:
126 | metric, score = k, v
127 | self.model.eval()
128 | self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format(
129 | self.paras.load, metric, score))
130 |
131 | def verbose(self, msg):
132 | ''' Verbose function for print information to stdout'''
133 | if self.paras.verbose:
134 | if type(msg) == list:
135 | for m in msg:
136 | print('[INFO]', m.ljust(100))
137 | else:
138 | print('[INFO]', msg.ljust(100))
139 |
140 | def progress(self, msg):
141 | ''' Verbose function for updating progress on stdout (do not include newline) '''
142 | if self.paras.verbose:
143 | sys.stdout.write("\033[K") # Clear line
144 | print('[{}] {}'.format(human_format(self.step), msg), end='\r')
145 |
146 | def write_log(self, log_name, log_dict):
147 | '''
148 | Write log to TensorBoard
149 | log_name - Name of tensorboard variable
150 |             log_dict - Value(s) to log: a dict of scalars, an (image, dataformats) pair for 'align'/'spec', or a string for 'text'/'hyp'
151 | '''
152 | if type(log_dict) is dict:
153 | log_dict = {key: val for key, val in log_dict.items() if (
154 | val is not None and not math.isnan(val))}
155 | if log_dict is None:
156 | pass
157 | elif len(log_dict) > 0:
158 | if 'align' in log_name or 'spec' in log_name:
159 | img, form = log_dict
160 | self.log.add_image(
161 | log_name, img, global_step=self.step, dataformats=form)
162 | elif 'text' in log_name or 'hyp' in log_name:
163 | self.log.add_text(log_name, log_dict, self.step)
164 | else:
165 | self.log.add_scalars(log_name, log_dict, self.step)
166 |
167 | def save_checkpoint(self, f_name, metric, score, show_msg=True):
168 |         '''
169 |         Ckpt saver
170 |         f_name - the name of the ckpt file (w/o prefix) to store; overwritten if it exists
171 | score - The value of metric used to evaluate model
172 | '''
173 | ckpt_path = os.path.join(self.ckpdir, f_name)
174 | full_dict = {
175 | "model": self.model.state_dict(),
176 | "optimizer": self.optimizer.get_opt_state_dict(),
177 | "global_step": self.step,
178 | metric: score
179 | }
180 |
181 | torch.save(full_dict, ckpt_path)
182 | if show_msg:
183 | self.verbose("Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}".
184 | format(human_format(self.step), metric, score, ckpt_path))
185 |
186 |
187 |     # ----------------------------------- Abstract Methods ------------------------------------------ #
188 | @abc.abstractmethod
189 | def load_data(self):
190 | '''
191 | Called by main to load all data
192 | After this call, data related attributes should be setup (e.g. self.tr_set, self.dev_set)
193 | No return value
194 | '''
195 | raise NotImplementedError
196 |
197 | @abc.abstractmethod
198 | def set_model(self):
199 | '''
200 | Called by main to set models
201 | After this call, model related attributes should be setup (e.g. self.l2_loss)
202 | The followings MUST be setup
203 | - self.model (torch.nn.Module)
204 | - self.optimizer (src.Optimizer),
205 | init. w/ self.optimizer = src.Optimizer(self.model.parameters(),**self.config['hparas'])
206 | Loading pre-trained model should also be performed here
207 | No return value
208 | '''
209 | raise NotImplementedError
210 |
211 | @abc.abstractmethod
212 | def exec(self):
213 | '''
214 | Called by main to execute training/inference
215 | '''
216 | raise NotImplementedError
217 |
--------------------------------------------------------------------------------
/bin/train_ppg2mel_oneshotvc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | import numpy as np
4 | from src.solver import BaseSolver
5 | # from src.data_load import VcDataset, VcCollate
6 | from src.data_load import OneshotVcDataset, MultiSpkVcCollate
7 | from src.rnn_ppg2mel import BiRnnPpg2MelModel
8 | from src.optim import Optimizer
9 | from src.util import human_format, feat_to_fig
10 | from src.loss_fn import MaskedMSELoss
11 |
12 |
13 | class Solver(BaseSolver):
14 | """Customized Solver."""
15 | def __init__(self, config, paras, mode):
16 | super().__init__(config, paras, mode)
17 | self.best_loss = np.inf
18 |
19 | def fetch_data(self, data):
20 | """Move data to device"""
21 | data = [i.to(self.device) for i in data]
22 | return data
23 |
24 | def load_data(self):
25 | """ Load data for training/validation/plotting."""
26 | train_dataset = OneshotVcDataset(
27 | meta_file=self.config.data.train_fid_list,
28 | vctk_ppg_dir=self.config.data.vctk_ppg_dir,
29 | libri_ppg_dir=self.config.data.libri_ppg_dir,
30 | vctk_f0_dir=self.config.data.vctk_f0_dir,
31 | libri_f0_dir=self.config.data.libri_f0_dir,
32 | vctk_wav_dir=self.config.data.vctk_wav_dir,
33 | libri_wav_dir=self.config.data.libri_wav_dir,
34 | vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir,
35 | libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir,
36 | ppg_file_ext=self.config.data.ppg_file_ext,
37 | min_max_norm_mel=self.config.data.min_max_norm_mel,
38 | mel_min=self.config.data.mel_min,
39 | mel_max=self.config.data.mel_max,
40 | )
41 | dev_dataset = OneshotVcDataset(
42 | meta_file=self.config.data.dev_fid_list,
43 | vctk_ppg_dir=self.config.data.vctk_ppg_dir,
44 | libri_ppg_dir=self.config.data.libri_ppg_dir,
45 | vctk_f0_dir=self.config.data.vctk_f0_dir,
46 | libri_f0_dir=self.config.data.libri_f0_dir,
47 | vctk_wav_dir=self.config.data.vctk_wav_dir,
48 | libri_wav_dir=self.config.data.libri_wav_dir,
49 | vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir,
50 | libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir,
51 | ppg_file_ext=self.config.data.ppg_file_ext,
52 | min_max_norm_mel=self.config.data.min_max_norm_mel,
53 | mel_min=self.config.data.mel_min,
54 | mel_max=self.config.data.mel_max,
55 | )
56 | self.train_dataloader = DataLoader(
57 | train_dataset,
58 | num_workers=self.paras.njobs,
59 | shuffle=True,
60 | batch_size=self.config.hparas.batch_size,
61 | pin_memory=False,
62 | drop_last=True,
63 | collate_fn=MultiSpkVcCollate(n_frames_per_step=1,
64 | f02ppg_length_ratio=1,
65 | use_spk_dvec=True),
66 | )
67 | self.dev_dataloader = DataLoader(
68 | dev_dataset,
69 | num_workers=self.paras.njobs,
70 | shuffle=False,
71 | batch_size=self.config.hparas.batch_size,
72 | pin_memory=False,
73 | drop_last=False,
74 | collate_fn=MultiSpkVcCollate(n_frames_per_step=1,
75 | f02ppg_length_ratio=1,
76 | use_spk_dvec=True),
77 | )
78 | msg = "Have prepared training set and dev set."
79 | self.verbose(msg)
80 |
81 | def load_pretrained_params(self):
82 | prefix = "ppg2mel_model"
83 | ignore_layers = ["ppg2mel_model.spk_embedding.weight"]
84 | pretrain_model_file = self.config.data.pretrain_model_file
85 | pretrain_ckpt = torch.load(
86 | pretrain_model_file, map_location=self.device
87 | )
88 | model_dict = self.model.state_dict()
89 |
 90 |         # 1. filter out unnecessary keys
91 | pretrain_dict = {k.split(".", maxsplit=1)[1]: v
92 | for k, v in pretrain_ckpt.items() if "spk_embedding" not in k
93 | and "wav2ppg_model" not in k and "reduce_proj" not in k}
94 | # assert len(pretrain_dict.keys()) == len(model_dict.keys())
95 |
96 | # 2. overwrite entries in the existing state dict
97 | model_dict.update(pretrain_dict)
98 |
99 | # 3. load the new state dict
100 | self.model.load_state_dict(model_dict)
101 |
102 | def set_model(self):
103 | """Setup model and optimizer"""
104 | # Model
105 | self.model = BiRnnPpg2MelModel(**self.config["model"]).to(self.device)
106 | if "pretrain_model_file" in self.config.data:
107 | self.load_pretrained_params()
108 |
109 | # model_params = [{'params': self.model.spk_embedding.weight}]
110 | model_params = [{'params': self.model.parameters()}]
111 |
112 | # Loss criterion
113 | self.loss_criterion = MaskedMSELoss()
114 |
115 | # Optimizer
116 | self.optimizer = Optimizer(model_params, **self.config["hparas"])
117 | self.verbose(self.optimizer.create_msg())
118 |
119 | # Automatically load pre-trained model if self.paras.load is given
120 | self.load_ckpt()
121 |
122 | def exec(self):
123 | self.verbose("Total training steps {}.".format(
124 | human_format(self.max_step)))
125 |
126 | mel_loss = None
127 | n_epochs = 0
128 | # Set as current time
129 | self.timer.set()
130 |
131 | while self.step < self.max_step:
132 | for data in self.train_dataloader:
133 |                 # Pre-step: update lr_rate and do zero_grad
134 | lr_rate = self.optimizer.pre_step(self.step)
135 | total_loss = 0
136 | # data to device
137 | ppgs, lf0_uvs, mels, in_lengths, \
138 | out_lengths, spk_ids, _ = self.fetch_data(data)
139 | self.timer.cnt("rd")
140 | mel_pred = self.model(
141 | ppg=ppgs,
142 | ppg_lengths=out_lengths,
143 | logf0_uv=lf0_uvs,
144 | spembs=spk_ids,
145 | )
146 | loss = self.loss_criterion(mel_pred, mels, out_lengths)
147 |
148 | self.timer.cnt("fw")
149 |
150 | # Back-prop
151 | grad_norm = self.backward(loss)
152 | self.step += 1
153 |
154 | # Logger
155 | if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
156 | self.progress("Tr stat | Loss - {:.4f} | Grad. Norm - {:.2f} | {}"
157 | .format(loss.cpu().item(), grad_norm, self.timer.show()))
158 | self.write_log('loss', {'tr': loss})
159 |
160 | # Validation
161 | if (self.step == 1) or (self.step % self.valid_step == 0):
162 | self.validate()
163 |
164 | # End of step
165 | # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
166 | torch.cuda.empty_cache()
167 | self.timer.set()
168 | if self.step > self.max_step:
169 | break
170 | n_epochs += 1
171 | self.log.close()
172 |
173 | def validate(self):
174 | self.model.eval()
175 | dev_loss = 0.0
176 |
177 | for i, data in enumerate(self.dev_dataloader):
178 | self.progress('Valid step - {}/{}'.format(i+1, len(self.dev_dataloader)))
179 | # Fetch data
180 | # ppgs, lf0_uvs, mels, lengths = self.fetch_data(data)
181 | ppgs, lf0_uvs, mels, in_lengths, \
182 | out_lengths, spk_ids, _ = self.fetch_data(data)
183 |
184 | with torch.no_grad():
185 | mel_pred = self.model(
186 | ppg=ppgs,
187 | ppg_lengths=out_lengths,
188 | logf0_uv=lf0_uvs,
189 | spembs=spk_ids,
190 | )
191 | loss = self.loss_criterion(mel_pred, mels, out_lengths)
192 | dev_loss += loss.cpu().item()
193 |
194 | dev_loss = dev_loss / (i + 1)
195 | self.save_checkpoint(f'step_{self.step}.pth', 'loss', dev_loss, show_msg=False)
196 | if dev_loss < self.best_loss:
197 | self.best_loss = dev_loss
198 | self.save_checkpoint(f'best_loss_step_{self.step}.pth', 'loss', dev_loss)
199 | self.write_log('loss', {'dv_loss': dev_loss})
200 |
201 | # Resume training
202 | self.model.train()
203 |
204 |
--------------------------------------------------------------------------------
/conf/bilstm_ppg2mel_vctk_libri_oneshotvc.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_fid_list: "/home/shaunxliu/data/vctk/fidlists/train_fidlist.new"
3 | dev_fid_list: "/home/shaunxliu/data/vctk/fidlists/dev_fidlist.new"
4 | eval_fid_list: "/home/shaunxliu/data/vctk/fidlists/eval_fidlist.txt"
5 | vctk_ppg_dir: "/home/shaunxliu/data/vctk/conformer_bnf10ms"
6 | libri_ppg_dir: "/home/shaunxliu/data/LibriTTS/conformer_bnf10ms"
7 | vctk_f0_dir: "/home/shaunxliu/data/vctk/merged_f0s"
8 | libri_f0_dir: "/home/shaunxliu/data/LibriTTS/f0s"
9 | vctk_wav_dir: "/home/shaunxliu/data/vctk/wav_mono_24k_16b_norm-6db"
10 | libri_wav_dir: "/home/shaunxliu/data/LibriTTS/LibriTTS/train-wavs-clean460/"
11 | vctk_spk_dvec_dir: "/home/shaunxliu/data/vctk/GE2E_spkEmbed_step_5805000_perSpk"
12 | libri_spk_dvec_dir: "/home/shaunxliu/data/LibriTTS/GE2E_spkEmbed_step_5805000_perSpk"
13 | ppg_file_ext: "ling_feat.npy"
14 | f0_file_ext: "f0.npy"
15 | wav_file_ext: "wav"
16 | min_max_norm_mel: true
17 | mel_min: -12.0
18 | mel_max: 2.5
19 |
20 | hparas:
21 | batch_size: 32
22 | valid_step: 1000
23 | max_step: 1000000
24 | optimizer: 'Adam'
25 | lr: 0.001
26 | eps: 1.0e-8
27 | weight_decay: 1.0e-6
28 | lr_scheduler: 'warmup' # "fixed", "warmup"
29 |
30 | model_name: "bilstm"
31 | model:
32 | input_size: 146 # 144 ppg-dim and 2 pitch
33 | multi_spk: True
34 | use_spk_dvec: True # for one-shot VC
35 |
36 |
37 |
--------------------------------------------------------------------------------
/conf/seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | #train_fid_list: "../wavernn-vc/dump/vc_train_set/f0s.scp"
3 | train_fid_list: "/home/shaunxliu/data/vctk/fidlists/train_fidlist.new"
4 | dev_fid_list: "/home/shaunxliu/data/vctk/fidlists/dev_fidlist.new"
5 | eval_fid_list: "/home/shaunxliu/data/vctk/fidlists/eval_fidlist.txt"
6 | vctk_ppg_dir: "/home/shaunxliu/data/vctk/conformer_bnf10ms"
7 | libri_ppg_dir: "/home/shaunxliu/data/LibriTTS/conformer_bnf10ms"
8 | vctk_f0_dir: "/home/shaunxliu/data/vctk/merged_f0s"
9 | libri_f0_dir: "/home/shaunxliu/data/LibriTTS/f0s"
10 | vctk_wav_dir: "/home/shaunxliu/data/vctk/wav_mono_24k_16b_norm-6db"
11 | libri_wav_dir: "/home/shaunxliu/data/LibriTTS/LibriTTS/train-wavs-clean460/"
12 | vctk_spk_dvec_dir: "/home/shaunxliu/data/vctk/GE2E_spkEmbed_step_5805000_perSpk"
13 | libri_spk_dvec_dir: "/home/shaunxliu/data/LibriTTS/GE2E_spkEmbed_step_5805000_perSpk"
14 | ppg_file_ext: "ling_feat.npy"
15 | f0_file_ext: "f0.npy"
16 | wav_file_ext: "wav"
17 | min_max_norm_mel: true
18 | mel_min: -12.0
19 | mel_max: 2.5
20 |
21 | hparas:
22 | batch_size: 32
23 | valid_step: 2000
24 | max_step: 1000000
25 | optimizer: 'Adam'
26 | lr: 0.001
27 | eps: 1.0e-8
28 | weight_decay: 1.0e-6
29 | lr_scheduler: 'warmup' # "fixed", "warmup"
30 |
31 | model_name: "seq2seqmolv2"
32 | model:
33 | num_speakers: 1250
34 | spk_embed_dim: 256
35 | bottle_neck_feature_dim: 144
36 | encoder_downsample_rates: [2, 2]
37 | attention_rnn_dim: 512
38 | attention_dim: 512
39 | decoder_rnn_dim: 512
40 | num_decoder_rnn_layer: 1
41 | concat_context_to_last: True
42 | prenet_dims: [256, 128]
43 | prenet_dropout: 0.5
44 | num_mixtures: 5
45 | frames_per_step: 4
46 | postnet_num_layers: 5
47 | postnet_hidden_dim: 512
48 | mask_padding: True
49 | use_spk_dvec: True
50 |
--------------------------------------------------------------------------------
/conformer_ppg_model/build_ppg_model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from pathlib import Path
4 | import yaml
5 |
6 |
7 | from .frontend import DefaultFrontend
8 | from .utterance_mvn import UtteranceMVN
9 | from .encoder.conformer_encoder import ConformerEncoder
10 |
11 |
12 | class PPGModel(torch.nn.Module):
13 | def __init__(
14 | self,
15 | frontend,
16 | normalizer,
17 | encoder,
18 | ):
19 | super().__init__()
20 | self.frontend = frontend
21 | self.normalize = normalizer
22 | self.encoder = encoder
23 |
24 | def forward(self, speech, speech_lengths):
25 | """
26 |
27 | Args:
28 | speech (tensor): (B, L)
29 | speech_lengths (tensor): (B, )
30 |
31 | Returns:
32 | bottle_neck_feats (tensor): (B, L//hop_size, 144)
33 |
34 | """
35 | feats, feats_lengths = self._extract_feats(speech, speech_lengths)
36 | feats, feats_lengths = self.normalize(feats, feats_lengths)
37 | encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
38 | return encoder_out
39 |
40 | def _extract_feats(
41 | self, speech: torch.Tensor, speech_lengths: torch.Tensor
42 | ):
43 | assert speech_lengths.dim() == 1, speech_lengths.shape
44 |
45 | # for data-parallel
46 | speech = speech[:, : speech_lengths.max()]
47 |
48 | if self.frontend is not None:
49 | # Frontend
50 | # e.g. STFT and Feature extract
51 | # data_loader may send time-domain signal in this case
52 | # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
53 | feats, feats_lengths = self.frontend(speech, speech_lengths)
54 | else:
55 | # No frontend and no feature extract
56 | feats, feats_lengths = speech, speech_lengths
57 | return feats, feats_lengths
58 |
59 |
60 | def build_model(args):
61 | normalizer = UtteranceMVN(**args.normalize_conf)
62 | frontend = DefaultFrontend(**args.frontend_conf)
63 | encoder = ConformerEncoder(input_size=80, **args.encoder_conf)
64 | model = PPGModel(frontend, normalizer, encoder)
65 |
66 | return model
67 |
68 |
69 | def load_ppg_model(train_config, model_file, device):
70 | config_file = Path(train_config)
71 | with config_file.open("r", encoding="utf-8") as f:
72 | args = yaml.safe_load(f)
73 |
74 | args = argparse.Namespace(**args)
75 |
76 | model = build_model(args)
77 | model_state_dict = model.state_dict()
78 |
79 | ckpt_state_dict = torch.load(model_file, map_location='cpu')
80 | ckpt_state_dict = {k:v for k,v in ckpt_state_dict.items() if 'encoder' in k}
81 |
82 | model_state_dict.update(ckpt_state_dict)
83 | model.load_state_dict(model_state_dict)
84 |
85 | return model.eval().to(device)
86 |
--------------------------------------------------------------------------------
/conformer_ppg_model/en_conformer_ctc_att/24epoch.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/conformer_ppg_model/en_conformer_ctc_att/24epoch.pth
--------------------------------------------------------------------------------
/conformer_ppg_model/en_conformer_ctc_att/config.yaml:
--------------------------------------------------------------------------------
1 | config: ./conf/train_asr_conformer_nodownsample.v1.yaml
2 | print_config: false
3 | log_level: INFO
4 | dry_run: false
5 | iterator_type: sequence
6 | output_dir: exp/asr_train_asr_conformer_nodownsample.v1_raw_sp
7 | ngpu: 1
8 | seed: 0
9 | num_workers: 1
10 | num_att_plot: 5
11 | dist_backend: nccl
12 | dist_init_method: env://
13 | dist_world_size: 4
14 | dist_rank: 1
15 | local_rank: 1
16 | dist_master_addr: localhost
17 | dist_master_port: 50826
18 | dist_launcher: null
19 | multiprocessing_distributed: true
20 | cudnn_enabled: true
21 | cudnn_benchmark: false
22 | cudnn_deterministic: true
23 | collect_stats: false
24 | write_collected_feats: false
25 | max_epoch: 2000
26 | patience: 200
27 | val_scheduler_criterion:
28 | - valid
29 | - acc
30 | early_stopping_criterion:
31 | - valid
32 | - loss
33 | - min
34 | best_model_criterion:
35 | - - valid
36 | - acc
37 | - max
38 | keep_nbest_models: 10
39 | grad_clip: 5.0
40 | grad_noise: false
41 | accum_grad: 8
42 | no_forward_run: false
43 | resume: true
44 | train_dtype: float32
45 | log_interval: null
46 | pretrain_path: []
47 | pretrain_key: []
48 | num_iters_per_epoch: null
49 | batch_size: 32
50 | valid_batch_size: 32
51 | batch_bins: 1000000
52 | valid_batch_bins: 1000000
53 | train_shape_file:
54 | - exp/asr_stats_raw_sp/train/speech_shape
55 | - exp/asr_stats_raw_sp/train/text_shape.phn
56 | valid_shape_file:
57 | - exp/asr_stats_raw_sp/valid/speech_shape
58 | - exp/asr_stats_raw_sp/valid/text_shape.phn
59 | batch_type: folded
60 | valid_batch_type: folded
61 | fold_length:
62 | - 128000
63 | - 150
64 | sort_in_batch: descending
65 | sort_batch: descending
66 | chunk_length: 500
67 | chunk_shift_ratio: 0.5
68 | num_cache_chunks: 1024
69 | train_data_path_and_name_and_type:
70 | - - dump/raw/train_960_sp/wav.scp
71 | - speech
72 | - sound
73 | - - dump/raw/train_960_sp/text
74 | - text
75 | - text
76 | valid_data_path_and_name_and_type:
77 | - - dump/raw/dev_set/wav.scp
78 | - speech
79 | - sound
80 | - - dump/raw/dev_set/text
81 | - text
82 | - text
83 | allow_variable_data_keys: false
84 | max_cache_size: 0.0
85 | valid_max_cache_size: 0.0
86 | optim: adam
87 | optim_conf:
88 | lr: 0.0015
89 | scheduler: warmuplr
90 | scheduler_conf:
91 | warmup_steps: 25000
92 | token_list:
93 | -
94 | -
95 | - AA0
96 | - AA1
97 | - AA2
98 | - AE0
99 | - AE1
100 | - AE2
101 | - AH0
102 | - AH1
103 | - AH2
104 | - AO0
105 | - AO1
106 | - AO2
107 | - AW0
108 | - AW1
109 | - AW2
110 | - AY0
111 | - AY1
112 | - AY2
113 | - B
114 | - CH
115 | - D
116 | - DH
117 | - EH0
118 | - EH1
119 | - EH2
120 | - ER0
121 | - ER1
122 | - ER2
123 | - EY0
124 | - EY1
125 | - EY2
126 | - F
127 | - G
128 | - HH
129 | - IH0
130 | - IH1
131 | - IH2
132 | - IY0
133 | - IY1
134 | - IY2
135 | - JH
136 | - K
137 | - L
138 | - M
139 | - N
140 | - NG
141 | - OW0
142 | - OW1
143 | - OW2
144 | - OY0
145 | - OY1
146 | - OY2
147 | - P
148 | - R
149 | - S
150 | - SH
151 | - T
152 | - TH
153 | - UH0
154 | - UH1
155 | - UH2
156 | - UW0
157 | - UW1
158 | - UW2
159 | - V
160 | - W
161 | - Y
162 | - Z
163 | - ZH
164 | - sil
165 | - sp
166 | - spn
167 | -
168 | init: null
169 | input_size: null
170 | ctc_conf:
171 | dropout_rate: 0.0
172 | ctc_type: builtin
173 | reduce: true
174 | model_conf:
175 | ctc_weight: 0.5
176 | lsm_weight: 0.1
177 | length_normalized_loss: false
178 | use_preprocessor: true
179 | token_type: phn
180 | bpemodel: null
181 | non_linguistic_symbols: null
182 | frontend: default
183 | frontend_conf:
184 | fs: 16000
185 | specaug: specaug
186 | specaug_conf:
187 | apply_time_warp: true
188 | time_warp_window: 5
189 | time_warp_mode: bicubic
190 | apply_freq_mask: true
191 | freq_mask_width_range:
192 | - 0
193 | - 30
194 | num_freq_mask: 2
195 | apply_time_mask: true
196 | time_mask_width_range:
197 | - 0
198 | - 40
199 | num_time_mask: 2
200 | normalize: utterance_mvn
201 | normalize_conf:
202 | norm_means: true
203 | norm_vars: true
204 | encoder: conformer
205 | encoder_conf:
206 | attention_dim: 144
207 | attention_heads: 4
208 | linear_units: 576
209 | num_blocks: 16
210 | dropout_rate: 0.1
211 | positional_dropout_rate: 0.1
212 | attention_dropout_rate: 0.0
213 | input_layer: conv2d
214 | normalize_before: true
215 | concat_after: false
216 | positionwise_layer_type: linear
217 | positionwise_conv_kernel_size: 1
218 | macaron_style: true
219 | pos_enc_layer_type: rel_pos
220 | selfattention_layer_type: rel_selfattn
221 | activation_type: swish
222 | use_cnn_module: true
223 | cnn_module_kernel: 15
224 | no_subsample: true
225 | subsample_by_2: false
226 | decoder: rnn
227 | decoder_conf:
228 | rnn_type: lstm
229 | num_layers: 1
230 | hidden_size: 320
231 | sampling_probability: 0.0
232 | dropout: 0.0
233 | att_conf:
234 | adim: 320
235 | required:
236 | - output_dir
237 | - token_list
238 | distributed: true
239 |
--------------------------------------------------------------------------------
/conformer_ppg_model/encoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/conformer_ppg_model/encoder/__init__.py
--------------------------------------------------------------------------------
/conformer_ppg_model/encoder/attention.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Shigeki Karita
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 |
7 | """Multi-Head Attention layer definition."""
8 |
9 | import math
10 |
11 | import numpy
12 | import torch
13 | from torch import nn
14 |
15 |
16 | class MultiHeadedAttention(nn.Module):
17 | """Multi-Head Attention layer.
18 |
 19 |     :param int n_head: the number of heads
20 | :param int n_feat: the number of features
21 | :param float dropout_rate: dropout rate
22 |
23 | """
24 |
25 | def __init__(self, n_head, n_feat, dropout_rate):
26 | """Construct an MultiHeadedAttention object."""
27 | super(MultiHeadedAttention, self).__init__()
28 | assert n_feat % n_head == 0
29 | # We assume d_v always equals d_k
30 | self.d_k = n_feat // n_head
31 | self.h = n_head
32 | self.linear_q = nn.Linear(n_feat, n_feat)
33 | self.linear_k = nn.Linear(n_feat, n_feat)
34 | self.linear_v = nn.Linear(n_feat, n_feat)
35 | self.linear_out = nn.Linear(n_feat, n_feat)
36 | self.attn = None
37 | self.dropout = nn.Dropout(p=dropout_rate)
38 |
39 | def forward_qkv(self, query, key, value):
40 | """Transform query, key and value.
41 |
42 | :param torch.Tensor query: (batch, time1, size)
43 | :param torch.Tensor key: (batch, time2, size)
44 | :param torch.Tensor value: (batch, time2, size)
45 | :return torch.Tensor transformed query, key and value
46 |
47 | """
48 | n_batch = query.size(0)
49 | q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
50 | k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
51 | v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
52 | q = q.transpose(1, 2) # (batch, head, time1, d_k)
53 | k = k.transpose(1, 2) # (batch, head, time2, d_k)
54 | v = v.transpose(1, 2) # (batch, head, time2, d_k)
55 |
56 | return q, k, v
57 |
58 | def forward_attention(self, value, scores, mask):
59 | """Compute attention context vector.
60 |
61 | :param torch.Tensor value: (batch, head, time2, size)
62 | :param torch.Tensor scores: (batch, head, time1, time2)
63 | :param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2)
64 | :return torch.Tensor transformed `value` (batch, time1, d_model)
65 | weighted by the attention score (batch, time1, time2)
66 |
67 | """
68 | n_batch = value.size(0)
69 | if mask is not None:
70 | mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
71 | min_value = float(
72 | numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
73 | )
74 | scores = scores.masked_fill(mask, min_value)
75 | self.attn = torch.softmax(scores, dim=-1).masked_fill(
76 | mask, 0.0
77 | ) # (batch, head, time1, time2)
78 | else:
79 | self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
80 |
81 | p_attn = self.dropout(self.attn)
82 | x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
83 | x = (
84 | x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
85 | ) # (batch, time1, d_model)
86 |
87 | return self.linear_out(x) # (batch, time1, d_model)
88 |
89 | def forward(self, query, key, value, mask):
90 | """Compute 'Scaled Dot Product Attention'.
91 |
92 | :param torch.Tensor query: (batch, time1, size)
93 | :param torch.Tensor key: (batch, time2, size)
94 | :param torch.Tensor value: (batch, time2, size)
95 | :param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2)
96 | :param torch.nn.Dropout dropout:
97 | :return torch.Tensor: attention output (batch, time1, d_model)
98 | """
99 | q, k, v = self.forward_qkv(query, key, value)
100 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
101 | return self.forward_attention(v, scores, mask)
102 |
103 |
104 | class RelPositionMultiHeadedAttention(MultiHeadedAttention):
105 | """Multi-Head Attention layer with relative position encoding.
106 |
107 | Paper: https://arxiv.org/abs/1901.02860
108 |
109 |     :param int n_head: the number of heads
110 | :param int n_feat: the number of features
111 | :param float dropout_rate: dropout rate
112 |
113 | """
114 |
115 | def __init__(self, n_head, n_feat, dropout_rate):
116 | """Construct an RelPositionMultiHeadedAttention object."""
117 | super().__init__(n_head, n_feat, dropout_rate)
118 |         # linear transformation for positional encoding
119 | self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
120 | # these two learnable bias are used in matrix c and matrix d
121 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3
122 | self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
123 | self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
124 | torch.nn.init.xavier_uniform_(self.pos_bias_u)
125 | torch.nn.init.xavier_uniform_(self.pos_bias_v)
126 |
127 | def rel_shift(self, x, zero_triu=False):
128 |         """Compute relative positional encoding.
129 |
130 | :param torch.Tensor x: (batch, time, size)
131 | :param bool zero_triu: return the lower triangular part of the matrix
132 | """
133 | zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
134 | x_padded = torch.cat([zero_pad, x], dim=-1)
135 |
136 | x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
137 | x = x_padded[:, :, 1:].view_as(x)
138 |
139 | if zero_triu:
140 | ones = torch.ones((x.size(2), x.size(3)))
141 | x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
142 |
143 | return x
144 |
145 | def forward(self, query, key, value, pos_emb, mask):
146 | """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
147 |
148 | :param torch.Tensor query: (batch, time1, size)
149 | :param torch.Tensor key: (batch, time2, size)
150 | :param torch.Tensor value: (batch, time2, size)
151 | :param torch.Tensor pos_emb: (batch, time1, size)
152 | :param torch.Tensor mask: (batch, time1, time2)
153 | :param torch.nn.Dropout dropout:
154 | :return torch.Tensor: attention output (batch, time1, d_model)
155 | """
156 | q, k, v = self.forward_qkv(query, key, value)
157 | q = q.transpose(1, 2) # (batch, time1, head, d_k)
158 |
159 | n_batch_pos = pos_emb.size(0)
160 | p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
161 | p = p.transpose(1, 2) # (batch, head, time1, d_k)
162 |
163 | # (batch, head, time1, d_k)
164 | q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
165 | # (batch, head, time1, d_k)
166 | q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
167 |
168 | # compute attention score
169 | # first compute matrix a and matrix c
170 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3
171 | # (batch, head, time1, time2)
172 | matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
173 |
174 | # compute matrix b and matrix d
175 | # (batch, head, time1, time2)
176 | matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
177 | matrix_bd = self.rel_shift(matrix_bd)
178 |
179 | scores = (matrix_ac + matrix_bd) / math.sqrt(
180 | self.d_k
181 | ) # (batch, head, time1, time2)
182 |
183 | return self.forward_attention(v, scores, mask)
184 |
--------------------------------------------------------------------------------
/conformer_ppg_model/encoder/convolution.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2020 Johns Hopkins University (Shinji Watanabe)
5 | # Northwestern Polytechnical University (Pengcheng Guo)
6 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
7 |
8 | """ConvolutionModule definition."""
9 |
10 | from torch import nn
11 |
12 |
13 | class ConvolutionModule(nn.Module):
14 | """ConvolutionModule in Conformer model.
15 |
16 | :param int channels: channels of cnn
 17 |     :param int kernel_size: kernel size of cnn
18 |
19 | """
20 |
21 | def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
22 | """Construct an ConvolutionModule object."""
23 | super(ConvolutionModule, self).__init__()
 24 |         # kernel_size should be an odd number for 'SAME' padding
25 | assert (kernel_size - 1) % 2 == 0
26 |
27 | self.pointwise_conv1 = nn.Conv1d(
28 | channels,
29 | 2 * channels,
30 | kernel_size=1,
31 | stride=1,
32 | padding=0,
33 | bias=bias,
34 | )
35 | self.depthwise_conv = nn.Conv1d(
36 | channels,
37 | channels,
38 | kernel_size,
39 | stride=1,
40 | padding=(kernel_size - 1) // 2,
41 | groups=channels,
42 | bias=bias,
43 | )
44 | self.norm = nn.BatchNorm1d(channels)
45 | self.pointwise_conv2 = nn.Conv1d(
46 | channels,
47 | channels,
48 | kernel_size=1,
49 | stride=1,
50 | padding=0,
51 | bias=bias,
52 | )
53 | self.activation = activation
54 |
55 | def forward(self, x):
56 | """Compute convolution module.
57 |
58 | :param torch.Tensor x: (batch, time, size)
59 | :return torch.Tensor: convoluted `value` (batch, time, d_model)
60 | """
61 | # exchange the temporal dimension and the feature dimension
62 | x = x.transpose(1, 2)
63 |
64 | # GLU mechanism
65 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
66 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
67 |
68 | # 1D Depthwise Conv
69 | x = self.depthwise_conv(x)
70 | x = self.activation(self.norm(x))
71 |
72 | x = self.pointwise_conv2(x)
73 |
74 | return x.transpose(1, 2)
75 |
--------------------------------------------------------------------------------
/conformer_ppg_model/encoder/embedding.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Shigeki Karita
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 |
 7 | """Positional Encoding Module."""
8 |
9 | import math
10 |
11 | import torch
12 |
13 |
14 | def _pre_hook(
15 | state_dict,
16 | prefix,
17 | local_metadata,
18 | strict,
19 | missing_keys,
20 | unexpected_keys,
21 | error_msgs,
22 | ):
23 | """Perform pre-hook in load_state_dict for backward compatibility.
24 |
25 | Note:
26 | We saved self.pe until v.0.5.2 but we have omitted it later.
27 | Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
28 |
29 | """
30 | k = prefix + "pe"
31 | if k in state_dict:
32 | state_dict.pop(k)
33 |
34 |
35 | class PositionalEncoding(torch.nn.Module):
36 | """Positional encoding.
37 |
38 | :param int d_model: embedding dim
39 | :param float dropout_rate: dropout rate
40 | :param int max_len: maximum input length
41 | :param reverse: whether to reverse the input position
42 |
43 | """
44 |
45 | def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
46 | """Construct an PositionalEncoding object."""
47 | super(PositionalEncoding, self).__init__()
48 | self.d_model = d_model
49 | self.reverse = reverse
50 | self.xscale = math.sqrt(self.d_model)
51 | self.dropout = torch.nn.Dropout(p=dropout_rate)
52 | self.pe = None
53 | self.extend_pe(torch.tensor(0.0).expand(1, max_len))
54 | self._register_load_state_dict_pre_hook(_pre_hook)
55 |
56 | def extend_pe(self, x):
57 | """Reset the positional encodings."""
58 | if self.pe is not None:
59 | if self.pe.size(1) >= x.size(1):
60 | if self.pe.dtype != x.dtype or self.pe.device != x.device:
61 | self.pe = self.pe.to(dtype=x.dtype, device=x.device)
62 | return
63 | pe = torch.zeros(x.size(1), self.d_model)
64 | if self.reverse:
65 | position = torch.arange(
66 | x.size(1) - 1, -1, -1.0, dtype=torch.float32
67 | ).unsqueeze(1)
68 | else:
69 | position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
70 | div_term = torch.exp(
71 | torch.arange(0, self.d_model, 2, dtype=torch.float32)
72 | * -(math.log(10000.0) / self.d_model)
73 | )
74 | pe[:, 0::2] = torch.sin(position * div_term)
75 | pe[:, 1::2] = torch.cos(position * div_term)
76 | pe = pe.unsqueeze(0)
77 | self.pe = pe.to(device=x.device, dtype=x.dtype)
78 |
79 | def forward(self, x: torch.Tensor):
80 | """Add positional encoding.
81 |
82 | Args:
83 | x (torch.Tensor): Input. Its shape is (batch, time, ...)
84 |
85 | Returns:
86 | torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
87 |
88 | """
89 | self.extend_pe(x)
90 | x = x * self.xscale + self.pe[:, : x.size(1)]
91 | return self.dropout(x)
92 |
93 |
94 | class ScaledPositionalEncoding(PositionalEncoding):
95 | """Scaled positional encoding module.
96 |
97 | See also: Sec. 3.2 https://arxiv.org/pdf/1809.08895.pdf
98 |
99 | """
100 |
101 | def __init__(self, d_model, dropout_rate, max_len=5000):
102 | """Initialize class.
103 |
104 | :param int d_model: embedding dim
105 | :param float dropout_rate: dropout rate
106 | :param int max_len: maximum input length
107 |
108 | """
109 | super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
110 | self.alpha = torch.nn.Parameter(torch.tensor(1.0))
111 |
112 | def reset_parameters(self):
113 | """Reset parameters."""
114 | self.alpha.data = torch.tensor(1.0)
115 |
116 | def forward(self, x):
117 | """Add positional encoding.
118 |
119 | Args:
120 | x (torch.Tensor): Input. Its shape is (batch, time, ...)
121 |
122 | Returns:
123 | torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
124 |
125 | """
126 | self.extend_pe(x)
127 | x = x + self.alpha * self.pe[:, : x.size(1)]
128 | return self.dropout(x)
129 |
130 |
131 | class RelPositionalEncoding(PositionalEncoding):
132 | """Relitive positional encoding module.
133 |
134 | See : Appendix B in https://arxiv.org/abs/1901.02860
135 |
136 | :param int d_model: embedding dim
137 | :param float dropout_rate: dropout rate
138 | :param int max_len: maximum input length
139 |
140 | """
141 |
142 | def __init__(self, d_model, dropout_rate, max_len=5000):
143 | """Initialize class.
144 |
145 | :param int d_model: embedding dim
146 | :param float dropout_rate: dropout rate
147 | :param int max_len: maximum input length
148 |
149 | """
150 | super().__init__(d_model, dropout_rate, max_len, reverse=True)
151 |
152 | def forward(self, x):
153 | """Compute positional encoding.
154 |
155 | Args:
156 | x (torch.Tensor): Input. Its shape is (batch, time, ...)
157 |
158 | Returns:
159 | torch.Tensor: x. Its shape is (batch, time, ...)
160 | torch.Tensor: pos_emb. Its shape is (1, time, ...)
161 |
162 | """
163 | self.extend_pe(x)
164 | x = x * self.xscale
165 | pos_emb = self.pe[:, : x.size(1)]
166 | return self.dropout(x), self.dropout(pos_emb)
167 |
--------------------------------------------------------------------------------
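
A minimal usage sketch of the three positional-encoding variants above (assuming the repository root is on PYTHONPATH; batch size, sequence length and d_model=256 are arbitrary illustration values):

import torch
from conformer_ppg_model.encoder.embedding import (
    PositionalEncoding,
    ScaledPositionalEncoding,
    RelPositionalEncoding,
)

x = torch.randn(2, 50, 256)                # (batch, time, d_model)

pe = PositionalEncoding(d_model=256, dropout_rate=0.0)
print(pe(x).shape)                         # torch.Size([2, 50, 256]): x * sqrt(d_model) + PE

spe = ScaledPositionalEncoding(d_model=256, dropout_rate=0.0)
print(spe(x).shape)                        # torch.Size([2, 50, 256]): x + alpha * PE, alpha learnable

rpe = RelPositionalEncoding(d_model=256, dropout_rate=0.0)
y, pos_emb = rpe(x)                        # the relative variant also returns the PE table
print(y.shape, pos_emb.shape)              # torch.Size([2, 50, 256]) torch.Size([1, 50, 256])
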
/conformer_ppg_model/encoder/encoder_layer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2020 Johns Hopkins University (Shinji Watanabe)
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 |
7 | """Encoder self-attention layer definition."""
8 |
9 | import torch
10 |
11 | from torch import nn
12 |
13 | from .layer_norm import LayerNorm
14 |
15 |
16 | class EncoderLayer(nn.Module):
17 | """Encoder layer module.
18 |
19 | :param int size: input dim
20 | :param espnet.nets.pytorch_backend.transformer.attention.
21 | MultiHeadedAttention self_attn: self attention module
22 | RelPositionMultiHeadedAttention self_attn: self attention module
23 | :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
24 | PositionwiseFeedForward feed_forward:
25 | feed forward module
26 | :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
27 | PositionwiseFeedForward feed_forward_macaron:
28 | feed forward module for the macaron style
29 | (pass None to disable the macaron branch)
30 | :param espnet.nets.pytorch_backend.conformer.convolution.
31 | ConvolutionModule conv_module:
32 | convolution module
33 | :param float dropout_rate: dropout rate
34 | :param bool normalize_before: whether to use layer_norm before the first block
35 | :param bool concat_after: whether to concat attention layer's input and output
36 | if True, additional linear will be applied.
37 | i.e. x -> x + linear(concat(x, att(x)))
38 | if False, no additional linear will be applied. i.e. x -> x + att(x)
39 |
40 | """
41 |
42 | def __init__(
43 | self,
44 | size,
45 | self_attn,
46 | feed_forward,
47 | feed_forward_macaron,
48 | conv_module,
49 | dropout_rate,
50 | normalize_before=True,
51 | concat_after=False,
52 | ):
53 | """Construct an EncoderLayer object."""
54 | super(EncoderLayer, self).__init__()
55 | self.self_attn = self_attn
56 | self.feed_forward = feed_forward
57 | self.feed_forward_macaron = feed_forward_macaron
58 | self.conv_module = conv_module
59 | self.norm_ff = LayerNorm(size) # for the FNN module
60 | self.norm_mha = LayerNorm(size) # for the MHA module
61 | if feed_forward_macaron is not None:
62 | self.norm_ff_macaron = LayerNorm(size)
63 | self.ff_scale = 0.5
64 | else:
65 | self.ff_scale = 1.0
66 | if self.conv_module is not None:
67 | self.norm_conv = LayerNorm(size) # for the CNN module
68 | self.norm_final = LayerNorm(size) # for the final output of the block
69 | self.dropout = nn.Dropout(dropout_rate)
70 | self.size = size
71 | self.normalize_before = normalize_before
72 | self.concat_after = concat_after
73 | if self.concat_after:
74 | self.concat_linear = nn.Linear(size + size, size)
75 |
76 | def forward(self, x_input, mask, cache=None):
77 | """Compute encoded features.
78 |
79 | :param torch.Tensor x_input: encoded source features, with or without pos_emb:
80 | tuple((batch, max_time_in, size), (1, max_time_in, size))
81 | or (batch, max_time_in, size)
82 | :param torch.Tensor mask: mask for x (batch, max_time_in)
83 | :param torch.Tensor cache: cache for x (batch, max_time_in - 1, size)
84 | :rtype: Tuple[torch.Tensor, torch.Tensor]
85 | """
86 | if isinstance(x_input, tuple):
87 | x, pos_emb = x_input[0], x_input[1]
88 | else:
89 | x, pos_emb = x_input, None
90 |
91 | # whether to use macaron style
92 | if self.feed_forward_macaron is not None:
93 | residual = x
94 | if self.normalize_before:
95 | x = self.norm_ff_macaron(x)
96 | x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
97 | if not self.normalize_before:
98 | x = self.norm_ff_macaron(x)
99 |
100 | # multi-headed self-attention module
101 | residual = x
102 | if self.normalize_before:
103 | x = self.norm_mha(x)
104 |
105 | if cache is None:
106 | x_q = x
107 | else:
108 | assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
109 | x_q = x[:, -1:, :]
110 | residual = residual[:, -1:, :]
111 | mask = None if mask is None else mask[:, -1:, :]
112 |
113 | if pos_emb is not None:
114 | x_att = self.self_attn(x_q, x, x, pos_emb, mask)
115 | else:
116 | x_att = self.self_attn(x_q, x, x, mask)
117 |
118 | if self.concat_after:
119 | x_concat = torch.cat((x, x_att), dim=-1)
120 | x = residual + self.concat_linear(x_concat)
121 | else:
122 | x = residual + self.dropout(x_att)
123 | if not self.normalize_before:
124 | x = self.norm_mha(x)
125 |
126 | # convolution module
127 | if self.conv_module is not None:
128 | residual = x
129 | if self.normalize_before:
130 | x = self.norm_conv(x)
131 | x = residual + self.dropout(self.conv_module(x))
132 | if not self.normalize_before:
133 | x = self.norm_conv(x)
134 |
135 | # feed forward module
136 | residual = x
137 | if self.normalize_before:
138 | x = self.norm_ff(x)
139 | x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
140 | if not self.normalize_before:
141 | x = self.norm_ff(x)
142 |
143 | if self.conv_module is not None:
144 | x = self.norm_final(x)
145 |
146 | if cache is not None:
147 | x = torch.cat([cache, x], dim=1)
148 |
149 | if pos_emb is not None:
150 | return (x, pos_emb), mask
151 |
152 | return x, mask
153 |
--------------------------------------------------------------------------------
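
A minimal sketch of how EncoderLayer threads (x, mask) through its sub-blocks. DummySelfAttn is a hypothetical stand-in that only mimics the (query, key, value, mask) call signature; the real attention modules live in conformer_ppg_model/encoder/attention.py.

import torch
from conformer_ppg_model.encoder.encoder_layer import EncoderLayer
from conformer_ppg_model.encoder.positionwise_feed_forward import PositionwiseFeedForward

class DummySelfAttn(torch.nn.Module):
    # Stand-in with the (query, key, value, mask) signature expected by EncoderLayer.
    def forward(self, q, k, v, mask):
        return v

layer = EncoderLayer(
    size=256,
    self_attn=DummySelfAttn(),
    feed_forward=PositionwiseFeedForward(256, 1024, 0.1),
    feed_forward_macaron=PositionwiseFeedForward(256, 1024, 0.1),  # enables the 0.5-scaled macaron branch
    conv_module=None,                                              # convolution block disabled in this sketch
    dropout_rate=0.1,
)
x = torch.randn(2, 50, 256)                   # (batch, time, size)
mask = torch.ones(2, 1, 50, dtype=torch.bool)
y, y_mask = layer(x, mask)
print(y.shape)                                # torch.Size([2, 50, 256])
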
/conformer_ppg_model/encoder/layer_norm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Shigeki Karita
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 |
7 | """Layer normalization module."""
8 |
9 | import torch
10 |
11 |
12 | class LayerNorm(torch.nn.LayerNorm):
13 | """Layer normalization module.
14 |
15 | :param int nout: output dim size
16 | :param int dim: dimension to be normalized
17 | """
18 |
19 | def __init__(self, nout, dim=-1):
20 | """Construct an LayerNorm object."""
21 | super(LayerNorm, self).__init__(nout, eps=1e-12)
22 | self.dim = dim
23 |
24 | def forward(self, x):
25 | """Apply layer normalization.
26 |
27 | :param torch.Tensor x: input tensor
28 | :return: layer normalized tensor
29 | :rtype torch.Tensor
30 | """
31 | if self.dim == -1:
32 | return super(LayerNorm, self).forward(x)
33 | return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
34 |
--------------------------------------------------------------------------------
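
A small illustrative sketch of the dim argument: with dim=1 the module normalizes the channel axis of a (batch, channels, time) tensor by transposing, applying torch.nn.LayerNorm, and transposing back.

import torch
from conformer_ppg_model.encoder.layer_norm import LayerNorm

ln = LayerNorm(256, dim=1)       # normalize dim 1 of a (B, C, T) tensor
x = torch.randn(2, 256, 50)
y = ln(x)                        # internally: transpose -> LayerNorm over last dim -> transpose back
print(y.shape)                   # torch.Size([2, 256, 50])
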
/conformer_ppg_model/encoder/multi_layer_conv.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Tomoki Hayashi
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 |
7 | """Layer modules for FFT block in FastSpeech (Feed-forward Transformer)."""
8 |
9 | import torch
10 |
11 |
12 | class MultiLayeredConv1d(torch.nn.Module):
13 | """Multi-layered conv1d for Transformer block.
14 |
15 | This is a module of multi-layered conv1d designed
16 | to replace the positionwise feed-forward network
17 | in a Transformer block, as introduced in
18 | `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
19 |
20 | .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
21 | https://arxiv.org/pdf/1905.09263.pdf
22 |
23 | """
24 |
25 | def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
26 | """Initialize MultiLayeredConv1d module.
27 |
28 | Args:
29 | in_chans (int): Number of input channels.
30 | hidden_chans (int): Number of hidden channels.
31 | kernel_size (int): Kernel size of conv1d.
32 | dropout_rate (float): Dropout rate.
33 |
34 | """
35 | super(MultiLayeredConv1d, self).__init__()
36 | self.w_1 = torch.nn.Conv1d(
37 | in_chans,
38 | hidden_chans,
39 | kernel_size,
40 | stride=1,
41 | padding=(kernel_size - 1) // 2,
42 | )
43 | self.w_2 = torch.nn.Conv1d(
44 | hidden_chans,
45 | in_chans,
46 | kernel_size,
47 | stride=1,
48 | padding=(kernel_size - 1) // 2,
49 | )
50 | self.dropout = torch.nn.Dropout(dropout_rate)
51 |
52 | def forward(self, x):
53 | """Calculate forward propagation.
54 |
55 | Args:
56 | x (Tensor): Batch of input tensors (B, ..., in_chans).
57 |
58 | Returns:
59 | Tensor: Batch of output tensors (B, ..., in_chans).
60 |
61 | """
62 | x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
63 | return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
64 |
65 |
66 | class Conv1dLinear(torch.nn.Module):
67 | """Conv1D + Linear for Transformer block.
68 |
69 | A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
70 |
71 | """
72 |
73 | def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
74 | """Initialize Conv1dLinear module.
75 |
76 | Args:
77 | in_chans (int): Number of input channels.
78 | hidden_chans (int): Number of hidden channels.
79 | kernel_size (int): Kernel size of conv1d.
80 | dropout_rate (float): Dropout rate.
81 |
82 | """
83 | super(Conv1dLinear, self).__init__()
84 | self.w_1 = torch.nn.Conv1d(
85 | in_chans,
86 | hidden_chans,
87 | kernel_size,
88 | stride=1,
89 | padding=(kernel_size - 1) // 2,
90 | )
91 | self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
92 | self.dropout = torch.nn.Dropout(dropout_rate)
93 |
94 | def forward(self, x):
95 | """Calculate forward propagation.
96 |
97 | Args:
98 | x (Tensor): Batch of input tensors (B, ..., in_chans).
99 |
100 | Returns:
101 | Tensor: Batch of output tensors (B, ..., in_chans).
102 |
103 | """
104 | x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
105 | return self.w_2(self.dropout(x))
106 |
--------------------------------------------------------------------------------
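
A short sketch (illustrative shapes) of both feed-forward variants above; each maps back to in_chans, so it can sit inside a residual branch.

import torch
from conformer_ppg_model.encoder.multi_layer_conv import MultiLayeredConv1d, Conv1dLinear

x = torch.randn(2, 50, 256)      # (B, T, in_chans)

ff_conv = MultiLayeredConv1d(in_chans=256, hidden_chans=1024, kernel_size=3, dropout_rate=0.1)
print(ff_conv(x).shape)          # torch.Size([2, 50, 256]): two conv1d layers, back to in_chans

ff_lin = Conv1dLinear(in_chans=256, hidden_chans=1024, kernel_size=3, dropout_rate=0.1)
print(ff_lin(x).shape)           # torch.Size([2, 50, 256]): conv1d followed by a linear layer
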
/conformer_ppg_model/encoder/positionwise_feed_forward.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Shigeki Karita
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 |
7 | """Positionwise feed forward layer definition."""
8 |
9 | import torch
10 |
11 |
12 | class PositionwiseFeedForward(torch.nn.Module):
13 | """Positionwise feed forward layer.
14 |
15 | :param int idim: input dimension
16 | :param int hidden_units: number of hidden units
17 | :param float dropout_rate: dropout rate
18 |
19 | """
20 |
21 | def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
22 | """Construct an PositionwiseFeedForward object."""
23 | super(PositionwiseFeedForward, self).__init__()
24 | self.w_1 = torch.nn.Linear(idim, hidden_units)
25 | self.w_2 = torch.nn.Linear(hidden_units, idim)
26 | self.dropout = torch.nn.Dropout(dropout_rate)
27 | self.activation = activation
28 |
29 | def forward(self, x):
30 | """Forward funciton."""
31 | return self.w_2(self.dropout(self.activation(self.w_1(x))))
32 |
--------------------------------------------------------------------------------
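
A tiny usage sketch with illustrative sizes, pairing the layer with the Swish activation defined later in this package:

import torch
from conformer_ppg_model.encoder.positionwise_feed_forward import PositionwiseFeedForward
from conformer_ppg_model.encoder.swish import Swish

ff = PositionwiseFeedForward(idim=256, hidden_units=2048, dropout_rate=0.1, activation=Swish())
x = torch.randn(2, 50, 256)
print(ff(x).shape)               # torch.Size([2, 50, 256]): idim -> hidden_units -> idim
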
/conformer_ppg_model/encoder/repeat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Shigeki Karita
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 |
7 | """Repeat the same layer definition."""
8 |
9 | import torch
10 |
11 |
12 | class MultiSequential(torch.nn.Sequential):
13 | """Multi-input multi-output torch.nn.Sequential."""
14 |
15 | def forward(self, *args):
16 | """Repeat."""
17 | for m in self:
18 | args = m(*args)
19 | return args
20 |
21 |
22 | def repeat(N, fn):
23 | """Repeat module N times.
24 |
25 | :param int N: number of times to repeat
26 | :param function fn: function to generate module
27 | :return: repeated modules
28 | :rtype: MultiSequential
29 | """
30 | return MultiSequential(*[fn(n) for n in range(N)])
31 |
--------------------------------------------------------------------------------
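
A sketch of the multi-input/multi-output contract: every repeated module must accept and return the same argument tuple, because MultiSequential feeds one block's outputs straight into the next. AddAndCount is a hypothetical toy module used only for illustration.

import torch
from conformer_ppg_model.encoder.repeat import repeat

class AddAndCount(torch.nn.Module):
    # Takes (x, step) and returns (x + 1, step + 1) so the tuple threads through the stack.
    def forward(self, x, step):
        return x + 1.0, step + 1

stack = repeat(3, lambda n: AddAndCount())   # n is the layer index, unused here
x, steps = stack(torch.zeros(2, 4), 0)
print(x[0, 0].item(), steps)                 # 3.0 3
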
/conformer_ppg_model/encoder/subsampling.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Shigeki Karita
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 |
7 | """Subsampling layer definition."""
8 | import logging
9 | import torch
10 |
11 | from .embedding import PositionalEncoding  # use the local copy; espnet itself is not listed in requirements.txt
12 |
13 |
14 | class Conv2dSubsampling(torch.nn.Module):
15 | """Convolutional 2D subsampling (to 1/4 length or 1/2 length).
16 |
17 | :param int idim: input dim
18 | :param int odim: output dim
19 | :param float dropout_rate: dropout rate
20 | :param torch.nn.Module pos_enc: custom position encoding layer
21 |
22 | """
23 |
24 | def __init__(self, idim, odim, dropout_rate, pos_enc=None,
25 | subsample_by_2=False,
26 | ):
27 | """Construct an Conv2dSubsampling object."""
28 | super(Conv2dSubsampling, self).__init__()
29 | self.subsample_by_2 = subsample_by_2
30 | if subsample_by_2:
31 | self.conv = torch.nn.Sequential(
32 | torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2),
33 | torch.nn.ReLU(),
34 | torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1),
35 | torch.nn.ReLU(),
36 | )
37 | self.out = torch.nn.Sequential(
38 | torch.nn.Linear(odim * (idim // 2), odim),
39 | pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
40 | )
41 | else:
42 | self.conv = torch.nn.Sequential(
43 | torch.nn.Conv2d(1, odim, kernel_size=4, stride=2, padding=1),
44 | torch.nn.ReLU(),
45 | torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1),
46 | torch.nn.ReLU(),
47 | )
48 | self.out = torch.nn.Sequential(
49 | torch.nn.Linear(odim * (idim // 4), odim),
50 | pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
51 | )
52 |
53 | def forward(self, x, x_mask):
54 | """Subsample x.
55 |
56 | :param torch.Tensor x: input tensor
57 | :param torch.Tensor x_mask: input mask
58 | :return: subsampled x and mask
59 | :rtype Tuple[torch.Tensor, torch.Tensor]
60 |
61 | """
62 | x = x.unsqueeze(1) # (b, c, t, f)
63 | x = self.conv(x)
64 | b, c, t, f = x.size()
65 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
66 | if x_mask is None:
67 | return x, None
68 | if self.subsample_by_2:
69 | return x, x_mask[:, :, ::2]
70 | else:
71 | return x, x_mask[:, :, ::2][:, :, ::2]
72 |
73 | def __getitem__(self, key):
74 | """Subsample x.
75 |
76 | When reset_parameters() is called, if use_scaled_pos_enc is used,
77 | return the positioning encoding.
78 |
79 | """
80 | if key != -1:
81 | raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
82 | return self.out[key]
83 |
84 |
85 | class Conv2dNoSubsampling(torch.nn.Module):
86 | """Convolutional 2D without subsampling.
87 |
88 | :param int idim: input dim
89 | :param int odim: output dim
90 | :param float dropout_rate: dropout rate
91 | :param torch.nn.Module pos_enc: custom position encoding layer
92 |
93 | """
94 |
95 | def __init__(self, idim, odim, dropout_rate, pos_enc=None):
96 | """Construct an Conv2dSubsampling object."""
97 | super().__init__()
98 | logging.info("Encoder does not down-sample the mel-spectrogram.")
99 | self.conv = torch.nn.Sequential(
100 | torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2),
101 | torch.nn.ReLU(),
102 | torch.nn.Conv2d(odim, odim, kernel_size=5, stride=1, padding=2),
103 | torch.nn.ReLU(),
104 | )
105 | self.out = torch.nn.Sequential(
106 | torch.nn.Linear(odim * idim, odim),
107 | pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
108 | )
109 |
110 | def forward(self, x, x_mask):
111 | """Subsample x.
112 |
113 | :param torch.Tensor x: input tensor
114 | :param torch.Tensor x_mask: input mask
115 | :return: subsampled x and mask
116 | :rtype Tuple[torch.Tensor, torch.Tensor]
117 |
118 | """
119 | x = x.unsqueeze(1) # (b, c, t, f)
120 | x = self.conv(x)
121 | b, c, t, f = x.size()
122 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
123 | if x_mask is None:
124 | return x, None
125 | return x, x_mask
126 |
127 | def __getitem__(self, key):
128 | """Subsample x.
129 |
130 | When reset_parameters() is called, if use_scaled_pos_enc is used,
131 | return the positioning encoding.
132 |
133 | """
134 | if key != -1:
135 | raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
136 | return self.out[key]
137 |
138 |
139 | class Conv2dSubsampling6(torch.nn.Module):
140 | """Convolutional 2D subsampling (to 1/6 length).
141 |
142 | :param int idim: input dim
143 | :param int odim: output dim
144 | :param float dropout_rate: dropout rate
145 |
146 | """
147 |
148 | def __init__(self, idim, odim, dropout_rate):
149 | """Construct an Conv2dSubsampling object."""
150 | super(Conv2dSubsampling6, self).__init__()
151 | self.conv = torch.nn.Sequential(
152 | torch.nn.Conv2d(1, odim, 3, 2),
153 | torch.nn.ReLU(),
154 | torch.nn.Conv2d(odim, odim, 5, 3),
155 | torch.nn.ReLU(),
156 | )
157 | self.out = torch.nn.Sequential(
158 | torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
159 | PositionalEncoding(odim, dropout_rate),
160 | )
161 |
162 | def forward(self, x, x_mask):
163 | """Subsample x.
164 |
165 | :param torch.Tensor x: input tensor
166 | :param torch.Tensor x_mask: input mask
167 | :return: subsampled x and mask
168 | :rtype Tuple[torch.Tensor, torch.Tensor]
169 | """
170 | x = x.unsqueeze(1) # (b, c, t, f)
171 | x = self.conv(x)
172 | b, c, t, f = x.size()
173 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
174 | if x_mask is None:
175 | return x, None
176 | return x, x_mask[:, :, :-2:2][:, :, :-4:3]
177 |
178 |
179 | class Conv2dSubsampling8(torch.nn.Module):
180 | """Convolutional 2D subsampling (to 1/8 length).
181 |
182 | :param int idim: input dim
183 | :param int odim: output dim
184 | :param float dropout_rate: dropout rate
185 |
186 | """
187 |
188 | def __init__(self, idim, odim, dropout_rate):
189 | """Construct an Conv2dSubsampling object."""
190 | super(Conv2dSubsampling8, self).__init__()
191 | self.conv = torch.nn.Sequential(
192 | torch.nn.Conv2d(1, odim, 3, 2),
193 | torch.nn.ReLU(),
194 | torch.nn.Conv2d(odim, odim, 3, 2),
195 | torch.nn.ReLU(),
196 | torch.nn.Conv2d(odim, odim, 3, 2),
197 | torch.nn.ReLU(),
198 | )
199 | self.out = torch.nn.Sequential(
200 | torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
201 | PositionalEncoding(odim, dropout_rate),
202 | )
203 |
204 | def forward(self, x, x_mask):
205 | """Subsample x.
206 |
207 | :param torch.Tensor x: input tensor
208 | :param torch.Tensor x_mask: input mask
209 | :return: subsampled x and mask
210 | :rtype Tuple[torch.Tensor, torch.Tensor]
211 | """
212 | x = x.unsqueeze(1) # (b, c, t, f)
213 | x = self.conv(x)
214 | b, c, t, f = x.size()
215 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
216 | if x_mask is None:
217 | return x, None
218 | return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
219 |
--------------------------------------------------------------------------------
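
A small sketch (illustrative shapes) of the default 1/4-length subsampling and the matching mask subsampling, assuming the PositionalEncoding import at the top of this file resolves:

import torch
from conformer_ppg_model.encoder.subsampling import Conv2dSubsampling

sub = Conv2dSubsampling(idim=80, odim=256, dropout_rate=0.1)
feats = torch.randn(2, 100, 80)                 # (batch, time, n_mels)
mask = torch.ones(2, 1, 100, dtype=torch.bool)
out, out_mask = sub(feats, mask)
print(out.shape, out_mask.shape)                # torch.Size([2, 25, 256]) torch.Size([2, 1, 25])
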
/conformer_ppg_model/encoder/swish.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2020 Johns Hopkins University (Shinji Watanabe)
5 | # Northwestern Polytechnical University (Pengcheng Guo)
6 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
7 |
8 | """Swish() activation function for Conformer."""
9 |
10 | import torch
11 |
12 |
13 | class Swish(torch.nn.Module):
14 | """Construct an Swish object."""
15 |
16 | def forward(self, x):
17 | """Return Swich activation function."""
18 | return x * torch.sigmoid(x)
19 |
--------------------------------------------------------------------------------
/conformer_ppg_model/encoder/vgg.py:
--------------------------------------------------------------------------------
1 | """VGG2L definition for transformer-transducer."""
2 |
3 | import torch
4 |
5 |
6 | class VGG2L(torch.nn.Module):
7 | """VGG2L module for transformer-transducer encoder."""
8 |
9 | def __init__(self, idim, odim):
10 | """Construct a VGG2L object.
11 |
12 | Args:
13 | idim (int): dimension of inputs
14 | odim (int): dimension of outputs
15 |
16 | """
17 | super(VGG2L, self).__init__()
18 |
19 | self.vgg2l = torch.nn.Sequential(
20 | torch.nn.Conv2d(1, 64, 3, stride=1, padding=1),
21 | torch.nn.ReLU(),
22 | torch.nn.Conv2d(64, 64, 3, stride=1, padding=1),
23 | torch.nn.ReLU(),
24 | torch.nn.MaxPool2d((3, 2)),
25 | torch.nn.Conv2d(64, 128, 3, stride=1, padding=1),
26 | torch.nn.ReLU(),
27 | torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
28 | torch.nn.ReLU(),
29 | torch.nn.MaxPool2d((2, 2)),
30 | )
31 |
32 | self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim)
33 |
34 | def forward(self, x, x_mask):
35 | """VGG2L forward for x.
36 |
37 | Args:
38 | x (torch.Tensor): input tensor (B, T, idim)
39 | x_mask (torch.Tensor): (B, 1, T)
40 |
41 | Returns:
42 | x (torch.Tensor): output tensor (B, sub(T), attention_dim)
43 | x_mask (torch.Tensor): (B, 1, sub(T))
44 |
45 | """
46 | x = x.unsqueeze(1)
47 | x = self.vgg2l(x)
48 |
49 | b, c, t, f = x.size()
50 |
51 | x = self.output(x.transpose(1, 2).contiguous().view(b, t, c * f))
52 |
53 | if x_mask is None:
54 | return x, None
55 | else:
56 | x_mask = self.create_new_mask(x_mask, x)
57 |
58 | return x, x_mask
59 |
60 | def create_new_mask(self, x_mask, x):
61 | """Create a subsampled version of x_mask.
62 |
63 | Args:
64 | x_mask (torch.Tensor): (B, 1, T)
65 | x (torch.Tensor): (B, sub(T), attention_dim)
66 |
67 | Returns:
68 | x_mask (torch.Tensor): (B, 1, sub(T))
69 |
70 | """
71 | x_t1 = x_mask.size(2) - (x_mask.size(2) % 3)
72 | x_mask = x_mask[:, :, :x_t1][:, :, ::3]
73 |
74 | x_t2 = x_mask.size(2) - (x_mask.size(2) % 2)
75 | x_mask = x_mask[:, :, :x_t2][:, :, ::2]
76 |
77 | return x_mask
78 |
--------------------------------------------------------------------------------
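
A short sketch with illustrative shapes: the two max-pool stages reduce the time axis by a factor of about 6 (3, then 2), and create_new_mask subsamples the mask to match.

import torch
from conformer_ppg_model.encoder.vgg import VGG2L

vgg = VGG2L(idim=80, odim=256)
x = torch.randn(2, 120, 80)                     # (B, T, idim)
mask = torch.ones(2, 1, 120, dtype=torch.bool)
y, y_mask = vgg(x, mask)
print(y.shape, y_mask.shape)                    # torch.Size([2, 20, 256]) torch.Size([2, 1, 20])
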
/conformer_ppg_model/frontend.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Optional
3 | from typing import Tuple
4 | from typing import Union
5 |
6 | import humanfriendly
7 | import numpy as np
8 | import torch
9 | from torch_complex.tensor import ComplexTensor
10 |
11 | from .log_mel import LogMel
12 | from .stft import Stft
13 |
14 |
15 | class DefaultFrontend(torch.nn.Module):
16 | """Conventional frontend structure for ASR
17 |
18 | Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
19 | """
20 |
21 | def __init__(
22 | self,
23 | fs: Union[int, str] = 16000,
24 | n_fft: int = 1024,
25 | win_length: int = 800,
26 | hop_length: int = 160,
27 | center: bool = True,
28 | pad_mode: str = "reflect",
29 | normalized: bool = False,
30 | onesided: bool = True,
31 | n_mels: int = 80,
32 | fmin: int = None,
33 | fmax: int = None,
34 | htk: bool = False,
35 | norm=1,
36 | frontend_conf=None, #Optional[dict] = get_default_kwargs(Frontend),
37 | kaldi_padding_mode=False,
38 | downsample_rate: int = 1,
39 | ):
40 | super().__init__()
41 | if isinstance(fs, str):
42 | fs = humanfriendly.parse_size(fs)
43 | self.downsample_rate = downsample_rate
44 |
45 | # Deepcopy (In general, dict shouldn't be used as default arg)
46 | frontend_conf = copy.deepcopy(frontend_conf)
47 |
48 | self.stft = Stft(
49 | n_fft=n_fft,
50 | win_length=win_length,
51 | hop_length=hop_length,
52 | center=center,
53 | pad_mode=pad_mode,
54 | normalized=normalized,
55 | onesided=onesided,
56 | kaldi_padding_mode=kaldi_padding_mode,
57 | )
58 | if frontend_conf is not None:
59 | self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
60 | else:
61 | self.frontend = None
62 |
63 | self.logmel = LogMel(
64 | fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm,
65 | )
66 | self.n_mels = n_mels
67 |
68 | def output_size(self) -> int:
69 | return self.n_mels
70 |
71 | def forward(
72 | self, input: torch.Tensor, input_lengths: torch.Tensor
73 | ) -> Tuple[torch.Tensor, torch.Tensor]:
74 | # 1. Domain-conversion: e.g. Stft: time -> time-freq
75 | input_stft, feats_lens = self.stft(input, input_lengths)
76 |
77 | assert input_stft.dim() >= 4, input_stft.shape
78 | # "2" refers to the real/imag parts of Complex
79 | assert input_stft.shape[-1] == 2, input_stft.shape
80 |
81 | # Change torch.Tensor to ComplexTensor
82 | # input_stft: (..., F, 2) -> (..., F)
83 | input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
84 |
85 | # 2. [Option] Speech enhancement
86 | if self.frontend is not None:
87 | assert isinstance(input_stft, ComplexTensor), type(input_stft)
88 | # input_stft: (Batch, Length, [Channel], Freq)
89 | input_stft, _, mask = self.frontend(input_stft, feats_lens)
90 |
91 | # 3. [Multi channel case]: Select a channel
92 | if input_stft.dim() == 4:
93 | # h: (B, T, C, F) -> h: (B, T, F)
94 | if self.training:
95 | # Select 1ch randomly
96 | ch = np.random.randint(input_stft.size(2))
97 | input_stft = input_stft[:, :, ch, :]
98 | else:
99 | # Use the first channel
100 | input_stft = input_stft[:, :, 0, :]
101 |
102 | # 4. STFT -> Power spectrum
103 | # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
104 | input_power = input_stft.real ** 2 + input_stft.imag ** 2
105 |
106 | # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
107 | # input_power: (Batch, [Channel,] Length, Freq)
108 | # -> input_feats: (Batch, Length, Dim)
109 | input_feats, _ = self.logmel(input_power, feats_lens)
110 |
111 | # NOTE(sx): pad
112 | max_len = input_feats.size(1)
113 | if self.downsample_rate > 1 and max_len % self.downsample_rate != 0:
114 | padding = self.downsample_rate - max_len % self.downsample_rate
115 | # print("Logmel: ", input_feats.size())
116 | input_feats = torch.nn.functional.pad(input_feats, (0, 0, 0, padding),
117 | "constant", 0)
118 | # print("Logmel(after padding): ",input_feats.size())
119 | feats_lens[torch.argmax(feats_lens)] = max_len + padding
120 |
121 | return input_feats, feats_lens
122 |
--------------------------------------------------------------------------------
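
A minimal sketch of extracting 10 ms log-mel frames from a raw waveform with DefaultFrontend, assuming its optional dependencies (humanfriendly, torch_complex) are installed; the one-second random input is only an illustration.

import torch
from conformer_ppg_model.frontend import DefaultFrontend

frontend = DefaultFrontend(fs=16000, n_fft=1024, win_length=800, hop_length=160, n_mels=80)
wav = torch.randn(1, 16000)                     # one second of 16 kHz "audio"
lengths = torch.LongTensor([16000])
feats, feat_lengths = frontend(wav, lengths)
print(feats.shape, feat_lengths)                # (1, 101, 80) with these settings: 10 ms frames of 80-dim log-mel
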
/conformer_ppg_model/log_mel.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import numpy as np
3 | import torch
4 | from typing import Tuple
5 |
6 | from .nets_utils import make_pad_mask
7 |
8 |
9 | class LogMel(torch.nn.Module):
10 | """Convert STFT to fbank feats
11 |
12 | The arguments are the same as those of librosa.filters.mel
13 |
14 | Args:
15 | fs: number > 0 [scalar] sampling rate of the incoming signal
16 | n_fft: int > 0 [scalar] number of FFT components
17 | n_mels: int > 0 [scalar] number of Mel bands to generate
18 | fmin: float >= 0 [scalar] lowest frequency (in Hz)
19 | fmax: float >= 0 [scalar] highest frequency (in Hz).
20 | If `None`, use `fmax = fs / 2.0`
21 | htk: use HTK formula instead of Slaney
22 | norm: {None, 1, np.inf} [scalar]
23 | if 1, divide the triangular mel weights by the width of the mel band
24 | (area normalization). Otherwise, leave all the triangles aiming for
25 | a peak value of 1.0
26 |
27 | """
28 |
29 | def __init__(
30 | self,
31 | fs: int = 16000,
32 | n_fft: int = 512,
33 | n_mels: int = 80,
34 | fmin: float = None,
35 | fmax: float = None,
36 | htk: bool = False,
37 | norm=1,
38 | ):
39 | super().__init__()
40 |
41 | fmin = 0 if fmin is None else fmin
42 | fmax = fs / 2 if fmax is None else fmax
43 | _mel_options = dict(
44 | sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
45 | )
46 | self.mel_options = _mel_options
47 |
48 | # Note(kamo): The mel matrix of librosa is different from kaldi.
49 | melmat = librosa.filters.mel(**_mel_options)
50 | # melmat: (D2, D1) -> (D1, D2)
51 | self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
52 | inv_mel = np.linalg.pinv(melmat)
53 | self.register_buffer("inv_melmat", torch.from_numpy(inv_mel.T).float())
54 |
55 | def extra_repr(self):
56 | return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
57 |
58 | def forward(
59 | self, feat: torch.Tensor, ilens: torch.Tensor = None,
60 | ) -> Tuple[torch.Tensor, torch.Tensor]:
61 | # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
62 | mel_feat = torch.matmul(feat, self.melmat)
63 |
64 | logmel_feat = (mel_feat + 1e-20).log()
65 | # Zero padding
66 | if ilens is not None:
67 | logmel_feat = logmel_feat.masked_fill(
68 | make_pad_mask(ilens, logmel_feat, 1), 0.0
69 | )
70 | else:
71 | ilens = feat.new_full(
72 | [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
73 | )
74 | return logmel_feat, ilens
75 |
--------------------------------------------------------------------------------
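
A small usage sketch with illustrative shapes: LogMel expects a power spectrum with n_fft // 2 + 1 bins per frame and returns log-mel features.

import torch
from conformer_ppg_model.log_mel import LogMel

logmel = LogMel(fs=16000, n_fft=512, n_mels=80)
power_spec = torch.rand(1, 100, 257)            # (B, T, n_fft // 2 + 1) power spectrum
feats, lens = logmel(power_spec)
print(feats.shape)                              # torch.Size([1, 100, 80])
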
/conformer_ppg_model/stft.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from typing import Tuple
3 | from typing import Union
4 |
5 | import torch
6 |
7 | from .nets_utils import make_pad_mask
8 |
9 |
10 | class Stft(torch.nn.Module):
11 | def __init__(
12 | self,
13 | n_fft: int = 512,
14 | win_length: Union[int, None] = 512,
15 | hop_length: int = 128,
16 | center: bool = True,
17 | pad_mode: str = "reflect",
18 | normalized: bool = False,
19 | onesided: bool = True,
20 | kaldi_padding_mode=False,
21 | ):
22 | super().__init__()
23 | self.n_fft = n_fft
24 | if win_length is None:
25 | self.win_length = n_fft
26 | else:
27 | self.win_length = win_length
28 | self.hop_length = hop_length
29 | self.center = center
30 | self.pad_mode = pad_mode
31 | self.normalized = normalized
32 | self.onesided = onesided
33 | self.kaldi_padding_mode = kaldi_padding_mode
34 | if self.kaldi_padding_mode:
35 | self.win_length = 400
36 |
37 | def extra_repr(self):
38 | return (
39 | f"n_fft={self.n_fft}, "
40 | f"win_length={self.win_length}, "
41 | f"hop_length={self.hop_length}, "
42 | f"center={self.center}, "
43 | f"pad_mode={self.pad_mode}, "
44 | f"normalized={self.normalized}, "
45 | f"onesided={self.onesided}"
46 | )
47 |
48 | def forward(
49 | self, input: torch.Tensor, ilens: torch.Tensor = None
50 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
51 | """STFT forward function.
52 |
53 | Args:
54 | input: (Batch, Nsamples) or (Batch, Nsample, Channels)
55 | ilens: (Batch)
56 | Returns:
57 | output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
58 |
59 | """
60 | bs = input.size(0)
61 | if input.dim() == 3:
62 | multi_channel = True
63 | # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
64 | input = input.transpose(1, 2).reshape(-1, input.size(1))
65 | else:
66 | multi_channel = False
67 |
68 | # output: (Batch, Freq, Frames, 2=real_imag)
69 | # or (Batch, Channel, Freq, Frames, 2=real_imag)
70 | if not self.kaldi_padding_mode:
71 | output = torch.stft(
72 | input,
73 | n_fft=self.n_fft,
74 | win_length=self.win_length,
75 | hop_length=self.hop_length,
76 | center=self.center,
77 | pad_mode=self.pad_mode,
78 | normalized=self.normalized,
79 | onesided=self.onesided,
80 | )
81 | else:
82 | # NOTE(sx): Use Kaldi-fashion padding; this may be incorrect
83 | num_pads = self.n_fft - self.win_length
84 | input = torch.nn.functional.pad(input, (num_pads, 0))
85 | output = torch.stft(
86 | input,
87 | n_fft=self.n_fft,
88 | win_length=self.win_length,
89 | hop_length=self.hop_length,
90 | center=False,
91 | pad_mode=self.pad_mode,
92 | normalized=self.normalized,
93 | onesided=self.onesided,
94 | )
95 |
96 | # output: (Batch, Freq, Frames, 2=real_imag)
97 | # -> (Batch, Frames, Freq, 2=real_imag)
98 | output = output.transpose(1, 2)
99 | if multi_channel:
100 | # output: (Batch * Channel, Frames, Freq, 2=real_imag)
101 | # -> (Batch, Frame, Channel, Freq, 2=real_imag)
102 | output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
103 | 1, 2
104 | )
105 |
106 | if ilens is not None:
107 | if self.center:
108 | pad = self.win_length // 2
109 | ilens = ilens + 2 * pad
110 |
111 | olens = (ilens - self.win_length) // self.hop_length + 1
112 | output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
113 | else:
114 | olens = None
115 |
116 | return output, olens
117 |
--------------------------------------------------------------------------------
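
A short sketch of the output layout, assuming the pinned torch version (the torchaudio 0.6 era), where torch.stft returns a real tensor with a trailing real/imag axis:

import torch
from conformer_ppg_model.stft import Stft

stft = Stft(n_fft=512, win_length=400, hop_length=160)
wav = torch.randn(1, 16000)
spec, olens = stft(wav, torch.LongTensor([16000]))
print(spec.shape, olens)                        # (1, 101, 257, 2): (batch, frames, freq, real/imag)
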
/conformer_ppg_model/utterance_mvn.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch
4 |
5 | from .nets_utils import make_pad_mask
6 |
7 |
8 | class UtteranceMVN(torch.nn.Module):
9 | def __init__(
10 | self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20,
11 | ):
12 | super().__init__()
13 | self.norm_means = norm_means
14 | self.norm_vars = norm_vars
15 | self.eps = eps
16 |
17 | def extra_repr(self):
18 | return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
19 |
20 | def forward(
21 | self, x: torch.Tensor, ilens: torch.Tensor = None
22 | ) -> Tuple[torch.Tensor, torch.Tensor]:
23 | """Forward function
24 |
25 | Args:
26 | x: (B, L, ...)
27 | ilens: (B,)
28 |
29 | """
30 | return utterance_mvn(
31 | x,
32 | ilens,
33 | norm_means=self.norm_means,
34 | norm_vars=self.norm_vars,
35 | eps=self.eps,
36 | )
37 |
38 |
39 | def utterance_mvn(
40 | x: torch.Tensor,
41 | ilens: torch.Tensor = None,
42 | norm_means: bool = True,
43 | norm_vars: bool = False,
44 | eps: float = 1.0e-20,
45 | ) -> Tuple[torch.Tensor, torch.Tensor]:
46 | """Apply utterance mean and variance normalization
47 |
48 | Args:
49 | x: (B, T, D), assumed zero padded
50 | ilens: (B,)
51 | norm_means:
52 | norm_vars:
53 | eps:
54 |
55 | """
56 | if ilens is None:
57 | ilens = x.new_full([x.size(0)], x.size(1))
58 | ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)])
59 | # Zero padding
60 | if x.requires_grad:
61 | x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
62 | else:
63 | x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
64 | # mean: (B, 1, D)
65 | mean = x.sum(dim=1, keepdim=True) / ilens_
66 |
67 | if norm_means:
68 | x -= mean
69 |
70 | if norm_vars:
71 | var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
72 | std = torch.clamp(var.sqrt(), min=eps)
73 | x = x / std  # divide by the std (not its square root), consistent with the branch below
74 | return x, ilens
75 | else:
76 | if norm_vars:
77 | y = x - mean
78 | y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
79 | var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
80 | std = torch.clamp(var.sqrt(), min=eps)
81 | x /= std
82 | return x, ilens
83 |
--------------------------------------------------------------------------------
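
A small sketch (illustrative values) of per-utterance mean normalization with a padded batch:

import torch
from conformer_ppg_model.utterance_mvn import utterance_mvn

x = torch.randn(2, 100, 80)
ilens = torch.LongTensor([100, 80])             # the second utterance is treated as padded after frame 80
y, _ = utterance_mvn(x, ilens, norm_means=True, norm_vars=False)
print(y[0].mean(dim=0).abs().max())             # ~0: each feature dim has zero mean over the utterance
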
/convert_from_wav.py:
--------------------------------------------------------------------------------
1 | import time
2 | import sys
3 | import os
4 | import argparse
5 | import torch
6 | import numpy as np
7 | import glob
8 | from pathlib import Path
9 | from tqdm import tqdm
10 | from conformer_ppg_model.build_ppg_model import load_ppg_model
11 | from src.mel_decoder_mol_encAddlf0 import MelDecoderMOL
12 | from src.mel_decoder_lsa import MelDecoderLSA
13 | from src.rnn_ppg2mel import BiRnnPpg2MelModel
14 | import pyworld
15 | import librosa
16 | import resampy
17 | import soundfile as sf
18 | from utils.f0_utils import get_cont_lf0
19 | from utils.load_yaml import HpsYaml
20 |
21 | from vocoders.hifigan_model import load_hifigan_generator
22 |
23 | from speaker_encoder.voice_encoder import SpeakerEncoder
24 | from speaker_encoder.audio import preprocess_wav
25 | from src import build_model
26 |
27 |
28 | def compute_spk_dvec(
29 | wav_path, weights_fpath="speaker_encoder/ckpt/pretrained_bak_5805000.pt",
30 | ):
31 | fpath = Path(wav_path)
32 | wav = preprocess_wav(fpath)
33 | encoder = SpeakerEncoder(weights_fpath)
34 | spk_dvec = encoder.embed_utterance(wav)
35 | return spk_dvec
36 |
37 |
38 | def compute_f0(wav, sr=16000, frame_period=10.0):
39 | wav = wav.astype(np.float64)
40 | f0, timeaxis = pyworld.harvest(
41 | wav, sr, frame_period=frame_period, f0_floor=20.0, f0_ceil=600.0)
42 | return f0
43 |
44 |
45 | def compute_mean_std(lf0):
46 | nonzero_indices = np.nonzero(lf0)
47 | mean = np.mean(lf0[nonzero_indices])
48 | std = np.std(lf0[nonzero_indices])
49 | return mean, std
50 |
51 |
52 | def f02lf0(f0):
53 | lf0 = f0.copy()
54 | nonzero_indices = np.nonzero(f0)
55 | lf0[nonzero_indices] = np.log(f0[nonzero_indices])
56 | return lf0
57 |
58 |
59 | def get_converted_lf0uv(
60 | wav,
61 | lf0_mean_trg,
62 | lf0_std_trg,
63 | convert=True,
64 | ):
65 | f0_src = compute_f0(wav)
66 | if not convert:
67 | uv, cont_lf0 = get_cont_lf0(f0_src)
68 | lf0_uv = np.concatenate([cont_lf0[:, np.newaxis], uv[:, np.newaxis]], axis=1)
69 | return lf0_uv
70 |
71 | lf0_src = f02lf0(f0_src)
72 | lf0_mean_src, lf0_std_src = compute_mean_std(lf0_src)
73 |
74 | lf0_vc = lf0_src.copy()
75 | lf0_vc[lf0_src > 0.0] = (lf0_src[lf0_src > 0.0] - lf0_mean_src) / lf0_std_src * lf0_std_trg + lf0_mean_trg
76 | f0_vc = lf0_vc.copy()
77 | f0_vc[lf0_src > 0.0] = np.exp(lf0_vc[lf0_src > 0.0])
78 |
79 | uv, cont_lf0_vc = get_cont_lf0(f0_vc)
80 | lf0_uv = np.concatenate([cont_lf0_vc[:, np.newaxis], uv[:, np.newaxis]], axis=1)
81 | return lf0_uv
82 |
83 |
84 | def build_ppg2mel_model(model_config, model_file, device):
85 | model_class = build_model(model_config["model_name"])
86 | ppg2mel_model = model_class(
87 | **model_config["model"]
88 | ).to(device)
89 | ckpt = torch.load(model_file, map_location=device)
90 | ppg2mel_model.load_state_dict(ckpt["model"])
91 | ppg2mel_model.eval()
92 | return ppg2mel_model
93 |
94 |
95 | @torch.no_grad()
96 | def convert(args):
97 | device = 'cuda'
98 | ppg2mel_config = HpsYaml(args.ppg2mel_model_train_config)
99 | output_dir = args.output_dir
100 | os.makedirs(output_dir, exist_ok=True)
101 |
102 | step = os.path.basename(args.ppg2mel_model_file)[:-4].split("_")[-1]
103 |
104 | # Build models
105 | print("Load PPG-model, PPG2Mel-model, Vocoder-model...")
106 | ppg_model = load_ppg_model(
107 | './conformer_ppg_model/en_conformer_ctc_att/config.yaml',
108 | './conformer_ppg_model/en_conformer_ctc_att/24epoch.pth',
109 | device,
110 | )
111 | ppg2mel_model = build_ppg2mel_model(ppg2mel_config, args.ppg2mel_model_file, device)
112 | hifigan_model = load_hifigan_generator(device)
113 |
114 | # Data related
115 | ref_wav_path = args.ref_wav_path
116 | ref_fid = os.path.basename(ref_wav_path)[:-4]
117 | ref_spk_dvec = compute_spk_dvec(ref_wav_path)
118 | ref_spk_dvec = torch.from_numpy(ref_spk_dvec).unsqueeze(0).to(device)
119 | ref_wav, _ = librosa.load(ref_wav_path, sr=16000)
120 | ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
121 |
122 | source_file_list = sorted(glob.glob(f"{args.src_wav_dir}/*.wav"))
123 | print(f"Number of source utterances: {len(source_file_list)}.")
124 |
125 | total_rtf = 0.0
126 | cnt = 0
127 | for src_wav_path in tqdm(source_file_list):
128 | # Load the audio to a numpy array:
129 | src_wav, _ = librosa.load(src_wav_path, sr=16000)
130 | src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(device)
131 | src_wav_lengths = torch.LongTensor([len(src_wav)]).to(device)
132 | ppg = ppg_model(src_wav_tensor, src_wav_lengths)
133 |
134 | lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
135 | min_len = min(ppg.shape[1], len(lf0_uv))
136 |
137 | ppg = ppg[:, :min_len]
138 | lf0_uv = lf0_uv[:min_len]
139 |
140 | start = time.time()
141 | if isinstance(ppg2mel_model, BiRnnPpg2MelModel):
142 | ppg_length = torch.LongTensor([ppg.shape[1]]).to(device)
143 | logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device)
144 | mel_pred = ppg2mel_model(ppg, ppg_length, logf0_uv, ref_spk_dvec)
145 | else:
146 | _, mel_pred, att_ws = ppg2mel_model.inference(
147 | ppg,
148 | logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
149 | spembs=ref_spk_dvec,
150 | use_stop_tokens=True,
151 | )
152 | # if ppg2mel_config.data.min_max_norm_mel:
153 | # mel_min = ppg2mel_config.data.mel_min
154 | # mel_max = ppg2mel_config.data.mel_max
155 | # mel_pred = (mel_pred + 4.0) / 8.0 * (mel_max - mel_min) + mel_min
156 | src_fid = os.path.basename(src_wav_path)[:-4]
157 | wav_fname = f"{output_dir}/vc_{src_fid}_ref_{ref_fid}_step{step}.wav"
158 | mel_len = mel_pred.shape[0]
159 | rtf = (time.time() - start) / (0.01 * mel_len)
160 | total_rtf += rtf
161 | cnt += 1
162 | # continue
163 | y = hifigan_model(mel_pred.view(1, -1, 80).transpose(1, 2))
164 | sf.write(wav_fname, y.squeeze().cpu().numpy(), 24000, "PCM_16")
165 |
166 | print("RTF:")
167 | print(total_rtf / cnt)
168 |
169 |
170 | def get_parser():
171 | parser = argparse.ArgumentParser(description="Conversion from wave input")
172 | parser.add_argument(
173 | "--src_wav_dir",
174 | type=str,
175 | default=None,
176 | required=True,
177 | help="Source wave directory.",
178 | )
179 | parser.add_argument(
180 | "--ref_wav_path",
181 | type=str,
182 | required=True,
183 | help="Reference wave file path.",
184 | )
185 | parser.add_argument(
186 | "--ppg2mel_model_train_config", "-c",
187 | type=str,
188 | default=None,
189 | required=True,
190 | help="Training config file (yaml file)",
191 | )
192 | parser.add_argument(
193 | "--ppg2mel_model_file", "-m",
194 | type=str,
195 | default=None,
196 | required=True,
197 | help="ppg2mel model checkpoint file path"
198 | )
199 | parser.add_argument(
200 | "--output_dir", "-o",
201 | type=str,
202 | default="vc_gens_vctk_oneshot",
203 | help="Output folder to save the converted wave."
204 | )
205 |
206 | return parser
207 |
208 |
209 | def main():
210 | parser = get_parser()
211 | args = parser.parse_args()
212 | convert(args)
213 |
214 |
215 | if __name__ == "__main__":
216 | main()
217 |
--------------------------------------------------------------------------------
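
A hypothetical invocation sketch of the conversion entry point from Python. Every path below is a placeholder that must point to real data and checkpoints, and a CUDA device is required because the script hard-codes device = 'cuda'; running the equivalent command line is the usual workflow.

from convert_from_wav import get_parser, convert

args = get_parser().parse_args([
    "--src_wav_dir", "data/src_wavs",                 # placeholder: folder of 16 kHz source .wav files
    "--ref_wav_path", "data/ref/p225_001.wav",        # placeholder: one reference (target-speaker) wav
    "-c", "ckpt/oneshot_vc/config.yaml",              # placeholder: ppg2mel training config
    "-m", "ckpt/oneshot_vc/loss_step_250000.pth",     # placeholder: ppg2mel checkpoint (step parsed from the name)
    "-o", "vc_gen_wavs",
])
convert(args)
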
/figs/README.md:
--------------------------------------------------------------------------------
1 | images
2 |
--------------------------------------------------------------------------------
/figs/seq2seq_bnf2mel.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/figs/seq2seq_bnf2mel.pdf
--------------------------------------------------------------------------------
/figs/seq2seq_bnf2mel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/figs/seq2seq_bnf2mel.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | import sys
4 | import torch
5 | import argparse
6 | import numpy as np
7 | from utils.load_yaml import HpsYaml
8 |
9 | # For reproducibility; commenting these out may speed up training
10 | torch.backends.cudnn.deterministic = True
11 | torch.backends.cudnn.benchmark = False
12 |
13 | # Arguments
14 | parser = argparse.ArgumentParser(description=
15 | 'Training PPG2Mel VC model.')
16 | parser.add_argument('--config', type=str,
17 | help='Path to experiment config, e.g., config/vc.yaml')
18 | parser.add_argument('--name', default=None, type=str, help='Name for logging.')
19 | parser.add_argument('--logdir', default='log/', type=str,
20 | help='Logging path.', required=False)
21 | parser.add_argument('--ckpdir', default='ckpt/', type=str,
22 | help='Checkpoint path.', required=False)
23 | parser.add_argument('--outdir', default='result/', type=str,
24 | help='Decode output path.', required=False)
25 | parser.add_argument('--load', default=None, type=str,
26 | help='Load pre-trained model (for training only)', required=False)
27 | parser.add_argument('--warm_start', action='store_true',
28 | help='Load model weights only, ignore specified layers.')
29 | parser.add_argument('--seed', default=0, type=int,
30 | help='Random seed for reproducible results.', required=False)
31 | parser.add_argument('--njobs', default=8, type=int,
32 | help='Number of threads for dataloader/decoding.', required=False)
33 | parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
34 | parser.add_argument('--no-pin', action='store_true',
35 | help='Disable pin-memory for dataloader')
36 | parser.add_argument('--test', action='store_true', help='Test the model.')
37 | parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
38 | parser.add_argument('--finetune', action='store_true', help='Finetune model')
39 | parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
40 | parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
41 | parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
42 |
43 | ###
44 |
45 | paras = parser.parse_args()
46 | setattr(paras, 'gpu', not paras.cpu)
47 | setattr(paras, 'pin_memory', not paras.no_pin)
48 | setattr(paras, 'verbose', not paras.no_msg)
49 | # Make the config dict accessible via dot notation
50 | config = HpsYaml(paras.config) # yaml.load(open(paras.config, 'r'), Loader=yaml.FullLoader)
51 |
52 | np.random.seed(paras.seed)
53 | torch.manual_seed(paras.seed)
54 | if torch.cuda.is_available():
55 | torch.cuda.manual_seed_all(paras.seed)
56 | # For debug use
57 | # torch.autograd.set_detect_anomaly(True)
58 |
59 | # Hack to preserve GPU ram just in case OOM later on server
60 | # if paras.gpu and paras.reserve_gpu > 0:
61 | # buff = torch.randn(int(paras.reserve_gpu*1e9//4)).cuda()
62 | # del buff
63 |
64 | if paras.oneshotvc:
65 | print(">>> OneShot VC training ...")
66 | if paras.bilstm:
67 | from bin.train_ppg2mel_oneshotvc import Solver
68 | else:
69 | from bin.train_linglf02mel_seq2seq_oneshotvc import Solver
70 | mode = "train"
71 | solver = Solver(config, paras, mode)
72 | solver.load_data()
73 | solver.set_model()
74 | solver.exec()
75 | print(">>> Oneshot VC train finished!")
76 | sys.exit(0)
77 | elif paras.lsa:
78 | print(">>> Use location sensitive attention based seq2seq model")
79 | from bin.train_linglf02mel_seq2seq_lsa import Solver
80 | mode = "train"
81 | solver = Solver(config, paras, mode)
82 | solver.load_data()
83 | solver.set_model()
84 | solver.exec()
85 | print(">>> VC train finished!")
86 | sys.exit(0)
87 | else:
88 | if paras.finetune:
89 | from bin.finetune_linglf02mel_seq2seq_encAddlf0 import Solver
90 | else:
91 | from bin.train_linglf02mel_seq2seq_encAddlf0 import Solver
92 | mode = 'train'
93 | solver = Solver(config, paras, mode)
94 | solver.load_data()
95 | solver.set_model()
96 | solver.exec()
97 | print(">>> VC train finished!")
98 | sys.exit(0)
99 |
--------------------------------------------------------------------------------
/path.sh:
--------------------------------------------------------------------------------
1 | # cuda related
2 | export CUDA_HOME=/usr/local/cuda
3 | export LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}"
4 |
5 | # path related
6 | export PRJ_ROOT="./"
7 | if [ -e "${PRJ_ROOT}/tools/venv/bin/activate" ]; then
8 | # shellcheck disable=SC1090
9 | . "${PRJ_ROOT}/tools/venv/bin/activate"
10 | fi
11 |
12 | # python related
13 | export OMP_NUM_THREADS=1
14 | export PYTHONIOENCODING=UTF-8
15 | export MPL_BACKEND=Agg
16 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torchaudio==0.6.0
2 | tqdm==4.49.0
3 | librosa==0.8.0
4 | soundfile==0.10.3.post1
5 | resampy==0.2.2
6 | pyworld==0.2.11.post0
7 | PyYAML>=5.4
8 | typeguard==2.12.1
9 | matplotlib==3.3.2
10 | numpy==1.19.2
11 | scipy==1.3.3
12 | webrtcvad==2.0.10
13 | glob2
14 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/bash
2 |
3 | . ./path.sh || exit 1;
4 | export CUDA_VISIBLE_DEVICES=0
5 |
6 | ########## Train BiLSTM oneshot VC model ##########
7 | python main.py --config ./conf/bilstm_ppg2mel_vctk_libri_oneshotvc.yaml \
8 | --oneshotvc \
9 | --bilstm
10 | ###################################################
11 |
12 | ########## Train Seq2seq oneshot VC model ###########
13 | #python main.py --config ./conf/seq2seq_mol_ppg2mel_vctk_libri_oneshotvc.yaml \
14 | #--oneshotvc \
15 | ###################################################
16 |
--------------------------------------------------------------------------------
/speaker_encoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/speaker_encoder/__init__.py
--------------------------------------------------------------------------------
/speaker_encoder/audio.py:
--------------------------------------------------------------------------------
1 | from scipy.ndimage.morphology import binary_dilation
2 | from speaker_encoder.params_data import *
3 | from pathlib import Path
4 | from typing import Optional, Union
5 | import numpy as np
6 | import webrtcvad
7 | import librosa
8 | import struct
9 |
10 | int16_max = (2 ** 15) - 1
11 |
12 |
13 | def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
14 | source_sr: Optional[int] = None):
15 | """
16 | Applies the preprocessing operations used in training the Speaker Encoder to a waveform
17 | either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
18 |
19 | :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
20 | just .wav), or the waveform itself as a numpy array of floats.
21 | :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
22 | preprocessing. After preprocessing, the waveform's sampling rate will match the data
23 | hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
24 | this argument will be ignored.
25 | """
26 | # Load the wav from disk if needed
27 | if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
28 | wav, source_sr = librosa.load(fpath_or_wav, sr=None)
29 | else:
30 | wav = fpath_or_wav
31 |
32 | # Resample the wav if needed
33 | if source_sr is not None and source_sr != sampling_rate:
34 | wav = librosa.resample(wav, source_sr, sampling_rate)
35 |
36 | # Apply the preprocessing: normalize volume and shorten long silences
37 | wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
38 | wav = trim_long_silences(wav)
39 |
40 | return wav
41 |
42 |
43 | def wav_to_mel_spectrogram(wav):
44 | """
45 | Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
46 | Note: this is not a log-mel spectrogram.
47 | """
48 | frames = librosa.feature.melspectrogram(
49 | wav,
50 | sampling_rate,
51 | n_fft=int(sampling_rate * mel_window_length / 1000),
52 | hop_length=int(sampling_rate * mel_window_step / 1000),
53 | n_mels=mel_n_channels
54 | )
55 | return frames.astype(np.float32).T
56 |
57 |
58 | def trim_long_silences(wav):
59 | """
60 | Ensures that segments without voice in the waveform remain no longer than a
61 | threshold determined by the VAD parameters in params.py.
62 |
63 | :param wav: the raw waveform as a numpy array of floats
64 | :return: the same waveform with silences trimmed away (length <= original wav length)
65 | """
66 | # Compute the voice detection window size
67 | samples_per_window = (vad_window_length * sampling_rate) // 1000
68 |
69 | # Trim the end of the audio to have a multiple of the window size
70 | wav = wav[:len(wav) - (len(wav) % samples_per_window)]
71 |
72 | # Convert the float waveform to 16-bit mono PCM
73 | pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
74 |
75 | # Perform voice activation detection
76 | voice_flags = []
77 | vad = webrtcvad.Vad(mode=3)
78 | for window_start in range(0, len(wav), samples_per_window):
79 | window_end = window_start + samples_per_window
80 | voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
81 | sample_rate=sampling_rate))
82 | voice_flags = np.array(voice_flags)
83 |
84 | # Smooth the voice detection with a moving average
85 | def moving_average(array, width):
86 | array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
87 | ret = np.cumsum(array_padded, dtype=float)
88 | ret[width:] = ret[width:] - ret[:-width]
89 | return ret[width - 1:] / width
90 |
91 | audio_mask = moving_average(voice_flags, vad_moving_average_width)
92 | audio_mask = np.round(audio_mask).astype(bool)  # np.bool is deprecated in newer numpy
93 |
94 | # Dilate the voiced regions
95 | audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
96 | audio_mask = np.repeat(audio_mask, samples_per_window)
97 |
98 | return wav[audio_mask]
99 |
100 |
101 | def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
102 | if increase_only and decrease_only:
103 | raise ValueError("Both increase only and decrease only are set")
104 | dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
105 | if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
106 | return wav
107 | return wav * (10 ** (dBFS_change / 20))
108 |
--------------------------------------------------------------------------------
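
A minimal usage sketch, assuming webrtcvad and the pinned librosa are installed; the wav path is a placeholder. preprocess_wav resamples, volume-normalizes and trims long silences, and wav_to_mel_spectrogram produces frames sized by the hyperparameters in params_data.py.

from speaker_encoder.audio import preprocess_wav, wav_to_mel_spectrogram

wav = preprocess_wav("data/ref/p225_001.wav")   # placeholder path to any audio file
mel = wav_to_mel_spectrogram(wav)
print(wav.shape, mel.shape)                     # (n_samples,), (n_frames, mel_n_channels)
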
/speaker_encoder/ckpt/pretrained_bak_5805000.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/speaker_encoder/ckpt/pretrained_bak_5805000.pt
--------------------------------------------------------------------------------
/speaker_encoder/compute_embed.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder import inference as encoder
2 | from multiprocessing.pool import Pool
3 | from functools import partial
4 | from pathlib import Path
5 | # from utils import logmmse
6 | from tqdm import tqdm
7 | import numpy as np
8 | # import librosa
9 |
10 |
11 | def embed_utterance(fpaths, encoder_model_fpath):
12 | if not encoder.is_loaded():
13 | encoder.load_model(encoder_model_fpath)
14 |
15 | # Compute the speaker embedding of the utterance
16 | wav_fpath, embed_fpath = fpaths
17 | wav = np.load(wav_fpath)
18 | wav = encoder.preprocess_wav(wav)
19 | embed = encoder.embed_utterance(wav)
20 | np.save(embed_fpath, embed, allow_pickle=False)
21 |
22 |
23 | def create_embeddings(outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int):
24 |
25 | wav_dir = outdir_root.joinpath("audio")
26 | metadata_fpath = outdir_root.joinpath("train.txt")  # was `synthesizer_root`, which is undefined here
27 | assert wav_dir.exists() and metadata_fpath.exists()
28 | embed_dir = outdir_root.joinpath("embeds")
29 | embed_dir.mkdir(exist_ok=True)
30 |
31 | # Gather the input wave filepath and the target output embed filepath
32 | with metadata_fpath.open("r") as metadata_file:
33 | metadata = [line.split("|") for line in metadata_file]
34 | fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
35 |
36 | # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
37 | # Embed the utterances in separate threads
38 | func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
39 | job = Pool(n_processes).imap(func, fpaths)
40 | list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
--------------------------------------------------------------------------------
/speaker_encoder/config.py:
--------------------------------------------------------------------------------
1 | librispeech_datasets = {
2 | "train": {
3 | "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
4 | "other": ["LibriSpeech/train-other-500"]
5 | },
6 | "test": {
7 | "clean": ["LibriSpeech/test-clean"],
8 | "other": ["LibriSpeech/test-other"]
9 | },
10 | "dev": {
11 | "clean": ["LibriSpeech/dev-clean"],
12 | "other": ["LibriSpeech/dev-other"]
13 | },
14 | }
15 | libritts_datasets = {
16 | "train": {
17 | "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
18 | "other": ["LibriTTS/train-other-500"]
19 | },
20 | "test": {
21 | "clean": ["LibriTTS/test-clean"],
22 | "other": ["LibriTTS/test-other"]
23 | },
24 | "dev": {
25 | "clean": ["LibriTTS/dev-clean"],
26 | "other": ["LibriTTS/dev-other"]
27 | },
28 | }
29 | voxceleb_datasets = {
30 | "voxceleb1" : {
31 | "train": ["VoxCeleb1/wav"],
32 | "test": ["VoxCeleb1/test_wav"]
33 | },
34 | "voxceleb2" : {
35 | "train": ["VoxCeleb2/dev/aac"],
36 | "test": ["VoxCeleb2/test_wav"]
37 | }
38 | }
39 |
40 | other_datasets = [
41 | "LJSpeech-1.1",
42 | "VCTK-Corpus/wav48",
43 | ]
44 |
45 | anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
46 |
--------------------------------------------------------------------------------
/speaker_encoder/data_objects/__init__.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2 | from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
3 |
--------------------------------------------------------------------------------
/speaker_encoder/data_objects/random_cycler.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | class RandomCycler:
4 | """
5 | Creates an internal copy of a sequence and allows access to its items in a constrained random
6 | order. For a source sequence of n items and one or several consecutive queries of a total
7 | of m items, the following guarantees hold (one implies the other):
8 | - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
9 | - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
10 | """
11 |
12 | def __init__(self, source):
13 | if len(source) == 0:
14 | raise Exception("Can't create RandomCycler from an empty collection")
15 | self.all_items = list(source)
16 | self.next_items = []
17 |
18 | def sample(self, count: int):
19 | shuffle = lambda l: random.sample(l, len(l))
20 |
21 | out = []
22 | while count > 0:
23 | if count >= len(self.all_items):
24 | out.extend(shuffle(list(self.all_items)))
25 | count -= len(self.all_items)
26 | continue
27 | n = min(count, len(self.next_items))
28 | out.extend(self.next_items[:n])
29 | count -= n
30 | self.next_items = self.next_items[n:]
31 | if len(self.next_items) == 0:
32 | self.next_items = shuffle(list(self.all_items))
33 | return out
34 |
35 | def __next__(self):
36 | return self.sample(1)[0]
37 |
38 |
--------------------------------------------------------------------------------
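
A quick usage sketch of RandomCycler to illustrate the guarantee documented above: over m = 8 draws from n = 3 items, every item comes back between m // n = 2 and ((m - 1) // n) + 1 = 3 times.

from collections import Counter
from speaker_encoder.data_objects.random_cycler import RandomCycler

cycler = RandomCycler(["a", "b", "c"])
draws = cycler.sample(8)
print(draws, Counter(draws))   # each of "a", "b", "c" occurs 2 or 3 times
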
/speaker_encoder/data_objects/speaker.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder.data_objects.random_cycler import RandomCycler
2 | from speaker_encoder.data_objects.utterance import Utterance
3 | from pathlib import Path
4 |
5 | # Contains the set of utterances of a single speaker
6 | class Speaker:
7 | def __init__(self, root: Path):
8 | self.root = root
9 | self.name = root.name
10 | self.utterances = None
11 | self.utterance_cycler = None
12 |
13 | def _load_utterances(self):
14 | with self.root.joinpath("_sources.txt").open("r") as sources_file:
15 | sources = [l.split(",") for l in sources_file]
16 | sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
17 | self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
18 | self.utterance_cycler = RandomCycler(self.utterances)
19 |
20 | def random_partial(self, count, n_frames):
21 | """
22 | Samples a batch of unique partial utterances from the disk in a way that all
23 | utterances come up at least once every two cycles and in a random order every time.
24 |
25 | :param count: The number of partial utterances to sample from the set of utterances from
26 |         that speaker. Utterances are guaranteed not to be repeated if count is not larger than
27 | the number of utterances available.
28 | :param n_frames: The number of frames in the partial utterance.
29 | :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
30 | frames are the frames of the partial utterances and range is the range of the partial
31 | utterance with regard to the complete utterance.
32 | """
33 | if self.utterances is None:
34 | self._load_utterances()
35 |
36 | utterances = self.utterance_cycler.sample(count)
37 |
38 | a = [(u,) + u.random_partial(n_frames) for u in utterances]
39 |
40 | return a
41 |
--------------------------------------------------------------------------------
/speaker_encoder/data_objects/speaker_batch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import List
3 | from speaker_encoder.data_objects.speaker import Speaker
4 |
5 | class SpeakerBatch:
6 | def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
7 | self.speakers = speakers
8 | self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
9 |
10 | # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
11 | # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
12 | self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
13 |
--------------------------------------------------------------------------------
/speaker_encoder/data_objects/speaker_verification_dataset.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder.data_objects.random_cycler import RandomCycler
2 | from speaker_encoder.data_objects.speaker_batch import SpeakerBatch
3 | from speaker_encoder.data_objects.speaker import Speaker
4 | from speaker_encoder.params_data import partials_n_frames
5 | from torch.utils.data import Dataset, DataLoader
6 | from pathlib import Path
7 |
8 | # TODO: improve with a pool of speakers for data efficiency
9 |
10 | class SpeakerVerificationDataset(Dataset):
11 | def __init__(self, datasets_root: Path):
12 | self.root = datasets_root
13 | speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
14 | if len(speaker_dirs) == 0:
15 | raise Exception("No speakers found. Make sure you are pointing to the directory "
16 | "containing all preprocessed speaker directories.")
17 | self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
18 | self.speaker_cycler = RandomCycler(self.speakers)
19 |
20 | def __len__(self):
21 | return int(1e10)
22 |
23 | def __getitem__(self, index):
24 | return next(self.speaker_cycler)
25 |
26 | def get_logs(self):
27 | log_string = ""
28 | for log_fpath in self.root.glob("*.txt"):
29 | with log_fpath.open("r") as log_file:
30 | log_string += "".join(log_file.readlines())
31 | return log_string
32 |
33 |
34 | class SpeakerVerificationDataLoader(DataLoader):
35 | def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
36 | batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
37 | worker_init_fn=None):
38 | self.utterances_per_speaker = utterances_per_speaker
39 |
40 | super().__init__(
41 | dataset=dataset,
42 | batch_size=speakers_per_batch,
43 | shuffle=False,
44 | sampler=sampler,
45 | batch_sampler=batch_sampler,
46 | num_workers=num_workers,
47 | collate_fn=self.collate,
48 | pin_memory=pin_memory,
49 | drop_last=False,
50 | timeout=timeout,
51 | worker_init_fn=worker_init_fn
52 | )
53 |
54 | def collate(self, speakers):
55 | return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
56 |
--------------------------------------------------------------------------------
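
A hedged usage sketch of the dataset and loader above (the dataset path is a placeholder): each batch is a SpeakerBatch whose data array stacks utterances_per_speaker partial utterances for each of speakers_per_batch speakers.

from pathlib import Path
from speaker_encoder.data_objects import SpeakerVerificationDataset, SpeakerVerificationDataLoader

dataset = SpeakerVerificationDataset(Path("/path/to/preprocessed/speakers"))   # placeholder path
loader = SpeakerVerificationDataLoader(dataset,
                                       speakers_per_batch=64,
                                       utterances_per_speaker=10)
batch = next(iter(loader))
print(batch.data.shape)   # (640, 160, 40) with the defaults in params_data.py
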
/speaker_encoder/data_objects/utterance.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Utterance:
5 | def __init__(self, frames_fpath, wave_fpath):
6 | self.frames_fpath = frames_fpath
7 | self.wave_fpath = wave_fpath
8 |
9 | def get_frames(self):
10 | return np.load(self.frames_fpath)
11 |
12 | def random_partial(self, n_frames):
13 | """
14 | Crops the frames into a partial utterance of n_frames
15 |
16 | :param n_frames: The number of frames of the partial utterance
17 | :return: the partial utterance frames and a tuple indicating the start and end of the
18 | partial utterance in the complete utterance.
19 | """
20 | frames = self.get_frames()
21 | if frames.shape[0] == n_frames:
22 | start = 0
23 | else:
24 | start = np.random.randint(0, frames.shape[0] - n_frames)
25 | end = start + n_frames
26 | return frames[start:end], (start, end)
--------------------------------------------------------------------------------
/speaker_encoder/hparams.py:
--------------------------------------------------------------------------------
1 | ## Mel-filterbank
2 | mel_window_length = 25 # In milliseconds
3 | mel_window_step = 10 # In milliseconds
4 | mel_n_channels = 40
5 |
6 |
7 | ## Audio
8 | sampling_rate = 16000
9 | # Number of spectrogram frames in a partial utterance
10 | partials_n_frames = 160 # 1600 ms
11 |
12 |
13 | ## Voice Activation Detection
14 | # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
15 | # This sets the granularity of the VAD. Should not need to be changed.
16 | vad_window_length = 30 # In milliseconds
17 | # Number of frames to average together when performing the moving average smoothing.
18 | # The larger this value, the larger the VAD variations must be to not get smoothed out.
19 | vad_moving_average_width = 8
20 | # Maximum number of consecutive silent frames a segment can have.
21 | vad_max_silence_length = 6
22 |
23 |
24 | ## Audio volume normalization
25 | audio_norm_target_dBFS = -30
26 |
27 |
28 | ## Model parameters
29 | model_hidden_size = 256
30 | model_embedding_size = 256
31 | model_num_layers = 3
--------------------------------------------------------------------------------
/speaker_encoder/inference.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder.params_data import *
2 | from speaker_encoder.model import SpeakerEncoder
3 | from speaker_encoder.audio import preprocess_wav # We want to expose this function from here
4 | from matplotlib import cm
5 | from speaker_encoder import audio
6 | from pathlib import Path
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import torch
10 |
11 | _model = None # type: SpeakerEncoder
12 | _device = None # type: torch.device
13 |
14 |
15 | def load_model(weights_fpath: Path, device=None):
16 | """
17 |     Loads the model in memory. If this function is not explicitly called, it will be run on the
18 | first call to embed_frames() with the default weights file.
19 |
20 | :param weights_fpath: the path to saved model weights.
21 | :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
22 | model will be loaded and will run on this device. Outputs will however always be on the cpu.
23 |     If None, will default to your GPU if it's available, otherwise your CPU.
24 | """
25 | # TODO: I think the slow loading of the encoder might have something to do with the device it
26 | # was saved on. Worth investigating.
27 | global _model, _device
28 | if device is None:
29 | _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30 | elif isinstance(device, str):
31 | _device = torch.device(device)
32 | _model = SpeakerEncoder(_device, torch.device("cpu"))
33 | checkpoint = torch.load(weights_fpath)
34 | _model.load_state_dict(checkpoint["model_state"])
35 | _model.eval()
36 | print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
37 |
38 |
39 | def is_loaded():
40 | return _model is not None
41 |
42 |
43 | def embed_frames_batch(frames_batch):
44 | """
45 |     Computes embeddings for a batch of mel spectrograms.
46 |
47 |     :param frames_batch: a batch of mel spectrograms as a numpy array of float32 of shape
48 | (batch_size, n_frames, n_channels)
49 | :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
50 | """
51 | if _model is None:
52 | raise Exception("Model was not loaded. Call load_model() before inference.")
53 |
54 | frames = torch.from_numpy(frames_batch).to(_device)
55 | embed = _model.forward(frames).detach().cpu().numpy()
56 | return embed
57 |
58 |
59 | def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
60 | min_pad_coverage=0.75, overlap=0.5):
61 | """
62 | Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
63 | partial utterances of each. Both the waveform and the mel
64 | spectrogram slices are returned, so as to make each partial utterance waveform correspond to
65 | its spectrogram. This function assumes that the mel spectrogram parameters used are those
66 | defined in params_data.py.
67 |
68 |     The returned ranges may index past the end of the waveform. It is
69 | recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
70 |
71 | :param n_samples: the number of samples in the waveform
72 | :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
73 | utterance
74 | :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
75 |     enough frames. If at least min_pad_coverage of partial_utterance_n_frames are present,
76 | then the last partial utterance will be considered, as if we padded the audio. Otherwise,
77 | it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
78 | utterance, this parameter is ignored so that the function always returns at least 1 slice.
79 | :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
80 | utterances are entirely disjoint.
81 | :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
82 | respectively the waveform and the mel spectrogram with these slices to obtain the partial
83 | utterances.
84 | """
85 | assert 0 <= overlap < 1
86 | assert 0 < min_pad_coverage <= 1
87 |
88 | samples_per_frame = int((sampling_rate * mel_window_step / 1000))
89 | n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
90 | frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
91 |
92 | # Compute the slices
93 | wav_slices, mel_slices = [], []
94 | steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
95 | for i in range(0, steps, frame_step):
96 | mel_range = np.array([i, i + partial_utterance_n_frames])
97 | wav_range = mel_range * samples_per_frame
98 | mel_slices.append(slice(*mel_range))
99 | wav_slices.append(slice(*wav_range))
100 |
101 | # Evaluate whether extra padding is warranted or not
102 | last_wav_range = wav_slices[-1]
103 | coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
104 | if coverage < min_pad_coverage and len(mel_slices) > 1:
105 | mel_slices = mel_slices[:-1]
106 | wav_slices = wav_slices[:-1]
107 |
108 | return wav_slices, mel_slices
109 |
110 |
111 | def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
112 | """
113 | Computes an embedding for a single utterance.
114 |
115 | # TODO: handle multiple wavs to benefit from batching on GPU
116 | :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
117 |     :param using_partials: if True, then the utterance is split in partial utterances of
118 |     partial_utterance_n_frames frames and the utterance embedding is computed from their
119 |     normalized average. If False, the utterance is instead computed from feeding the entire
120 |     spectrogram to the network.
121 |     :param return_partials: if True, the partial embeddings will also be returned along with the
122 |     wav slices that correspond to the partial embeddings.
123 |     :param kwargs: additional arguments to compute_partial_slices()
124 |     :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
125 |     return_partials is True, the partial embeddings as a numpy array of float32 of shape
126 |     (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
127 |     returned. If using_partials is simultaneously set to False, both these values will be None
128 |     instead.
129 | """
130 | # Process the entire utterance if not using partials
131 | if not using_partials:
132 | frames = audio.wav_to_mel_spectrogram(wav)
133 | embed = embed_frames_batch(frames[None, ...])[0]
134 | if return_partials:
135 | return embed, None, None
136 | return embed
137 |
138 | # Compute where to split the utterance into partials and pad if necessary
139 | wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
140 | max_wave_length = wave_slices[-1].stop
141 | if max_wave_length >= len(wav):
142 | wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
143 |
144 | # Split the utterance into partials
145 | frames = audio.wav_to_mel_spectrogram(wav)
146 | frames_batch = np.array([frames[s] for s in mel_slices])
147 | partial_embeds = embed_frames_batch(frames_batch)
148 |
149 | # Compute the utterance embedding from the partial embeddings
150 | raw_embed = np.mean(partial_embeds, axis=0)
151 | embed = raw_embed / np.linalg.norm(raw_embed, 2)
152 |
153 | if return_partials:
154 | return embed, partial_embeds, wave_slices
155 | return embed
156 |
157 |
158 | def embed_speaker(wavs, **kwargs):
159 |     raise NotImplementedError()
160 |
161 |
162 | def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
163 | if ax is None:
164 | ax = plt.gca()
165 |
166 | if shape is None:
167 | height = int(np.sqrt(len(embed)))
168 | shape = (height, -1)
169 | embed = embed.reshape(shape)
170 |
171 | cmap = cm.get_cmap()
172 | mappable = ax.imshow(embed, cmap=cmap)
173 | cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
174 | cbar.set_clim(*color_range)
175 |
176 | ax.set_xticks([]), ax.set_yticks([])
177 | ax.set_title(title)
178 |
--------------------------------------------------------------------------------
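
A minimal end-to-end sketch of the inference helpers above, assuming the pretrained checkpoint shipped under speaker_encoder/ckpt/ and a 16 kHz utterance (the wav path is a placeholder).

from pathlib import Path
import numpy as np
import librosa
from speaker_encoder import inference as encoder

encoder.load_model(Path("speaker_encoder/ckpt/pretrained_bak_5805000.pt"))
wav, _ = librosa.load("/path/to/utterance.wav", sr=16000)   # placeholder path
wav = encoder.preprocess_wav(wav)                           # preprocessing from audio.py
embed = encoder.embed_utterance(wav)
print(embed.shape, np.linalg.norm(embed))                   # (256,) 1.0 -- L2-normalized
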
/speaker_encoder/model.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder.params_model import *
2 | from speaker_encoder.params_data import *
3 | from scipy.interpolate import interp1d
4 | from sklearn.metrics import roc_curve
5 | from torch.nn.utils import clip_grad_norm_
6 | from scipy.optimize import brentq
7 | from torch import nn
8 | import numpy as np
9 | import torch
10 |
11 |
12 | class SpeakerEncoder(nn.Module):
13 | def __init__(self, device, loss_device):
14 | super().__init__()
15 | self.loss_device = loss_device
16 |
17 |         # Network definition
18 | self.lstm = nn.LSTM(input_size=mel_n_channels, # 40
19 | hidden_size=model_hidden_size, # 256
20 | num_layers=model_num_layers, # 3
21 | batch_first=True).to(device)
22 | self.linear = nn.Linear(in_features=model_hidden_size,
23 | out_features=model_embedding_size).to(device)
24 | self.relu = torch.nn.ReLU().to(device)
25 |
26 | # Cosine similarity scaling (with fixed initial parameter values)
27 | self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
28 | self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
29 |
30 | # Loss
31 | self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
32 |
33 | def do_gradient_ops(self):
34 | # Gradient scale
35 | self.similarity_weight.grad *= 0.01
36 | self.similarity_bias.grad *= 0.01
37 |
38 | # Gradient clipping
39 | clip_grad_norm_(self.parameters(), 3, norm_type=2)
40 |
41 | def forward(self, utterances, hidden_init=None):
42 | """
43 | Computes the embeddings of a batch of utterance spectrograms.
44 |
45 | :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
46 | (batch_size, n_frames, n_channels)
47 | :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
48 | batch_size, hidden_size). Will default to a tensor of zeros if None.
49 | :return: the embeddings as a tensor of shape (batch_size, embedding_size)
50 | """
51 | # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
52 | # and the final cell state.
53 | out, (hidden, cell) = self.lstm(utterances, hidden_init)
54 |
55 | # We take only the hidden state of the last layer
56 | embeds_raw = self.relu(self.linear(hidden[-1]))
57 |
58 | # L2-normalize it
59 | embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
60 |
61 | return embeds
62 |
63 | def similarity_matrix(self, embeds):
64 | """
65 |         Computes the similarity matrix according to section 2.1 of GE2E.
66 |
67 | :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
68 | utterances_per_speaker, embedding_size)
69 | :return: the similarity matrix as a tensor of shape (speakers_per_batch,
70 | utterances_per_speaker, speakers_per_batch)
71 | """
72 | speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
73 |
74 | # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
75 | centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
76 | centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)
77 |
78 | # Exclusive centroids (1 per utterance)
79 | centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
80 | centroids_excl /= (utterances_per_speaker - 1)
81 | centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)
82 |
83 | # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
84 | # product of these vectors (which is just an element-wise multiplication reduced by a sum).
85 | # We vectorize the computation for efficiency.
86 | sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
87 | speakers_per_batch).to(self.loss_device)
88 |         mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
89 | for j in range(speakers_per_batch):
90 | mask = np.where(mask_matrix[j])[0]
91 | sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
92 | sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
93 |
94 | ## Even more vectorized version (slower maybe because of transpose)
95 | # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
96 | # ).to(self.loss_device)
97 | # eye = np.eye(speakers_per_batch, dtype=np.int)
98 | # mask = np.where(1 - eye)
99 | # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
100 | # mask = np.where(eye)
101 | # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
102 | # sim_matrix2 = sim_matrix2.transpose(1, 2)
103 |
104 | sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
105 | return sim_matrix
106 |
107 | def loss(self, embeds):
108 | """
109 |         Computes the softmax loss according to section 2.1 of GE2E.
110 |
111 | :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
112 | utterances_per_speaker, embedding_size)
113 | :return: the loss and the EER for this batch of embeddings.
114 | """
115 | speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
116 |
117 | # Loss
118 | sim_matrix = self.similarity_matrix(embeds)
119 | sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
120 | speakers_per_batch))
121 | ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
122 | target = torch.from_numpy(ground_truth).long().to(self.loss_device)
123 | loss = self.loss_fn(sim_matrix, target)
124 |
125 | # EER (not backpropagated)
126 | with torch.no_grad():
127 |             inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
128 | labels = np.array([inv_argmax(i) for i in ground_truth])
129 | preds = sim_matrix.detach().cpu().numpy()
130 |
131 | # Snippet from https://yangcha.github.io/EER-ROC/
132 | fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
133 | eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
134 |
135 | return loss, eer
--------------------------------------------------------------------------------
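
A small CPU-only shape check of the GE2E pieces above, using random embeddings purely to illustrate the tensor shapes (4 speakers, 5 utterances each, 256-dim embeddings); it is not a training example.

import torch
from speaker_encoder.model import SpeakerEncoder

device = torch.device("cpu")
model = SpeakerEncoder(device, loss_device=device)
embeds = torch.randn(4, 5, 256)                      # (speakers, utterances, embedding_size)
embeds = embeds / embeds.norm(dim=2, keepdim=True)   # GE2E expects L2-normalized embeddings
sim = model.similarity_matrix(embeds)
loss, eer = model.loss(embeds)
print(sim.shape, float(loss), eer)                   # torch.Size([4, 5, 4]) ...
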
/speaker_encoder/params_data.py:
--------------------------------------------------------------------------------
1 |
2 | ## Mel-filterbank
3 | mel_window_length = 25 # In milliseconds
4 | mel_window_step = 10 # In milliseconds
5 | mel_n_channels = 40
6 |
7 |
8 | ## Audio
9 | sampling_rate = 16000
10 | # Number of spectrogram frames in a partial utterance
11 | partials_n_frames = 160 # 1600 ms
12 | # Number of spectrogram frames at inference
13 | inference_n_frames = 80 # 800 ms
14 |
15 |
16 | ## Voice Activation Detection
17 | # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
18 | # This sets the granularity of the VAD. Should not need to be changed.
19 | vad_window_length = 30 # In milliseconds
20 | # Number of frames to average together when performing the moving average smoothing.
21 | # The larger this value, the larger the VAD variations must be to not get smoothed out.
22 | vad_moving_average_width = 8
23 | # Maximum number of consecutive silent frames a segment can have.
24 | vad_max_silence_length = 6
25 |
26 |
27 | ## Audio volume normalization
28 | audio_norm_target_dBFS = -30
29 |
30 |
--------------------------------------------------------------------------------
/speaker_encoder/params_model.py:
--------------------------------------------------------------------------------
1 |
2 | ## Model parameters
3 | model_hidden_size = 256
4 | model_embedding_size = 256
5 | model_num_layers = 3
6 |
7 |
8 | ## Training parameters
9 | learning_rate_init = 1e-4
10 | speakers_per_batch = 64
11 | utterances_per_speaker = 10
12 |
--------------------------------------------------------------------------------
/speaker_encoder/train.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder.visualizations import Visualizations
2 | from speaker_encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
3 | from speaker_encoder.params_model import *
4 | from speaker_encoder.model import SpeakerEncoder
5 | from utils.profiler import Profiler
6 | from pathlib import Path
7 | import torch
8 |
9 | def sync(device: torch.device):
10 | # FIXME
11 | return
12 | # For correct profiling (cuda operations are async)
13 | if device.type == "cuda":
14 | torch.cuda.synchronize(device)
15 |
16 | def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
17 | backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
18 | no_visdom: bool):
19 | # Create a dataset and a dataloader
20 | dataset = SpeakerVerificationDataset(clean_data_root)
21 | loader = SpeakerVerificationDataLoader(
22 | dataset,
23 | speakers_per_batch, # 64
24 | utterances_per_speaker, # 10
25 | num_workers=8,
26 | )
27 |
28 | # Setup the device on which to run the forward pass and the loss. These can be different,
29 | # because the forward pass is faster on the GPU whereas the loss is often (depending on your
30 | # hyperparameters) faster on the CPU.
31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32 | # FIXME: currently, the gradient is None if loss_device is cuda
33 | loss_device = torch.device("cpu")
34 |
35 | # Create the model and the optimizer
36 | model = SpeakerEncoder(device, loss_device)
37 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
38 | init_step = 1
39 |
40 | # Configure file path for the model
41 | state_fpath = models_dir.joinpath(run_id + ".pt")
42 | backup_dir = models_dir.joinpath(run_id + "_backups")
43 |
44 | # Load any existing model
45 | if not force_restart:
46 | if state_fpath.exists():
47 | print("Found existing model \"%s\", loading it and resuming training." % run_id)
48 | checkpoint = torch.load(state_fpath)
49 | init_step = checkpoint["step"]
50 | model.load_state_dict(checkpoint["model_state"])
51 | optimizer.load_state_dict(checkpoint["optimizer_state"])
52 | optimizer.param_groups[0]["lr"] = learning_rate_init
53 | else:
54 | print("No model \"%s\" found, starting training from scratch." % run_id)
55 | else:
56 | print("Starting the training from scratch.")
57 | model.train()
58 |
59 | # Initialize the visualization environment
60 | vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
61 | vis.log_dataset(dataset)
62 | vis.log_params()
63 | device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
64 | vis.log_implementation({"Device": device_name})
65 |
66 | # Training loop
67 | profiler = Profiler(summarize_every=10, disabled=False)
68 | for step, speaker_batch in enumerate(loader, init_step):
69 | profiler.tick("Blocking, waiting for batch (threaded)")
70 |
71 | # Forward pass
72 | inputs = torch.from_numpy(speaker_batch.data).to(device)
73 | sync(device)
74 | profiler.tick("Data to %s" % device)
75 | embeds = model(inputs)
76 | sync(device)
77 | profiler.tick("Forward pass")
78 | embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
79 | loss, eer = model.loss(embeds_loss)
80 | sync(loss_device)
81 | profiler.tick("Loss")
82 |
83 | # Backward pass
84 | model.zero_grad()
85 | loss.backward()
86 | profiler.tick("Backward pass")
87 | model.do_gradient_ops()
88 | optimizer.step()
89 | profiler.tick("Parameter update")
90 |
91 | # Update visualizations
92 | # learning_rate = optimizer.param_groups[0]["lr"]
93 | vis.update(loss.item(), eer, step)
94 |
95 | # Draw projections and save them to the backup folder
96 | if umap_every != 0 and step % umap_every == 0:
97 | print("Drawing and saving projections (step %d)" % step)
98 | backup_dir.mkdir(exist_ok=True)
99 | projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
100 | embeds = embeds.detach().cpu().numpy()
101 | vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
102 | vis.save()
103 |
104 | # Overwrite the latest version of the model
105 | if save_every != 0 and step % save_every == 0:
106 | print("Saving the model (step %d)" % step)
107 | torch.save({
108 | "step": step + 1,
109 | "model_state": model.state_dict(),
110 | "optimizer_state": optimizer.state_dict(),
111 | }, state_fpath)
112 |
113 | # Make a backup
114 | if backup_every != 0 and step % backup_every == 0:
115 | print("Making a backup (step %d)" % step)
116 | backup_dir.mkdir(exist_ok=True)
117 | backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
118 | torch.save({
119 | "step": step + 1,
120 | "model_state": model.state_dict(),
121 | "optimizer_state": optimizer.state_dict(),
122 | }, backup_fpath)
123 |
124 | profiler.tick("Extras (visualizations, saving)")
125 |
--------------------------------------------------------------------------------
/speaker_encoder/visualizations.py:
--------------------------------------------------------------------------------
1 | from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2 | from datetime import datetime
3 | from time import perf_counter as timer
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | # import webbrowser
7 | import visdom
8 | import umap
9 |
10 | colormap = np.array([
11 | [76, 255, 0],
12 | [0, 127, 70],
13 | [255, 0, 0],
14 | [255, 217, 38],
15 | [0, 135, 255],
16 | [165, 0, 165],
17 | [255, 167, 255],
18 | [0, 255, 255],
19 | [255, 96, 38],
20 | [142, 76, 0],
21 | [33, 0, 127],
22 | [0, 0, 0],
23 | [183, 183, 183],
24 | ], dtype=float) / 255
25 |
26 |
27 | class Visualizations:
28 | def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
29 | # Tracking data
30 | self.last_update_timestamp = timer()
31 | self.update_every = update_every
32 | self.step_times = []
33 | self.losses = []
34 | self.eers = []
35 | print("Updating the visualizations every %d steps." % update_every)
36 |
37 | # If visdom is disabled TODO: use a better paradigm for that
38 | self.disabled = disabled
39 | if self.disabled:
40 | return
41 |
42 | # Set the environment name
43 | now = str(datetime.now().strftime("%d-%m %Hh%M"))
44 | if env_name is None:
45 | self.env_name = now
46 | else:
47 | self.env_name = "%s (%s)" % (env_name, now)
48 |
49 | # Connect to visdom and open the corresponding window in the browser
50 | try:
51 | self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
52 | except ConnectionError:
53 | raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
54 | "start it.")
55 | # webbrowser.open("http://localhost:8097/env/" + self.env_name)
56 |
57 | # Create the windows
58 | self.loss_win = None
59 | self.eer_win = None
60 | # self.lr_win = None
61 | self.implementation_win = None
62 | self.projection_win = None
63 | self.implementation_string = ""
64 |
65 | def log_params(self):
66 | if self.disabled:
67 | return
68 | from speaker_encoder import params_data
69 | from speaker_encoder import params_model
70 |         param_string = "Model parameters:<br>"
71 | for param_name in (p for p in dir(params_model) if not p.startswith("__")):
72 | value = getattr(params_model, param_name)
73 |             param_string += "\t%s: %s<br>" % (param_name, value)
74 |         param_string += "Data parameters:<br>"
75 | for param_name in (p for p in dir(params_data) if not p.startswith("__")):
76 | value = getattr(params_data, param_name)
77 |             param_string += "\t%s: %s<br>" % (param_name, value)
78 | self.vis.text(param_string, opts={"title": "Parameters"})
79 |
80 | def log_dataset(self, dataset: SpeakerVerificationDataset):
81 | if self.disabled:
82 | return
83 | dataset_string = ""
84 | dataset_string += "Speakers: %s\n" % len(dataset.speakers)
85 | dataset_string += "\n" + dataset.get_logs()
86 |         dataset_string = dataset_string.replace("\n", "<br>")
87 | self.vis.text(dataset_string, opts={"title": "Dataset"})
88 |
89 | def log_implementation(self, params):
90 | if self.disabled:
91 | return
92 | implementation_string = ""
93 | for param, value in params.items():
94 | implementation_string += "%s: %s\n" % (param, value)
95 |         implementation_string = implementation_string.replace("\n", "<br>")
96 | self.implementation_string = implementation_string
97 | self.implementation_win = self.vis.text(
98 | implementation_string,
99 | opts={"title": "Training implementation"}
100 | )
101 |
102 | def update(self, loss, eer, step):
103 | # Update the tracking data
104 | now = timer()
105 | self.step_times.append(1000 * (now - self.last_update_timestamp))
106 | self.last_update_timestamp = now
107 | self.losses.append(loss)
108 | self.eers.append(eer)
109 | print(".", end="")
110 |
111 |         # Update the plots every update_every steps
112 | if step % self.update_every != 0:
113 | return
114 | time_string = "Step time: mean: %5dms std: %5dms" % \
115 | (int(np.mean(self.step_times)), int(np.std(self.step_times)))
116 | print("\nStep %6d Loss: %.4f EER: %.4f %s" %
117 | (step, np.mean(self.losses), np.mean(self.eers), time_string))
118 | if not self.disabled:
119 | self.loss_win = self.vis.line(
120 | [np.mean(self.losses)],
121 | [step],
122 | win=self.loss_win,
123 | update="append" if self.loss_win else None,
124 | opts=dict(
125 | legend=["Avg. loss"],
126 | xlabel="Step",
127 | ylabel="Loss",
128 | title="Loss",
129 | )
130 | )
131 | self.eer_win = self.vis.line(
132 | [np.mean(self.eers)],
133 | [step],
134 | win=self.eer_win,
135 | update="append" if self.eer_win else None,
136 | opts=dict(
137 | legend=["Avg. EER"],
138 | xlabel="Step",
139 | ylabel="EER",
140 | title="Equal error rate"
141 | )
142 | )
143 | if self.implementation_win is not None:
144 | self.vis.text(
145 | self.implementation_string + ("%s" % time_string),
146 | win=self.implementation_win,
147 | opts={"title": "Training implementation"},
148 | )
149 |
150 | # Reset the tracking
151 | self.losses.clear()
152 | self.eers.clear()
153 | self.step_times.clear()
154 |
155 | def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
156 | max_speakers=10):
157 | max_speakers = min(max_speakers, len(colormap))
158 | embeds = embeds[:max_speakers * utterances_per_speaker]
159 |
160 | n_speakers = len(embeds) // utterances_per_speaker
161 | ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
162 | colors = [colormap[i] for i in ground_truth]
163 |
164 | reducer = umap.UMAP()
165 | projected = reducer.fit_transform(embeds)
166 | plt.scatter(projected[:, 0], projected[:, 1], c=colors)
167 | plt.gca().set_aspect("equal", "datalim")
168 | plt.title("UMAP projection (step %d)" % step)
169 | if not self.disabled:
170 | self.projection_win = self.vis.matplot(plt, win=self.projection_win)
171 | if out_fpath is not None:
172 | plt.savefig(out_fpath)
173 | plt.clf()
174 |
175 | def save(self):
176 | if not self.disabled:
177 | self.vis.save([self.env_name])
178 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | from .mel_decoder_mol_encAddlf0 import MelDecoderMOL
2 | from .mel_decoder_mol_v2 import MelDecoderMOLv2
3 | from .rnn_ppg2mel import BiRnnPpg2MelModel
4 | from .mel_decoder_lsa import MelDecoderLSA
5 |
6 |
7 | def build_model(model_name: str):
8 | if model_name == "seq2seqmol":
9 | return MelDecoderMOL
10 | elif model_name == "seq2seqmolv2":
11 | return MelDecoderMOLv2
12 | elif model_name == "bilstm":
13 | return BiRnnPpg2MelModel
14 | elif model_name == "seq2seqlsa":
15 | return MelDecoderLSA
16 | else:
17 | raise ValueError(f"Unknown model name: {model_name}.")
18 |
--------------------------------------------------------------------------------
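
Usage sketch of the factory above: build_model returns a class, not an instance, so the caller still constructs it with whatever keyword arguments the chosen config under conf/ provides (model_config below is a placeholder).

from src import build_model

model_cls = build_model("seq2seqmol")        # -> MelDecoderMOL
print(model_cls.__name__)
# model = model_cls(**model_config)          # model_config: placeholder dict loaded from a YAML config
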
/src/abs_model.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from abc import abstractmethod
3 | from typing import Tuple
4 |
5 | import torch
6 |
7 |
8 | class AbsMelDecoder(torch.nn.Module, ABC):
9 | """The abstract PPG-based voice conversion class
10 | This "model" is one of mediator objects for "Task" class.
11 |
12 | """
13 |
14 | @abstractmethod
15 | def forward(
16 | self,
17 | bottle_neck_features: torch.Tensor,
18 | feature_lengths: torch.Tensor,
19 | speech: torch.Tensor,
20 | speech_lengths: torch.Tensor,
21 | logf0_uv: torch.Tensor = None,
22 | spembs: torch.Tensor = None,
23 | styleembs: torch.Tensor = None,
24 | ) -> torch.Tensor:
25 | raise NotImplementedError
26 |
--------------------------------------------------------------------------------
/src/audio_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import torch
5 | import torch.utils.data
6 | import numpy as np
7 | from librosa.util import normalize
8 | from scipy.io.wavfile import read
9 | from librosa.filters import mel as librosa_mel_fn
10 |
11 | MAX_WAV_VALUE = 32768.0
12 |
13 |
14 | def load_wav(full_path):
15 | sampling_rate, data = read(full_path)
16 | return data, sampling_rate
17 |
18 |
19 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
20 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
21 |
22 |
23 | def dynamic_range_decompression(x, C=1):
24 | return np.exp(x) / C
25 |
26 |
27 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
28 | return torch.log(torch.clamp(x, min=clip_val) * C)
29 |
30 |
31 | def dynamic_range_decompression_torch(x, C=1):
32 | return torch.exp(x) / C
33 |
34 |
35 | def spectral_normalize_torch(magnitudes):
36 | output = dynamic_range_compression_torch(magnitudes)
37 | return output
38 |
39 |
40 | def spectral_de_normalize_torch(magnitudes):
41 | output = dynamic_range_decompression_torch(magnitudes)
42 | return output
43 |
44 |
45 | mel_basis = {}
46 | hann_window = {}
47 |
48 |
49 | def mel_spectrogram(
50 | y,
51 | n_fft=1024,
52 | num_mels=80,
53 | sampling_rate=24000,
54 | hop_size=240,
55 | win_size=1024,
56 | fmin=0,
57 | fmax=8000,
58 | center=False,
59 | output_energy=False,
60 | ):
61 | if torch.min(y) < -1.:
62 | print('min value is ', torch.min(y))
63 | if torch.max(y) > 1.:
64 | print('max value is ', torch.max(y))
65 |
66 | global mel_basis, hann_window
67 |     if str(fmax)+'_'+str(y.device) not in mel_basis:
68 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
69 | mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
70 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
71 |
72 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
73 | y = y.squeeze(1)
74 |
75 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
76 | center=center, pad_mode='reflect', normalized=False, onesided=True)
77 |
78 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
79 | mel_spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
80 | mel_spec = spectral_normalize_torch(mel_spec)
81 | if output_energy:
82 | energy = torch.norm(spec, dim=1)
83 | return mel_spec, energy
84 | else:
85 | return mel_spec
86 |
87 |
88 | # def get_dataset_filelist(a):
89 | # with open(a.input_training_file, 'r', encoding='utf-8') as fi:
90 | # training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
91 | # for x in fi.read().split('\n') if len(x) > 0]
92 |
93 | # with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
94 | # validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
95 | # for x in fi.read().split('\n') if len(x) > 0]
96 | # return training_files, validation_files
97 |
98 |
99 | # class MelDataset(torch.utils.data.Dataset):
100 | # def __init__(self, training_files, segment_size, n_fft, num_mels,
101 | # hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
102 | # device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
103 | # self.audio_files = training_files
104 | # random.seed(1234)
105 | # if shuffle:
106 | # random.shuffle(self.audio_files)
107 | # self.segment_size = segment_size
108 | # self.sampling_rate = sampling_rate
109 | # self.split = split
110 | # self.n_fft = n_fft
111 | # self.num_mels = num_mels
112 | # self.hop_size = hop_size
113 | # self.win_size = win_size
114 | # self.fmin = fmin
115 | # self.fmax = fmax
116 | # self.fmax_loss = fmax_loss
117 | # self.cached_wav = None
118 | # self.n_cache_reuse = n_cache_reuse
119 | # self._cache_ref_count = 0
120 | # self.device = device
121 | # self.fine_tuning = fine_tuning
122 | # self.base_mels_path = base_mels_path
123 |
124 | # def __getitem__(self, index):
125 | # filename = self.audio_files[index]
126 | # if self._cache_ref_count == 0:
127 | # audio, sampling_rate = load_wav(filename)
128 | # audio = audio / MAX_WAV_VALUE
129 | # if not self.fine_tuning:
130 | # audio = normalize(audio) * 0.95
131 | # self.cached_wav = audio
132 | # if sampling_rate != self.sampling_rate:
133 | # raise ValueError("{} SR doesn't match target {} SR".format(
134 | # sampling_rate, self.sampling_rate))
135 | # self._cache_ref_count = self.n_cache_reuse
136 | # else:
137 | # audio = self.cached_wav
138 | # self._cache_ref_count -= 1
139 |
140 | # audio = torch.FloatTensor(audio)
141 | # audio = audio.unsqueeze(0)
142 |
143 | # if not self.fine_tuning:
144 | # if self.split:
145 | # if audio.size(1) >= self.segment_size:
146 | # max_audio_start = audio.size(1) - self.segment_size
147 | # audio_start = random.randint(0, max_audio_start)
148 | # audio = audio[:, audio_start:audio_start+self.segment_size]
149 | # else:
150 | # audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
151 |
152 | # mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
153 | # self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
154 | # center=False)
155 | # else:
156 | # mel = np.load(
157 | # os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
158 | # mel = torch.from_numpy(mel)
159 |
160 | # if len(mel.shape) < 3:
161 | # mel = mel.unsqueeze(0)
162 |
163 | # if self.split:
164 | # frames_per_seg = math.ceil(self.segment_size / self.hop_size)
165 |
166 | # if audio.size(1) >= self.segment_size:
167 | # mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
168 | # mel = mel[:, :, mel_start:mel_start + frames_per_seg]
169 | # audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
170 | # else:
171 | # mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
172 | # audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
173 |
174 | # mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
175 | # self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
176 | # center=False)
177 |
178 | # return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
179 |
180 | # def __len__(self):
181 | # return len(self.audio_files)
182 |
--------------------------------------------------------------------------------
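
A quick sanity check of mel_spectrogram above with its 24 kHz defaults (hop_size=240, i.e. 10 ms frames), under the library versions this repo targets: one second of audio yields about sampling_rate / hop_size = 100 frames.

import torch
from src.audio_utils import mel_spectrogram

y = torch.randn(1, 24000).clamp(-1.0, 1.0)    # (batch, samples), values kept in [-1, 1]
mel = mel_spectrogram(y)                      # defaults: n_fft=1024, num_mels=80, hop_size=240
print(mel.shape)                              # torch.Size([1, 80, 100])
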
/src/basic_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from torch.autograd import Function
5 |
6 | def tile(x, count, dim=0):
7 | """
8 | Tiles x on dimension dim count times.
9 | """
10 | perm = list(range(len(x.size())))
11 | if dim != 0:
12 | perm[0], perm[dim] = perm[dim], perm[0]
13 | x = x.permute(perm).contiguous()
14 | out_size = list(x.size())
15 | out_size[0] *= count
16 | batch = x.size(0)
17 | x = x.view(batch, -1) \
18 | .transpose(0, 1) \
19 | .repeat(count, 1) \
20 | .transpose(0, 1) \
21 | .contiguous() \
22 | .view(*out_size)
23 | if dim != 0:
24 | x = x.permute(perm).contiguous()
25 | return x
26 |
27 | class Linear(torch.nn.Module):
28 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
29 | super(Linear, self).__init__()
30 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
31 |
32 | torch.nn.init.xavier_uniform_(
33 | self.linear_layer.weight,
34 | gain=torch.nn.init.calculate_gain(w_init_gain))
35 |
36 | def forward(self, x):
37 | return self.linear_layer(x)
38 |
39 | class Conv1d(torch.nn.Module):
40 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
41 | padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
42 | super(Conv1d, self).__init__()
43 | if padding is None:
44 | assert(kernel_size % 2 == 1)
45 | padding = int(dilation * (kernel_size - 1)/2)
46 |
47 | self.conv = torch.nn.Conv1d(in_channels, out_channels,
48 | kernel_size=kernel_size, stride=stride,
49 | padding=padding, dilation=dilation,
50 | bias=bias)
51 | torch.nn.init.xavier_uniform_(
52 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
53 |
54 | def forward(self, x):
55 | # x: BxDxT
56 | return self.conv(x)
57 |
58 | class Flatten(nn.Module):
59 | def forward(self, x):
60 | return x.view(x.size(0), -1)
61 |
62 | # Reshape layer
63 | class Reshape(nn.Module):
64 | def __init__(self, outer_shape):
65 | super(Reshape, self).__init__()
66 | self.outer_shape = outer_shape
67 | def forward(self, x):
68 | return x.view(x.size(0), *self.outer_shape)
69 |
70 | # Sample from the Gumbel-Softmax distribution and optionally discretize.
71 | class GumbelSoftmax(nn.Module):
72 |
73 | def __init__(self, f_dim, c_dim):
74 | super(GumbelSoftmax, self).__init__()
75 | self.logits = nn.Linear(f_dim, c_dim)
76 | self.f_dim = f_dim
77 | self.c_dim = c_dim
78 |
79 | def sample_gumbel(self, shape, is_cuda=False, eps=1e-20):
80 | U = torch.rand(shape)
81 | if is_cuda:
82 | U = U.cuda()
83 | return -torch.log(-torch.log(U + eps) + eps)
84 |
85 | def gumbel_softmax_sample(self, logits, temperature):
86 | y = logits + self.sample_gumbel(logits.size(), logits.is_cuda)
87 | return F.softmax(y / temperature, dim=-1)
88 |
89 | def gumbel_softmax(self, logits, temperature, hard=False):
90 | """
91 |         ST-Gumbel-Softmax
92 |         input: [*, n_class]
93 |         return: flatten --> [*, n_class], a one-hot vector
94 | """
95 | #categorical_dim = 10
96 | y = self.gumbel_softmax_sample(logits, temperature)
97 |
98 | if not hard:
99 | return y
100 |
101 | shape = y.size()
102 | _, ind = y.max(dim=-1)
103 | y_hard = torch.zeros_like(y).view(-1, shape[-1])
104 | y_hard.scatter_(1, ind.view(-1, 1), 1)
105 | y_hard = y_hard.view(*shape)
106 | # Set gradients w.r.t. y_hard gradients w.r.t. y
107 | y_hard = (y_hard - y).detach() + y
108 | return y_hard
109 |
110 | def forward(self, x, temperature=1.0, hard=False):
111 | logits = self.logits(x).view(-1, self.c_dim)
112 | prob = F.softmax(logits, dim=-1)
113 | y = self.gumbel_softmax(logits, temperature, hard)
114 | return logits, prob, y
115 |
116 | class Softmax(nn.Module):
117 |
118 | def __init__(self, in_dim, out_dim):
119 | super().__init__()
120 | self.logits = nn.Linear(in_dim, out_dim)
121 | self.in_dim = in_dim
122 | self.out_dim = out_dim
123 |
124 | def forward(self, x):
125 | logits = self.logits(x).view(-1, self.out_dim)
126 | prob = F.softmax(logits, dim=-1)
127 | return logits, prob
128 |
129 | # Sample from a Gaussian distribution
130 | class Gaussian(nn.Module):
131 | def __init__(self, in_dim, z_dim, use_bias=True):
132 | super(Gaussian, self).__init__()
133 | self.mu = nn.Linear(in_dim, z_dim, bias=use_bias)
134 | self.log_std = nn.Linear(in_dim, z_dim, bias=use_bias)
135 | # self.mu.weight.data.fill_(0.0)
136 | # self.log_std.weight.data.fill_(0.0)
137 | # if use_bias:
138 | # self.mu.bias.data.fill_(0.0)
139 | # self.log_std.bias.data.fill_(0.0)
140 |
141 | def reparameterize(self, mu, log_std):
142 | std = torch.exp(log_std)
143 | noise = torch.randn_like(std)
144 | z = mu + noise * std
145 | return z
146 |
147 | def forward(self, x):
148 | mu = self.mu(x)
149 | log_std = self.log_std(x)
150 | z = self.reparameterize(mu, log_std)
151 | return mu, log_std, z
152 |
153 | # https://github.com/fungtion/DANN/blob/master/models/functions.py
154 | class ReverseLayerF(Function):
155 |
156 | @staticmethod
157 | def forward(ctx, x, alpha):
158 | ctx.alpha = alpha
159 | return x.view_as(x)
160 |
161 | @staticmethod
162 | def backward(ctx, grad_output):
163 | output = grad_output.neg() * ctx.alpha
164 | return output, None
165 | # class Gaussian(nn.Module):
166 | # def __init__(self, in_dim, z_dim):
167 | # super(Gaussian, self).__init__()
168 | # self.mu = nn.Linear(in_dim, z_dim)
169 | # self.var = nn.Linear(in_dim, z_dim)
170 |
171 | # def reparameterize(self, mu, var):
172 | # std = torch.sqrt(var + 1e-10)
173 | # noise = torch.randn_like(std)
174 | # z = mu + noise * std
175 | # return z
176 |
177 | # def forward(self, x):
178 | # mu = self.mu(x)
179 | # var = F.softplus(self.var(x))
180 | # z = self.reparameterize(mu, var)
181 | # return mu, var, z
182 |
183 | def tile(x, count, dim=0):
184 | """
185 | Tiles x on dimension dim count times.
186 | """
187 | perm = list(range(len(x.size())))
188 | if dim != 0:
189 | perm[0], perm[dim] = perm[dim], perm[0]
190 | x = x.permute(perm).contiguous()
191 | out_size = list(x.size())
192 | out_size[0] *= count
193 | batch = x.size(0)
194 | x = x.view(batch, -1) \
195 | .transpose(0, 1) \
196 | .repeat(count, 1) \
197 | .transpose(0, 1) \
198 | .contiguous() \
199 | .view(*out_size)
200 | if dim != 0:
201 | x = x.permute(perm).contiguous()
202 | return x
203 |
204 |
205 | def sort_batch(data, lengths):
206 | '''
207 |     Sort data by length in descending order.
208 |     sorted_data[initial_index] == data (restores the original order).
209 | '''
210 | sorted_lengths, sorted_index = lengths.sort(0, descending=True)
211 | sorted_data = data[sorted_index]
212 | _, initial_index = sorted_index.sort(0, descending=False)
213 |
214 | return sorted_data, sorted_lengths, initial_index
215 |
--------------------------------------------------------------------------------
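
A tiny example of what tile above does: each slice along dim is repeated count times in place (interleaved), which differs from Tensor.repeat, which appends whole copies of the tensor.

import torch
from src.basic_layers import tile

x = torch.tensor([[1, 2], [3, 4]])
print(tile(x, 2, dim=0))
# tensor([[1, 2],
#         [1, 2],
#         [3, 4],
#         [3, 4]])
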
/src/cnn_postnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .basic_layers import Linear, Conv1d
5 |
6 |
7 | class Postnet(nn.Module):
8 | """Postnet
9 |     - Five 1-d convolutions with 512 channels and kernel size 5
10 | """
11 | def __init__(self, num_mels=80,
12 | num_layers=5,
13 | hidden_dim=512,
14 | kernel_size=5):
15 | super(Postnet, self).__init__()
16 | self.convolutions = nn.ModuleList()
17 |
18 | self.convolutions.append(
19 | nn.Sequential(
20 | Conv1d(
21 | num_mels, hidden_dim,
22 | kernel_size=kernel_size, stride=1,
23 | padding=int((kernel_size - 1) / 2),
24 | dilation=1, w_init_gain='tanh'),
25 | nn.BatchNorm1d(hidden_dim)))
26 |
27 | for i in range(1, num_layers - 1):
28 | self.convolutions.append(
29 | nn.Sequential(
30 | Conv1d(
31 | hidden_dim,
32 | hidden_dim,
33 | kernel_size=kernel_size, stride=1,
34 | padding=int((kernel_size - 1) / 2),
35 | dilation=1, w_init_gain='tanh'),
36 | nn.BatchNorm1d(hidden_dim)))
37 |
38 | self.convolutions.append(
39 | nn.Sequential(
40 | Conv1d(
41 | hidden_dim, num_mels,
42 | kernel_size=kernel_size, stride=1,
43 | padding=int((kernel_size - 1) / 2),
44 | dilation=1, w_init_gain='linear'),
45 | nn.BatchNorm1d(num_mels)))
46 |
47 | def forward(self, x):
48 | # x: (B, num_mels, T_dec)
49 | for i in range(len(self.convolutions) - 1):
50 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
51 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
52 | return x
53 |
--------------------------------------------------------------------------------
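
Shape sketch for Postnet above: it keeps the (batch, num_mels, frames) layout, and in Tacotron-style decoders its output is typically added to the coarse mel prediction as a residual correction.

import torch
from src.cnn_postnet import Postnet

postnet = Postnet(num_mels=80)
mel = torch.randn(2, 80, 120)           # (batch, mel bins, decoder frames)
residual = postnet(mel)
print(residual.shape)                   # torch.Size([2, 80, 120])
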
/src/f0_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | import pyworld
4 | from scipy.interpolate import interp1d
5 | from scipy.signal import firwin, get_window, lfilter
6 |
7 |
8 | def compute_f0(wav, sr=16000, frame_period=10.0):
9 | """Compute f0 from wav using pyworld harvest algorithm."""
10 | wav = wav.astype(np.float64)
11 | f0, _ = pyworld.harvest(
12 | wav, sr, frame_period=frame_period, f0_floor=20.0, f0_ceil=600.0)
13 | return f0.astype(np.float32)
14 |
15 |
16 | def low_pass_filter(x, fs, cutoff=70, padding=True):
17 | """FUNCTION TO APPLY LOW PASS FILTER
18 |
19 | Args:
20 | x (ndarray): Waveform sequence
21 | fs (int): Sampling frequency
22 | cutoff (float): Cutoff frequency of low pass filter
23 |
24 | Return:
25 | (ndarray): Low pass filtered waveform sequence
26 | """
27 |
28 | nyquist = fs // 2
29 | norm_cutoff = cutoff / nyquist
30 |
31 | # low cut filter
32 | numtaps = 255
33 | fil = firwin(numtaps, norm_cutoff)
34 | x_pad = np.pad(x, (numtaps, numtaps), 'edge')
35 | lpf_x = lfilter(fil, 1, x_pad)
36 | lpf_x = lpf_x[numtaps + numtaps // 2: -numtaps // 2]
37 |
38 | return lpf_x
39 |
40 |
41 | def convert_continuos_f0(f0):
42 | """CONVERT F0 TO CONTINUOUS F0
43 |
44 | Args:
45 | f0 (ndarray): original f0 sequence with the shape (T)
46 |
47 | Return:
48 | (ndarray): continuous f0 with the shape (T)
49 | """
50 | # get uv information as binary
51 | uv = np.float32(f0 != 0)
52 |
53 | # get start and end of f0
54 | if (f0 == 0).all():
55 |         logging.warning("all of the f0 values are 0.")
56 | return uv, f0
57 | start_f0 = f0[f0 != 0][0]
58 | end_f0 = f0[f0 != 0][-1]
59 |
60 | # padding start and end of f0 sequence
61 | start_idx = np.where(f0 == start_f0)[0][0]
62 | end_idx = np.where(f0 == end_f0)[0][-1]
63 | f0[:start_idx] = start_f0
64 | f0[end_idx:] = end_f0
65 |
66 | # get non-zero frame index
67 | nz_frames = np.where(f0 != 0)[0]
68 |
69 | # perform linear interpolation
70 | f = interp1d(nz_frames, f0[nz_frames])
71 | cont_f0 = f(np.arange(0, f0.shape[0]))
72 |
73 | return uv, cont_f0
74 |
75 |
76 | def get_cont_lf0(f0, frame_period=10.0, lpf=False):
77 | uv, cont_f0 = convert_continuos_f0(f0)
78 | if lpf:
79 | cont_f0_lpf = low_pass_filter(cont_f0, int(1.0 / (frame_period * 0.001)), cutoff=20)
80 | cont_lf0_lpf = cont_f0_lpf.copy()
81 | nonzero_indices = np.nonzero(cont_lf0_lpf)
82 | cont_lf0_lpf[nonzero_indices] = np.log(cont_f0_lpf[nonzero_indices])
83 | # cont_lf0_lpf = np.log(cont_f0_lpf)
84 | return uv, cont_lf0_lpf
85 | else:
86 | nonzero_indices = np.nonzero(cont_f0)
87 | cont_lf0 = cont_f0.copy()
88 | cont_lf0[cont_f0>0] = np.log(cont_f0[cont_f0>0])
89 | return uv, cont_lf0
90 |
--------------------------------------------------------------------------------
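A sketch of how the f0 helpers above are typically chained, assuming a 16 kHz mono waveform; the wav path is hypothetical and the import assumes the repository root is on PYTHONPATH. The resulting (T, 2) matrix matches the concatenated log-f0/u-v input expected by the ppg2mel models.

import numpy as np
import soundfile as sf
from src.f0_utils import compute_f0, get_cont_lf0  # assumed import path

wav, sr = sf.read("example_16k.wav")                 # hypothetical 16 kHz mono file
f0 = compute_f0(wav, sr=sr, frame_period=10.0)       # (T,) f0 in Hz, 0 for unvoiced
uv, cont_lf0 = get_cont_lf0(f0)                      # u/v flags and interpolated log-f0
logf0_uv = np.concatenate([cont_lf0[:, None], uv[:, None]], axis=1)  # (T, 2)
print(f0.shape, logf0_uv.shape)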
/src/loss.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | from typing import Tuple
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from .nets_utils import make_pad_mask
9 |
10 |
11 | class MaskedMSELoss(nn.Module):
12 | def __init__(self, frames_per_step):
13 | super().__init__()
14 | self.frames_per_step = frames_per_step
15 | self.mel_loss_criterion = nn.MSELoss(reduction='none')
16 | # self.loss = nn.MSELoss()
17 | self.stop_loss_criterion = nn.BCEWithLogitsLoss(reduction='none')
18 |
19 | def get_mask(self, lengths, max_len=None):
20 | # lengths: [B,]
21 | if max_len is None:
22 | max_len = torch.max(lengths)
23 | batch_size = lengths.size(0)
24 | seq_range = torch.arange(0, max_len).long()
25 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len).to(lengths.device)
26 | seq_length_expand = lengths.unsqueeze(1).expand_as(seq_range_expand)
27 | return (seq_range_expand < seq_length_expand).float()
28 |
29 | def forward(self, mel_pred, mel_pred_postnet, mel_trg, lengths,
30 | stop_target, stop_pred):
31 | ## process stop_target
32 | B = stop_target.size(0)
33 | stop_target = stop_target.reshape(B, -1, self.frames_per_step)[:, :, 0]
34 | stop_lengths = torch.ceil(lengths.float() / self.frames_per_step).long()
35 | stop_mask = self.get_mask(stop_lengths, int(mel_trg.size(1)/self.frames_per_step))
36 |
37 | mel_trg.requires_grad = False
38 | # (B, T, 1)
39 | mel_mask = self.get_mask(lengths, mel_trg.size(1)).unsqueeze(-1)
40 | # (B, T, D)
41 | mel_mask = mel_mask.expand_as(mel_trg)
42 | mel_loss_pre = (self.mel_loss_criterion(mel_pred, mel_trg) * mel_mask).sum() / mel_mask.sum()
43 | mel_loss_post = (self.mel_loss_criterion(mel_pred_postnet, mel_trg) * mel_mask).sum() / mel_mask.sum()
44 |
45 | mel_loss = mel_loss_pre + mel_loss_post
46 |
47 | # stop token loss
48 | stop_loss = torch.sum(self.stop_loss_criterion(stop_pred, stop_target) * stop_mask) / stop_mask.sum()
49 |
50 | return mel_loss, stop_loss
51 |
--------------------------------------------------------------------------------
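A toy call of the MaskedMSELoss above with random tensors and a reduction factor of 2 frames per decoder step; shapes are illustrative, and T must be a multiple of frames_per_step, as the reshape in forward requires.

import torch
from src.loss import MaskedMSELoss  # assumed import path

B, T, D, r = 2, 8, 80, 2
criterion = MaskedMSELoss(frames_per_step=r)

mel_pred = torch.randn(B, T, D)
mel_post = torch.randn(B, T, D)
mel_trg = torch.randn(B, T, D)
lengths = torch.tensor([8, 6])               # valid frames per utterance
stop_target = torch.zeros(B, T)              # 1.0 marks the final frames
stop_target[0, -r:] = 1.0
stop_pred = torch.randn(B, T // r)           # one stop logit per decoder step

mel_loss, stop_loss = criterion(mel_pred, mel_post, mel_trg,
                                lengths, stop_target, stop_pred)
print(mel_loss.item(), stop_loss.item())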
/src/loss_fn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class MaskedMSELoss(nn.Module):
6 | def __init__(self):
7 | super().__init__()
8 | self.loss = nn.MSELoss(reduction='none')
9 |
10 | def get_mask(self, lengths):
11 | # lengths: [B,]
12 | max_len = torch.max(lengths)
13 | batch_size = lengths.size(0)
14 | seq_range = torch.arange(0, max_len).long()
15 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len).to(lengths.device)
16 | seq_length_expand = lengths.unsqueeze(1).expand_as(seq_range_expand)
17 | return (seq_range_expand < seq_length_expand).float()
18 |
19 | def forward(self, mel_pred, mel_trg, lengths):
20 | # (B, T, 1)
21 | mask = self.get_mask(lengths).unsqueeze(-1)
22 | # (B, T, D)
23 | mask_ = mask.expand_as(mel_trg)
24 | loss = self.loss(mel_pred, mel_trg)
25 | return ((loss * mask_).sum()) / mask_.sum()
26 |
--------------------------------------------------------------------------------
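The simpler masked loss above is used the same way, just without stop tokens or a reduction factor; a minimal sketch:

import torch
from src.loss_fn import MaskedMSELoss  # assumed import path

criterion = MaskedMSELoss()
mel_pred, mel_trg = torch.randn(2, 100, 80), torch.randn(2, 100, 80)
lengths = torch.tensor([100, 73])
loss = criterion(mel_pred, mel_trg, lengths)   # scalar; padded frames are ignored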
/src/lsa_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from .basic_layers import Linear, Conv1d
5 |
6 |
7 |
8 | class LocationLayer(nn.Module):
9 | def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
10 | super().__init__()
11 | padding = int((attention_kernel_size - 1) / 2)
12 | self.location_conv = Conv1d(2, attention_n_filters,
13 | kernel_size=attention_kernel_size,
14 | padding=padding, bias=False,
15 | stride=1, dilation=1)
16 | self.location_dense = Linear(attention_n_filters, attention_dim,
17 | bias=False, w_init_gain='tanh')
18 |
19 | def forward(self, attention_weights_cat):
20 | processed_attention = self.location_conv(attention_weights_cat)
21 | processed_attention = processed_attention.transpose(1, 2)
22 | processed_attention = self.location_dense(processed_attention)
23 | return processed_attention
24 |
25 | class LocationSensitiveAttention(nn.Module):
26 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
27 | attention_location_n_filters, attention_location_kernel_size):
28 | super().__init__()
29 | self.query_layer = Linear(attention_rnn_dim, attention_dim,
30 | bias=False, w_init_gain='tanh')
31 | self.memory_layer = Linear(embedding_dim, attention_dim,
32 | bias=False, w_init_gain='tanh')
33 | self.v = Linear(attention_dim, 1, bias=False)
34 | self.location_layer = LocationLayer(attention_location_n_filters,
35 | attention_location_kernel_size,
36 | attention_dim)
37 | self.score_mask_value = -float('inf')
38 |
39 | def get_alignment_energies(self, query, processed_memory,
40 | attention_weights_cat):
41 | """
42 | Args:
43 | query: attention rnn output (B, att_rnn_dim)
44 | processed_memory: processed encoder outputs (B, T_in, enc_dim)
45 | attention_weights_cat: cumulative and previous attention weights (B, 2, T_in)
46 | Returns:
47 | attention alignment (B, T_in)
48 | """
49 | processed_query = self.query_layer(query.unsqueeze(1))
50 | processed_attention_weights = self.location_layer(attention_weights_cat)
51 | energies = self.v(torch.tanh(
52 | processed_query + processed_attention_weights + processed_memory))
53 | energies = energies.squeeze(-1)
54 | return energies
55 |
56 | def forward(self, attention_hidden_state, memory, processed_memory, attention_weights_cat, mask):
57 | """
58 | Args:
59 | attention_hidden_state: attention rnn last output
60 | memory: encoder outputs
61 | processed_memory: processed encoder outputs
62 | attention_weights_cat: previous and cumulative attention weights
63 | mask: binary mask for padded data
64 | """
65 | alignment = self.get_alignment_energies(
66 | attention_hidden_state, processed_memory, attention_weights_cat)
67 |
68 | if mask is not None:
69 | alignment.data.masked_fill_(mask, self.score_mask_value)
70 |
71 | attention_weights = F.softmax(alignment, dim=1)
72 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
73 | attention_context = attention_context.squeeze(1)
74 |
75 | return attention_context, attention_weights
76 |
77 |
78 | class ForwardAttentionV2(nn.Module):
79 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
80 | attention_location_n_filters, attention_location_kernel_size):
81 | super(ForwardAttentionV2, self).__init__()
82 | self.query_layer = Linear(attention_rnn_dim, attention_dim,
83 | bias=False, w_init_gain='tanh')
84 | self.memory_layer = Linear(embedding_dim, attention_dim, bias=False,
85 | w_init_gain='tanh')
86 | self.v = Linear(attention_dim, 1, bias=False)
87 | self.location_layer = LocationLayer(attention_location_n_filters,
88 | attention_location_kernel_size,
89 | attention_dim)
90 | self.score_mask_value = -float(1e20)
91 |
92 | def get_alignment_energies(self, query, processed_memory,
93 | attention_weights_cat):
94 | """
95 | PARAMS
96 | ------
97 | query: decoder output (batch, n_mel_channels * n_frames_per_step)
98 | processed_memory: processed encoder outputs (B, T_in, attention_dim)
99 | attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
100 | RETURNS
101 | -------
102 | alignment (batch, max_time)
103 | """
104 |
105 | processed_query = self.query_layer(query.unsqueeze(1))
106 | processed_attention_weights = self.location_layer(attention_weights_cat)
107 | energies = self.v(torch.tanh(
108 | processed_query + processed_attention_weights + processed_memory))
109 |
110 | energies = energies.squeeze(-1)
111 | return energies
112 |
113 | def forward(self, attention_hidden_state, memory, processed_memory,
114 | attention_weights_cat, mask, log_alpha):
115 | """
116 | PARAMS
117 | ------
118 | attention_hidden_state: attention rnn last output
119 | memory: encoder outputs
120 | processed_memory: processed encoder outputs
121 |         attention_weights_cat: previous and cumulative attention weights
122 | mask: binary mask for padded data
123 | """
124 | log_energy = self.get_alignment_energies(
125 | attention_hidden_state, processed_memory, attention_weights_cat)
126 |
127 | #log_energy =
128 |
129 | if mask is not None:
130 | log_energy.data.masked_fill_(mask, self.score_mask_value)
131 |
132 | #attention_weights = F.softmax(alignment, dim=1)
133 |
134 | #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
135 | #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
136 |
137 | #log_total_score = log_alpha + content_score
138 |
139 | #previous_attention_weights = attention_weights_cat[:,0,:]
140 |
141 | log_alpha_shift_padded = []
142 | max_time = log_energy.size(1)
143 | for sft in range(2):
144 | shifted = log_alpha[:,:max_time-sft]
145 | shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value)
146 | log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
147 |
148 | biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2)
149 |
150 | log_alpha_new = biased + log_energy
151 |
152 | attention_weights = F.softmax(log_alpha_new, dim=1)
153 |
154 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
155 | attention_context = attention_context.squeeze(1)
156 |
157 | return attention_context, attention_weights, log_alpha_new
158 |
159 |
--------------------------------------------------------------------------------
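One attention step with the LocationSensitiveAttention above, using illustrative dimensions (not the ones from the training configs); processed_memory is precomputed with the module's own memory_layer, as a decoder would do once per utterance.

import torch
from src.lsa_attention import LocationSensitiveAttention  # assumed import path

B, T_in, enc_dim, rnn_dim, att_dim = 2, 50, 256, 512, 128
attn = LocationSensitiveAttention(rnn_dim, enc_dim, att_dim,
                                  attention_location_n_filters=32,
                                  attention_location_kernel_size=31)

query = torch.randn(B, rnn_dim)                  # attention RNN output
memory = torch.randn(B, T_in, enc_dim)           # encoder outputs
processed_memory = attn.memory_layer(memory)     # (B, T_in, att_dim)
attention_weights_cat = torch.zeros(B, 2, T_in)  # previous + cumulative weights
mask = None                                      # or a (B, T_in) bool padding mask

context, weights = attn(query, memory, processed_memory,
                        attention_weights_cat, mask)
print(context.shape, weights.shape)              # (B, enc_dim), (B, T_in)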
/src/mol_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class MOLAttention(nn.Module):
7 | """ Discretized Mixture of Logistic (MOL) attention.
8 | C.f. Section 5 of "MelNet: A Generative Model for Audio in the Frequency Domain" and
9 | GMMv2b model in "Location-relative attention mechanisms for robust long-form speech synthesis".
10 | """
11 | def __init__(
12 | self,
13 | query_dim,
14 | r=1,
15 | M=5,
16 | ):
17 | """
18 | Args:
19 | query_dim: attention_rnn_dim.
20 | M: number of mixtures.
21 | """
22 | super().__init__()
23 | if r < 1:
24 | self.r = float(r)
25 | else:
26 | self.r = int(r)
27 | self.M = M
28 | self.score_mask_value = 0.0 # -float("inf")
29 | self.eps = 1e-5
30 |         # Position array for encoder time steps
31 | self.J = None
32 |         # Query layer: predicts the mixture parameters [w, sigma, Delta]
33 | self.query_layer = torch.nn.Sequential(
34 | nn.Linear(query_dim, 256, bias=True),
35 | nn.ReLU(),
36 | nn.Linear(256, 3*M, bias=True)
37 | )
38 | self.mu_prev = None
39 | self.initialize_bias()
40 |
41 | def initialize_bias(self):
42 | """Initialize sigma and Delta."""
43 | # sigma
44 | torch.nn.init.constant_(self.query_layer[2].bias[self.M:2*self.M], 1.0)
45 | # Delta: softplus(1.8545) = 2.0; softplus(3.9815) = 4.0; softplus(0.5413) = 1.0
46 | # softplus(-0.432) = 0.5003
47 | if self.r == 2:
48 | torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 1.8545)
49 | elif self.r == 4:
50 | torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 3.9815)
51 | elif self.r == 1:
52 | torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 0.5413)
53 | else:
54 | torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], -0.432)
55 |
56 |
57 | def init_states(self, memory):
58 | """Initialize mu_prev and J.
59 | This function should be called by the decoder before decoding one batch.
60 | Args:
61 | memory: (B, T, D_enc) encoder output.
62 | """
63 | B, T_enc, _ = memory.size()
64 | device = memory.device
65 | self.J = torch.arange(0, T_enc + 2.0).to(device) + 0.5 # NOTE: for discretize usage
66 | # self.J = memory.new_tensor(np.arange(T_enc), dtype=torch.float)
67 | self.mu_prev = torch.zeros(B, self.M).to(device)
68 |
69 | def forward(self, att_rnn_h, memory, memory_pitch=None, mask=None):
70 | """
71 |         att_rnn_h: attention rnn hidden state.
72 | memory: encoder outputs (B, T_enc, D).
73 | mask: binary mask for padded data (B, T_enc).
74 | """
75 | # [B, 3M]
76 | mixture_params = self.query_layer(att_rnn_h)
77 |
78 | # [B, M]
79 | w_hat = mixture_params[:, :self.M]
80 | sigma_hat = mixture_params[:, self.M:2*self.M]
81 | Delta_hat = mixture_params[:, 2*self.M:3*self.M]
82 |
83 | # print("w_hat: ", w_hat)
84 | # print("sigma_hat: ", sigma_hat)
85 | # print("Delta_hat: ", Delta_hat)
86 |
87 | # Dropout to de-correlate attention heads
88 | w_hat = F.dropout(w_hat, p=0.5, training=self.training) # NOTE(sx): needed?
89 |
90 | # Mixture parameters
91 | w = torch.softmax(w_hat, dim=-1) + self.eps
92 | sigma = F.softplus(sigma_hat) + self.eps
93 | Delta = F.softplus(Delta_hat)
94 | mu_cur = self.mu_prev + Delta
95 | # print("w:", w)
96 | j = self.J[:memory.size(1) + 1]
97 |
98 | # Attention weights
99 | # CDF of logistic distribution
100 | phi_t = w.unsqueeze(-1) * (1 / (1 + torch.sigmoid(
101 | (mu_cur.unsqueeze(-1) - j) / sigma.unsqueeze(-1))))
102 | # print("phi_t:", phi_t)
103 |
104 | # Discretize attention weights
105 | # (B, T_enc + 1)
106 | alpha_t = torch.sum(phi_t, dim=1)
107 | alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
108 | alpha_t[alpha_t == 0] = self.eps
109 | # print("alpha_t: ", alpha_t.size())
110 | # Apply masking
111 | if mask is not None:
112 | alpha_t.data.masked_fill_(mask, self.score_mask_value)
113 |
114 | context = torch.bmm(alpha_t.unsqueeze(1), memory).squeeze(1)
115 | if memory_pitch is not None:
116 | context_pitch = torch.bmm(alpha_t.unsqueeze(1), memory_pitch).squeeze(1)
117 |
118 | self.mu_prev = mu_cur
119 |
120 | if memory_pitch is not None:
121 | return context, context_pitch, alpha_t
122 | return context, alpha_t
123 |
124 |
--------------------------------------------------------------------------------
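A sketch of driving the MOLAttention above inside a decoder loop: init_states is called once per batch to reset mu_prev and the position grid, then forward is called once per decoder step. Sizes are illustrative.

import torch
from src.mol_attention import MOLAttention  # assumed import path

B, T_enc, D_enc, att_rnn_dim = 2, 50, 256, 512
attn = MOLAttention(query_dim=att_rnn_dim, r=4, M=5)

memory = torch.randn(B, T_enc, D_enc)            # encoder outputs
attn.init_states(memory)                         # reset mu_prev and position grid J

for _ in range(3):                               # a few decoder steps
    att_rnn_h = torch.randn(B, att_rnn_dim)      # attention RNN hidden state
    context, alpha = attn(att_rnn_h, memory)     # (B, D_enc), (B, T_enc)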
/src/optim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class Optimizer():
6 | def __init__(self, parameters, optimizer, lr, eps, lr_scheduler,
7 | **kwargs):
8 |
9 | # Setup torch optimizer
10 | self.opt_type = optimizer
11 | self.init_lr = lr
12 | self.sch_type = lr_scheduler
13 | opt = getattr(torch.optim, optimizer)
14 | if lr_scheduler == 'warmup':
15 | warmup_step = 4000.0
16 | init_lr = lr
17 | self.lr_scheduler = lambda step: init_lr * warmup_step ** 0.5 * \
18 | np.minimum((step+1)*warmup_step**-1.5, (step+1)**-0.5)
19 | self.opt = opt(parameters, lr=1.0)
20 | else:
21 | self.lr_scheduler = None
22 | self.opt = opt(parameters, lr=lr, eps=eps) # ToDo: 1e-8 better?
23 |
24 | def get_opt_state_dict(self):
25 | return self.opt.state_dict()
26 |
27 | def load_opt_state_dict(self, state_dict):
28 | self.opt.load_state_dict(state_dict)
29 |
30 | def pre_step(self, step):
31 | if self.lr_scheduler is not None:
32 | cur_lr = self.lr_scheduler(step)
33 | for param_group in self.opt.param_groups:
34 | param_group['lr'] = cur_lr
35 | else:
36 | cur_lr = self.init_lr
37 | self.opt.zero_grad()
38 | return cur_lr
39 |
40 | def step(self):
41 | self.opt.step()
42 |
43 | def create_msg(self):
44 | return ['Optim.Info.| Algo. = {}\t| Lr = {}\t (schedule = {})'
45 | .format(self.opt_type, self.init_lr, self.sch_type)]
46 |
--------------------------------------------------------------------------------
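A sketch of the Optimizer wrapper above with the 'warmup' schedule; the optimizer name and eps are illustrative, and pre_step must be called before every update so that the learning rate is set and gradients are zeroed.

import torch
from src.optim import Optimizer  # assumed import path

model = torch.nn.Linear(80, 80)
optim = Optimizer(model.parameters(), optimizer="Adam", lr=1e-3,
                  eps=1e-8, lr_scheduler="warmup")

x, y = torch.randn(4, 80), torch.randn(4, 80)
for step in range(5):
    cur_lr = optim.pre_step(step)        # apply warmup LR and zero gradients
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optim.step()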
/src/option.py:
--------------------------------------------------------------------------------
1 | # Default parameters which will be imported by solver
2 | default_hparas = {
3 | 'GRAD_CLIP': 5.0, # Grad. clip threshold
4 | 'PROGRESS_STEP': 100, # Std. output refresh freq.
5 | # Decode steps for objective validation (step = ratio*input_txt_len)
6 | 'DEV_STEP_RATIO': 1.2,
7 | # Number of examples (alignment/text) to show in tensorboard
8 | 'DEV_N_EXAMPLE': 4,
9 | 'TB_FLUSH_FREQ': 180 # Update frequency of tensorboard (secs)
10 | }
11 |
--------------------------------------------------------------------------------
/src/rnn_ppg2mel.py:
--------------------------------------------------------------------------------
1 | """Sequential implementation of Recurrent Neural Network Duration Model."""
2 | from typing import Tuple
3 | from typing import Union
4 |
5 | import torch
6 | import torch.nn as nn
7 | from typeguard import check_argument_types
8 |
9 |
10 | class BiRnnPpg2MelModel(torch.nn.Module):
11 | """ Bidirectional RNN-based PPG-to-Mel Model for voice conversion tasks.
12 | RNN could be LSTM-based or GRU-based.
13 | """
14 | def __init__(
15 | self,
16 | input_size: int,
17 | multi_spk: bool = False,
18 | num_speakers: int = 1,
19 | spk_embed_dim: int = 256,
20 | use_spk_dvec: bool = False,
21 | multi_styles: bool = False,
22 | num_styles: int = 3,
23 | style_embed_dim: int = 256,
24 | dense_layer_size: int = 256,
25 | num_layers: int = 4,
26 | bidirectional: bool = True,
27 | hidden_dim: int = 256,
28 | dropout_rate: float = 0.5,
29 | output_size: int = 80,
30 | rnn_type: str = "lstm"
31 | ):
32 | assert check_argument_types()
33 | super().__init__()
34 |
35 | self.multi_spk = multi_spk
36 | self.spk_embed_dim = spk_embed_dim
37 |         self.use_spk_dvec = use_spk_dvec
38 | self.multi_styles = multi_styles
39 | self.style_embed_dim = style_embed_dim
40 | self.hidden_dim = hidden_dim
41 | self.num_layers = num_layers
42 | self.num_direction = 2 if bidirectional else 1
43 |
44 | self.ppg_dense_layer = nn.Linear(input_size - 2, hidden_dim)
45 | self.logf0_uv_layer = nn.Linear(2, hidden_dim)
46 |
47 | projection_input_size = hidden_dim
48 | if self.multi_spk:
49 | if not self.use_spk_dvec:
50 | self.spk_embedding = nn.Embedding(num_speakers, spk_embed_dim)
51 | projection_input_size += self.spk_embed_dim
52 | if self.multi_styles:
53 | self.style_embedding = nn.Embedding(num_styles, style_embed_dim)
54 | projection_input_size += self.style_embed_dim
55 |
56 | self.reduce_proj = nn.Sequential(
57 | nn.Linear(projection_input_size, hidden_dim),
58 | nn.ReLU(),
59 | nn.Dropout(dropout_rate)
60 | )
61 |
62 | rnn_type = rnn_type.upper()
63 | if rnn_type in ["LSTM", "GRU"]:
64 | rnn_class = getattr(nn, rnn_type)
65 | self.rnn = rnn_class(
66 | hidden_dim, hidden_dim, num_layers,
67 | bidirectional=bidirectional,
68 | dropout=dropout_rate,
69 | batch_first=True)
70 | else:
71 | # Default: use BiLSTM
72 | self.rnn = nn.LSTM(
73 | hidden_dim, hidden_dim, num_layers,
74 | bidirectional=bidirectional,
75 |                 dropout=dropout_rate,
76 | batch_first=True)
77 | # Fully connected layers
78 | self.hidden2out_layers = nn.Sequential(
79 | nn.Linear(self.num_direction * hidden_dim, dense_layer_size),
80 | nn.ReLU(),
81 | nn.Dropout(dropout_rate),
82 | nn.Linear(dense_layer_size, output_size)
83 | )
84 |
85 | def forward(
86 | self,
87 | ppg: torch.Tensor,
88 | ppg_lengths: torch.Tensor,
89 | logf0_uv: torch.Tensor,
90 | spembs: torch.Tensor = None,
91 | styleembs: torch.Tensor = None,
92 | ) -> torch.Tensor:
93 | """
94 | Args:
95 | ppg (tensor): [B, T, D_ppg]
96 | ppg_lengths (tensor): [B,]
97 |             logf0_uv (tensor): [B, T, 2], concatenated logf0 and u/v flags.
98 |             spembs (tensor): [B,] speaker indices, or [B, spk_embed_dim] speaker d-vectors when use_spk_dvec is True.
99 |             styleembs (tensor): [B,] index-represented speaking style (e.g. emotion).
100 | """
101 | ppg = self.ppg_dense_layer(ppg)
102 | logf0_uv = self.logf0_uv_layer(logf0_uv)
103 |
104 | ## Concatenate/add ppg and logf0_uv
105 | x = ppg + logf0_uv
106 | B, T, _ = x.size()
107 |
108 | if self.multi_spk:
109 | assert spembs is not None
110 | # spk_embs = self.spk_embedding(torch.LongTensor([0,]*ppg.size(0)).to(ppg.device))
111 | if not self.use_spk_dvec:
112 | spk_embs = self.spk_embedding(spembs)
113 | spk_embs = torch.nn.functional.normalize(
114 | spk_embs).unsqueeze(1).expand(-1, T, -1)
115 | else:
116 | spk_embs = torch.nn.functional.normalize(
117 | spembs).unsqueeze(1).expand(-1, T, -1)
118 | x = torch.cat([x, spk_embs], dim=2)
119 |
120 | if self.multi_styles and styleembs is not None:
121 | style_embs = self.style_embedding(styleembs)
122 | style_embs = torch.nn.functional.normalize(
123 | style_embs).unsqueeze(1).expand(-1, T, -1)
124 | x = torch.cat([x, style_embs], dim=2)
125 | ## FC projection
126 | x = self.reduce_proj(x)
127 |
128 | if ppg_lengths is not None:
129 | x = torch.nn.utils.rnn.pack_padded_sequence(x, ppg_lengths,
130 | batch_first=True,
131 | enforce_sorted=False)
132 | x, _ = self.rnn(x)
133 | if ppg_lengths is not None:
134 | x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
135 | x = self.hidden2out_layers(x)
136 |
137 | return x
138 |
--------------------------------------------------------------------------------
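A forward-pass sketch for the BiRnnPpg2MelModel above in the one-shot (d-vector) setting; the PPG and speaker-embedding sizes are illustrative, and input_size counts the two logf0/u-v channels on top of the PPG dimension, as the constructor's input_size - 2 implies.

import torch
from src.rnn_ppg2mel import BiRnnPpg2MelModel  # assumed import path

ppg_dim, spk_dim = 144, 256                        # illustrative sizes
model = BiRnnPpg2MelModel(input_size=ppg_dim + 2,  # PPG dim + 2 (log-f0, u/v)
                          multi_spk=True, use_spk_dvec=True,
                          spk_embed_dim=spk_dim, output_size=80)

B, T = 2, 200
ppg = torch.randn(B, T, ppg_dim)
ppg_lengths = torch.tensor([200, 150])
logf0_uv = torch.randn(B, T, 2)
spembs = torch.randn(B, spk_dim)                   # speaker d-vectors
mel = model(ppg, ppg_lengths, logf0_uv, spembs=spembs)   # (B, T, 80)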
/src/util.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | import matplotlib.pyplot as plt
4 | import math
5 | import time
6 | import torch
7 | import numpy as np
8 | from torch import nn
9 | import editdistance as ed
10 |
11 |
12 | class Timer():
13 | ''' Timer for recording training time distribution. '''
14 |
15 | def __init__(self):
16 | self.prev_t = time.time()
17 | self.clear()
18 |
19 | def set(self):
20 | self.prev_t = time.time()
21 |
22 | def cnt(self, mode):
23 | self.time_table[mode] += time.time()-self.prev_t
24 | self.set()
25 | if mode == 'bw':
26 | self.click += 1
27 |
28 | def show(self):
29 | total_time = sum(self.time_table.values())
30 | self.time_table['avg'] = total_time/self.click
31 | self.time_table['rd'] = 100*self.time_table['rd']/total_time
32 | self.time_table['fw'] = 100*self.time_table['fw']/total_time
33 | self.time_table['bw'] = 100*self.time_table['bw']/total_time
34 | msg = '{avg:.3f} sec/step (rd {rd:.1f}% | fw {fw:.1f}% | bw {bw:.1f}%)'.format(
35 | **self.time_table)
36 | self.clear()
37 | return msg
38 |
39 | def clear(self):
40 | self.time_table = {'rd': 0, 'fw': 0, 'bw': 0}
41 | self.click = 0
42 |
43 | # Reference : https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/e2e_asr.py#L168
44 |
45 |
46 | def init_weights(module):
47 | # Exceptions
48 | if type(module) == nn.Embedding:
49 | module.weight.data.normal_(0, 1)
50 | else:
51 | for p in module.parameters():
52 | data = p.data
53 | if data.dim() == 1:
54 | # bias
55 | data.zero_()
56 | elif data.dim() == 2:
57 | # linear weight
58 | n = data.size(1)
59 | stdv = 1. / math.sqrt(n)
60 | data.normal_(0, stdv)
61 | elif data.dim() in [3, 4]:
62 | # conv weight
63 | n = data.size(1)
64 | for k in data.size()[2:]:
65 | n *= k
66 | stdv = 1. / math.sqrt(n)
67 | data.normal_(0, stdv)
68 | else:
69 | raise NotImplementedError
70 |
71 |
72 | def init_gate(bias):
73 | n = bias.size(0)
74 | start, end = n // 4, n // 2
75 | bias.data[start:end].fill_(1.)
76 | return bias
77 |
78 | # Convert Tensor to Figure on tensorboard
79 |
80 |
81 | def feat_to_fig(feat):
82 | # feat TxD tensor
83 | data = _save_canvas(feat.numpy())
84 | return torch.FloatTensor(data), "HWC"
85 |
86 |
87 | def _save_canvas(data, meta=None):
88 | fig, ax = plt.subplots(figsize=(16, 8))
89 | if meta is None:
90 | ax.imshow(data, aspect="auto", origin="lower")
91 | else:
92 | ax.bar(meta[0], data[0], tick_label=meta[1], fc=(0, 0, 1, 0.5))
93 | ax.bar(meta[0], data[1], tick_label=meta[1], fc=(1, 0, 0, 0.5))
94 | fig.canvas.draw()
95 | # Note : torch tb add_image takes color as [0,1]
96 | data = np.array(fig.canvas.renderer._renderer)[:, :, :-1]/255.0
97 | plt.close(fig)
98 | return data
99 |
100 | # Reference : https://stackoverflow.com/questions/579310/formatting-long-numbers-as-strings-in-python
101 |
102 |
103 | def human_format(num):
104 | magnitude = 0
105 | while num >= 1000:
106 | magnitude += 1
107 | num /= 1000.0
108 | # add more suffixes if you need them
109 | return '{:3.1f}{}'.format(num, [' ', 'K', 'M', 'G', 'T', 'P'][magnitude])
110 |
111 |
112 | def cal_er(tokenizer, pred, truth, mode='wer', ctc=False):
113 | # Calculate error rate of a batch
114 | if pred is None:
115 | return np.nan
116 | elif len(pred.shape) >= 3:
117 | pred = pred.argmax(dim=-1)
118 | er = []
119 | for p, t in zip(pred, truth):
120 | p = tokenizer.decode(p.tolist(), ignore_repeat=ctc)
121 | t = tokenizer.decode(t.tolist())
122 | if mode == 'wer':
123 | p = p.split(' ')
124 | t = t.split(' ')
125 | er.append(float(ed.eval(p, t))/len(t))
126 | return sum(er)/len(er)
127 |
128 |
129 | def load_embedding(text_encoder, embedding_filepath):
130 | with open(embedding_filepath, "r") as f:
131 | vocab_size, embedding_size = [int(x)
132 | for x in f.readline().strip().split()]
133 | embeddings = np.zeros((text_encoder.vocab_size, embedding_size))
134 |
135 | unk_count = 0
136 |
137 | for line in f:
138 | vocab, emb = line.strip().split(" ", 1)
139 |             # fasttext's </s> corresponds to the text encoder's <eos>
140 |             if vocab == "</s>":
141 |                 vocab = "<eos>"
142 |
143 | if text_encoder.token_type == "subword":
144 | idx = text_encoder.spm.piece_to_id(vocab)
145 | else:
146 |                 # get rid of <eos>
147 | idx = text_encoder.encode(vocab)[0]
148 |
149 | if idx == text_encoder.unk_idx:
150 | unk_count += 1
151 | embeddings[idx] += np.asarray([float(x)
152 | for x in emb.split(" ")])
153 | else:
154 | # Suppose there is only one (w, v) pair in embedding file
155 | embeddings[idx] = np.asarray(
156 | [float(x) for x in emb.split(" ")])
157 |
158 |         # Average the <unk> vector
159 | if unk_count != 0:
160 | embeddings[text_encoder.unk_idx] /= unk_count
161 |
162 | return embeddings
163 |
--------------------------------------------------------------------------------
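Two of the small helpers above in isolation (import path assumed); the sleeps only stand in for real work so that Timer.show has something to report.

import time
from src.util import human_format, Timer  # assumed import path

print(human_format(1234))          # '1.2K'
print(human_format(2_500_000))     # '2.5M'

timer = Timer()
time.sleep(0.01); timer.cnt('rd')  # pretend: data loading
time.sleep(0.01); timer.cnt('fw')  # pretend: forward pass
time.sleep(0.01); timer.cnt('bw')  # pretend: backward pass (counts one step)
print(timer.show())                # e.g. '0.030 sec/step (rd 33.3% | fw 33.3% | bw 33.3%)'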
/src/vc_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def gcd(a, b):
6 | """Greatest common divisor."""
7 |     a, b = (a, b) if a >= b else (b, a)
8 |     if a % b == 0:
9 |         return b
10 |     else:
11 |         return gcd(b, a % b)
12 |
13 | def lcm(a, b):
14 | """Least common multiple"""
15 | return a * b // gcd(a, b)
16 |
17 | def get_mask_from_lengths(lengths, max_len=None):
18 | if max_len is None:
19 | max_len = torch.max(lengths).item()
20 | ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
21 | mask = (ids < lengths.unsqueeze(1)).bool()
22 | return mask
23 |
24 |
--------------------------------------------------------------------------------
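A quick check of the helpers above; note that get_mask_from_lengths allocates a torch.cuda.LongTensor internally, so it only runs when a CUDA device is available.

import torch
from src.vc_utils import lcm, get_mask_from_lengths  # assumed import path

print(lcm(4, 6))   # 12

if torch.cuda.is_available():
    lengths = torch.tensor([3, 5], device="cuda")
    mask = get_mask_from_lengths(lengths)      # (2, 5) bool, True for valid frames
    print(mask)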
/test.sh:
--------------------------------------------------------------------------------
1 | . ./path.sh || exit 1;
2 | export CUDA_VISIBLE_DEVICES=7
3 |
4 | stage=1
5 | stop_stage=1
6 | config=$1
7 | model_file=$2
8 | src_wav_dir=$3
9 | ref_wav_path=$4
10 | echo ${config}
11 |
12 | # =============== One-shot VC ================
13 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
14 | exp_name="$(basename "${config}" .yaml)"
15 | echo Experiment name: "${exp_name}"
16 | # src_wav_dir="/home/shaunxliu/data/cmu_arctic/cmu_us_rms_arctic/wav"
17 | # ref_wav_path="/home/shaunxliu/data/cmu_arctic/cmu_us_slt_arctic/wav/arctic_a0001.wav"
18 | output_dir="vc_gen_wavs/$(basename "${config}" .yaml)"
19 |
20 | python convert_from_wav.py \
21 | --ppg2mel_model_train_config ${config} \
22 | --ppg2mel_model_file ${model_file} \
23 | --src_wav_dir "${src_wav_dir}" \
24 | --ref_wav_path "${ref_wav_path}" \
25 | -o "${output_dir}"
26 | fi
27 |
--------------------------------------------------------------------------------
/tools/Makefile:
--------------------------------------------------------------------------------
1 | PYTHON:= python3.7
2 | CUDA_VERSION:= 10.1
3 | PYTORCH_VERSION:= 1.6
4 | DOT:= .
5 | .PHONY: all clean
6 |
7 | all: virtualenv
8 |
9 | virtualenv:
10 | test -d venv || virtualenv -p $(PYTHON) venv
11 | . venv/bin/activate; pip install torch==$(PYTORCH_VERSION) \
12 | -f https://download.pytorch.org/whl/cu$(subst $(DOT),,$(CUDA_VERSION))/torch_stable.html
13 | . venv/bin/activate; cd ../; pip install -r requirements.txt
14 | touch venv/bin/activate
15 |
16 | clean:
17 | rm -fr venv
18 | find -iname "*.pyc" -delete
19 |
--------------------------------------------------------------------------------
/utils/f0_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | import pyworld
4 | from scipy.interpolate import interp1d
5 | from scipy.signal import firwin, get_window, lfilter
6 |
7 |
8 | def compute_f0(wav, sr=16000, frame_period=10.0):
9 | """Compute f0 from wav using pyworld harvest algorithm."""
10 | wav = wav.astype(np.float64)
11 | f0, _ = pyworld.harvest(
12 | wav, sr, frame_period=frame_period, f0_floor=20.0, f0_ceil=600.0)
13 | return f0.astype(np.float32)
14 |
15 |
16 | def low_pass_filter(x, fs, cutoff=70, padding=True):
17 | """FUNCTION TO APPLY LOW PASS FILTER
18 |
19 | Args:
20 | x (ndarray): Waveform sequence
21 | fs (int): Sampling frequency
22 | cutoff (float): Cutoff frequency of low pass filter
23 |
24 | Return:
25 | (ndarray): Low pass filtered waveform sequence
26 | """
27 |
28 | nyquist = fs // 2
29 | norm_cutoff = cutoff / nyquist
30 |
31 |     # low-pass filter
32 | numtaps = 255
33 | fil = firwin(numtaps, norm_cutoff)
34 | x_pad = np.pad(x, (numtaps, numtaps), 'edge')
35 | lpf_x = lfilter(fil, 1, x_pad)
36 | lpf_x = lpf_x[numtaps + numtaps // 2: -numtaps // 2]
37 |
38 | return lpf_x
39 |
40 |
41 | def convert_continuous_f0(f0):
42 | """CONVERT F0 TO CONTINUOUS F0
43 |
44 | Args:
45 | f0 (ndarray): original f0 sequence with the shape (T)
46 |
47 | Return:
48 | (ndarray): continuous f0 with the shape (T)
49 | """
50 | # get uv information as binary
51 | uv = np.float32(f0 != 0)
52 |
53 | # get start and end of f0
54 | if (f0 == 0).all():
55 |         logging.warning("all of the f0 values are 0.")
56 | return uv, f0
57 | start_f0 = f0[f0 != 0][0]
58 | end_f0 = f0[f0 != 0][-1]
59 |
60 | # padding start and end of f0 sequence
61 | start_idx = np.where(f0 == start_f0)[0][0]
62 | end_idx = np.where(f0 == end_f0)[0][-1]
63 | f0[:start_idx] = start_f0
64 | f0[end_idx:] = end_f0
65 |
66 | # get non-zero frame index
67 | nz_frames = np.where(f0 != 0)[0]
68 |
69 | # perform linear interpolation
70 | f = interp1d(nz_frames, f0[nz_frames])
71 | cont_f0 = f(np.arange(0, f0.shape[0]))
72 |
73 | return uv, cont_f0
74 |
75 |
76 | def get_cont_lf0(f0, frame_period=10.0, lpf=False):
77 | uv, cont_f0 = convert_continuous_f0(f0)
78 | if lpf:
79 | cont_f0_lpf = low_pass_filter(cont_f0, int(1.0 / (frame_period * 0.001)), cutoff=20)
80 | cont_lf0_lpf = cont_f0_lpf.copy()
81 | nonzero_indices = np.nonzero(cont_lf0_lpf)
82 | cont_lf0_lpf[nonzero_indices] = np.log(cont_f0_lpf[nonzero_indices])
83 | # cont_lf0_lpf = np.log(cont_f0_lpf)
84 | return uv, cont_lf0_lpf
85 | else:
86 | nonzero_indices = np.nonzero(cont_f0)
87 | cont_lf0 = cont_f0.copy()
88 | cont_lf0[cont_f0>0] = np.log(cont_f0[cont_f0>0])
89 | return uv, cont_lf0
90 |
--------------------------------------------------------------------------------
/utils/file_related.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def load_filepaths_and_text(filename, split="|"):
6 | with open(filename, encoding='utf-8') as f:
7 | filepaths_and_text = [line.strip().split(split) for line in f]
8 | return filepaths_and_text
9 |
10 |
11 |
--------------------------------------------------------------------------------
/utils/load_yaml.py:
--------------------------------------------------------------------------------
1 | import yaml
2 |
3 |
4 | def load_hparams(filename):
5 | stream = open(filename, 'r')
6 | docs = yaml.safe_load_all(stream)
7 | hparams_dict = dict()
8 | for doc in docs:
9 | for k, v in doc.items():
10 | hparams_dict[k] = v
11 | return hparams_dict
12 |
13 | def merge_dict(user, default):
14 | if isinstance(user, dict) and isinstance(default, dict):
15 | for k, v in default.items():
16 | if k not in user:
17 | user[k] = v
18 | else:
19 | user[k] = merge_dict(user[k], v)
20 | return user
21 |
22 | class Dotdict(dict):
23 | """
24 | a dictionary that supports dot notation
25 | as well as dictionary access notation
26 |     usage: d = Dotdict() or d = Dotdict({'val1': 'first'})
27 | set attributes: d.val2 = 'second' or d['val2'] = 'second'
28 | get attributes: d.val2 or d['val2']
29 | """
30 | __getattr__ = dict.__getitem__
31 | __setattr__ = dict.__setitem__
32 | __delattr__ = dict.__delitem__
33 |
34 | def __init__(self, dct=None):
35 | dct = dict() if not dct else dct
36 | for key, value in dct.items():
37 | if hasattr(value, 'keys'):
38 | value = Dotdict(value)
39 | self[key] = value
40 |
41 | class HpsYaml(Dotdict):
42 | def __init__(self, yaml_file):
43 | super(Dotdict, self).__init__()
44 | hps = load_hparams(yaml_file)
45 | hp_dict = Dotdict(hps)
46 | for k, v in hp_dict.items():
47 | setattr(self, k, v)
48 |
49 | __getattr__ = Dotdict.__getitem__
50 | __setattr__ = Dotdict.__setitem__
51 | __delattr__ = Dotdict.__delitem__
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
--------------------------------------------------------------------------------
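A small sketch of the config utilities above; the keys are made up, and the commented HpsYaml line would load one of the YAML files under conf/ into the same dot-accessible structure.

from utils.load_yaml import Dotdict, HpsYaml, merge_dict  # assumed import path

# Dotdict gives both dict-style and attribute-style access to nested mappings
cfg = Dotdict({"model": {"hidden_dim": 256}, "lr": 1e-3})
print(cfg.model.hidden_dim, cfg["lr"])         # 256 0.001

# fill in missing keys from a default dict without overwriting user values
cfg = merge_dict(cfg, {"lr": 5e-4, "seed": 1234})
print(cfg["lr"], cfg["seed"])                  # 0.001 1234

# hps = HpsYaml("conf/bilstm_ppg2mel_vctk_libri_oneshotvc.yaml")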
/utils/tensor_ops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | def pad(inputs, max_length=None):
7 |
8 | if max_length:
9 | out_list = list()
10 | for i, mat in enumerate(inputs):
11 | mat_padded = F.pad(
12 | mat, (0, 0, 0, max_length-mat.size(0)), "constant", 0.0)
13 | out_list.append(mat_padded)
14 | out_padded = torch.stack(out_list)
15 | return out_padded
16 | else:
17 | out_list = list()
18 |         max_length = max([inputs[i].size(0) for i in range(len(inputs))])
19 |
20 | for i, mat in enumerate(inputs):
21 | mat_padded = F.pad(
22 | mat, (0, 0, 0, max_length-mat.size(0)), "constant", 0.0)
23 | out_list.append(mat_padded)
24 | out_padded = torch.stack(out_list)
25 | return out_padded
26 |
--------------------------------------------------------------------------------
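The pad helper above stacks variable-length feature matrices into one zero-padded batch tensor; a minimal example:

import torch
from utils.tensor_ops import pad  # assumed import path

feats = [torch.randn(120, 80), torch.randn(95, 80)]   # (T_i, D) features
batch = pad(feats)                      # padded to the longest sequence
print(batch.shape)                      # torch.Size([2, 120, 80])

batch_fixed = pad(feats, max_length=200)
print(batch_fixed.shape)                # torch.Size([2, 200, 80])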
/vocoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/vocoders/__init__.py
--------------------------------------------------------------------------------
/vocoders/env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 |
5 | class AttrDict(dict):
6 | def __init__(self, *args, **kwargs):
7 | super(AttrDict, self).__init__(*args, **kwargs)
8 | self.__dict__ = self
9 |
10 |
11 | def build_env(config, config_name, path):
12 | t_path = os.path.join(path, config_name)
13 | if config != t_path:
14 | os.makedirs(path, exist_ok=True)
15 | shutil.copyfile(config, os.path.join(path, config_name))
16 |
--------------------------------------------------------------------------------
/vocoders/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import matplotlib
4 | import torch
5 | from torch.nn.utils import weight_norm
6 | matplotlib.use("Agg")
7 | import matplotlib.pylab as plt
8 |
9 |
10 | def plot_spectrogram(spectrogram):
11 | fig, ax = plt.subplots(figsize=(10, 2))
12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
13 | interpolation='none')
14 | plt.colorbar(im, ax=ax)
15 |
16 | fig.canvas.draw()
17 | plt.close()
18 |
19 | return fig
20 |
21 |
22 | def init_weights(m, mean=0.0, std=0.01):
23 | classname = m.__class__.__name__
24 | if classname.find("Conv") != -1:
25 | m.weight.data.normal_(mean, std)
26 |
27 |
28 | def apply_weight_norm(m):
29 | classname = m.__class__.__name__
30 | if classname.find("Conv") != -1:
31 | weight_norm(m)
32 |
33 |
34 | def get_padding(kernel_size, dilation=1):
35 | return int((kernel_size*dilation - dilation)/2)
36 |
37 |
38 | def load_checkpoint(filepath, device):
39 | assert os.path.isfile(filepath)
40 | print("Loading '{}'".format(filepath))
41 | checkpoint_dict = torch.load(filepath, map_location=device)
42 | print("Complete.")
43 | return checkpoint_dict
44 |
45 |
46 | def save_checkpoint(filepath, obj):
47 | print("Saving checkpoint to {}".format(filepath))
48 | torch.save(obj, filepath)
49 | print("Complete.")
50 |
51 |
52 | def scan_checkpoint(cp_dir, prefix):
53 | pattern = os.path.join(cp_dir, prefix + '????????')
54 | cp_list = glob.glob(pattern)
55 | if len(cp_list) == 0:
56 | return None
57 | return sorted(cp_list)[-1]
58 |
59 |
--------------------------------------------------------------------------------
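A sketch of the checkpoint helpers above, picking up the newest generator checkpoint from a directory; the 'g_' prefix matches the naming of vocoders/vctk_24k10ms/g_02830000, and the 'generator' state-dict key is an assumption.

import torch
from vocoders.utils import scan_checkpoint, load_checkpoint  # assumed import path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cp_dir = "vocoders/vctk_24k10ms"           # directory holding g_XXXXXXXX files
cp_g = scan_checkpoint(cp_dir, "g_")       # newest file matching g_????????
if cp_g is not None:
    state_dict_g = load_checkpoint(cp_g, device)
    # e.g. generator.load_state_dict(state_dict_g["generator"])  # key name assumed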
/vocoders/vctk_24k10ms/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 0,
4 | "batch_size": 16,
5 | "learning_rate": 0.0002,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [8,5,3,2],
12 | "upsample_kernel_sizes": [15,15,5,5],
13 | "upsample_initial_channel": 512,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 |
17 | "segment_size": 12000,
18 | "num_mels": 80,
19 | "num_freq": 1025,
20 | "n_fft": 1024,
21 | "hop_size": 240,
22 | "win_size": 1024,
23 |
24 | "sampling_rate": 24000,
25 |
26 | "fmin": 0,
27 | "fmax": 8000,
28 | "fmax_for_loss": null,
29 |
30 | "num_workers": 4,
31 |
32 | "dist_config": {
33 | "dist_backend": "nccl",
34 | "dist_url": "tcp://localhost:54321",
35 | "world_size": 1
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
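As a sanity check on the vocoder config above: for a HiFi-GAN-style generator the product of upsample_rates should equal hop_size, and 8*5*3*2 = 240 samples at 24 kHz is exactly the 10 ms frame shift the directory name refers to. A short verification sketch (run from the repository root):

import json
from functools import reduce

with open("vocoders/vctk_24k10ms/config.json") as f:
    h = json.load(f)

total_upsample = reduce(lambda a, b: a * b, h["upsample_rates"])
assert total_upsample == h["hop_size"]                 # 8 * 5 * 3 * 2 == 240
print(1000.0 * h["hop_size"] / h["sampling_rate"])     # 10.0 (ms per frame)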
/vocoders/vctk_24k10ms/g_02830000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liusongxiang/ppg-vc/b59cb9862cf4b82a3bdb589950e25cab85fc9b03/vocoders/vctk_24k10ms/g_02830000
--------------------------------------------------------------------------------