├── .gitignore ├── Ch1 ├── Ch2 ├── Ch3 ├── Ch4 ├── Ch5 ├── Ch6 ├── Ch7 ├── Ch8 ├── Ch9 ├── LICENSE ├── README.md ├── assets ├── converted_melspectrogram.jpg ├── melspectrogram_visualization.jpg ├── model_architecture.png ├── train_commitment_loss.jpg ├── train_perplexity.jpg ├── train_reconstruction_loss.jpg └── train_total_loss.jpg ├── config.py ├── conversion.py ├── data ├── korean_emotional_speech_dataset.py └── vctk.py ├── dataset.py ├── evaluate.py ├── model.py ├── module.py ├── network.py ├── prepro.py ├── requirements.txt ├── train.py ├── utils ├── audio │ ├── __init__.py │ ├── audio_preprocessing.py │ ├── stft.py │ └── tools.py ├── checkpoint.py ├── dataset.py ├── figure.py ├── path.py ├── scheduler.py ├── vocoder.py └── writer.py └── vocoder └── vocgan └── generator.py /.gitignore: -------------------------------------------------------------------------------- 1 | # VQVC ignore list 2 | preprocessed/ 3 | logs/ 4 | ckpts/ 5 | eval_results/ 6 | results/ 7 | vocoder/vocgan/pretrained_models 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | -------------------------------------------------------------------------------- /Ch1: -------------------------------------------------------------------------------- 1 | we propose a vector quantization (VQ) based 2 | one-shot voice conversion (VC) approach without any supervision 3 | on speaker label. We model the content embedding 4 | as a series of discrete codes and take the difference between 5 | quantize-before and quantize-after vector as the speaker embedding. 6 | We show that this approach has a strong ability to 7 | disentangle the content and speaker information with reconstruction 8 | loss only, and one-shot VC is thus achieved 9 | -------------------------------------------------------------------------------- /Ch2: -------------------------------------------------------------------------------- 1 | A typical example of Voice Conversion (VC) task is to change 2 | the voice of a source speaker to the voice of a target speaker 3 | without changing the linguistic information. To imitate the 4 | target speaker, a VC system should modify the tone, accent, 5 | and pronunciation of the voice of the source speaker, and this 6 | task can also be formulated as a style transfer problem. In 7 | practice, this technology can be used in many applications 8 | such as entertainment, creativity industry, and virtual implantation 9 | -------------------------------------------------------------------------------- /Ch3: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | DIR_REF='SF1' 5 | DIR_TST='TF1' 6 | 7 | #FILENAME=t1 8 | FILENAME=100020 9 | 10 | #----------------------------------------------------------------- 11 | # TODO Cleanup or delete 12 | WINDOW_SIZE=20; FRAME_RATE=4; 13 | PRM_NAME="mMFPC,ed1E" 14 | PRM_OPT="-q 2,1 -e 0.1 -l $WINDOW_SIZE -d $FRAME_RATE -v HAMMING -o 20 -c 16" 15 | ZASKA="Zaska -P $PRM_NAME $PRM_OPT" 16 | 17 | 18 | # Compute mfcc $DIR_REF/$FILENAME.wav $DIR_TST/$FILENAME.wav => mfcc/$DIR_REF/$FILENAME.prm mfcc/$DIR_TST/$FILENAME.prm 19 | $ZASKA -t RAW -x wav=msw -n . 
-p mfcc -F $DIR_REF/$FILENAME $DIR_TST/$FILENAME 20 | 21 | # Align: mfcc/$DIR_REF/$FILENAME.prm, mfcc/$DIR_TST/$FILENAME.prm => dtw/${DIR_REF}-$DIR_TST/$FILENAME.dtw 22 | b=2 23 | dtw -b -$b -t mfcc/$DIR_REF -r mfcc/$DIR_TST -a dtw/beam$b -w -B -f -F $FILENAME 24 | 25 | 26 | 27 | # Check: create transcription labels, one for every 0.1 seconds in the reference: 28 | 29 | step=0.1 30 | dur=$(soxi -D $DIR_REF/$FILENAME.wav) 31 | 32 | perl -e ' 33 | print "#\n"; 34 | $t = 0; 35 | $n = 1; 36 | while ($t <= '$dur') { 37 | print "$t\t121\tT$n\n"; 38 | $t += '$step'; 39 | $n ++; 40 | } 41 | ' > $DIR_REF/$FILENAME.ts 42 | 43 | 44 | echo '#' > $DIR_TST/$FILENAME.ts 45 | 46 | cat $DIR_REF/$FILENAME.ts | grep -v '#' | cut -f 1 |\ 47 | dtw_project -p dtw/${DIR_REF}-$DIR_TST/$FILENAME.dtw - |\ 48 | perl -ne ' 49 | BEGIN {$n=1} 50 | chomp; 51 | print "$_\t121\tR$n\n"; 52 | $n++' >> $DIR_TST/$FILENAME.ts 53 | -------------------------------------------------------------------------------- /Ch4: -------------------------------------------------------------------------------- 1 | We proposed a novel VQ-based one-shot VC with a self-learned 2 | speaker representation. The disentanglement experiments 3 | and visualization show that the VQVC learns a 4 | meaningful embedding space without any supervision, and 5 | an ablation study on the quantization and IN shows that 6 | normalization of the codebook Q and placing the IN before the 7 | quantization achieve the best result. Further, we perform VC 8 | for unseen speakers with only one utterance, and subjective 9 | evaluations showed good results in terms of similarity to 10 | the target speakers. 11 | -------------------------------------------------------------------------------- /Ch5: -------------------------------------------------------------------------------- 1 | https://github.com/Jackson-Kang/VQVC-Pytorch 2 | -------------------------------------------------------------------------------- /Ch6: -------------------------------------------------------------------------------- 1 | I'm Hesam Najafi. 2 | I'm a master's student in biomedical engineering at South Tehran University. 3 | I'm very grateful to my professor, Dr. Mahdi Eslami, who helped us learn machine learning and digital signal processing, and who also helped us learn more about Python, 4 | GitHub, and LinkedIn.
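To make the approach summarized in Ch1 and Ch4 concrete, here is a minimal PyTorch sketch of the VQVC bottleneck: instance normalization (IN) is applied before quantization, the codebook is L2-normalized, the content is the nearest codebook vector (with a straight-through gradient), and the speaker embedding is the time-averaged difference between the quantize-before and quantize-after vectors. This is an illustration only, not the repository's exact implementation (see model.py and network.py later in this dump); the random tensors and the dimensions used here are placeholder assumptions.
```
import torch
import torch.nn.functional as F

def instance_norm(x, eps=1e-5):
    # normalize each channel over the time axis; x: (N, T, D)
    mu = x.mean(dim=1, keepdim=True)
    std = x.std(dim=1, keepdim=True)
    return (x - mu) / (std + eps)

def quantize(z, codebook):
    # nearest-neighbour lookup against an L2-normalized codebook, straight-through gradient
    codebook = codebook / (codebook.norm(dim=1, keepdim=True) + 1e-4)
    flat = z.reshape(-1, z.size(-1))                   # (N*T, D)
    indices = torch.cdist(flat, codebook).argmin(dim=-1)
    q = F.embedding(indices, codebook).view_as(z)      # (N, T, D)
    return z + (q - z).detach()

N, T, D, M = 2, 40, 32, 256                            # batch, frames, bottleneck dim, codebook size
z_enc = torch.randn(N, T, D)                           # stand-in for Encoder(mel)
codebook = torch.randn(M, D)

z_in = instance_norm(z_enc)                            # IN before quantization (best setting in the ablation)
content = quantize(z_in, codebook)                     # discrete content codes
speaker = (z_in - content).mean(dim=1, keepdim=True)   # speaker embedding = time-averaged residual

decoder_input = content + speaker                      # the decoder reconstructs the mel from this sum
print(decoder_input.shape)                             # torch.Size([2, 40, 32])
```
Because the speaker part is whatever the quantizer throws away, disentanglement comes from the reconstruction loss alone; swapping in a reference speaker's residual at conversion time is what yields one-shot VC.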
5 | -------------------------------------------------------------------------------- /Ch7: -------------------------------------------------------------------------------- 1 | https://drive.google.com/file/d/1sI4JriYpi2bxnWEemld1l7AW-gZjYYY2/view?usp=drivesdk 2 | -------------------------------------------------------------------------------- /Ch8: -------------------------------------------------------------------------------- 1 | https://colab.research.google.com/drive/1CVNFf7HhkOW4tY-psXl6zi4oDc2j8err 2 | 3 | https://colab.research.google.com/drive/14t-e-08Rns837D0_hrR23TD_rdTz54YEN 4 | -------------------------------------------------------------------------------- /Ch9: -------------------------------------------------------------------------------- 1 | https://colab.research.google.com/drive/1mdlnu19J7jJqfgUhLCUfNB2-scxC9GFwA 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jackson Kang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VQVC-Pytorch 2 | ONE-SHOT VOICE CONVERSION BY VECTOR QUANTIZATION 3 | # Hesam Najafi 4 | # 401141140111021 5 | # How-to-run 6 | 1. Install dependencies. 7 | * python=3.7 8 | * [pytorch](https://pytorch.org/)=1.7 9 | 10 | ``` 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | 2. Download the dataset and the pretrained VocGAN model. 15 | * Please download the [VCTK dataset](https://datashare.ed.ac.uk/handle/10283/3443) and edit ```dataset_path``` in ```config.py```. 16 | * Download the [VocGAN pretrained model](https://github.com/Jackson-Kang/VQVC-Pytorch#pretrained-models). 17 | 18 | 3. Preprocess 19 | * Preprocess mel-spectrograms via the following command: 20 | ``` 21 | python prepro.py 1 1 22 | ``` 23 | * First argument: mel-spectrogram preprocessing. 24 | * Second argument: metadata split (you may change the proportion of samples used for train/eval via ```data_split_ratio``` in ```config.py```). 25 | 26 | 4. Train the model 27 | ``` 28 | python train.py 29 | ``` 30 | * In ```config.py```, you may edit ```train_visible_devices``` to choose the GPU used for training. 31 | * As in the paper, 60K training steps are enough. 32 | * Training the model takes only 30 minutes. 33 | 34 | 5.
Voice conversion 35 | * After training, specify the source and reference speech files for voice conversion. (You may edit ```src_paths``` and ```ref_paths``` in ```conversion.py```.) 36 | * The converted samples can be found in the ```results``` directory. 37 | ``` 38 | python conversion.py 39 | ``` 40 | 41 | 42 | # Inference results 43 | * You can listen to [audio samples](https://jackson-kang.github.io/opensource_samples/vqvc/). 44 | 45 | * Visualization of the converted mel-spectrogram 46 | - source mel (top), reference mel (middle), converted mel (bottom) 47 | 48 | ![converted_melspectrogram](./assets/converted_melspectrogram.jpg) 49 | 50 | 51 | # Pretrained models 52 | 1. [VQVC pretrained model](https://drive.google.com/file/d/1wiG8CyzNhq7dVZG3LZqCJ5bnoPTPS08a/view?usp=sharing) 53 | * Download the pretrained VQVC model and place it in ```ckpts/VCTK-Corpus/```. 54 | 2. [VocGAN pretrained model](https://drive.google.com/file/d/1nfD84ot7o3u2tFR7YkSp2vQWVnNJ-md_/view?usp=sharing) 55 | * Download the pretrained VocGAN model and place it in ```vocoder/vocgan/pretrained_models```. 56 | 57 | # Experimental Notes 58 | * Trimming silence and the stride of the convolution are very important for transferring the style from the reference speech. 59 | * Unlike the paper, I used [NVIDIA's preprocessing method](https://github.com/NVIDIA/tacotron2/blob/fc0cf6a89a47166350b65daa1beaa06979e4cddf/stft.py) so that the pretrained [VocGAN](https://arxiv.org/pdf/2007.15256.pdf) model can be used. 60 | * Training is very unstable. (After 70K steps, the perplexity of the codebook drops sharply to 1.) 61 | * **(Future work)** The model trained on the [Korean Emotional Speech dataset](https://www.aihub.or.kr/keti_data_board/expression) is not completed yet. 62 | 63 | # References (or acknowledgements) 64 | * [One-shot Voice Conversion by Vector Quantization](https://ieeexplore.ieee.org/document/9053854) (D. Y. Wu et
al., 2020) 65 | * [VocGAN implementation](https://github.com/rishikksh20/VocGAN) by rishikksh20 66 | * [NVIDIA's preprocessing method](https://github.com/NVIDIA/tacotron2/blob/fc0cf6a89a47166350b65daa1beaa06979e4cddf/stft.py) 67 | -------------------------------------------------------------------------------- /assets/converted_melspectrogram.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/vqvc-pytorch/621cd2481891e492cbd947b62979494cf0f14a69/assets/converted_melspectrogram.jpg -------------------------------------------------------------------------------- /assets/melspectrogram_visualization.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/vqvc-pytorch/621cd2481891e492cbd947b62979494cf0f14a69/assets/melspectrogram_visualization.jpg -------------------------------------------------------------------------------- /assets/model_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/vqvc-pytorch/621cd2481891e492cbd947b62979494cf0f14a69/assets/model_architecture.png -------------------------------------------------------------------------------- /assets/train_commitment_loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/vqvc-pytorch/621cd2481891e492cbd947b62979494cf0f14a69/assets/train_commitment_loss.jpg -------------------------------------------------------------------------------- /assets/train_perplexity.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/vqvc-pytorch/621cd2481891e492cbd947b62979494cf0f14a69/assets/train_perplexity.jpg -------------------------------------------------------------------------------- /assets/train_reconstruction_loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/vqvc-pytorch/621cd2481891e492cbd947b62979494cf0f14a69/assets/train_reconstruction_loss.jpg -------------------------------------------------------------------------------- /assets/train_total_loss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahdeslami11/vqvc-pytorch/621cd2481891e492cbd947b62979494cf0f14a69/assets/train_total_loss.jpg -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.path import get_path 3 | 4 | class Arguments: 5 | 6 | """ 7 | path configurations 8 | """ 9 | dataset_name = "VCTK-Corpus" 10 | dataset_path = get_path("/home/minsu/dataset/VCTK/", dataset_name) 11 | 12 | converted_sample_dir = "results" 13 | prepro_dir = "preprocessed" 14 | model_log_dir = "logs" 15 | model_checkpoint_dir = "ckpts" 16 | 17 | 18 | # path for loading audio(wav) samples to be preprocessed 19 | wav_dir = get_path(dataset_path, "wavs") 20 | 21 | # by default, preprocessed samples and metadata are stored in "prepro_path" 22 | prepro_path = get_path(prepro_dir, dataset_name) 23 | prepro_mel_dir = get_path(prepro_path, "mels") 24 | prepro_meta_dir = get_path(prepro_path, "metas") 25 | prepro_meta_train = get_path(prepro_meta_dir, "meta_train.csv") 26 | prepro_meta_eval = 
get_path(prepro_meta_dir, "meta_eval.csv") 27 | prepro_meta_unseen = get_path(prepro_meta_dir, "meta_unseen.csv") 28 | 29 | mel_stat_path = get_path(prepro_path, "mel_stats.npy") 30 | 31 | model_log_path = get_path(model_log_dir, dataset_name) 32 | model_checkpoint_path = get_path(model_checkpoint_dir, dataset_name) 33 | 34 | 35 | """ 36 | preprocessing hyperparams 37 | """ 38 | max_frame_length = 40 # window size of random resampling 39 | 40 | sr = 22050 # 22050kHz sampling rate 41 | n_mels = 80 42 | filter_length = 1024 43 | hop_length = 256 44 | win_length = 1024 45 | 46 | max_wav_value = 32768.0 # for other dataset 47 | mel_fmin = 0 48 | mel_fmax = 8000 49 | 50 | trim_silence = True 51 | top_db = 15 # threshold for trimming silence 52 | 53 | """ 54 | VQVC hyperparameters 55 | """ 56 | 57 | n_embeddings = 256 # of codes in VQ-codebook 58 | z_dim=32 # bottleneck dimension 59 | 60 | commitment_cost = 0.01 # commitment cost 61 | 62 | norm_epsilon = 1e-4 63 | speaker_emb_reduction=1 64 | 65 | warmup_steps = 1000 66 | init_lr = 1e-3 # initial learning rate 67 | max_lr = 4e-2 # maximum learning rate 68 | gamma = 0.25 69 | milestones = [20000] 70 | 71 | 72 | """ 73 | data & training setting 74 | """ 75 | grad_clip_thresh=3.0 76 | seed = 999 77 | n_workers = 10 78 | 79 | #scheduler setting 80 | 81 | use_cuda = True 82 | mem_mode = True 83 | 84 | data_split_ratio = [0.95, 0.05] # [train, evaluation] in 0 ~ 1 range 85 | 86 | train_visible_devices = "7" 87 | conversion_visible_devices = "7" 88 | 89 | train_batch_size = 120 90 | eval_batch_size = 100 91 | eval_step = 1000 92 | eval_path = "eval_results" 93 | save_checkpoint_step = 5000 94 | 95 | log_tensorboard = True 96 | max_training_step = 60000 97 | 98 | # vocoder setting 99 | vocoder = "vocgan" 100 | vocoder_pretrained_model_name = "vocgan_universal_pretrained_model_epoch_1280.pt" 101 | vocoder_pretrained_model_path = get_path("./vocoder", "{}", "pretrained_models", vocoder_pretrained_model_name).format(vocoder) 102 | 103 | -------------------------------------------------------------------------------- /conversion.py: -------------------------------------------------------------------------------- 1 | from config import Arguments as args 2 | 3 | import os 4 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 5 | os.environ["CUDA_VISIBLE_DEVICES"]=args.conversion_visible_devices 6 | 7 | import sys 8 | import torch 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | from model import VQVC 13 | 14 | from utils.dataset import get_src_and_ref_mels, normalize, de_normalize 15 | from utils.vocoder import get_vocgan, vocgan_infer 16 | from utils.path import get_path, create_dir 17 | from utils.checkpoint import load_checkpoint 18 | from utils.figure import draw_converted_melspectrogram 19 | 20 | from config import Arguments as args 21 | 22 | 23 | def convert(model, vocoder, mel_stat, conversion_wav_paths, DEVICE=None): 24 | 25 | 26 | for idx, (src_wav_path, ref_wav_path) in tqdm(enumerate(conversion_wav_paths), total=len(conversion_wav_paths), unit='B', ncols=70, leave=False): 27 | 28 | mel_mean, mel_std = mel_stat 29 | 30 | src_mel, ref_mel = get_src_and_ref_mels(src_wav_path, ref_wav_path, trim_silence=args.trim_silence, frame_length=args.filter_length, hop_length=args.hop_length, top_db=args.top_db) 31 | src_mel, ref_mel = normalize(src_mel, mel_mean, mel_std), normalize(ref_mel, mel_mean, mel_std) 32 | 33 | src_mel = torch.from_numpy(src_mel).float().unsqueeze(0).to(DEVICE) 34 | ref_mel = 
torch.from_numpy(ref_mel).float().unsqueeze(0).to(DEVICE) 35 | 36 | mel_converted, mel_src_code, mel_src_style, mel_ref_code, mel_ref_style = model.convert(src_mel, ref_mel) 37 | 38 | src_wav_name = src_wav_path.split("/")[-1] 39 | ref_wav_name = ref_wav_path.split("/")[-1] 40 | 41 | src_wav_path = get_path(args.converted_sample_dir, "{}_src_{}".format(idx, src_wav_name)) 42 | ref_wav_path = get_path(args.converted_sample_dir, "{}_ref_{}".format(idx, ref_wav_name)) 43 | converted_wav_path = get_path(args.converted_sample_dir, "{}_converted_{}_{}".format(idx, src_wav_name.replace(".wav", ""), ref_wav_name)) 44 | src_code_wav_path = get_path(args.converted_sample_dir, "{}_src_code_{}".format(idx, src_wav_name)) 45 | src_style_wav_path = get_path(args.converted_sample_dir, "{}_src_style_{}".format(idx, src_wav_name)) 46 | ref_code_wav_path = get_path(args.converted_sample_dir, "{}_ref_code_{}".format(idx, ref_wav_name)) 47 | ref_style_wav_path = get_path(args.converted_sample_dir, "{}_ref_style_{}".format(idx, ref_wav_name)) 48 | 49 | mel_mean, mel_std = torch.from_numpy(mel_mean).float().to(DEVICE), torch.from_numpy(mel_std).float().to(DEVICE) 50 | 51 | 52 | src_mel = de_normalize(src_mel, mel_mean, mel_std) 53 | ref_mel = de_normalize(ref_mel, mel_mean, mel_std) 54 | mel_converted = de_normalize(mel_converted, mel_mean, mel_std) 55 | mel_src_code = de_normalize(mel_src_code, mel_mean, mel_std) 56 | mel_src_style = de_normalize(mel_src_style, mel_mean, mel_std) 57 | mel_ref_code = de_normalize(mel_ref_code, mel_mean, mel_std) 58 | mel_ref_style = de_normalize(mel_ref_style, mel_mean, mel_std) 59 | 60 | vocgan_infer(src_mel.transpose(1, 2), vocoder, path=src_wav_path) 61 | vocgan_infer(ref_mel.transpose(1, 2), vocoder, path = ref_wav_path) 62 | vocgan_infer(mel_converted.transpose(1, 2), vocoder, path = converted_wav_path) 63 | vocgan_infer(mel_src_code.transpose(1, 2), vocoder, path = src_code_wav_path) 64 | vocgan_infer(mel_src_style.transpose(1, 2), vocoder, path= src_style_wav_path) 65 | vocgan_infer(mel_ref_code.transpose(1, 2), vocoder, path= ref_code_wav_path) 66 | vocgan_infer(mel_ref_style.transpose(1, 2), vocoder, path = ref_style_wav_path) 67 | 68 | src_mel = src_mel.transpose(1, 2).squeeze().detach().cpu().numpy() 69 | ref_mel = ref_mel.transpose(1, 2).squeeze().detach().cpu().numpy() 70 | mel_converted = mel_converted.transpose(1, 2).squeeze().detach().cpu().numpy() 71 | mel_src_code = mel_src_code.transpose(1, 2).squeeze().detach().cpu().numpy() 72 | mel_src_style = mel_src_style.transpose(1, 2).squeeze().detach().cpu().numpy() 73 | mel_ref_code = mel_ref_code.transpose(1, 2).squeeze().detach().cpu().numpy() 74 | mel_ref_style = mel_ref_style.transpose(1, 2).squeeze().detach().cpu().numpy() 75 | 76 | fig = draw_converted_melspectrogram(src_mel, ref_mel, mel_converted, mel_src_code, mel_src_style, mel_ref_code, mel_ref_style) 77 | fig.savefig(get_path(args.converted_sample_dir, "contents_{}_style_{}.png".format(src_wav_name.replace(".wav", ""), ref_wav_name.replace(".wav", "")))) 78 | 79 | 80 | def main(DEVICE): 81 | 82 | # load model 83 | model = VQVC().to(DEVICE) 84 | vocoder = get_vocgan(ckpt_path=args.vocoder_pretrained_model_path).to(DEVICE) 85 | 86 | load_checkpoint(args.model_checkpoint_path, model) 87 | mel_stat = np.load(args.mel_stat_path) 88 | 89 | dataset_root = args.wav_dir 90 | 91 | src_paths = [get_path(dataset_root, "p226_354.wav"), get_path(dataset_root, "p225_335.wav")] 92 | ref_paths = [get_path(dataset_root, "p225_335.wav"), get_path(dataset_root, 
"p226_354.wav")] 93 | 94 | create_dir(args.converted_sample_dir) 95 | 96 | convert(model, vocoder, mel_stat, conversion_wav_paths=tuple(zip(src_paths, ref_paths)), DEVICE=DEVICE) 97 | 98 | 99 | if __name__ == "__main__": 100 | 101 | print("[LOG] Start conversion...") 102 | 103 | DEVICE = torch.device("cuda" if (torch.cuda.is_available() and args.use_cuda) else "cpu") 104 | 105 | main(DEVICE) 106 | print("[LOG] Finish..") 107 | -------------------------------------------------------------------------------- /data/korean_emotional_speech_dataset.py: -------------------------------------------------------------------------------- 1 | from utils.path import * 2 | from utils.audio.tools import get_mel 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | import glob, os, sys 7 | from multiprocessing import Pool 8 | 9 | from scipy.io.wavfile import write 10 | import librosa, ffmpeg 11 | from sklearn.preprocessing import StandardScaler 12 | 13 | 14 | def job(wav_filename): 15 | 16 | original_wav_filename, prepro_wav_dir, sampling_rate = wav_filename 17 | filename = original_wav_filename.split("/")[-1] 18 | new_wav_filename = get_path(prepro_wav_dir, filename) 19 | 20 | if not os.path.exists(new_wav_filename): 21 | try: 22 | out, err = (ffmpeg 23 | .input(original_wav_filename) 24 | .output(new_wav_filename, acodec='pcm_s16le', ac=1, ar=sampling_rate) 25 | .overwrite_output() 26 | .run(capture_stdout=True, capture_stderr=True)) 27 | 28 | except ffmpeg.Error as err: 29 | print(err.stderr, file=sys.stderr) 30 | raise 31 | 32 | 33 | def preprocess(data_path, prepro_wav_dir, prepro_path, mel_path, sampling_rate, n_workers=10, filter_length=1024, hop_length=256, top_db=10): 34 | p = Pool(n_workers) 35 | mel_scaler = StandardScaler(copy=False) 36 | 37 | prepro_wav_dir = create_dir(prepro_wav_dir) 38 | wav_paths=[[filename, prepro_wav_dir, sampling_rate] for filename in list(glob.glob(get_path(data_path, "**", "wav", "*.wav")))] 39 | 40 | print("\t[LOG] converting wav format...") 41 | with tqdm(total=len(wav_paths)) as pbar: 42 | for _ in tqdm(p.imap_unordered(job, wav_paths)): 43 | pbar.update() 44 | 45 | print("\t[LOG] saving mel-spectrogram...") 46 | with tqdm(total=len(wav_paths)) as pbar: 47 | for wav_filename in tqdm(glob.glob(get_path(prepro_wav_dir, "*.wav"))): 48 | mel_filename = wav_filename.split("/")[-1].replace("wav", "npy") 49 | mel_savepath = get_path(mel_path, mel_filename) 50 | mel_spectrogram, _ = get_mel(wav_filename, trim_silence=True, frame_length=filter_length, hop_length=hop_length, top_db=top_db) 51 | 52 | mel_scaler.partial_fit(mel_spectrogram) 53 | np.save(mel_savepath, mel_spectrogram) 54 | np.save(get_path(prepro_path, "mel_stats.npy"), np.array([mel_scaler.mean_, mel_scaler.scale_])) 55 | 56 | print("Done!") 57 | 58 | 59 | def split_unseen_emotions(prepro_mel_dir): 60 | print("[LOG] SEEN emotion: ANGRY(ang), HAPPY(hap), SAD(sad), NEUTRAL(neu) \n\tUNSEEN emotion: SURPRISE(sur), FEAR(fea), DISGUSTING(dis)") 61 | 62 | seen_emotion_list, unseen_emotion_list = ["ang", "sad", "hap", "neu"], ["sur", "fea", "dis"] 63 | 64 | seen_emotion_files, unseen_emotion_files = [], [] 65 | 66 | preprocessed_file_list = glob.glob(get_path(prepro_mel_dir, "*.npy")) 67 | 68 | for preprocessed_mel_file in preprocessed_file_list: 69 | emotion = preprocessed_mel_file.split("/")[-1].split("_")[1] 70 | if emotion in seen_emotion_list: 71 | seen_emotion_files.append(preprocessed_mel_file) 72 | elif emotion in unseen_emotion_list: 73 | unseen_emotion_files.append(preprocessed_mel_file) 74 | else: 75 | 
print("[WARNING] File({}) cannot be identified by emotion label.\n\t(This file will not contain in whole dataset.)".format(preprocessed_mel_file)) 76 | 77 | return seen_emotion_files, unseen_emotion_files 78 | 79 | 80 | 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /data/vctk.py: -------------------------------------------------------------------------------- 1 | from utils.path import * 2 | from utils.audio.tools import get_mel 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | import glob, os, sys 7 | from multiprocessing import Pool 8 | 9 | from scipy.io.wavfile import write 10 | import librosa, ffmpeg 11 | from sklearn.preprocessing import StandardScaler 12 | 13 | def job(wav_filename): 14 | 15 | original_wav_filename, prepro_wav_dir, sampling_rate = wav_filename 16 | filename = original_wav_filename.split("/")[-1] 17 | new_wav_filename = get_path(prepro_wav_dir, filename) 18 | 19 | if not os.path.exists(new_wav_filename): 20 | try: 21 | out, err = (ffmpeg 22 | .input(original_wav_filename) 23 | .output(new_wav_filename, acodec='pcm_s16le', ac=1, ar=sampling_rate) 24 | .overwrite_output() 25 | .run(capture_stdout=True, capture_stderr=True)) 26 | 27 | except ffmpeg.Error as err: 28 | print(err.stderr, file=sys.stderr) 29 | raise 30 | 31 | 32 | def preprocess(data_path, prepro_wav_dir, prepro_path, mel_path, sampling_rate, n_workers=10, filter_length=1024, hop_length=256, trim_silence=True, top_db=60): 33 | p = Pool(n_workers) 34 | 35 | mel_scaler = StandardScaler(copy=False) 36 | 37 | prepro_wav_dir = create_dir(prepro_wav_dir) 38 | wav_paths=[[filename, prepro_wav_dir, sampling_rate] for filename in list(glob.glob(get_path(data_path, "wav48", "**", "*.wav")))] 39 | 40 | print("\t[LOG] converting wav format...") 41 | with tqdm(total=len(wav_paths)) as pbar: 42 | for _ in tqdm(p.imap_unordered(job, wav_paths)): 43 | pbar.update() 44 | 45 | print("\t[LOG] saving mel-spectrogram...") 46 | with tqdm(total=len(wav_paths)) as pbar: 47 | for wav_filename in tqdm(glob.glob(get_path(prepro_wav_dir, "*.wav"))): 48 | mel_filename = wav_filename.split("/")[-1].replace("wav", "npy") 49 | mel_savepath = get_path(mel_path, mel_filename) 50 | mel_spectrogram, _ = get_mel(wav_filename, trim_silence=trim_silence, frame_length=filter_length, hop_length=hop_length, top_db=top_db) 51 | 52 | mel_scaler.partial_fit(mel_spectrogram) 53 | np.save(mel_savepath, mel_spectrogram) 54 | 55 | np.save(get_path(prepro_path, "mel_stats.npy"), np.array([mel_scaler.mean_, mel_scaler.scale_])) 56 | 57 | print("Done!") 58 | 59 | 60 | 61 | def split_unseen_speakers(prepro_mel_dir): 62 | 63 | print("[LOG] 6 UNSEEN speakers: \n\t p226(Male, English, Surrey) \n\t p256(Male, English, Birmingham) \ 64 | \n\t p266(Female, Irish, Athlone) \n\t p297(Female, American, Newyork) \ 65 | \n\t p323 (Female, SouthAfrican, Pretoria)\n\t p376(Male, Indian)") 66 | 67 | unseen_speaker_list = ["p226", "p256", "p266", "p297", "p323", "p376"] 68 | 69 | seen_speaker_files, unseen_speaker_files = [], [] 70 | 71 | preprocessed_file_list = glob.glob(get_path(prepro_mel_dir, "*.npy")) 72 | 73 | for preprocessed_mel_file in preprocessed_file_list: 74 | speaker = preprocessed_mel_file.split("/")[-1].split("_")[0] 75 | if speaker in unseen_speaker_list: 76 | unseen_speaker_files.append(preprocessed_mel_file) 77 | else: 78 | seen_speaker_files.append(preprocessed_mel_file) 79 | 80 | return seen_speaker_files, unseen_speaker_files 81 | 82 | 83 | 84 | 85 | 
-------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from config import Arguments as args 2 | from utils.path import get_path 3 | from utils.dataset import get_label_dictionary, normalize 4 | 5 | import numpy as np 6 | 7 | import torch 8 | import glob 9 | 10 | from torch.utils.data import dataset as dset 11 | 12 | class SpeechDataset(dset.Dataset): 13 | """ 14 | class: Dataset 15 | 16 | returns: 17 | text: (Tx, Dim) 18 | mel: (Ty, Dim) 19 | 20 | explanation: 21 | dataloader for training model 22 | mem_mode (in config.py): mags cannot be loaded onto the memory due to the large # of features 23 | """ 24 | 25 | def __init__(self, mem_mode, meta_dir, dataset_name, mel_stat_path, max_frame_length=120): 26 | 27 | self.__mel_file_paths = self.__get_mel_filename(meta_dir=meta_dir) 28 | self.__label_dictionary = get_label_dictionary(dataset_name) 29 | self.max_frame_length = max_frame_length 30 | self.mel_mean, self.mel_std = np.load(mel_stat_path).astype(np.float) 31 | 32 | if args.mem_mode: 33 | self.__mels = list(map(lambda mel_file_path: torch.tensor(np.load(mel_file_path)), self.__mel_file_paths)) 34 | 35 | def __len__(self): 36 | return len(self.__mel_file_paths) 37 | 38 | def __getitem__(self, index): 39 | mel = self.__mels[index] if args.mem_mode else np.load(self.__mel_file_paths[index]) 40 | 41 | T_mel, _ = mel.shape 42 | 43 | while T_mel <= self.max_frame_length: 44 | mel = torch.cat((mel, mel), dim=0) 45 | T_mel, _ = mel.shape 46 | 47 | index = np.random.randint(T_mel - self.max_frame_length + 1) 48 | normalized_mel = normalize(mel[index: index + self.max_frame_length], mean=self.mel_mean, std=self.mel_std) 49 | 50 | return normalized_mel, 0 51 | 52 | 53 | def __get_mel_filename(self, meta_dir): 54 | with open(meta_dir, "r") as f: 55 | mel_file_paths = list(map(lambda filename : filename.rstrip(), f.readlines())) 56 | return mel_file_paths 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | from config import Arguments as args 2 | 3 | from utils.vocoder import vocgan_infer 4 | from utils.path import create_dir, get_path 5 | from utils.dataset import de_normalize 6 | 7 | import torch 8 | 9 | def evaluate(model, vocoder, eval_data_loader, criterion, global_step, mel_stat, writer=None, DEVICE=None): 10 | 11 | eval_path = create_dir(args.eval_path) 12 | model.eval() 13 | mel_mean, mel_std = mel_stat 14 | 15 | with torch.no_grad(): 16 | eval_loss, eval_recon_loss, eval_perplexity, eval_commitment_loss = 0, 0, 0, 0 17 | 18 | for step, (mels, _) in enumerate(eval_data_loader): 19 | 20 | mels = mels.float().to(DEVICE) 21 | 22 | mels_hat, mels_code, mels_style, commitment_loss, perplexity = model.evaluate(mels.detach()) 23 | 24 | commitment_loss = args.commitment_cost * commitment_loss 25 | recon_loss = criterion(mels, mels_hat) 26 | 27 | total_loss = commitment_loss + recon_loss 28 | 29 | eval_perplexity += perplexity.item() 30 | eval_recon_loss += recon_loss.item() 31 | eval_commitment_loss += commitment_loss.item() 32 | eval_loss += total_loss.item() 33 | 34 | mel = de_normalize(mels[0], mean=mel_mean, std=mel_std).float() 35 | mel_hat = de_normalize(mels_hat[0], mean=mel_mean, std=mel_std).float() 36 | mel_code = de_normalize(mels_code[0], mean=mel_mean, std=mel_std).float() 37 | mel_style = de_normalize(mels_style[0], 
mean=mel_mean, std=mel_std).float() 38 | 39 | vocgan_infer(mel.transpose(0, 1), vocoder, path=get_path(args.eval_path, "{:0>3}_GT.wav".format(global_step//1000))) 40 | vocgan_infer(mel_hat.transpose(0, 1), vocoder, path=get_path(args.eval_path, "{:0>3}_reconstructed.wav".format(global_step//1000))) 41 | vocgan_infer(mel_code.transpose(0, 1), vocoder, path=get_path(args.eval_path, "{:0>3}_code.wav".format(global_step//1000))) 42 | vocgan_infer(mel_style.transpose(0, 1), vocoder, path=get_path(args.eval_path, "{:0>3}_style.wav".format(global_step//1000))) 43 | 44 | mel = mel.view(-1, args.n_mels).detach().cpu().numpy().T 45 | mel_hat = mel_hat.view(-1, args.n_mels).detach().cpu().numpy().T 46 | mel_code = mel_code.view(-1, args.n_mels).detach().cpu().numpy().T 47 | mel_style = mel_style.view(-1, args.n_mels).detach().cpu().numpy().T 48 | 49 | if args.log_tensorboard: 50 | writer.add_scalars(mode="eval_reconstruction_loss", global_step=global_step, loss=eval_recon_loss / len(eval_data_loader)) 51 | writer.add_scalars(mode="eval_commitment_loss", global_step=global_step, loss=eval_commitment_loss / len(eval_data_loader)) 52 | writer.add_scalars(mode="eval_perplexity", global_step=global_step, loss=eval_perplexity / len(eval_data_loader)) 53 | writer.add_scalars(mode="eval_total_loss", global_step=global_step, loss=eval_loss / len(eval_data_loader)) 54 | writer.add_mel_figures(mode="eval-mels_", global_step=global_step, mel=mel, mel_hat=mel_hat, mel_code=mel_code, mel_style=mel_style) 55 | 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from config import Arguments as args 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from network import Encoder, VQEmbeddingEMA, Decoder 7 | import module as mm 8 | 9 | class VQVC(nn.Module): 10 | """ 11 | VQVC 12 | 13 | Args: 14 | mels: (N, T, C) 15 | 16 | Returns: 17 | encode: 18 | z_enc: (N, T, z_dim) 19 | z_quan: (N, T, z_dim) 20 | c: (N, T, c_dim) 21 | indices: (N, 22 | forward: 23 | z_enc: (N, T, z_dim) 24 | z_quan: (N, T, z_dim) 25 | c: (N, T, c_dim) 26 | loss: (1, ) 27 | perplexity (1, ) 28 | """ 29 | 30 | def __init__(self, speaker_emb_reduction=3): 31 | super(VQVC, self).__init__() 32 | self.name = 'VQVC' 33 | 34 | self.speaker_emb_reduction = args.speaker_emb_reduction 35 | 36 | self.encoder = Encoder(mel_channels=args.n_mels, z_dim=args.z_dim) 37 | self.codebook = VQEmbeddingEMA(args.n_embeddings, args.z_dim) 38 | self.decoder = Decoder(in_channels=args.z_dim, mel_channels=args.n_mels) 39 | 40 | def average_through_time(self, x, dim): 41 | x = torch.mean(x, dim=dim, keepdim=True) 42 | return x 43 | 44 | def forward(self, mels): 45 | 46 | # encoder 47 | z_enc = self.encoder(mels) 48 | 49 | # quantization 50 | z_quan, commitment_loss, perplexity = self.codebook(z_enc) 51 | 52 | # speaker emb 53 | speaker_emb_ = z_enc - z_quan 54 | speaker_emb = self.average_through_time(speaker_emb_, dim=1) 55 | 56 | # decoder 57 | mels_hat = self.decoder(z_quan, speaker_emb) 58 | 59 | return mels_hat, commitment_loss, perplexity 60 | 61 | def evaluate(self, mels): 62 | # encoder 63 | z_enc = self.encoder(mels) 64 | 65 | # contents emb 66 | z_quan, commitment_loss, perplexity = self.codebook(z_enc) 67 | 68 | # speaker emb 69 | speaker_emb_ = z_enc - z_quan 70 | speaker_emb = self.average_through_time(speaker_emb_, dim=1) 71 | 72 | # decoder 73 | mels_hat, mels_code, mels_style = self.decoder.evaluate(z_quan, 
speaker_emb, speaker_emb_) 74 | 75 | return mels_hat, mels_code, mels_style, commitment_loss, perplexity 76 | 77 | 78 | def convert(self, src_mel, ref_mel): 79 | # source z_enc 80 | z_src_enc = self.encoder(src_mel) 81 | 82 | # source contents 83 | src_contents, _, _ = self.codebook(z_src_enc) 84 | 85 | # source style emb 86 | src_style_emb_ = z_src_enc - src_contents 87 | 88 | # ref z_enc 89 | ref_enc = self.encoder(ref_mel) 90 | 91 | # ref contents 92 | ref_contents, _, _ = self.codebook(ref_enc) 93 | 94 | # ref speaker emb 95 | ref_speaker_emb_ = ref_enc - ref_contents 96 | ref_speaker_emb = self.average_through_time(ref_speaker_emb_, dim=1) 97 | 98 | # decoder to generate mel 99 | mel_converted, mel_src_code, mel_src_style, mel_ref_code, mel_ref_style = self.decoder.convert(src_contents, src_style_emb_, ref_contents, ref_speaker_emb, ref_speaker_emb_) 100 | 101 | return mel_converted, mel_src_code, mel_src_style, mel_ref_code, mel_ref_style 102 | 103 | -------------------------------------------------------------------------------- /module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Linear(nn.Linear): 5 | """ 6 | Linear 7 | Args: 8 | x: (N, T, C_in) 9 | Returns: 10 | y: (N, T, C_out) 11 | """ 12 | 13 | def __init__(self, in_features, out_features, bias=True, activation_fn=None, ln=None, drop_rate=0.): 14 | super(Linear, self).__init__(in_features, out_features, bias=bias) 15 | 16 | self.activation_fn = activation_fn(inplace=True) if activation_fn is not None else None 17 | self.layer_norm = nn.LayerNorm(out_features) if ln is not None else None 18 | self.activation_fn = activation_fn(inplace=True) if activation_fn is not None else None 19 | self.drop_out = nn.Dropout(drop_rate) if drop_rate > 0 else None 20 | 21 | def forward(self, x): 22 | y = super(Linear, self).forward(x) 23 | y = self.layer_norm(y) if self.layer_norm is not None else y 24 | y = self.activation_fn(y) if self.activation_fn is not None else y 25 | y = self.drop_out(y) if self.drop_out is not None else y 26 | return y 27 | 28 | 29 | class Conv1d(nn.Conv1d): 30 | """ 31 | Convolution 1d 32 | Args: 33 | x: (N, T, C_in) 34 | Returns: 35 | y: (N, T, C_out) 36 | """ 37 | 38 | def __init__(self, in_channels, out_channels, kernel_size, activation_fn=None, drop_rate=0., 39 | stride=1, padding='same', dilation=1, groups=1, bias=True, ln=False): 40 | 41 | if padding == 'same': 42 | padding = kernel_size // 2 * dilation 43 | self.even_kernel = not bool(kernel_size % 2) 44 | 45 | super(Conv1d, self).__init__(in_channels, out_channels, kernel_size, 46 | stride=stride, padding=padding, dilation=dilation, 47 | groups=groups, bias=bias) 48 | 49 | self.activation_fn = activation_fn(inplace=True) if activation_fn is not None else None 50 | self.drop_out = nn.Dropout(drop_rate) if drop_rate > 0 else None 51 | self.layer_norm = nn.LayerNorm(out_channels) if ln else None 52 | 53 | def forward(self, x): 54 | y = x.transpose(1, 2) 55 | y = super(Conv1d, self).forward(y) 56 | y = y.transpose(1, 2) 57 | y = self.layer_norm(y) if self.layer_norm is not None else y 58 | y = self.activation_fn(y) if self.activation_fn is not None else y 59 | y = self.drop_out(y) if self.drop_out is not None else y 60 | y = y[:, :-1, :] if self.even_kernel else y 61 | return y 62 | 63 | class Conv1dResBlock(Conv1d): 64 | """ 65 | Convolution 1d with Residual connection 66 | 67 | Args: 68 | x: (N, T, C_in) 69 | Returns: 70 | y: (N, T, C_out) 71 | """ 72 | def 
__init__(self, in_channels, out_channels, kernel_size, activation_fn=None, drop_rate=0., 73 | stride=1, padding='same', dilation=1, groups=1, bias=True, ln=False): 74 | 75 | super(Conv1dResBlock, self).__init__(in_channels, out_channels, kernel_size, activation_fn, 76 | drop_rate, stride, padding, dilation, groups=groups, bias=bias, 77 | ln=ln) 78 | 79 | def forward(self, x): 80 | residual = x 81 | x = super(Conv1dResBlock, self).forward(x) 82 | x = x + residual 83 | 84 | return x 85 | 86 | class Upsample(nn.Upsample): 87 | """ 88 | Upsampling via interporlation 89 | 90 | Args: 91 | x: (N, T, C) 92 | Returns: 93 | y: (N, S * T, C) 94 | (S: scale_factor) 95 | """ 96 | 97 | def __init__(self, scale_factor=2, mode='nearest'): 98 | super(Upsample, self).__init__(scale_factor=scale_factor, mode=mode) 99 | 100 | def forward(self, x): 101 | x = x.transpose(1, 2) 102 | x = super(Upsample, self).forward(x) 103 | x = x.transpose(1, 2) 104 | 105 | return x 106 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import module as mm 5 | 6 | class Encoder(nn.Module): 7 | """ 8 | Encoder 9 | Args: 10 | mel: (N, Tx, C_mel) log-melspectrogram (variable length) 11 | Returns: 12 | y_: (N, Tx, C_hidden) 13 | """ 14 | 15 | def __init__(self, mel_channels=80, z_dim=256): 16 | super(Encoder, self).__init__() 17 | self.encoder = nn.Sequential( 18 | mm.Conv1d(mel_channels, 64, kernel_size=4, stride=2, padding='same', bias=True, activation_fn=nn.ReLU), 19 | mm.Conv1d(64, 256, kernel_size=3, padding='same', bias=True), 20 | mm.Conv1d(256, 128, kernel_size=3, padding='same', bias=True), 21 | mm.Conv1dResBlock(128, 128, kernel_size=3, padding='same', bias=True, activation_fn=nn.ReLU), 22 | mm.Conv1dResBlock(128, 128, kernel_size=3, padding='same', bias=True, activation_fn=nn.ReLU), 23 | mm.Conv1d(128, z_dim, kernel_size=1, padding='same', bias=False) 24 | ) 25 | 26 | def forward(self, mels): 27 | z = self.encoder(mels) 28 | return z 29 | 30 | 31 | class VQEmbeddingEMA(nn.Module): 32 | """ 33 | VQEmbeddingEMA 34 | - vector quantization module 35 | - ref 36 | from VectorQuantizedCPC official repository 37 | (https://github.com/bshall/VectorQuantizedCPC/blob/master/model.py) 38 | 39 | encode: 40 | args: 41 | x: (N, T, z_dim) 42 | returns: 43 | quantized: (N, T, z_dim) 44 | indices: (N, T) 45 | forward: 46 | args: 47 | x: (N, T, z_dim) 48 | returns: 49 | quantized: (N, T, z_dim) 50 | loss: (N, 1) 51 | perplexity: (N, 1) 52 | """ 53 | 54 | 55 | def __init__(self, n_embeddings, embedding_dim, epsilon=1e-5): 56 | super(VQEmbeddingEMA, self).__init__() 57 | self.epsilon = epsilon 58 | 59 | init_bound = 1 / n_embeddings 60 | embedding = torch.Tensor(n_embeddings, embedding_dim) 61 | embedding.uniform_(-init_bound, init_bound) 62 | embedding = embedding / (torch.norm(embedding, dim=1, keepdim=True) + 1e-4) 63 | self.register_buffer("embedding", embedding) 64 | self.register_buffer("ema_count", torch.zeros(n_embeddings)) 65 | self.register_buffer("ema_weight", self.embedding.clone()) 66 | 67 | def instance_norm(self, x, dim, epsilon=1e-5): 68 | mu = torch.mean(x, dim=dim, keepdim=True) 69 | std = torch.std(x, dim=dim, keepdim=True) 70 | 71 | z = (x - mu) / (std + epsilon) 72 | return z 73 | 74 | 75 | def forward(self, x): 76 | 77 | x = self.instance_norm(x, dim=1) 78 | 79 | embedding = self.embedding / 
(torch.norm(self.embedding, dim=1, keepdim=True) + 1e-4) 80 | 81 | M, D = embedding.size() 82 | x_flat = x.detach().reshape(-1, D) 83 | 84 | distances = torch.addmm(torch.sum(embedding ** 2, dim=1) + 85 | torch.sum(x_flat ** 2, dim=1, keepdim=True), 86 | x_flat, embedding.t(), 87 | alpha=-2.0, beta=1.0) 88 | 89 | indices = torch.argmin(distances.float(), dim=-1).detach() 90 | encodings = F.one_hot(indices, M).float() 91 | quantized = F.embedding(indices, self.embedding) 92 | 93 | quantized = quantized.view_as(x) 94 | 95 | commitment_loss = F.mse_loss(x, quantized.detach()) 96 | 97 | quantized_ = x + (quantized - x).detach() 98 | quantized_ = (quantized_ + quantized)/2 99 | 100 | avg_probs = torch.mean(encodings, dim=0) 101 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 102 | 103 | return quantized_, commitment_loss, perplexity 104 | 105 | 106 | 107 | class Decoder(nn.Module): 108 | """ 109 | Decoder 110 | 111 | args: 112 | z_enc: (N, T, z_dim) 113 | z_quan: (N, T, z_dim) 114 | return: 115 | mel_reconstructed: (N, T, C_mel) 116 | """ 117 | 118 | def __init__(self, in_channels, mel_channels=80): 119 | super(Decoder, self).__init__() 120 | 121 | self.res_blocks = nn.Sequential( 122 | mm.Conv1d(in_channels, 128, kernel_size=3, 123 | bias=False, padding='same'), 124 | mm.Conv1dResBlock(128, 128, kernel_size=3, 125 | bias=True, padding='same', activation_fn=nn.ReLU), 126 | mm.Conv1dResBlock(128, 128, kernel_size=3, 127 | bias=True, padding='same', activation_fn=nn.ReLU), 128 | mm.Upsample(scale_factor=2, mode='nearest'), 129 | mm.Conv1d(128, 256, kernel_size=2, 130 | bias=True, padding='same', activation_fn=nn.ReLU), 131 | mm.Linear(256, mel_channels) 132 | ) 133 | 134 | 135 | def forward(self, contents, speaker_emb): 136 | 137 | contents = self.norm(contents, dim=2) 138 | speaker_emb = self.norm(speaker_emb, dim=2) 139 | 140 | embedding = contents + speaker_emb 141 | 142 | mel_reconstructed = self.res_blocks(embedding) 143 | 144 | return mel_reconstructed 145 | 146 | 147 | def evaluate(self, src_contents, speaker_emb, speaker_emb_): 148 | 149 | # normalize the L2-norm of input vector into 1 on every time-step 150 | src_contents = self.norm(src_contents, dim=2) 151 | speaker_emb = self.norm(speaker_emb, dim=2) 152 | speaker_emb_ = self.norm(speaker_emb_, dim=2) 153 | 154 | embedding = src_contents + speaker_emb 155 | 156 | # converted mel_hat 157 | mel_converted = self.res_blocks(embedding) 158 | 159 | # only src-code 160 | mel_src_code = self.res_blocks(src_contents) 161 | 162 | # only ref-style 163 | mel_ref_style = self.res_blocks(speaker_emb_) 164 | 165 | return mel_converted, mel_src_code, mel_ref_style 166 | 167 | def convert(self, src_contents, src_style_emb_, ref_contents, ref_speaker_emb, ref_speaker_emb_): 168 | # normalize the L2-norm of input vector into 1 on every time-step 169 | src_contents = self.norm(src_contents, dim=2) 170 | src_style_emb_ = self.norm(src_style_emb_, dim=2) 171 | ref_contents = self.norm(ref_contents, dim=2) 172 | ref_speaker_emb = self.norm(ref_speaker_emb, dim=2) 173 | ref_speaker_emb_ = self.norm(ref_speaker_emb_, dim=2) 174 | 175 | embedding = src_contents + ref_speaker_emb 176 | 177 | # converted mel_hat 178 | mel_converted = self.res_blocks(embedding) 179 | 180 | # only src-code 181 | mel_src_code = self.res_blocks(src_contents) 182 | 183 | # only src_style_emb_ 184 | mel_src_style = self.res_blocks(src_style_emb_) 185 | 186 | # only ref-code 187 | mel_ref_code = self.res_blocks(ref_contents) 188 | 189 | # only 
ref_style_emb_ 190 | mel_ref_style = self.res_blocks(ref_speaker_emb_) 191 | 192 | return mel_converted, mel_src_code, mel_src_style, mel_ref_code, mel_ref_style 193 | 194 | 195 | 196 | 197 | def norm(self, x, dim, epsilon=1e-4): 198 | x_ = x / (torch.norm(x, dim=dim, keepdim = True) + epsilon) 199 | return x_ 200 | 201 | 202 | -------------------------------------------------------------------------------- /prepro.py: -------------------------------------------------------------------------------- 1 | from config import Arguments as args 2 | from utils.path import get_path, create_dir 3 | 4 | from data import korean_emotional_speech_dataset 5 | from data import vctk 6 | 7 | import codecs, sys 8 | import numpy as np 9 | import glob 10 | 11 | def prepro_wavs(): 12 | 13 | print("Start to preprocess {} wav signal...".format(args.dataset_name)) 14 | 15 | dataset_path = args.dataset_path 16 | wav_dir = args.wav_dir 17 | 18 | create_dir(args.prepro_dir) 19 | prepro_path = create_dir(args.prepro_path) 20 | mel_path = create_dir(args.prepro_mel_dir) 21 | sampling_rate = args.sr 22 | 23 | if "korean_emotional_speech" in args.dataset_name: 24 | korean_emotional_speech_dataset.preprocess(dataset_path, wav_dir, prepro_path, mel_path, sampling_rate, n_workers=args.n_workers, filter_length=args.filter_length, hop_length=args.hop_length, trim_silence=args.trim_silence, top_db=args.top_db) 25 | elif "VCTK" in args.dataset_name: 26 | vctk.preprocess(dataset_path, wav_dir, prepro_path, mel_path, sampling_rate, n_workers=args.n_workers, filter_length=args.filter_length, hop_length=args.hop_length, trim_silence=args.trim_silence, top_db=args.top_db) 27 | else: 28 | print("[ERROR] No Dataset named {}".format(args.dataset_name)) 29 | 30 | 31 | def write_meta(): 32 | """ 33 | [TO DO] apply sampling based on audio duration when splitting train, eval, test dataset 34 | """ 35 | 36 | # split dataset into meta-train, meta-eval, meta-test with split ratio 37 | print("[LOG] Start to split data with ratio:", args.data_split_ratio) 38 | 39 | assert np.sum(args.data_split_ratio) == 1., "sum of list data_split_ratio must be 1" 40 | 41 | meta_path = create_dir(args.prepro_meta_dir) 42 | 43 | meta_train = codecs.open(args.prepro_meta_train, mode="w") 44 | meta_eval = codecs.open(args.prepro_meta_eval, mode="w") 45 | meta_unseen = codecs.open(args.prepro_meta_unseen, mode="w") 46 | 47 | if "korean_emotional_speech" in args.dataset_name: 48 | seen_files, unseen_files = korean_emotional_speech_dataset.split_unseen_emotions(args.prepro_mel_dir) 49 | elif "VCTK" in args.dataset_name: 50 | seen_files, unseen_files = vctk.split_unseen_speakers(args.prepro_mel_dir) 51 | else: 52 | print("[ERROR] No Dataset named {}".format(args.dataset_name)) 53 | 54 | 55 | train_num = int(len(seen_files) * args.data_split_ratio[0]) 56 | 57 | train_file = "\n".join(seen_files[:train_num+1]) 58 | eval_file = "\n".join(seen_files[train_num+1:]) 59 | unseen_file = "\n".join(unseen_files) 60 | 61 | meta_train.writelines(train_file) 62 | meta_eval.writelines(eval_file) 63 | meta_unseen.writelines(unseen_file) 64 | 65 | print("[LOG] Done: split metadata") 66 | 67 | 68 | if __name__ == "__main__": 69 | 70 | assert len(sys.argv) == 3, "[ERROR] # of args must be 1" 71 | 72 | _, is_wav, is_meta = sys.argv 73 | 74 | print("Audio signal: {}\t\tWrite metadata: {}".format(is_wav, is_meta)) 75 | 76 | if is_wav in ["1", 1, "True"]: 77 | prepro_wavs() 78 | if is_meta in ["1", 1, "True"]: 79 | write_meta() 80 | 81 | 
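For reference, `python prepro.py 1 1` (as in the README) runs both steps of the script above: the first flag triggers mel preprocessing, the second writes the metadata split. Each metadata file produced by `write_meta()` is simply a newline-separated list of paths to preprocessed `.npy` mel files, which is exactly what `SpeechDataset` reads. Below is a quick sketch for inspecting such a file; the path shown is the default location derived from `config.py` and is an assumption about where your preprocessed data lives.
```
import numpy as np

meta_train = "preprocessed/VCTK-Corpus/metas/meta_train.csv"

# mirror SpeechDataset.__get_mel_filename: one mel path per line
with open(meta_train, "r") as f:
    mel_paths = [line.rstrip() for line in f if line.strip()]

print("training utterances:", len(mel_paths))
mel = np.load(mel_paths[0])          # (T, n_mels) mel-spectrogram, n_mels = 80 in config.py
print("first mel shape:", mel.shape)
```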
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.6.0 2 | librosa==0.8.0 3 | tqdm==4.56.2 4 | ffmpeg-python==0.2 5 | tensorboard==2.4.1 6 | matplotlib==3.3.4 7 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from config import Arguments as args 2 | import os 3 | 4 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 5 | os.environ["CUDA_VISIBLE_DEVICES"]=args.train_visible_devices 6 | 7 | import sys, random 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.optim import Adam 13 | from torch.utils.data import DataLoader 14 | from torch.utils.tensorboard import SummaryWriter 15 | 16 | from config import Arguments as args 17 | 18 | from model import VQVC 19 | 20 | from evaluate import evaluate 21 | from dataset import SpeechDataset #, collate_fn 22 | 23 | from utils.scheduler import WarmupScheduler 24 | from utils.checkpoint import load_checkpoint, save_checkpoint 25 | from utils.writer import Writer 26 | from utils.vocoder import get_vocgan 27 | 28 | from tqdm import tqdm 29 | 30 | 31 | def train(train_data_loader, eval_data_loader, model, reconstruction_loss, vocoder, mel_stat, optimizer, scheduler, global_step, writer=None, DEVICE=None): 32 | 33 | model.train() 34 | 35 | while global_step < args.max_training_step: 36 | 37 | for step, (mels, _) in tqdm(enumerate(train_data_loader), total=len(train_data_loader), unit='B', ncols=70, leave=False): 38 | mels = mels.float().to(DEVICE) 39 | optimizer.zero_grad() 40 | 41 | mels_hat, commitment_loss, perplexity = model(mels.detach()) 42 | 43 | commitment_loss = args.commitment_cost * commitment_loss 44 | recon_loss = reconstruction_loss(mels_hat, mels) 45 | 46 | loss = commitment_loss + recon_loss 47 | loss.backward() 48 | 49 | nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_thresh) 50 | optimizer.step() 51 | 52 | if global_step % args.save_checkpoint_step == 0: 53 | save_checkpoint(checkpoint_path=args.model_checkpoint_path, model=model, optimizer=optimizer, scheduler=scheduler, global_step=global_step) 54 | 55 | if global_step % args.eval_step == 0: 56 | evaluate(model=model, vocoder=vocoder, eval_data_loader=eval_data_loader, criterion=reconstruction_loss, mel_stat=mel_stat, global_step=global_step, writer=writer, DEVICE=DEVICE) 57 | model.train() 58 | 59 | if args.log_tensorboard: 60 | writer.add_scalars(mode="train_recon_loss", global_step=global_step, loss=recon_loss) 61 | writer.add_scalars(mode="train_commitment_loss", global_step=global_step, loss=commitment_loss) 62 | writer.add_scalars(mode="train_perplexity", global_step=global_step, loss=perplexity) 63 | writer.add_scalars(mode="train_total_loss", global_step=global_step, loss=loss) 64 | 65 | global_step += 1 66 | 67 | scheduler.step() 68 | 69 | def main(DEVICE): 70 | 71 | # define model, optimizer, scheduler 72 | model = VQVC().to(DEVICE) 73 | 74 | recon_loss = nn.L1Loss().to(DEVICE) 75 | vocoder = get_vocgan(ckpt_path=args.vocoder_pretrained_model_path).to(DEVICE) 76 | 77 | mel_stat = torch.tensor(np.load(args.mel_stat_path)).to(DEVICE) 78 | 79 | optimizer = Adam(model.parameters(), lr=args.init_lr) 80 | scheduler = WarmupScheduler( optimizer, warmup_epochs=args.warmup_steps, 81 | initial_lr=args.init_lr, max_lr=args.max_lr, 82 | milestones=args.milestones, 
gamma=args.gamma) 83 | 84 | global_step = load_checkpoint(checkpoint_path=args.model_checkpoint_path, model=model, optimizer=optimizer, scheduler=scheduler) 85 | 86 | # load dataset & dataloader 87 | train_dataset = SpeechDataset(mem_mode=args.mem_mode, meta_dir=args.prepro_meta_train, dataset_name = args.dataset_name, mel_stat_path=args.mel_stat_path, max_frame_length=args.max_frame_length) 88 | eval_dataset = SpeechDataset(mem_mode=args.mem_mode, meta_dir=args.prepro_meta_eval, dataset_name=args.dataset_name, mel_stat_path=args.mel_stat_path, max_frame_length=args.max_frame_length) 89 | 90 | train_data_loader = DataLoader(dataset=train_dataset, batch_size=args.train_batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=args.n_workers) 91 | eval_data_loader = DataLoader(dataset=eval_dataset, batch_size=args.train_batch_size, shuffle=False, pin_memory=True, drop_last=True) 92 | 93 | # tensorboard 94 | writer = Writer(args.model_log_path) if args.log_tensorboard else None 95 | 96 | # train the model! 97 | train(train_data_loader, eval_data_loader, model, recon_loss, vocoder, mel_stat, optimizer, scheduler, global_step, writer, DEVICE) 98 | 99 | 100 | if __name__ == "__main__": 101 | 102 | print("[LOG] Start training...") 103 | DEVICE = torch.device("cuda" if (torch.cuda.is_available() and args.use_cuda) else "cpu") 104 | 105 | seed = args.seed 106 | 107 | print("[Training environment]") 108 | print("\t\trandom_seed: ", seed) 109 | print("\t\tuse_cuda: ", args.use_cuda) 110 | print("\t\t{} threads are used...".format(torch.get_num_threads())) 111 | 112 | random.seed(seed) 113 | np.random.seed(seed) 114 | torch.manual_seed(seed) 115 | 116 | main(DEVICE) 117 | -------------------------------------------------------------------------------- /utils/audio/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | from NVIDIA's preprocessing 3 | 4 | 5 | reference) 6 | https://github.com/NVIDIA/tacotron2 7 | """ 8 | 9 | 10 | import utils.audio.tools 11 | import utils.audio.stft 12 | import utils.audio.audio_preprocessing 13 | -------------------------------------------------------------------------------- /utils/audio/audio_preprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | from NVIDIA's preprocessing 3 | 4 | reference) 5 | https://github.com/NVIDIA/tacotron2 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | from scipy.signal import get_window 11 | import librosa.util as librosa_util 12 | from config import Arguments as args 13 | 14 | def window_sumsquare(window, n_frames, hop_length=args.hop_length, win_length=args.win_length, 15 | n_fft=args.filter_length, dtype=np.float32, norm=None): 16 | """ 17 | # from librosa 0.6 18 | Compute the sum-square envelope of a window function at a given hop length. 19 | This is used to estimate modulation effects induced by windowing 20 | observations in short-time fourier transforms. 21 | Parameters 22 | ---------- 23 | window : string, tuple, number, callable, or list-like 24 | Window specification, as in `get_window` 25 | n_frames : int > 0 26 | The number of analysis frames 27 | hop_length : int > 0 28 | The number of samples to advance between frames 29 | win_length : [optional] 30 | The length of the window function. By default, this matches `n_fft`. 31 | n_fft : int > 0 32 | The length of each analysis frame. 
33 | dtype : np.dtype 34 | The data type of the output 35 | Returns 36 | ------- 37 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 38 | The sum-squared envelope of the window function 39 | """ 40 | if win_length is None: 41 | win_length = n_fft 42 | 43 | n = n_fft + hop_length * (n_frames - 1) 44 | x = np.zeros(n, dtype=dtype) 45 | 46 | # Compute the squared window at the desired length 47 | win_sq = get_window(window, win_length, fftbins=True) 48 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 49 | win_sq = librosa_util.pad_center(win_sq, n_fft) 50 | 51 | # Fill the envelope 52 | for i in range(n_frames): 53 | sample = i * hop_length 54 | x[sample:min(n, sample + n_fft) 55 | ] += win_sq[:max(0, min(n_fft, n - sample))] 56 | return x 57 | 58 | 59 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 60 | """ 61 | PARAMS 62 | ------ 63 | magnitudes: spectrogram magnitudes 64 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 65 | """ 66 | 67 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 68 | angles = angles.astype(np.float32) 69 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 70 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 71 | 72 | for i in range(n_iters): 73 | _, angles = stft_fn.transform(signal) 74 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 75 | return signal 76 | 77 | 78 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 79 | """ 80 | PARAMS 81 | ------ 82 | C: compression factor 83 | """ 84 | return torch.log(torch.clamp(x, min=clip_val) * C) 85 | 86 | 87 | def dynamic_range_decompression(x, C=1): 88 | """ 89 | PARAMS 90 | ------ 91 | C: compression factor used to compress 92 | """ 93 | return torch.exp(x) / C 94 | -------------------------------------------------------------------------------- /utils/audio/stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | from NVIDIA's preprocessing 3 | 4 | reference) 5 | https://github.com/NVIDIA/tacotron2 6 | """ 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | import numpy as np 12 | 13 | from scipy.signal import get_window 14 | from librosa.util import pad_center, tiny 15 | from librosa.filters import mel as librosa_mel_fn 16 | 17 | from utils.audio.audio_preprocessing import dynamic_range_compression 18 | from utils.audio.audio_preprocessing import dynamic_range_decompression 19 | from utils.audio.audio_preprocessing import window_sumsquare 20 | 21 | 22 | class STFT(torch.nn.Module): 23 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 24 | 25 | def __init__(self, filter_length, hop_length, win_length, 26 | window='hann'): 27 | super(STFT, self).__init__() 28 | self.filter_length = filter_length 29 | self.hop_length = hop_length 30 | self.win_length = win_length 31 | self.window = window 32 | self.forward_transform = None 33 | scale = self.filter_length / self.hop_length 34 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 35 | 36 | cutoff = int((self.filter_length / 2 + 1)) 37 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 38 | np.imag(fourier_basis[:cutoff, :])]) 39 | 40 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 41 | inverse_basis = torch.FloatTensor( 42 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 43 | 44 | if window is not None: 45 | assert(filter_length >= win_length) 46 | # get window and zero center pad it to filter_length 47 | 
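            # note: fourier_basis above stacks the real and imaginary DFT rows, so transform()
            # can run the STFT as a single conv1d; inverse_basis is its pseudo-inverse used by
            # inverse() for the ISTFT, and both are multiplied by this window below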
fft_window = get_window(window, win_length, fftbins=True) 48 | fft_window = pad_center(fft_window, filter_length) 49 | fft_window = torch.from_numpy(fft_window).float() 50 | 51 | # window the bases 52 | forward_basis *= fft_window 53 | inverse_basis *= fft_window 54 | 55 | self.register_buffer('forward_basis', forward_basis.float()) 56 | self.register_buffer('inverse_basis', inverse_basis.float()) 57 | 58 | def transform(self, input_data): 59 | 60 | num_batches = input_data.size(0) 61 | num_samples = input_data.size(1) 62 | 63 | self.num_samples = num_samples 64 | 65 | # similar to librosa, reflect-pad the input 66 | input_data = input_data.view(num_batches, 1, num_samples) 67 | input_data = F.pad( 68 | input_data.unsqueeze(1), 69 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 70 | mode='reflect') 71 | input_data = input_data.squeeze(1) 72 | 73 | forward_transform = F.conv1d( 74 | input_data.cuda(), 75 | Variable(self.forward_basis, requires_grad=False).cuda(), 76 | stride=self.hop_length, 77 | padding=0).cpu() 78 | 79 | cutoff = int((self.filter_length / 2) + 1) 80 | real_part = forward_transform[:, :cutoff, :] 81 | imag_part = forward_transform[:, cutoff:, :] 82 | 83 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 84 | phase = torch.autograd.Variable( 85 | torch.atan2(imag_part.data, real_part.data)) 86 | 87 | return magnitude, phase 88 | 89 | def inverse(self, magnitude, phase): 90 | recombine_magnitude_phase = torch.cat( 91 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 92 | 93 | inverse_transform = F.conv_transpose1d( 94 | recombine_magnitude_phase, 95 | Variable(self.inverse_basis, requires_grad=False), 96 | stride=self.hop_length, 97 | padding=0) 98 | 99 | if self.window is not None: 100 | window_sum = window_sumsquare( 101 | self.window, magnitude.size(-1), hop_length=self.hop_length, 102 | win_length=self.win_length, n_fft=self.filter_length, 103 | dtype=np.float32) 104 | # remove modulation effects 105 | approx_nonzero_indices = torch.from_numpy( 106 | np.where(window_sum > tiny(window_sum))[0]) 107 | window_sum = torch.autograd.Variable( 108 | torch.from_numpy(window_sum), requires_grad=False) 109 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 110 | inverse_transform[:, :, 111 | approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 112 | 113 | # scale by hop ratio 114 | inverse_transform *= float(self.filter_length) / self.hop_length 115 | 116 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 117 | inverse_transform = inverse_transform[:, 118 | :, :-int(self.filter_length/2):] 119 | 120 | return inverse_transform 121 | 122 | def forward(self, input_data): 123 | self.magnitude, self.phase = self.transform(input_data) 124 | reconstruction = self.inverse(self.magnitude, self.phase) 125 | return reconstruction 126 | 127 | 128 | class TacotronSTFT(torch.nn.Module): 129 | def __init__(self, filter_length, hop_length, win_length, 130 | n_mel_channels, sampling_rate, mel_fmin=0.0, 131 | mel_fmax=8000.0): 132 | super(TacotronSTFT, self).__init__() 133 | self.n_mel_channels = n_mel_channels 134 | self.sampling_rate = sampling_rate 135 | self.stft_fn = STFT(filter_length, hop_length, win_length) 136 | mel_basis = librosa_mel_fn( 137 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 138 | mel_basis = torch.from_numpy(mel_basis).float() 139 | self.register_buffer('mel_basis', mel_basis) 140 | 141 | def spectral_normalize(self, magnitudes): 142 | output = 
dynamic_range_compression(magnitudes) 143 | return output 144 | 145 | def spectral_de_normalize(self, magnitudes): 146 | output = dynamic_range_decompression(magnitudes) 147 | return output 148 | 149 | def mel_spectrogram(self, y): 150 | """Computes mel-spectrograms from a batch of waves 151 | PARAMS 152 | ------ 153 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 154 | RETURNS 155 | ------- 156 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 157 | """ 158 | assert(torch.min(y.data) >= -1) 159 | assert(torch.max(y.data) <= 1) 160 | 161 | magnitudes, phases = self.stft_fn.transform(y) 162 | magnitudes = magnitudes.data 163 | mel_output = torch.matmul(self.mel_basis, magnitudes) 164 | mel_output = self.spectral_normalize(mel_output) 165 | energy = torch.norm(magnitudes, dim=1) 166 | 167 | return mel_output, energy 168 | -------------------------------------------------------------------------------- /utils/audio/tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | from NVIDIA's preprocessing 3 | 4 | reference) 5 | https://github.com/NVIDIA/tacotron2 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | 11 | from scipy.io.wavfile import read 12 | from scipy.io.wavfile import write 13 | import scipy.signal as sps 14 | 15 | import librosa 16 | import os 17 | 18 | from . import stft as stft 19 | from .audio_preprocessing import griffin_lim 20 | from config import Arguments as args 21 | 22 | 23 | _stft = stft.TacotronSTFT( 24 | args.filter_length, args.hop_length, args.win_length, 25 | args.n_mels, args.sr, args.mel_fmin, args.mel_fmax) 26 | 27 | 28 | def load_wav_to_numpy(full_path): 29 | sampling_rate, data = read(full_path) 30 | return data.astype(np.float32), sampling_rate 31 | 32 | 33 | def get_mel(filename, trim_silence=False, frame_length=1024, hop_length=256, top_db=10): 34 | audio, sampling_rate = load_wav_to_numpy(filename) 35 | 36 | if sampling_rate != _stft.sampling_rate: 37 | raise ValueError("{} SR doesn't match target SR {}".format( 38 | sampling_rate, _stft.sampling_rate)) 39 | audio_norm = audio / args.max_wav_value 40 | if trim_silence: 41 | audio_norm = audio_norm[200:-200] 42 | audio_norm, idx = librosa.effects.trim(audio_norm, top_db=top_db, frame_length=frame_length, hop_length=hop_length) 43 | 44 | audio_norm = torch.FloatTensor(audio_norm) 45 | audio_norm = audio_norm.unsqueeze(0) 46 | audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) 47 | melspec, energy = _stft.mel_spectrogram(audio_norm) 48 | melspec = torch.squeeze(melspec, 0).detach().cpu().numpy().T 49 | energy = torch.squeeze(energy, 0).detach().cpu().numpy() 50 | 51 | return melspec, energy 52 | 53 | def get_mel_from_wav(audio): 54 | sampling_rate = args.sr 55 | if sampling_rate != _stft.sampling_rate: 56 | raise ValueError("{} {} SR doesn't match target {} SR".format( 57 | sampling_rate, _stft.sampling_rate)) 58 | audio_norm = audio / args.max_wav_value 59 | audio_norm = audio_norm.unsqueeze(0) 60 | audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) 61 | melspec, energy = _stft.mel_spectrogram(audio_norm) 62 | melspec = torch.squeeze(melspec, 0).detach().cpu().numpy().T 63 | energy = torch.squeeze(energy, 0).detach().cpu().numpy() 64 | 65 | 66 | return melspec, energy 67 | 68 | 69 | def inv_mel_spec(mel, out_filename, griffin_iters=60): 70 | mel = torch.stack([mel]) 71 | mel_decompress = _stft.spectral_de_normalize(mel) 72 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu() 73 | 
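    # approximate mel-to-linear inversion: undo the log compression, project back through
    # the mel filterbank, rescale, and let griffin_lim below recover the phase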
spec_from_mel_scaling = 1000 74 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) 75 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) 76 | spec_from_mel = spec_from_mel * spec_from_mel_scaling 77 | 78 | audio = griffin_lim(torch.autograd.Variable( 79 | spec_from_mel[:, :, :-1]), _stft.stft_fn, griffin_iters) 80 | 81 | audio = audio.squeeze() 82 | audio = audio.cpu().numpy() 83 | audio_path = out_filename 84 | write(audio_path, args.sr, audio) 85 | -------------------------------------------------------------------------------- /utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os, glob 3 | 4 | from .path import create_dir, get_path 5 | 6 | 7 | def load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None): 8 | 9 | if optimizer is not None: 10 | if not(os.path.exists(checkpoint_path)): 11 | print("[WARNING] No checkpoint exists. Start from scratch.") 12 | global_step = 0 13 | else: 14 | print("[WARNING] Already exists. Restart to train model.") 15 | last_model_path = sorted(glob.glob(get_path(checkpoint_path, '*.pth.tar')))[-1] 16 | state = torch.load(last_model_path) 17 | 18 | model.load_state_dict(state['model']) 19 | global_step = state['global_step'] 20 | optimizer.load_state_dict(state['optimizer']) 21 | scheduler.load_state_dict(state['scheduler']) 22 | else: 23 | last_model_path = sorted(glob.glob(get_path(checkpoint_path, '*.pth.tar')))[-1] 24 | state = torch.load(last_model_path) 25 | model.load_state_dict(state['model']) 26 | global_step = 0 27 | print("[WARNING] Model: {} has been loaded.".format(last_model_path.split("/")[-1].replace(".pth.tar", ""))) 28 | 29 | return global_step 30 | 31 | 32 | def save_checkpoint(checkpoint_path, global_step, model, optimizer, scheduler): 33 | 34 | create_dir("/".join(checkpoint_path.split("/")[:-1])) 35 | checkpoint_path = create_dir(checkpoint_path) 36 | 37 | cur_checkpoint_name = "model-{:03d}k.pth.tar".format(global_step//1000) 38 | 39 | state = { 40 | 'global_step': global_step, 41 | 'model': model.state_dict(), 42 | 'optimizer': optimizer.state_dict(), 43 | 'scheduler': scheduler.state_dict() 44 | } 45 | 46 | torch.save(state, get_path(checkpoint_path, cur_checkpoint_name)) 47 | 48 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | from .path import * 2 | from .audio.tools import get_mel 3 | 4 | import numpy as np 5 | import os 6 | import torch 7 | 8 | def get_label_dictionary(dataset_name): 9 | if "korean_emotional_speech" in dataset_name: 10 | return {'ang': 0, 'fea': 1, 'neu': 2, 'sur': 3, 'dis': 4, 'hap':5, 'sad': 6} 11 | else: 12 | return None 13 | 14 | def get_src_and_ref_mels(src_path, ref_path, trim_silence=True, frame_length=1024, hop_length=1024, top_db=10): 15 | src_mel, ref_mel = None, None 16 | 17 | if os.path.isfile(src_path) and os.path.isfile(ref_path): 18 | src_mel, _ = get_mel(src_path, trim_silence=trim_silence, frame_length=frame_length, hop_length=hop_length, top_db=top_db) 19 | ref_mel, _ = get_mel(ref_path, trim_silence=trim_silence, frame_length=frame_length, hop_length=hop_length, top_db = top_db) 20 | else: 21 | print("[ERROR] No paths exist! 
Check your filename.: \n\t src_path: {} ref_path: {}".format(src_path, ref_path)) 22 | 23 | return src_mel, ref_mel 24 | 25 | def normalize(x, mean, std): 26 | zero_idxs = np.where(x==0.0)[0] 27 | z = (x - mean) / std 28 | z[zero_idxs] = 0.0 29 | return z 30 | 31 | def de_normalize(z, mean, std): 32 | zero_idxs = torch.where(z == 0.0)[0] 33 | x = mean + std * z 34 | x[zero_idxs] = 0.0 35 | return x 36 | -------------------------------------------------------------------------------- /utils/figure.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('pdf') 3 | 4 | import matplotlib.pyplot as plt 5 | 6 | def draw_melspectrogram(mel, mel_hat, mel_code, mel_style): 7 | fig, axis = plt.subplots(4, 1, figsize=(20,30)) 8 | 9 | axis[0].set_title("Ground-truth mel") 10 | axis[0].imshow(mel, origin="lower", aspect="auto") 11 | 12 | axis[1].set_title("Reconstructed mel") 13 | axis[1].imshow(mel_hat, origin="lower", aspect="auto") 14 | 15 | axis[2].set_title("Contents mel") 16 | axis[2].imshow(mel_code, origin="lower", aspect="auto") 17 | 18 | axis[3].set_title("Style mel (normed_z - normed_z_quan") 19 | axis[3].imshow(mel_style, origin="lower", aspect="auto") 20 | 21 | return fig 22 | 23 | def draw_converted_melspectrogram(src_mel, ref_mel, mel_converted, mel_src_code, mel_src_style, mel_ref_code, mel_ref_style): 24 | fig, axis = plt.subplots(7, 1, figsize=(40, 60)) 25 | 26 | axis[0].set_title("Source mel") 27 | axis[0].imshow(src_mel, origin="lower", aspect="auto") 28 | 29 | axis[1].set_title("Reference mel") 30 | axis[1].imshow(ref_mel, origin="lower", aspect="auto") 31 | 32 | axis[2].set_title("Converted mel") 33 | axis[2].imshow(mel_converted, origin="lower", aspect="auto") 34 | 35 | axis[3].set_title("Source contents mel") 36 | axis[3].imshow(mel_src_code, origin="lower", aspect="auto") 37 | 38 | axis[4].set_title("Source style mel") 39 | axis[4].imshow(mel_src_style, origin="lower", aspect="auto") 40 | 41 | axis[5].set_title("Reference contents mel") 42 | axis[5].imshow(mel_ref_code, origin="lower", aspect="auto") 43 | 44 | axis[6].set_title("Reference style mel") 45 | axis[6].imshow(mel_ref_style, origin="lower", aspect="auto") 46 | 47 | return fig 48 | -------------------------------------------------------------------------------- /utils/path.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def get_path(*args): 4 | return os.path.join('', *args) 5 | 6 | def create_dir(*args): 7 | path = get_path(*args) 8 | if not os.path.exists(path): 9 | os.mkdir(path) 10 | return path 11 | 12 | -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | from collections import Counter 3 | import warnings 4 | 5 | class WarmupScheduler(optim.lr_scheduler._LRScheduler): 6 | def __init__(self, optimizer, warmup_epochs, initial_lr, max_lr, milestones, gamma=0.1, last_epoch=-1): 7 | assert warmup_epochs < milestones[0] 8 | self.warmup_epochs = warmup_epochs 9 | self.milestones = Counter(milestones) 10 | self.gamma = gamma 11 | 12 | initial_lrs = self._format_param("initial_lr", optimizer, initial_lr) 13 | max_lrs = self._format_param("max_lr", optimizer, max_lr) 14 | if last_epoch == -1: 15 | for idx, group in enumerate(optimizer.param_groups): 16 | group["initial_lr"] = initial_lrs[idx] 17 | group["max_lr"] = max_lrs[idx] 18 | 
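        # _LRScheduler.__init__ picks up the "initial_lr" values set above as base_lrs
        # and performs the first scheduler step, which calls get_lr()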
19 | super(WarmupScheduler, self).__init__(optimizer, last_epoch) 20 | 21 | def get_lr(self): 22 | if not self._get_lr_called_within_step: 23 | warnings.warn("To get the last learning rate computed by the scheduler, " 24 | "please use `get_last_lr()`.", DeprecationWarning) 25 | 26 | if self.last_epoch <= self.warmup_epochs: 27 | pct = self.last_epoch / self.warmup_epochs 28 | return [ 29 | (group["max_lr"] - group["initial_lr"]) * pct + group["initial_lr"] 30 | for group in self.optimizer.param_groups] 31 | else: 32 | if self.last_epoch not in self.milestones: 33 | return [group['lr'] for group in self.optimizer.param_groups] 34 | return [group['lr'] * self.gamma ** self.milestones[self.last_epoch] 35 | for group in self.optimizer.param_groups] 36 | 37 | @staticmethod 38 | def _format_param(name, optimizer, param): 39 | """Return correctly formatted lr/momentum for each param group.""" 40 | if isinstance(param, (list, tuple)): 41 | if len(param) != len(optimizer.param_groups): 42 | raise ValueError("expected {} values for {}, got {}".format( 43 | len(optimizer.param_groups), name, len(param))) 44 | return param 45 | else: 46 | return [param] * len(optimizer.param_groups) 47 | 48 | -------------------------------------------------------------------------------- /utils/vocoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from vocoder.vocgan.generator import Generator 4 | from scipy.io import wavfile 5 | from config import Arguments as args 6 | 7 | 8 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 9 | 10 | def get_vocgan(ckpt_path, n_mel_channels=args.n_mels, generator_ratio = [4, 4, 2, 2, 2, 2], n_residual_layers=4, mult=256, out_channels=1): 11 | 12 | checkpoint = torch.load(ckpt_path) 13 | model = Generator(n_mel_channels, n_residual_layers, 14 | ratios=generator_ratio, mult=mult, 15 | out_band=out_channels) 16 | 17 | model.load_state_dict(checkpoint['model_g']) 18 | model.to(device).eval() 19 | 20 | return model 21 | 22 | def vocgan_infer(mel, vocoder, path): 23 | model = vocoder 24 | 25 | with torch.no_grad(): 26 | if len(mel.shape) == 2: 27 | mel = mel.unsqueeze(0) 28 | 29 | audio = model.infer(mel).squeeze() 30 | audio = args.max_wav_value * audio[:-(args.hop_length*10)] 31 | audio = audio.clamp(min=-args.max_wav_value, max=args.max_wav_value-1) 32 | audio = audio.short().cpu().detach().numpy() 33 | 34 | wavfile.write(path, args.sr, audio) 35 | -------------------------------------------------------------------------------- /utils/writer.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | from .path import create_dir 3 | from .figure import draw_melspectrogram 4 | 5 | class Writer(SummaryWriter): 6 | 7 | def __init__(self, log_path): 8 | super(Writer, self).__init__(log_path) 9 | create_dir("/".join(log_path.split("/")[:-1])) 10 | create_dir(log_path) 11 | 12 | 13 | def add_scalars(self, mode, global_step, loss): 14 | # edit 15 | self.add_scalar("{}".format(mode), loss, global_step) 16 | 17 | def add_mel_figures(self, mode, global_step, mel, mel_hat, mel_code, mel_style): 18 | figure = draw_melspectrogram(mel, mel_hat, mel_code, mel_style) 19 | self.add_figure("{}_mel(top)_mel_hat(top_mid)_code(bottom_mid)_style(bottom)".format(mode), figure, global_step) 20 | 21 | -------------------------------------------------------------------------------- /vocoder/vocgan/generator.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | [VocGAN] Generator 3 | this source code is implemenation of the modified-VocGAN from rishikksh20 4 | git repository: https://github.com/rishikksh20/VocGAN 5 | """ 6 | 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from config import Arguments as args 13 | 14 | 15 | MAX_WAV_VALUE = args.max_wav_value 16 | 17 | def weights_init(m): 18 | classname = m.__class__.__name__ 19 | if classname.find("Conv") != -1: 20 | m.weight.data.normal_(0.0, 0.02) 21 | elif classname.find("BatchNorm2d") != -1: 22 | m.weight.data.normal_(1.0, 0.02) 23 | m.bias.data.fill_(0) 24 | 25 | class ResStack(nn.Module): 26 | def __init__(self, channel, dilation=1): 27 | super(ResStack, self).__init__() 28 | 29 | self.block = nn.Sequential( 30 | nn.LeakyReLU(0.2), 31 | nn.ReflectionPad1d(dilation), 32 | nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=3, dilation=dilation)), 33 | nn.LeakyReLU(0.2), 34 | nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)), 35 | ) 36 | 37 | 38 | self.shortcut = nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)) 39 | 40 | 41 | def forward(self, x): 42 | return self.shortcut(x) + self.block(x) 43 | 44 | def remove_weight_norm(self): 45 | nn.utils.remove_weight_norm(self.block[2]) 46 | nn.utils.remove_weight_norm(self.block[4]) 47 | nn.utils.remove_weight_norm(self.shortcut) 48 | 49 | 50 | # Modified VocGAN 51 | class Generator(nn.Module): 52 | def __init__(self, mel_channel, n_residual_layers, ratios=[4, 4, 2, 2, 2, 2], mult=256, out_band=1): 53 | super(Generator, self).__init__() 54 | self.mel_channel = mel_channel 55 | 56 | self.start = nn.Sequential( 57 | nn.ReflectionPad1d(3), 58 | nn.utils.weight_norm(nn.Conv1d(mel_channel, mult * 2, kernel_size=7, stride=1)) 59 | ) 60 | r = ratios[0] 61 | self.upsample_1 = nn.Sequential( 62 | nn.LeakyReLU(0.2), 63 | nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, 64 | kernel_size=r * 2, stride=r, 65 | padding=r // 2 + r % 2, 66 | output_padding=r % 2) 67 | ) 68 | ) 69 | self.res_stack_1 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) 70 | 71 | r = ratios[1] 72 | mult = mult // 2 73 | self.upsample_2 = nn.Sequential( 74 | nn.LeakyReLU(0.2), 75 | nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, 76 | kernel_size=r * 2, stride=r, 77 | padding=r // 2 + r % 2, 78 | output_padding=r % 2) 79 | ) 80 | ) 81 | self.res_stack_2 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) 82 | 83 | r = ratios[2] 84 | mult = mult // 2 85 | self.upsample_3 = nn.Sequential( 86 | nn.LeakyReLU(0.2), 87 | nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, 88 | kernel_size=r * 2, stride=r, 89 | padding=r // 2 + r % 2, 90 | output_padding=r % 2) 91 | ) 92 | ) 93 | 94 | self.skip_upsample_1 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult, 95 | kernel_size=64, stride=32, 96 | padding=16, 97 | output_padding=0) 98 | ) 99 | self.res_stack_3 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) 100 | 101 | 102 | 103 | r = ratios[3] 104 | mult = mult // 2 105 | self.upsample_4 = nn.Sequential( 106 | nn.LeakyReLU(0.2), 107 | nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, 108 | kernel_size=r * 2, stride=r, 109 | padding=r // 2 + r % 2, 110 | output_padding=r % 2) 111 | ) 112 | ) 113 | 114 | self.skip_upsample_2 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult, 115 
| kernel_size=128, stride=64, 116 | padding=32, 117 | output_padding=0) 118 | ) 119 | self.res_stack_4 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) 120 | 121 | 122 | r = ratios[4] 123 | mult = mult // 2 124 | self.upsample_5 = nn.Sequential( 125 | nn.LeakyReLU(0.2), 126 | nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, 127 | kernel_size=r * 2, stride=r, 128 | padding=r // 2 + r % 2, 129 | output_padding=r % 2) 130 | ) 131 | ) 132 | 133 | self.skip_upsample_3 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult, 134 | kernel_size=256, stride=128, 135 | padding=64, 136 | output_padding=0) 137 | ) 138 | self.res_stack_5 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) 139 | 140 | 141 | r = ratios[5] 142 | mult = mult // 2 143 | self.upsample_6 = nn.Sequential( 144 | nn.LeakyReLU(0.2), 145 | nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, 146 | kernel_size=r * 2, stride=r, 147 | padding=r // 2 + r % 2, 148 | output_padding=r % 2) 149 | ) 150 | ) 151 | 152 | self.skip_upsample_4 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult, 153 | kernel_size=512, stride=256, 154 | padding=128, 155 | output_padding=0) 156 | ) 157 | self.res_stack_6 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) 158 | 159 | self.out = nn.Sequential( 160 | nn.LeakyReLU(0.2), 161 | nn.ReflectionPad1d(3), 162 | nn.utils.weight_norm(nn.Conv1d(mult, out_band, kernel_size=7, stride=1)), 163 | nn.Tanh(), 164 | ) 165 | self.apply(weights_init) 166 | 167 | def forward(self, mel): 168 | mel = (mel + 5.0) / 5.0 # roughly normalize spectrogram 169 | # Mel Shape [B, num_mels, T] -> torch.Size([3, 80, 10]) 170 | x = self.start(mel) # [B, dim*2, T] -> torch.Size([3, 512, 10]) 171 | 172 | x = self.upsample_1(x) 173 | x = self.res_stack_1(x) # [B, dim, T*4] -> torch.Size([3, 256, 40]) 174 | 175 | x = self.upsample_2(x) 176 | x = self.res_stack_2(x) # [B, dim/2, T*16] -> torch.Size([3, 128, 160]) 177 | 178 | x = self.upsample_3(x) 179 | x = x + self.skip_upsample_1(mel) 180 | x = self.res_stack_3(x) # [B, dim/4, T*32] -> torch.Size([3, 64, 320]) 181 | 182 | x = self.upsample_4(x) 183 | x = x + self.skip_upsample_2(mel) 184 | x = self.res_stack_4(x) # [B, dim/8, T*64] -> torch.Size([3, 32, 640]) 185 | 186 | x = self.upsample_5(x) 187 | x = x + self.skip_upsample_3(mel) 188 | x = self.res_stack_5(x) # [B, dim/16, T*128] -> torch.Size([3, 16, 1280]) 189 | 190 | x = self.upsample_6(x) 191 | x = x + self.skip_upsample_4(mel) 192 | x = self.res_stack_6(x) # [B, dim/32, T*256] -> torch.Size([3, 8, 2560]) 193 | 194 | out = self.out(x) # [B, 1, T*256] -> torch.Size([3, 1, 2560]) 195 | 196 | return out 197 | 198 | def eval(self, inference=False): 199 | super(Generator, self).eval() 200 | 201 | # don't remove weight norm while validation in training loop 202 | if inference: 203 | self.remove_weight_norm() 204 | 205 | def remove_weight_norm(self): 206 | """Remove weight normalization module from all of the layers.""" 207 | 208 | def _remove_weight_norm(m): 209 | try: 210 | torch.nn.utils.remove_weight_norm(m) 211 | except ValueError: # this module didn't have weight norm 212 | return 213 | 214 | self.apply(_remove_weight_norm) 215 | 216 | def apply_weight_norm(self): 217 | """Apply weight normalization module from all of the layers.""" 218 | 219 | def _apply_weight_norm(m): 220 | if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d): 221 | torch.nn.utils.weight_norm(m) 222 | 223 | 
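        # registers weight normalization on every Conv1d / ConvTranspose1d of the generator,
        # the counterpart of remove_weight_norm() above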
self.apply(_apply_weight_norm) 224 | 225 | 226 | def infer(self, mel): 227 | hop_length = 256 228 | # pad input mel with zeros to cut artifact 229 | # see https://github.com/seungwonpark/melgan/issues/8 230 | zero = torch.full((1, self.mel_channel, 10), -11.5129).to(mel.device) 231 | mel = torch.cat((mel, zero), dim=2) 232 | 233 | audio = self.forward(mel) 234 | return audio 235 | 236 | --------------------------------------------------------------------------------
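A minimal copy-synthesis sketch showing how the utilities above fit together: extract a mel-spectrogram with utils/audio/tools.get_mel and vocode it with utils/vocoder.get_vocgan / vocgan_infer. The checkpoint and wav paths below are hypothetical placeholders, and the input wav is assumed to already be at args.sr (get_mel raises a ValueError otherwise).

import torch

from utils.audio.tools import get_mel
from utils.vocoder import get_vocgan, vocgan_infer

# hypothetical paths; point these at a real VocGAN checkpoint and a wav file sampled at args.sr
vocoder = get_vocgan(ckpt_path="vocoder/vocgan/pretrained_models/vocgan_pretrained.pth")
device = next(vocoder.parameters()).device

mel, _ = get_mel("sample.wav")                        # (T, n_mels) numpy array, log-compressed
mel = torch.from_numpy(mel.T).float().to(device)      # (n_mels, T), as vocgan_infer expects
vocgan_infer(mel, vocoder, path="resynthesized.wav")  # writes a 16-bit wav at args.sr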