├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── WaveRNN.png ├── energy_tb.png ├── fast_speech.png ├── model.png ├── tacotron_wavernn.png ├── tensorboard.png ├── tensorboard_wavernn.png ├── training_viz.gif └── wavernn_alt_model_hrz2.png ├── configs ├── multispeaker.yaml └── singlespeaker.yaml ├── docs ├── .gitignore ├── 404.html ├── Gemfile ├── Gemfile.lock ├── README.md ├── _config.yml ├── _layouts │ └── default.html ├── assets │ └── css │ │ └── style.scss ├── favicon.png └── index.md ├── duration_extraction ├── __init__.py ├── duration_extraction_pipe.py └── duration_extractor.py ├── gen_forward.py ├── models ├── __init__.py ├── common_layers.py ├── fast_pitch.py ├── forward_tacotron.py ├── multi_fast_pitch.py ├── multi_forward_tacotron.py └── tacotron.py ├── notebook_utils ├── __init__.py └── synthesize.py ├── notebooks └── synthesize.ipynb ├── pitch_extraction ├── __init__.py └── pitch_extractor.py ├── preprocess.py ├── requirements.txt ├── sentences.txt ├── tests ├── __init__.py ├── resources │ ├── test_config.yaml │ ├── test_mel.npy │ └── wavs │ │ ├── 0.wav │ │ └── 1.wav ├── test_cleaner.py ├── test_collator.py ├── test_dataset_filter.py ├── test_dsp.py ├── test_duration_extraction_pipe.py ├── test_duration_extractor.py ├── test_forward_dataset.py ├── test_forward_tacotron.py ├── test_guided_attention_matrix.py ├── test_multi_forward_tacotron.py ├── test_recipes.py ├── test_taco_binned_dataloader.py ├── test_taco_dataset.py └── test_tokenizer.py ├── train_forward.py ├── train_tacotron.py ├── trainer ├── __init__.py ├── common.py ├── forward_trainer.py ├── multi_forward_trainer.py └── taco_trainer.py └── utils ├── __init__.py ├── checkpoints.py ├── dataset.py ├── decorators.py ├── display.py ├── distribution.py ├── dsp.py ├── files.py ├── metrics.py ├── paths.py └── text ├── LICENSE ├── __init__.py ├── cleaners.py ├── numbers.py ├── recipes.py ├── symbols.py └── tokenizer.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-language=Python -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE files 2 | .idea 3 | .vscode 4 | 5 | # Mac files 6 | .DS_Store 7 | 8 | # Environments 9 | .env 10 | .venv 11 | env/ 12 | venv/ 13 | ENV/ 14 | env.bak/ 15 | venv.bak/ 16 | 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | pip-wheel-metadata/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Jupyter Notebook 48 | .ipynb_checkpoints 49 | 50 | # Jekyll 51 | _site/ 52 | .sass-cache/ 53 | .jekyll-cache/ 54 | .jekyll-metadata -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Axel Springer AI. All rights reserved. 
4 | Copyright (c) 2019 fatchord (https://github.com/fatchord) 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ⏩ ForwardTacotron 2 | 3 | Inspired by Microsoft's [FastSpeech](https://www.microsoft.com/en-us/research/blog/fastspeech-new-text-to-speech-model-improves-on-speed-accuracy-and-controllability/) 4 | we modified Tacotron (Fork from fatchord's [WaveRNN](https://github.com/fatchord/WaveRNN)) to generate speech in a single forward pass using a duration predictor to align text and generated mel spectrograms. Hence, we call the model ForwardTacotron (see Figure 1). 5 | 6 |
10 | Figure 1: Model Architecture.
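The core idea: instead of attending over the text while decoding, the model predicts a duration (in mel frames) for every input symbol and simply repeats the corresponding encoder outputs before decoding them into a mel spectrogram. A minimal sketch of this length regulation, mirroring `LengthRegulator` in `models/common_layers.py` (the shapes and duration values below are illustrative only):
```
import torch

# 5 phonemes encoded into 256 channels (illustrative shapes).
encoder_out = torch.randn(5, 256)
durations = torch.tensor([2, 3, 1, 4, 2])   # predicted mel frames per phoneme
expanded = torch.repeat_interleave(encoder_out, durations, dim=0)
print(expanded.shape)                       # torch.Size([12, 256]) -> one row per mel frame
```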
12 | 13 | The model has following advantages: 14 | - **Robustness:** No repeats and failed attention modes for challenging sentences. 15 | - **Speed:** The generation of a mel spectogram takes about 0.04s on a GeForce RTX 2080. 16 | - **Controllability:** It is possible to control the speed of the generated utterance. 17 | - **Efficiency:** In contrast to FastSpeech and Tacotron, the model of ForwardTacotron 18 | does not use any attention. Hence, the required memory grows linearly with text size, which makes it possible to synthesize large articles at once. 19 | 20 | 21 | ## UPDATE Improved attention mechanism (30.08.2023) 22 | - Faster tacotron attention buildup by adding alignment conditioning based on [one alignment to rule them all](https://arxiv.org/abs/2108.10447) 23 | - Improved attention translates to improved synth quality. 24 | 25 | ## 🔈 Samples 26 | 27 | [Can be found here.](https://as-ideas.github.io/ForwardTacotron/) 28 | 29 | The samples are generated with a model trained on LJSpeech and vocoded with WaveRNN, [MelGAN](https://github.com/seungwonpark/melgan), or [HiFiGAN](https://github.com/jik876/hifi-gan). 30 | You can try out the latest pretrained model with the following notebook: 31 | 32 | [](https://colab.research.google.com/github/as-ideas/ForwardTacotron/blob/master/notebooks/synthesize.ipynb) 33 | 34 | ## ⚙️ Installation 35 | 36 | Make sure you have: 37 | 38 | * Python >= 3.6 39 | 40 | Install espeak as phonemizer backend (for macOS use brew): 41 | ``` 42 | sudo apt-get install espeak 43 | ``` 44 | 45 | Then install the rest with pip: 46 | ``` 47 | pip install -r requirements.txt 48 | ``` 49 | 50 | ## 🚀 Training your own Model (Singlespeaker) 51 | 52 | Change the params in the config.yaml according to your needs and follow the steps below: 53 | 54 | (1) Download and preprocess the [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) dataset: 55 | ``` 56 | python preprocess.py --path /path/to/ljspeech 57 | ``` 58 | (2) Train Tacotron with: 59 | ``` 60 | python train_tacotron.py 61 | ``` 62 | Once the training is finished, the model will automatically extract the alignment features from the dataset. In case you stopped the training early, you 63 | can use the latest checkpoint to manually run the process with: 64 | ``` 65 | python train_tacotron.py --force_align 66 | ``` 67 | (3) Train ForwardTacotron with: 68 | ``` 69 | python train_forward.py 70 | ``` 71 | (4) Generate Sentences with Griffin-Lim vocoder: 72 | ``` 73 | python gen_forward.py --alpha 1 --input_text 'this is whatever you want it to be' griffinlim 74 | ``` 75 | If you want to use the [MelGAN](https://github.com/seungwonpark/melgan) vocoder, you can produce .mel files with: 76 | ``` 77 | python gen_forward.py --input_text 'this is whatever you want it to be' melgan 78 | ``` 79 | If you want to use the [HiFiGAN](https://github.com/jik876/hifi-gan) vocoder, you can produce .npy files with: 80 | ``` 81 | python gen_forward.py --input_text 'this is whatever you want it to be' hifigan 82 | ``` 83 | To vocode the resulting .mel or .npy files use the inference.py script from the MelGAN or HiFiGAN repo and point to the model output folder. 84 | 85 | For training the model on your own dataset just bring it to the LJSpeech-like format: 86 | ``` 87 | |- dataset_folder/ 88 | | |- metadata.csv 89 | | |- wav/ 90 | | |- file1.wav 91 | | |- ... 92 | ``` 93 | 94 | For languages other than English, change the language and cleaners params in the hparams.py, e.g. 
for French: 95 | ``` 96 | language = 'fr' 97 | tts_cleaner_name = 'no_cleaners' 98 | ``` 99 | 100 | ____ 101 | You can monitor the training processes for Tacotron and ForwardTacotron with 102 | ``` 103 | tensorboard --logdir checkpoints 104 | ``` 105 | Here is what the ForwardTacotron tensorboard looks like: 106 |
110 | Figure 2: Tensorboard example for training a ForwardTacotron model.
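Besides the command line scripts above, a trained singlespeaker model can also be driven programmatically. The following is a minimal sketch that mirrors `gen_forward.py`; the checkpoint path assumes the default singlespeaker config and the output file name is arbitrary:
```
from pathlib import Path

import torch

from utils.checkpoints import init_tts_model
from utils.dsp import DSP
from utils.text.cleaners import Cleaner
from utils.text.tokenizer import Tokenizer

# Load a trained ForwardTacotron checkpoint (default singlespeaker output path).
checkpoint = torch.load('checkpoints/ljspeech_tts.forward/latest_model.pt',
                        map_location=torch.device('cpu'))
config = checkpoint['config']
tts_model = init_tts_model(config)
tts_model.load_state_dict(checkpoint['model'])
tts_model.eval()

dsp = DSP.from_config(config)
cleaner = Cleaner.from_config(config)
tokenizer = Tokenizer()

# Text -> cleaned text -> token ids -> batch of size one.
x = torch.as_tensor(tokenizer(cleaner('This is whatever you want it to be.')),
                    dtype=torch.long).unsqueeze(0)

gen = tts_model.generate(x=x,
                         alpha=1.0,                   # speed control
                         pitch_function=lambda p: p,  # identity = unchanged pitch
                         energy_function=lambda e: e) # identity = unchanged energy

# Vocode the mel postnet output with Griffin-Lim and write it to disk.
wav = dsp.griffinlim(gen['mel_post'].squeeze().cpu().numpy())
dsp.save_wav(wav, Path('example.wav'))
```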
112 | 113 | 114 | ## Multispeaker Training 115 | Prepare the data in ljspeech format: 116 | ``` 117 | |- dataset_folder/ 118 | | |- metadata.csv 119 | | |- wav/ 120 | | |- file1.wav 121 | | |- ... 122 | ``` 123 | The metadata.csv is expected to have the speaker id in the second column: 124 | ``` 125 | id_001|speaker_1|this is the first text. 126 | id_002|speaker_1|this is the second text. 127 | id_003|speaker_2|this is the third text. 128 | ... 129 | ``` 130 | We also support the VCTK and a pandas format 131 | (can be set in the config multispeaker.yaml under preprocesing.metafile_format) 132 | 133 | Follow the same steps as for singlespaker, but provide the multispeaker config: 134 | ``` 135 | python preprocess.py --config configs/multispeaker.yaml --path /path/to/ljspeech 136 | python train_tacotron.py --config configs/multispeaker.yaml 137 | python train_forward.py --config configs/multispeaker.yaml 138 | ``` 139 | 140 | ## Pretrained Models 141 | 142 | | Model | Dataset | Commit Tag | 143 | |---|---|------------| 144 | |[forward_tacotron](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/ForwardTacotron/forward_step90k.pt)| ljspeech | v3.1 | 145 | |[fastpitch](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/ForwardTacotron/thorsten_fastpitch_50k.pt)| [thorstenmueller (german)](https://github.com/thorstenMueller/deep-learning-german-tts) | v3.1 | 146 | 147 | Our pre-trained LJSpeech model is compatible with the pre-trained vocoders: 148 | - [MelGAN](https://github.com/seungwonpark/melgan) 149 | - [HiFiGAN](https://github.com/jik876/hifi-gan) 150 | 151 | 152 | After downloading the models you can synthesize text using the pretrained models with 153 | ``` 154 | python gen_forward.py --input_text 'Hi there!' --checkpoint forward_step90k.pt wavernn --voc_checkpoint wave_step_575k.pt 155 | 156 | ``` 157 | 158 | ## Export Model with TorchScript 159 | 160 | Here is a dummy example of exporting the model in TorchScript: 161 | ``` 162 | import torch 163 | from models.forward_tacotron import ForwardTacotron 164 | 165 | tts_model = ForwardTacotron.from_checkpoint('checkpoints/ljspeech_tts.forward/latest_model.pt') 166 | tts_model.eval() 167 | model_script = torch.jit.script(tts_model) 168 | x = torch.ones((1, 5)).long() 169 | y = model_script.generate_jit(x) 170 | ``` 171 | For the necessary preprocessing steps (text to tokens) please refer to: 172 | ``` 173 | gen_forward.py 174 | ``` 175 | 176 | ## References 177 | 178 | * [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263) 179 | * [FastPitch: Parallel Text-to-speech with Pitch Prediction](https://arxiv.org/abs/2006.06873) 180 | * [HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis](https://arxiv.org/abs/2010.05646) 181 | * [MelGAN: Generative Adversarial Networks for Conditional Waveform Synthesis](https://arxiv.org/abs/1910.06711) 182 | * [Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis](https://arxiv.org/abs/1806.04558) 183 | 184 | ## Acknowlegements 185 | 186 | * [https://github.com/keithito/tacotron](https://github.com/keithito/tacotron) 187 | * [https://github.com/fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) 188 | * [https://github.com/seungwonpark/melgan](https://github.com/seungwonpark/melgan) 189 | * [https://github.com/jik876/hifi-gan](https://github.com/jik876/hifi-gan) 190 | * [https://github.com/xcmyz/LightSpeech](https://github.com/xcmyz/LightSpeech) 191 | * 
[https://github.com/resemble-ai/Resemblyzer](https://github.com/resemble-ai/Resemblyzer) 192 | * [https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/FastPitch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/FastPitch) 193 | * [https://github.com/resemble-ai/Resemblyzer](https://github.com/resemble-ai/Resemblyzer) 194 | 195 | ## Maintainers 196 | 197 | * Christian Schäfer, github: [cschaefer26](https://github.com/cschaefer26) 198 | 199 | ## Copyright 200 | 201 | See [LICENSE](LICENSE) for details. 202 | -------------------------------------------------------------------------------- /assets/WaveRNN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/assets/WaveRNN.png -------------------------------------------------------------------------------- /assets/energy_tb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/assets/energy_tb.png -------------------------------------------------------------------------------- /assets/fast_speech.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/assets/fast_speech.png -------------------------------------------------------------------------------- /assets/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/assets/model.png -------------------------------------------------------------------------------- /assets/tacotron_wavernn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/assets/tacotron_wavernn.png -------------------------------------------------------------------------------- /assets/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/assets/tensorboard.png -------------------------------------------------------------------------------- /assets/tensorboard_wavernn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/assets/tensorboard_wavernn.png -------------------------------------------------------------------------------- /assets/training_viz.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/assets/training_viz.gif -------------------------------------------------------------------------------- /assets/wavernn_alt_model_hrz2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/assets/wavernn_alt_model_hrz2.png -------------------------------------------------------------------------------- /configs/multispeaker.yaml: 
-------------------------------------------------------------------------------- 1 | 2 | tts_model_id: 'multispeaker_tts' 3 | data_path: 'data_multispeaker' # output data path 4 | 5 | tts_model: 'multi_forward_tacotron' # choices: [multi_forward_tacotron, multi_fast_pitch] 6 | 7 | 8 | dsp: 9 | 10 | sample_rate: 22050 11 | n_fft: 1024 12 | num_mels: 80 13 | hop_length: 256 14 | win_length: 1024 15 | fmin: 0 16 | fmax: 8000 17 | peak_norm: False # Normalise to the peak of each wav file 18 | trim_start_end_silence: True # Whether to trim leading and trailing silence 19 | trim_silence_top_db: 60 # Threshold in decibels below reference to consider silence for for trimming 20 | # start and end silences with librosa (no trimming if really high) 21 | 22 | trim_long_silences: True # Whether to reduce long silence using WebRTC Voice Activity Detector 23 | vad_window_length: 30 # In milliseconds 24 | vad_moving_average_width: 8 25 | vad_max_silence_length: 12 26 | vad_sample_rate: 16000 27 | 28 | 29 | preprocessing: 30 | 31 | metafile_format: 'ljspeech_multi' # Choices [ljspeech_multi, vctk, pandas] 32 | # ljspeech_multi expects a .csv file with rows: 'file_id|speaker_id|text" 33 | # pandas expects a .tsv with columns: ['file_id', 'speaker_id', 'text'] 34 | # expects VCTK version 0.92 (set audio_format to '_mic1.flac') 35 | audio_format: '.wav' # Audio extension, usually .wav (different for VCTK) 36 | seed: 42 37 | n_val: 2000 38 | language: 'en' 39 | cleaner_name: 'english_cleaners' # choices: ['english_cleaners', 'no_cleaners'], expands numbers and abbreviations. 40 | use_phonemes: True # whether to phonemize the text 41 | # if set to False, you have to provide the phonemized text yourself 42 | min_text_len: 2 43 | pitch_min_freq: 30 # Minimum value for pitch frequency to remove outliers (Common pitch range is about 60-300) 44 | pitch_max_freq: 600 # Maximum value for pitch frequency to remove outliers (Common pitch range is about 60-300) 45 | pitch_extractor: pyworld # choice of pitch extraction library, choices: [librosa, pyworld] 46 | pitch_frame_length: 2048 # Frame length for extracting pitch with librosa 47 | 48 | 49 | duration_extraction: 50 | 51 | silence_threshold: -11 # normalized mel value below which the voice is considered silent 52 | # minimum mel value = -11.512925465 for zeros in the wav array (=log(1e-5), 53 | # where 1e-5 is a cutoff value) 54 | silence_prob_shift: 0.25 # increase probability for silent characters in periods of silence 55 | # for better durations during non voiced periods 56 | max_batch_size: 32 # max allowed for binned dataloader used for tacotron inference 57 | num_workers: 12 # number of processes for costly dijkstra duration extraction 58 | 59 | 60 | tacotron: 61 | 62 | model: 63 | embed_dims: 256 64 | encoder_dims: 128 65 | decoder_dims: 256 66 | postnet_dims: 128 67 | speaker_emb_dim: 256 # dimension of speaker embedding, 68 | # set to 0 for no speaker conditioning, to 256 for speaker conditioning 69 | encoder_k: 16 70 | lstm_dims: 512 71 | postnet_k: 8 72 | num_highways: 4 73 | dropout: 0.5 74 | stop_threshold: -11 # Value below which audio generation ends. 
75 | 76 | aligner_hidden_dims: 256 # text-mel aligner hidden dimensions 77 | aligner_out_dims: 32 # text-mel aligner encoding dimensions for text and mel 78 | 79 | training: 80 | schedule: 81 | - 8, 1e-3, 30_000, 32 # progressive training schedule 82 | - 4, 1e-4, 40_000, 16 # (r, lr, step, batch_size) 83 | - 2, 1e-4, 50_000, 8 84 | - 1, 1e-4, 65_000, 8 85 | 86 | dia_loss_matrix_g: 0.2 # value of g for diatonal matrix (larger g = broader diagonal) 87 | dia_loss_factor: 1.0 # factor for scaling diagonal loss 88 | ctc_loss_factor: 0.1 # factor for scaling aligner CTC loss 89 | clip_grad_norm: 1.0 # clips the gradient norm to prevent explosion - set to None if not needed 90 | checkpoint_every: 10000 # checkpoints the model every x steps 91 | plot_every: 1000 # generates samples and plots every x steps 92 | num_workers: 2 # number of workers for dataloader 93 | 94 | filter: 95 | max_mel_len: 1250 # filter files with mel len larger than given 96 | filter_duration_stats: False # whether to filter according to the duration stats below 97 | min_attention_sharpness: 0.5 # filter files with bad attention sharpness score, if 0 then no filter 98 | min_attention_alignment: 0.95 # filter files with bad attention alignment score, if 0 then no filter 99 | max_duration: 40 # filter files with durations larger than given 100 | max_consecutive_ones: 6 # filter files where durations contain more consecutive ones than given 101 | 102 | 103 | multi_forward_tacotron: 104 | 105 | model: 106 | speaker_emb_dims: 256 107 | embed_dims: 256 # embedding dimension for main model 108 | series_embed_dims: 128 # embedding dimension for series predictor 109 | 110 | durpred_conv_dims: 256 111 | durpred_rnn_dims: 128 112 | durpred_dropout: 0.5 113 | 114 | pitch_conv_dims: 256 115 | pitch_rnn_dims: 256 116 | pitch_dropout: 0.5 117 | pitch_strength: 1. # set to 0 if you want no pitch conditioning 118 | 119 | energy_conv_dims: 256 120 | energy_rnn_dims: 64 121 | energy_dropout: 0.5 122 | energy_strength: 1. # set to 0 if you want no energy conditioning 123 | 124 | pitch_cond_conv_dims: 256 # predictor for pitch prior (predicts unvoiced phonemes with zero pitch) 125 | pitch_cond_rnn_dims: 128 126 | pitch_cond_dropout: 0.5 127 | pitch_cond_emb_dims: 4 # conditional embedding on pitch softmax prediction (zero vs non-zero pitch) 128 | pitch_cond_categorical_dims: 3 # dimension of categorical output of pitch conditioning, should be set to 3 129 | # (zero=padding, one=zero pitch, two=nonzero pitch) 130 | 131 | prenet_dims: 256 132 | prenet_k: 16 133 | prenet_dropout: 0.5 134 | prenet_num_highways: 4 135 | 136 | rnn_dims: 512 137 | 138 | postnet_dims: 256 139 | postnet_k: 8 140 | postnet_num_highways: 4 141 | postnet_dropout: 0. 
142 | 143 | training: 144 | schedule: 145 | - 5e-5, 500_000, 32 # progressive training schedule 146 | - 1e-5, 600_000, 32 # lr, step, batch_size 147 | dur_loss_factor: 0.1 148 | pitch_loss_factor: 0.1 149 | energy_loss_factor: 0.1 150 | pitch_cond_loss_factor: 0.1 151 | 152 | clip_grad_norm: 1.0 # clips the gradient norm to prevent explosion - set to None if not needed 153 | checkpoint_every: 50_000 # checkpoints the model every x steps 154 | plot_every: 5000 # generates samples and plots every x steps 155 | plot_n_speakers: 3 # max number of speakers to generate plots for 156 | plot_speakers: # speakers to generate plots for (additionally to plot_n_speakers) 157 | - default_speaker 158 | 159 | filter: 160 | max_mel_len: 1250 # filter files with mel len larger than given 161 | filter_duration_stats: True # whether to filter according to the duration stats below 162 | min_attention_sharpness: 0.5 # filter files with bad attention sharpness score, if 0 then no filter 163 | min_attention_alignment: 0.95 # filter files with bad attention alignment score, if 0 then no filter 164 | max_duration: 40 # filter files with durations larger than given 165 | max_consecutive_ones: 6 # filter files where durations contain more consecutive ones than given 166 | 167 | 168 | multi_fast_pitch: 169 | 170 | model: 171 | speaker_emb_dims: 256 172 | 173 | durpred_d_model: 128 174 | durpred_n_heads: 2 175 | durpred_layers: 4 176 | durpred_d_fft: 128 177 | durpred_dropout: 0.5 178 | 179 | pitch_d_model: 128 180 | pitch_n_heads: 2 181 | pitch_layers: 4 182 | pitch_d_fft: 128 183 | pitch_dropout: 0.5 184 | pitch_strength: 1.0 185 | 186 | energy_d_model: 128 187 | energy_n_heads: 2 188 | energy_layers: 4 189 | energy_d_fft: 128 190 | energy_dropout: 0.5 191 | energy_strength: 1.0 192 | 193 | pitch_cond_d_model: 128 194 | pitch_cond_n_heads: 2 195 | pitch_cond_layers: 4 196 | pitch_cond_d_fft: 128 197 | pitch_cond_dropout: 0.5 198 | pitch_cond_output_dims: 3 # dimension of categorical output of pitch conditioning, should be set to 3 199 | # (zero=padding, one=zero pitch, two=nonzero pitch) 200 | 201 | d_model: 256 202 | conv1_kernel: 9 203 | conv2_kernel: 1 204 | 205 | prenet_layers: 4 206 | prenet_heads: 2 207 | prenet_fft: 1024 208 | prenet_dropout: 0.1 209 | 210 | postnet_layers: 4 211 | postnet_heads: 2 212 | postnet_fft: 1024 213 | postnet_dropout: 0.1 214 | 215 | 216 | training: 217 | schedule: 218 | - 1e-5, 5_000, 32 # progressive training schedule 219 | - 5e-5, 300_000, 32 # lr, step, batch_size 220 | - 2e-5, 300_000, 32 221 | dur_loss_factor: 0.1 222 | pitch_loss_factor: 0.1 223 | energy_loss_factor: 0.1 224 | pitch_cond_loss_factor: 0.1 225 | pitch_zoneout: 0. # zoneout may regularize conditioning on pitch 226 | energy_zoneout: 0. 
# zoneout may regularize conditioning on energy 227 | 228 | max_mel_len: 1250 229 | clip_grad_norm: 1.0 # clips the gradient norm to prevent explosion - set to None if not needed 230 | checkpoint_every: 10_000 # checkpoints the model every x steps 231 | plot_every: 1000 232 | plot_n_speakers: 3 # max number of speakers to generate plots for 233 | plot_speakers: # speakers to generate plots for (additionally to plot_n_speakers) 234 | - default_speaker 235 | 236 | filter: 237 | max_mel_len: 1250 # filter files with mel len larger than given 238 | filter_duration_stats: True # whether to filter according to the duration stats below 239 | min_attention_sharpness: 0.5 # filter files with bad attention sharpness score, if 0 then no filter 240 | min_attention_alignment: 0.95 # filter files with bad attention alignment score, if 0 then no filter 241 | max_duration: 40 # filter files with durations larger than given 242 | max_consecutive_ones: 6 # filter files where durations contain more consecutive ones than given 243 | -------------------------------------------------------------------------------- /configs/singlespeaker.yaml: -------------------------------------------------------------------------------- 1 | 2 | tts_model_id: 'ljspeech_tts' 3 | data_path: 'data' # output data path 4 | 5 | tts_model: 'forward_tacotron' # choices: [forward_tacotron, fast_pitch] 6 | 7 | 8 | dsp: 9 | 10 | sample_rate: 22050 11 | n_fft: 1024 12 | num_mels: 80 13 | hop_length: 256 14 | win_length: 1024 15 | fmin: 0 16 | fmax: 8000 17 | target_dBFS: -30 # Target loudness in decibels, used for normalization 18 | peak_norm: False # Normalise to the peak of each wav file 19 | trim_start_end_silence: True # Whether to trim leading and trailing silence 20 | trim_silence_top_db: 60 # Threshold in decibels below reference to consider silence for for trimming 21 | # start and end silences with librosa (no trimming if really high) 22 | 23 | trim_long_silences: False # Whether to reduce long silence using WebRTC Voice Activity Detector 24 | vad_window_length: 30 # In milliseconds 25 | vad_moving_average_width: 8 26 | vad_max_silence_length: 12 27 | vad_sample_rate: 16000 28 | 29 | 30 | preprocessing: 31 | 32 | metafile_format: 'ljspeech' # not to be changed, we use the simplest format for singlespeaker models 33 | audio_format: '.wav' # extension for audio files (e.g. .wav or .flac) 34 | seed: 42 35 | n_val: 200 36 | language: 'en-us' 37 | cleaner_name: 'english_cleaners' # choices: ['english_cleaners', 'no_cleaners'], expands numbers and abbreviations. 
38 | use_phonemes: True # whether to phonemize the text 39 | # if set to False, you have to provide the phonemized text yourself 40 | min_text_len: 2 41 | pitch_min_freq: 30 # Minimum value for pitch frequency to remove outliers (Common pitch range is 42 | # about 60-300) 43 | pitch_max_freq: 600 # Maximum value for pitch frequency to remove outliers (Common pitch range is 44 | # about 60-300)¡ 45 | pitch_extractor: pyworld # choice of pitch extraction library, choices: [librosa, pyworld] 46 | pitch_frame_length: 2048 # Frame length for extracting pitch with librosa 47 | 48 | 49 | duration_extraction: 50 | 51 | silence_threshold: -11 # normalized mel value below which the voice is considered silent 52 | # minimum mel value = -11.512925465 for zeros in the wav array (=log(1e-5), 53 | # where 1e-5 is a cutoff value) 54 | silence_prob_shift: 0.25 # increase probability for silent characters in periods of silence 55 | # for better durations during non voiced periods 56 | max_batch_size: 32 # max allowed for binned dataloader used for tacotron inference 57 | num_workers: 12 # number of processes for costly dijkstra duration extraction 58 | 59 | 60 | tacotron: 61 | 62 | model: 63 | embed_dims: 256 64 | encoder_dims: 128 65 | decoder_dims: 256 66 | postnet_dims: 128 67 | speaker_emb_dim: 0 # dimension of speaker embedding, 68 | # set to 0 for no speaker conditioning, to 256 for speaker conditioning 69 | encoder_k: 16 70 | lstm_dims: 512 71 | postnet_k: 8 72 | num_highways: 4 73 | dropout: 0.5 74 | stop_threshold: -11 # Value below which audio generation ends. 75 | 76 | aligner_hidden_dims: 256 # text-mel aligner hidden dimensions 77 | aligner_out_dims: 32 # text-mel aligner encoding dimensions for text and mel 78 | 79 | training: 80 | schedule: 81 | - 5, 1e-3, 10_000, 32 # progressive training schedule 82 | - 3, 1e-4, 20_000, 16 # (r, lr, step, batch_size) 83 | - 2, 1e-4, 30_000, 8 84 | - 1, 1e-4, 40_000, 8 85 | 86 | dia_loss_matrix_g: 0.2 # value of g for diatonal matrix (larger g = broader diagonal) 87 | dia_loss_factor: 1.0 # factor for scaling diagonal loss 88 | ctc_loss_factor: 0.1 # factor for scaling aligner CTC loss 89 | clip_grad_norm: 1.0 # clips the gradient norm to prevent explosion - set to None if not needed 90 | checkpoint_every: 10000 # checkpoints the model every x steps 91 | plot_every: 1000 # generates samples and plots every x steps 92 | num_workers: 2 # number of workers for dataloader 93 | 94 | filter: 95 | max_mel_len: 1250 # filter files with mel len larger than given 96 | filter_duration_stats: False # whether to filter according to the duration stats below 97 | min_attention_sharpness: 0.5 # filter files with bad attention sharpness score, if 0 then no filter 98 | min_attention_alignment: 0.95 # filter files with bad attention alignment score, if 0 then no filter 99 | max_duration: 40 # filter files with durations larger than given 100 | max_consecutive_ones: 6 # filter files where durations contain more consecutive ones than given 101 | 102 | 103 | forward_tacotron: 104 | 105 | model: 106 | embed_dims: 256 # embedding dimension for main model 107 | series_embed_dims: 64 # embedding dimension for series predictor 108 | 109 | durpred_conv_dims: 256 110 | durpred_rnn_dims: 64 111 | durpred_dropout: 0.5 112 | 113 | pitch_conv_dims: 256 114 | pitch_rnn_dims: 128 115 | pitch_dropout: 0.5 116 | pitch_strength: 1. # set to 0 if you want no pitch conditioning 117 | 118 | energy_conv_dims: 256 119 | energy_rnn_dims: 64 120 | energy_dropout: 0.5 121 | energy_strength: 1. 
# set to 0 if you want no energy conditioning 122 | 123 | prenet_dims: 256 124 | prenet_k: 16 125 | prenet_dropout: 0.5 126 | prenet_num_highways: 4 127 | 128 | rnn_dims: 512 129 | 130 | postnet_dims: 256 131 | postnet_k: 8 132 | postnet_num_highways: 4 133 | postnet_dropout: 0. 134 | 135 | training: 136 | schedule: 137 | - 5e-5, 150_000, 32 # progressive training schedule 138 | - 1e-5, 300_000, 32 # lr, step, batch_size 139 | dur_loss_factor: 0.1 140 | pitch_loss_factor: 0.1 141 | energy_loss_factor: 0.1 142 | pitch_zoneout: 0. # zoneout may regularize conditioning on pitch 143 | energy_zoneout: 0. # zoneout may regularize conditioning on energy 144 | 145 | clip_grad_norm: 1.0 # clips the gradient norm to prevent explosion - set to None if not needed 146 | checkpoint_every: 10_000 # checkpoints the model every x steps 147 | plot_every: 1000 # generates samples and plots every x steps 148 | 149 | filter: 150 | max_mel_len: 1250 # filter files with mel len larger than given 151 | filter_duration_stats: True # whether to filter according to the duration stats below 152 | min_attention_sharpness: 0.5 # filter files with bad attention sharpness score, if 0 then no filter 153 | min_attention_alignment: 0.95 # filter files with bad attention alignment score, if 0 then no filter 154 | max_duration: 40 # filter files with durations larger than given 155 | max_consecutive_ones: 6 # filter files where durations contain more consecutive ones than given 156 | 157 | fast_pitch: 158 | 159 | model: 160 | durpred_d_model: 128 161 | durpred_n_heads: 2 162 | durpred_layers: 4 163 | durpred_d_fft: 128 164 | durpred_dropout: 0.5 165 | 166 | pitch_d_model: 128 167 | pitch_n_heads: 2 168 | pitch_layers: 4 169 | pitch_d_fft: 128 170 | pitch_dropout: 0.5 171 | pitch_strength: 1.0 172 | 173 | energy_d_model: 128 174 | energy_n_heads: 2 175 | energy_layers: 4 176 | energy_d_fft: 128 177 | energy_dropout: 0.5 178 | energy_strength: 1.0 179 | 180 | d_model: 256 181 | conv1_kernel: 9 182 | conv2_kernel: 1 183 | 184 | prenet_layers: 4 185 | prenet_heads: 2 186 | prenet_fft: 1024 187 | prenet_dropout: 0.1 188 | 189 | postnet_layers: 4 190 | postnet_heads: 2 191 | postnet_fft: 1024 192 | postnet_dropout: 0.1 193 | 194 | 195 | training: 196 | schedule: 197 | - 1e-5, 5_000, 32 # progressive training schedule 198 | - 5e-5, 100_000, 32 # lr, step, batch_size 199 | - 2e-5, 300_000, 32 200 | dur_loss_factor: 0.1 201 | pitch_loss_factor: 0.1 202 | energy_loss_factor: 0.1 203 | pitch_zoneout: 0. # zoneout may regularize conditioning on pitch 204 | energy_zoneout: 0. 
# zoneout may regularize conditioning on energy 205 | 206 | max_mel_len: 1250 207 | clip_grad_norm: 1.0 # clips the gradient norm to prevent explosion - set to None if not needed 208 | checkpoint_every: 10_000 # checkpoints the model every x steps 209 | plot_every: 1000 210 | 211 | filter: 212 | max_mel_len: 1250 # filter files with mel len larger than given 213 | filter_duration_stats: True # whether to filter according to the duration stats below 214 | min_attention_sharpness: 0.5 # filter files with bad attention sharpness score, if 0 then no filter 215 | min_attention_alignment: 0.95 # filter files with bad attention alignment score, if 0 then no filter 216 | max_duration: 40 # filter files with durations larger than given 217 | max_consecutive_ones: 6 # filter files where durations contain more consecutive ones than given 218 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _site 2 | .sass-cache 3 | .jekyll-cache 4 | .jekyll-metadata 5 | vendor 6 | .bundle -------------------------------------------------------------------------------- /docs/404.html: -------------------------------------------------------------------------------- 1 | --- 2 | permalink: /404.html 3 | layout: default 4 | --- 5 | 6 | 19 | 20 |Page not found :(
24 |The requested page could not be found.
25 |Scientists at the CERN laboratory say they have discovered a new particle.
9 | 10 | 11 | 12 |There’s a way to measure the acute emotional intelligence that has never gone out of style.
13 | 14 | 15 | 16 |President Trump met with other leaders at the Group of 20 conference.
17 | 18 | 19 | 20 |In a statement announcing his resignation, Mr Ross, said: "While the intentions may have been well meaning, the reaction to this news shows that Mr Cummings interpretation of the government advice was not shared by the vast majority of people who have done as the government asked."
21 | 22 | 23 | 24 | 25 | ## Forward Tacotron + MelGAN Vocoder 26 | 27 | The samples are generated with a model trained 400K steps on [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) together with the pretrained MelGAN vocoder provided by the [MelGAN repo](https://github.com/seungwonpark/melgan). 28 | 29 |Scientists at the CERN laboratory say they have discovered a new particle.
30 | 31 | | normal speed | faster (1.25) | slower (0.85) | 32 | |:---:|:---:|:---:| 33 | |||| 34 | 35 |There’s a way to measure the acute emotional intelligence that has never gone out of style.
36 | 37 | |:---:|:---:|:---:| 38 | |||| 39 | 40 |President Trump met with other leaders at the Group of 20 conference.
41 | 42 | |:---:|:---:|:---:| 43 | |||| 44 | 45 | ## Forward Tacotron + WaveRNN Vocoder 46 | 47 | The samples are generated with a model trained 100K steps on [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) together with the pretrained WaveRNN vocoder provided by the [WaveRNN repo](https://github.com/fatchord/WaveRNN). 48 | 49 |Scientists at the CERN laboratory say they have discovered a new particle.
50 | 51 | | normal speed | faster (1.25) | slower (0.8) | 52 | |:---:|:---:|:---:| 53 | |||| 54 | 55 |There’s a way to measure the acute emotional intelligence that has never gone out of style.
56 | 57 | |:---:|:---:|:---:| 58 | |||| 59 | 60 | 61 |President Trump met with other leaders at the Group of 20 conference.
62 | 63 | |:---:|:---:|:---:| 64 | |||| 65 | 66 | ## Forward Tacotron + Griffin-Lim 67 | 68 |The Senate's bill to repeal and replace the Affordable Care-Act is now imperiled.
69 | 70 | | normal speed | faster (1.4) | slower (0.6) | 71 | |:---:|:---:|:---:| 72 | |||| 73 | 74 |Generative adversarial network or variational auto-encoder.
75 | 76 | |:---:|:---:|:---:| 77 | |||| 78 | 79 |Basilar membrane and otolaryngology are not auto-correlations.
80 | 81 | |:---:|:---:|:---:| 82 | |||| 83 | 84 | 85 |Synthetic speech can be created by concatenating pieces of recorded speech that are stored in a database. Systems differ in the size of the stored speech units; a system that stores phones or diphones provides the largest output range, but may lack clarity. For specific usage domains, the storage of entire words or sentences allows for high-quality output. Alternatively, a synthesizer can incorporate a model of the vocal tract and other human voice characteristics to create a completely "synthetic" voice output.
86 | -------------------------------------------------------------------------------- /duration_extraction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/duration_extraction/__init__.py -------------------------------------------------------------------------------- /duration_extraction/duration_extraction_pipe.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from logging import INFO 4 | from typing import List, Dict, Any, Tuple 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import DataLoader, Dataset 9 | from tqdm import tqdm 10 | 11 | from duration_extraction.duration_extractor import DurationExtractor 12 | from models.tacotron import Tacotron 13 | from trainer.common import to_device 14 | from utils.dataset import BinnedLengthSampler, get_binned_taco_dataloader, DurationStats 15 | from utils.files import unpickle_binary 16 | from utils.metrics import attention_score 17 | from utils.paths import Paths 18 | from utils.text.tokenizer import Tokenizer 19 | 20 | 21 | @dataclass 22 | class DurationResult: 23 | item_id: str 24 | att_score: float 25 | align_score: float 26 | durations: np.array 27 | 28 | 29 | class DurationCollator: 30 | 31 | def __call__(self, x: List[DurationResult]) -> DurationResult: 32 | if len(x) > 1: 33 | raise ValueError(f'Batch size must be 1! Found batch size: {len(x)}') 34 | return x[0] 35 | 36 | 37 | class DurationExtractionDataset(Dataset): 38 | 39 | def __init__(self, 40 | duration_extractor: DurationExtractor, 41 | paths: Paths, 42 | dataset_ids: List[str], 43 | text_dict: Dict[str, str], 44 | tokenizer: Tokenizer): 45 | self.metadata = dataset_ids 46 | self.text_dict = text_dict 47 | self.tokenizer = tokenizer 48 | self.text_dict = text_dict 49 | self.duration_extractor = duration_extractor 50 | self.paths = paths 51 | 52 | def __getitem__(self, index: int) -> DurationResult: 53 | item_id = self.metadata[index] 54 | x = self.text_dict[item_id] 55 | x = self.tokenizer(x) 56 | mel = np.load(self.paths.mel / f'{item_id}.npy') 57 | mel = torch.from_numpy(mel) 58 | x = torch.tensor(x) 59 | attention_npy = np.load(str(self.paths.att_pred / f'{item_id}.npy')) 60 | attention = torch.from_numpy(attention_npy) 61 | mel_len = mel.shape[-1] 62 | mel_len = torch.tensor(mel_len).unsqueeze(0) 63 | align_score, _ = attention_score(attention.unsqueeze(0), mel_len, r=1) 64 | align_score = float(align_score) 65 | durations, att_score = self.duration_extractor(x=x, mel=mel, attention=attention) 66 | att_score = float(att_score) 67 | durations_npy = durations.cpu().numpy() 68 | if np.sum(durations_npy) != mel_len: 69 | print(f'WARNINNG: Sum of durations did not match mel length for item {item_id}!') 70 | return DurationResult(item_id=item_id, att_score=att_score, 71 | align_score=align_score, durations=durations_npy) 72 | 73 | def __len__(self): 74 | return len(self.metadata) 75 | 76 | 77 | class DurationExtractionPipeline: 78 | 79 | def __init__(self, 80 | paths: Paths, 81 | config: Dict[str, Any], 82 | duration_extractor: DurationExtractor) -> None: 83 | self.paths = paths 84 | self.config = config 85 | self.duration_extractor = duration_extractor 86 | self.logger = logging.Logger(__name__, level=INFO) 87 | 88 | def extract_attentions(self, 89 | model: Tacotron, 90 | max_batch_size: int = 1) -> float: 91 | """ 92 | 
Performs tacotron inference and stores the attention matrices as npy arrays in paths.data.att_pred. 93 | Returns average attention score. 94 | 95 | Args: 96 | model: Tacotron model to use for attention extraction. 97 | batch_size: Batch size to use for tacotron inference. 98 | 99 | Returns: Mean attention score. The attention matrices are saved as numpy arrays in paths.att_pred. 100 | 101 | """ 102 | 103 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 104 | model.to(device) 105 | 106 | dataloader = get_binned_taco_dataloader(paths=self.paths, max_batch_size=max_batch_size) 107 | 108 | sum_items = 0 109 | sum_att_score = 0 110 | pbar = tqdm(dataloader, total=len(dataloader), smoothing=0.01) 111 | for i, batch in enumerate(pbar, 1): 112 | batch = to_device(batch, device=device) 113 | with torch.no_grad(): 114 | out = model(batch) 115 | attention_batch = out['att'] 116 | _, att_score = attention_score(attention_batch, batch['mel_len'], r=1) 117 | sum_att_score += att_score.sum() 118 | B = batch['x_len'].size(0) 119 | sum_items += B 120 | for b in range(B): 121 | x_len = batch['x_len'][b].cpu() 122 | mel_len = batch['mel_len'][b].cpu() 123 | item_id = batch['item_id'][b] 124 | attention = attention_batch[b, :mel_len, :x_len].cpu() 125 | np.save(self.paths.att_pred / f'{item_id}.npy', attention.numpy(), allow_pickle=False) 126 | pbar.set_description(f'Avg attention score: {sum_att_score / sum_items}', refresh=True) 127 | 128 | return sum_att_score / len(dataloader) 129 | 130 | def extract_durations(self, 131 | num_workers: int = 0, 132 | sampler_bin_size: int = 1) -> Dict[str, DurationStats]: 133 | """ 134 | Extracts durations from saved attention matrices. 135 | 136 | Args: 137 | num_workers: Number of workers for multiprocessing. 138 | sampler_bin_size: Bin size of BinnedLengthSampler. 139 | Should be greater than one (but much less than length of dataset) for optimal performance. 140 | 141 | Returns: Dictionary containing the attention scores for each item id. 142 | The durations are saved as numpy arrays in paths.alg. 
143 | """ 144 | 145 | train_set = unpickle_binary(self.paths.train_dataset) 146 | val_set = unpickle_binary(self.paths.val_dataset) 147 | text_dict = unpickle_binary(self.paths.text_dict) 148 | dataset = train_set + val_set 149 | dataset = [(file_id, mel_len) for file_id, mel_len in dataset 150 | if (self.paths.att_pred / f'{file_id}.npy').is_file()] 151 | len_orig = len(dataset) 152 | data_ids, mel_lens = list(zip(*dataset)) 153 | self.logger.info(f'Found {len(data_ids)} / {len_orig} ' 154 | f'alignment files in {self.paths.att_pred}') 155 | 156 | duration_stats = {} 157 | sum_att_score = 0 158 | 159 | dataset = DurationExtractionDataset( 160 | duration_extractor=self.duration_extractor, 161 | paths=self.paths, dataset_ids=data_ids, 162 | text_dict=text_dict, tokenizer=Tokenizer()) 163 | 164 | dataset = DataLoader(dataset=dataset, 165 | batch_size=1, 166 | shuffle=False, 167 | pin_memory=False, 168 | collate_fn=DurationCollator(), 169 | sampler=BinnedLengthSampler(lengths=mel_lens, batch_size=1, bin_size=sampler_bin_size), 170 | num_workers=num_workers) 171 | 172 | pbar = tqdm(dataset, total=len(dataset), smoothing=0.01) 173 | 174 | for i, res in enumerate(pbar, 1): 175 | sum_att_score += res.att_score 176 | pbar.set_description(f'Avg duration attention score: {sum_att_score / i}', refresh=True) 177 | max_consecutive_ones = self._get_max_consecutive_ones(res.durations) 178 | max_duration = np.max(res.durations) 179 | duration_stats[res.item_id] = DurationStats(att_align_score=res.align_score, 180 | att_sharpness_score=res.att_score, 181 | max_consecutive_ones=max_consecutive_ones, 182 | max_duration=max_duration) 183 | np.save(self.paths.alg / f'{res.item_id}.npy', res.durations.astype(int), allow_pickle=False) 184 | 185 | return duration_stats 186 | 187 | @staticmethod 188 | def _get_max_consecutive_ones(durations: np.array) -> int: 189 | max_count = 0 190 | count = 0 191 | for d in durations: 192 | if d == 1: 193 | count += 1 194 | else: 195 | max_count = max(max_count, count) 196 | count = 0 197 | return max(max_count, count) 198 | -------------------------------------------------------------------------------- /duration_extraction/duration_extractor.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from scipy.sparse import coo_matrix, csr_matrix 6 | from scipy.sparse.csgraph import dijkstra 7 | 8 | from utils.text.symbols import silent_phonemes_indices 9 | 10 | 11 | class DurationExtractor: 12 | 13 | def __init__(self, 14 | silence_threshold: float, 15 | silence_prob_shift: float) -> None: 16 | """ 17 | :param silence_threshold: Mel spec threshold below which the voice is considered silent. 18 | :param silence_prob_shift: Attention probability that is added to silent phonemes in unvoiced parts. 19 | """ 20 | self.silence_prob_shift = silence_prob_shift 21 | self.silence_threshold = silence_threshold 22 | 23 | def __call__(self, 24 | x: torch.Tensor, 25 | mel: torch.Tensor, 26 | attention: torch.Tensor) -> Tuple[torch.tensor, float]: 27 | """ 28 | Extracts durations from the attention matrix by finding the shortest monotonic path from 29 | top left to bottom right. 30 | 31 | :param x: Tokenized sequence. 32 | :param mel: Mel spec. 33 | :param attention: Attention matrix with shape (mel_len, x_len). 34 | :return: Tuple, where the first entry is the durations and the second entry is the average attention probability. 35 | """ 36 | attention = attention[...] 
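        # Overall approach: the attention matrix (mel_len x text_len) is treated as a cost grid
        # (cost = 1 - attention probability), and durations are read off the cheapest monotonic
        # path from its top-left to its bottom-right cell, computed with Dijkstra over right,
        # down and diagonal moves (see _to_adj_matrix below).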
37 | mel_len = mel.shape[-1] 38 | 39 | # We add a little probability to silent phonemes within unvoiced parts of the spec where the tacotron attention 40 | # is usually very unreliable. As a result we get more accurate (larger) durations for unvoiced parts and 41 | # avoid 'leakage' of durations into surrounding word phonemes. 42 | sil_mask = mel.mean(dim=0) < self.silence_threshold 43 | sil_mel_inds = sil_mask.nonzero().squeeze() 44 | sil_mel_inds = list(sil_mel_inds) if len(sil_mel_inds.size()) > 0 else [] 45 | 46 | sil_phon_inds = torch.tensor(silent_phonemes_indices) 47 | for i in sil_mel_inds: 48 | sil_tok_inds = torch.isin(x, sil_phon_inds) 49 | att_shift = sil_tok_inds.float() * self.silence_prob_shift * 2 - self.silence_prob_shift 50 | attention[i, :] = attention[i, :] + att_shift 51 | 52 | attention = torch.clamp(attention, min=0., max=1.) 53 | path_probs = 1. - attention[:mel_len, :] 54 | adj_matrix = self._to_adj_matrix(path_probs) 55 | dist_matrix, predecessors = dijkstra(csgraph=adj_matrix, directed=True, 56 | indices=0, return_predecessors=True) 57 | path = [] 58 | pr_index = predecessors[-1] 59 | while pr_index != 0: 60 | path.append(pr_index) 61 | pr_index = predecessors[pr_index] 62 | path.reverse() 63 | 64 | # append first and last node 65 | path = [0] + path + [dist_matrix.size-1] 66 | cols = path_probs.shape[1] 67 | mel_text = {} 68 | durations = torch.zeros(x.shape[0]) 69 | 70 | att_scores = [] 71 | 72 | # collect indices (mel, text) along the path 73 | for node_index in path: 74 | i, j = self._from_node_index(node_index, cols) 75 | mel_text[i] = j 76 | if not sil_mask[i]: 77 | att_scores.append(float(attention[i, j])) 78 | 79 | for j in mel_text.values(): 80 | durations[j] += 1 81 | 82 | att_score = sum(att_scores) / len(att_scores) 83 | 84 | return durations, att_score 85 | 86 | @staticmethod 87 | def _to_node_index(i: int, j: int, cols: int) -> int: 88 | return cols * i + j 89 | 90 | @staticmethod 91 | def _from_node_index(node_index: int, cols: int) -> Tuple[int, int]: 92 | return node_index // cols, node_index % cols 93 | 94 | @staticmethod 95 | def _to_adj_matrix(mat: np.array) -> csr_matrix: 96 | rows = mat.shape[0] 97 | cols = mat.shape[1] 98 | 99 | row_ind = [] 100 | col_ind = [] 101 | data = [] 102 | 103 | for i in range(rows): 104 | for j in range(cols): 105 | 106 | node = DurationExtractor._to_node_index(i, j, cols) 107 | 108 | if j < cols - 1: 109 | right_node = DurationExtractor._to_node_index(i, j + 1, cols) 110 | weight_right = mat[i, j + 1] 111 | row_ind.append(node) 112 | col_ind.append(right_node) 113 | data.append(weight_right) 114 | 115 | if i < rows - 1 and j < cols: 116 | bottom_node = DurationExtractor._to_node_index(i + 1, j, cols) 117 | weight_bottom = mat[i + 1, j] 118 | row_ind.append(node) 119 | col_ind.append(bottom_node) 120 | data.append(weight_bottom) 121 | 122 | if i < rows - 1 and j < cols - 1: 123 | bottom_right_node = DurationExtractor._to_node_index(i + 1, j + 1, cols) 124 | weight_bottom_right = mat[i + 1, j + 1] 125 | row_ind.append(node) 126 | col_ind.append(bottom_right_node) 127 | data.append(weight_bottom_right) 128 | 129 | adj_mat = coo_matrix((data, (row_ind, col_ind)), shape=(rows * cols, rows * cols)) 130 | return adj_mat.tocsr() 131 | -------------------------------------------------------------------------------- /gen_forward.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import numpy as np 4 | import torch 5 | from 
utils.checkpoints import init_tts_model 6 | from utils.display import simple_table 7 | from utils.dsp import DSP 8 | from utils.files import read_config 9 | from utils.paths import Paths 10 | from utils.text.cleaners import Cleaner 11 | from utils.text.tokenizer import Tokenizer 12 | 13 | 14 | if __name__ == '__main__': 15 | 16 | # Parse Arguments 17 | parser = argparse.ArgumentParser(description='TTS Generator') 18 | parser.add_argument('--input_text', '-i', default=None, type=str, help='[string] Type in something here and TTS will generate it!') 19 | parser.add_argument('--checkpoint', type=str, default=None, help='[string/path] path to .pt model file.') 20 | parser.add_argument('--config', metavar='FILE', default='default.yaml', help='The config containing all hyperparams. Only' 21 | 'used if no checkpoint is set.') 22 | parser.add_argument('--speaker', type=str, default=None, help='Speaker to generate audio for (only multispeaker).') 23 | 24 | parser.add_argument('--alpha', type=float, default=1., help='Parameter for controlling length regulator for speedup ' 25 | 'or slow-down of generated speech, e.g. alpha=2.0 is double-time') 26 | parser.add_argument('--amp', type=float, default=1., help='Parameter for controlling pitch amplification') 27 | 28 | # name of subcommand goes to args.vocoder 29 | subparsers = parser.add_subparsers(dest='vocoder') 30 | gl_parser = subparsers.add_parser('griffinlim') 31 | mg_parser = subparsers.add_parser('melgan') 32 | hg_parser = subparsers.add_parser('hifigan') 33 | 34 | args = parser.parse_args() 35 | 36 | assert args.vocoder in {'griffinlim', 'melgan', 'hifigan'}, \ 37 | 'Please provide a valid vocoder! Choices: [griffinlim, melgan, hifigan]' 38 | 39 | checkpoint_path = args.checkpoint 40 | if checkpoint_path is None: 41 | config = read_config(args.config) 42 | paths = Paths(config['data_path'], config['tts_model_id']) 43 | checkpoint_path = paths.forward_checkpoints / 'latest_model.pt' 44 | 45 | checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu')) 46 | config = checkpoint['config'] 47 | tts_model = init_tts_model(config) 48 | tts_model.load_state_dict(checkpoint['model']) 49 | speaker_embedding = None 50 | if args.speaker is not None: 51 | assert 'speaker_embeddings' in checkpoint, 'Could not find speaker embeddings in checkpoint! Make sure you ' \ 52 | 'use trained multispeaker model!' 
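        # Multispeaker checkpoints store a dict of speaker embeddings keyed by speaker name;
        # pick the embedding of the requested speaker.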
53 | speaker_embeddings = checkpoint.get('speaker_embeddings', None) 54 | assert args.speaker in speaker_embeddings, \ 55 | f'Provided speaker not found in speaker embeddings: {args.speaker},\n' \ 56 | f'Available speakers: {checkpoint["speaker_embeddings"].keys()}' 57 | speaker_embedding = speaker_embeddings[args.speaker] 58 | 59 | print(f'Initialized tts model: {tts_model}') 60 | print(f'Restored model with step {tts_model.get_step()}') 61 | dsp = DSP.from_config(config) 62 | 63 | voc_model, voc_dsp = None, None 64 | out_path = Path('model_outputs') 65 | out_path.mkdir(parents=True, exist_ok=True) 66 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 67 | tts_model.to(device) 68 | cleaner = Cleaner.from_config(config) 69 | tokenizer = Tokenizer() 70 | 71 | print(f'Using device: {device}\n') 72 | if args.input_text: 73 | texts = [args.input_text] 74 | else: 75 | with open('sentences.txt', 'r', encoding='utf-8') as f: 76 | texts = f.readlines() 77 | 78 | tts_k = tts_model.get_step() // 1000 79 | tts_model.eval() 80 | 81 | simple_table([('Forward Tacotron', str(tts_k) + 'k'), 82 | ('Vocoder Type', args.vocoder)]) 83 | 84 | # simple amplification of pitch 85 | pitch_function = lambda x: x * args.amp 86 | energy_function = lambda x: x 87 | 88 | for i, x in enumerate(texts, 1): 89 | print(f'\n| Generating {i}/{len(texts)}') 90 | text = x 91 | x = cleaner(x) 92 | x = tokenizer(x) 93 | x = torch.as_tensor(x, dtype=torch.long, device=device).unsqueeze(0) 94 | 95 | speaker_name = args.speaker if args.speaker is not None else 'default_speaker' 96 | wav_name = f'{i}_forward_{tts_k}k_{speaker_name}_alpha{args.alpha}_amp{args.amp}_{args.vocoder}' 97 | 98 | input = { 99 | 'x': x, 100 | 'alpha': args.alpha, 101 | 'pitch_function': pitch_function, 102 | 'energy_function': energy_function 103 | } 104 | if speaker_embedding is not None: 105 | input.update({'speaker_emb': speaker_embedding}) 106 | 107 | gen = tts_model.generate(**input) 108 | 109 | m = gen['mel_post'].cpu() 110 | if args.vocoder == 'melgan': 111 | torch.save(m, out_path / f'{wav_name}.mel') 112 | if args.vocoder == 'hifigan': 113 | np.save(str(out_path / f'{wav_name}.npy'), m.numpy(), allow_pickle=False) 114 | elif args.vocoder == 'griffinlim': 115 | wav = dsp.griffinlim(m.squeeze().numpy()) 116 | dsp.save_wav(wav, out_path / f'{wav_name}.wav') 117 | 118 | print('\n\nDone.\n') -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/models/__init__.py -------------------------------------------------------------------------------- /models/common_layers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | from typing import Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn import LayerNorm, MultiheadAttention 9 | from torch.nn.utils.rnn import pad_sequence 10 | 11 | 12 | class LengthRegulator(nn.Module): 13 | 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def forward(self, x: torch.Tensor, dur: torch.Tensor) -> torch.Tensor: 18 | dur[dur < 0] = 0. 
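        # Round each (possibly fractional) duration to the nearest frame count and repeat the
        # corresponding encoder timestep that many times, then re-pad the batch to equal length.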
19 | x_expanded = [] 20 | for i in range(x.size(0)): 21 | x_exp = torch.repeat_interleave(x[i], (dur[i] + 0.5).long(), dim=0) 22 | x_expanded.append(x_exp) 23 | x_expanded = pad_sequence(x_expanded, padding_value=0., batch_first=True) 24 | return x_expanded 25 | 26 | 27 | class HighwayNetwork(nn.Module): 28 | 29 | def __init__(self, size: int) -> None: 30 | super().__init__() 31 | self.W1 = nn.Linear(size, size) 32 | self.W2 = nn.Linear(size, size) 33 | self.W1.bias.data.fill_(0.) 34 | 35 | def forward(self, x: torch.Tensor) -> torch.Tensor: 36 | x1 = self.W1(x) 37 | x2 = self.W2(x) 38 | g = torch.sigmoid(x2) 39 | y = g * F.relu(x1) + (1. - g) * x 40 | return y 41 | 42 | 43 | class BatchNormConv(nn.Module): 44 | 45 | def __init__(self, 46 | in_channels: int, 47 | out_channels: int, 48 | kernel: int, relu=True) -> None: 49 | super().__init__() 50 | self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False) 51 | self.bnorm = nn.BatchNorm1d(out_channels) 52 | self.relu = relu 53 | 54 | def forward(self, x: torch.Tensor) -> torch.Tensor: 55 | x = self.conv(x) 56 | x = F.relu(x) if self.relu is True else x 57 | return self.bnorm(x) 58 | 59 | 60 | class CBHG(nn.Module): 61 | 62 | def __init__(self, 63 | K: int, 64 | in_channels: int, 65 | channels: int, 66 | proj_channels: list, 67 | num_highways: int, 68 | dropout: float = 0.5) -> None: 69 | super().__init__() 70 | 71 | self.dropout = dropout 72 | self.bank_kernels = [i for i in range(1, K + 1)] 73 | self.conv1d_bank = nn.ModuleList() 74 | for k in self.bank_kernels: 75 | conv = BatchNormConv(in_channels, channels, k) 76 | self.conv1d_bank.append(conv) 77 | 78 | self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1) 79 | 80 | self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3) 81 | self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False) 82 | 83 | self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False) 84 | self.highways = nn.ModuleList() 85 | for i in range(num_highways): 86 | hn = HighwayNetwork(channels) 87 | self.highways.append(hn) 88 | 89 | self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | residual = x 93 | seq_len = x.size(-1) 94 | conv_bank = [] 95 | 96 | # Convolution Bank 97 | for conv in self.conv1d_bank: 98 | c = conv(x) # Convolution 99 | conv_bank.append(c[:, :, :seq_len]) 100 | 101 | # Stack along the channel axis 102 | conv_bank = torch.cat(conv_bank, dim=1) 103 | 104 | # dump the last padding to fit residual 105 | x = self.maxpool(conv_bank)[:, :, :seq_len] 106 | x = F.dropout(x, p=self.dropout, training=self.training) 107 | 108 | # Conv1d projections 109 | x = self.conv_project1(x) 110 | x = F.dropout(x, p=self.dropout, training=self.training) 111 | x = self.conv_project2(x) 112 | 113 | # Residual Connect 114 | x = x + residual 115 | 116 | # Through the highways 117 | x = x.transpose(1, 2) 118 | x = self.pre_highway(x) 119 | for h in self.highways: 120 | x = h(x) 121 | 122 | # And then the RNN 123 | x, _ = self.rnn(x) 124 | return x 125 | 126 | 127 | class PositionalEncoding(torch.nn.Module): 128 | 129 | def __init__(self, d_model: int, dropout=0.1, max_len=5000) -> None: 130 | super(PositionalEncoding, self).__init__() 131 | self.dropout = torch.nn.Dropout(p=dropout) 132 | self.scale = torch.nn.Parameter(torch.ones(1)) 133 | 134 | pe = torch.zeros(max_len, d_model) 135 | position = torch.arange(0, 
max_len, dtype=torch.float).unsqueeze(1) 136 | div_term = torch.exp(torch.arange( 137 | 0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 138 | pe[:, 0::2] = torch.sin(position * div_term) 139 | pe[:, 1::2] = torch.cos(position * div_term) 140 | pe = pe.unsqueeze(0).transpose(0, 1) 141 | self.register_buffer('pe', pe) 142 | 143 | def forward(self, x: torch.Tensor) -> torch.Tensor: # shape: [T, N] 144 | x = x + self.scale * self.pe[:x.size(0), :] 145 | return self.dropout(x) 146 | 147 | 148 | class FFTBlock(nn.Module): 149 | 150 | def __init__(self, 151 | d_model: int, 152 | nhead: int, 153 | conv1_kernel: int, 154 | conv2_kernel: int, 155 | d_fft: int, 156 | dropout: float = 0.1): 157 | super(FFTBlock, self).__init__() 158 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 159 | self.dropout = nn.Dropout(dropout) 160 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fft, 161 | kernel_size=conv1_kernel, stride=1, padding=conv1_kernel // 2) 162 | self.conv2 = nn.Conv1d(in_channels=d_fft, out_channels=d_model, 163 | kernel_size=conv2_kernel, stride=1, padding=conv2_kernel // 2) 164 | self.norm1 = LayerNorm(d_model) 165 | self.norm2 = LayerNorm(d_model) 166 | self.dropout1 = nn.Dropout(dropout) 167 | self.dropout2 = nn.Dropout(dropout) 168 | self.activation = torch.nn.ReLU() 169 | 170 | def forward(self, 171 | src: torch.Tensor, 172 | src_pad_mask: Optional[torch.Tensor] = None) -> torch.Tensor: 173 | src2 = self.self_attn(src, src, src, 174 | attn_mask=None, 175 | key_padding_mask=src_pad_mask)[0] 176 | src = src + self.dropout1(src2) 177 | src = self.norm1(src) 178 | src = src.transpose(0, 1).transpose(1, 2) 179 | src2 = self.conv1(src) 180 | src2 = self.activation(src2) 181 | src2 = self.conv2(src2) 182 | src = src + self.dropout2(src2) 183 | src = src.transpose(1, 2).transpose(0, 1) 184 | src = self.norm2(src) 185 | return src 186 | 187 | 188 | class ForwardTransformer(torch.nn.Module): 189 | 190 | def __init__(self, 191 | d_model: int, 192 | d_fft: int, 193 | layers: int, 194 | heads: int, 195 | conv1_kernel: int, 196 | conv2_kernel: int, 197 | dropout: float = 0.1, 198 | ) -> None: 199 | super().__init__() 200 | 201 | self.d_model = d_model 202 | self.pos_encoder = PositionalEncoding(d_model, dropout) 203 | encoder_layer = FFTBlock(d_model=d_model, 204 | nhead=heads, 205 | d_fft=d_fft, 206 | conv1_kernel=conv1_kernel, 207 | conv2_kernel=conv2_kernel, 208 | dropout=dropout) 209 | encoder_norm = LayerNorm(d_model) 210 | self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) 211 | for _ in range(layers)]) 212 | self.norm = encoder_norm 213 | 214 | def forward(self, 215 | x: torch.Tensor, 216 | src_pad_mask: Optional[torch.Tensor] = None) -> torch.Tensor: # shape: [N, T] 217 | x = x.transpose(0, 1) # shape: [T, N] 218 | x = self.pos_encoder(x) 219 | for layer in self.layers: 220 | x = layer(x, src_pad_mask=src_pad_mask) 221 | x = self.norm(x) 222 | x = x.transpose(0, 1) 223 | return x 224 | 225 | 226 | def generate_square_subsequent_mask(sz: int) -> torch.Tensor: 227 | mask = torch.triu(torch.ones(sz, sz), 1) 228 | mask = mask.masked_fill(mask == 1, float('-inf')) 229 | return mask 230 | 231 | 232 | def make_token_len_mask(x: torch.Tensor) -> torch.Tensor: 233 | return (x == 0).transpose(0, 1) 234 | 235 | 236 | def make_mel_len_mask(x: torch.Tensor, mel_lens: torch.Tensor) -> torch.Tensor: 237 | len_mask = torch.zeros((x.size(0), x.size(1))).bool().to(x.device) 238 | for i, mel_len in enumerate(mel_lens): 239 | len_mask[i, mel_len:] = True 240 | 
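    # True marks positions at or beyond each mel length, i.e. padded frames to be masked out.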
return len_mask 241 | -------------------------------------------------------------------------------- /models/fast_pitch.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Union, Callable, Dict, Any, Optional 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn import Embedding 9 | 10 | from models.common_layers import LengthRegulator, ForwardTransformer, make_token_len_mask 11 | from utils.text.symbols import phonemes 12 | 13 | 14 | class SeriesPredictor(nn.Module): 15 | 16 | def __init__(self, 17 | num_chars: int, 18 | d_model: int, 19 | n_heads: int, 20 | d_fft: int, 21 | layers: int, 22 | conv1_kernel: int, 23 | conv2_kernel: int, 24 | dropout=0.1): 25 | super().__init__() 26 | self.embedding = Embedding(num_chars, d_model) 27 | self.transformer = ForwardTransformer(heads=n_heads, dropout=dropout, 28 | d_model=d_model, d_fft=d_fft, 29 | conv1_kernel=conv1_kernel, 30 | conv2_kernel=conv2_kernel, 31 | layers=layers) 32 | self.lin = nn.Linear(d_model, 1) 33 | 34 | def forward(self, 35 | x: torch.Tensor, 36 | src_pad_mask: Optional[torch.Tensor] = None, 37 | alpha: float = 1.0) -> torch.Tensor: 38 | x = self.embedding(x) 39 | x = self.transformer(x, src_pad_mask=src_pad_mask) 40 | x = self.lin(x) 41 | return x / alpha 42 | 43 | 44 | class FastPitch(nn.Module): 45 | 46 | def __init__(self, 47 | num_chars: int, 48 | durpred_dropout: float, 49 | durpred_d_model: int, 50 | durpred_n_heads: int, 51 | durpred_layers: int, 52 | durpred_d_fft: int, 53 | pitch_dropout: float, 54 | pitch_d_model: int, 55 | pitch_n_heads: int, 56 | pitch_layers: int, 57 | pitch_d_fft: int, 58 | energy_dropout: float, 59 | energy_d_model: int, 60 | energy_n_heads: int, 61 | energy_layers: int, 62 | energy_d_fft: int, 63 | pitch_strength: float, 64 | energy_strength: float, 65 | d_model: int, 66 | conv1_kernel: int, 67 | conv2_kernel: int, 68 | prenet_layers: int, 69 | prenet_heads: int, 70 | prenet_fft: int, 71 | prenet_dropout: float, 72 | postnet_layers: int, 73 | postnet_heads: int, 74 | postnet_fft: int, 75 | postnet_dropout: float, 76 | n_mels: int, 77 | padding_value=-11.5129): 78 | super().__init__() 79 | self.padding_value = padding_value 80 | self.lr = LengthRegulator() 81 | self.dur_pred = SeriesPredictor(num_chars=num_chars, 82 | d_model=durpred_d_model, 83 | n_heads=durpred_n_heads, 84 | layers=durpred_layers, 85 | d_fft=durpred_d_fft, 86 | conv1_kernel=conv1_kernel, 87 | conv2_kernel=conv2_kernel, 88 | dropout=durpred_dropout) 89 | self.pitch_pred = SeriesPredictor(num_chars=num_chars, 90 | d_model=pitch_d_model, 91 | n_heads=pitch_n_heads, 92 | layers=pitch_layers, 93 | d_fft=pitch_d_fft, 94 | conv1_kernel=conv1_kernel, 95 | conv2_kernel=conv2_kernel, 96 | dropout=pitch_dropout) 97 | self.energy_pred = SeriesPredictor(num_chars=num_chars, 98 | d_model=energy_d_model, 99 | n_heads=energy_n_heads, 100 | layers=energy_layers, 101 | d_fft=energy_d_fft, 102 | conv1_kernel=conv1_kernel, 103 | conv2_kernel=conv2_kernel, 104 | dropout=energy_dropout) 105 | self.embedding = Embedding(num_embeddings=num_chars, embedding_dim=d_model) 106 | self.prenet = ForwardTransformer(heads=prenet_heads, dropout=prenet_dropout, 107 | conv1_kernel=conv1_kernel, conv2_kernel=conv2_kernel, 108 | d_model=d_model, d_fft=prenet_fft, layers=prenet_layers) 109 | self.postnet = ForwardTransformer(heads=postnet_heads, dropout=postnet_dropout, 110 | conv1_kernel=conv1_kernel, 
conv2_kernel=conv2_kernel, 111 | d_model=d_model, d_fft=postnet_fft, layers=postnet_layers) 112 | self.lin = torch.nn.Linear(d_model, n_mels) 113 | self.register_buffer('step', torch.zeros(1, dtype=torch.long)) 114 | self.pitch_strength = pitch_strength 115 | self.energy_strength = energy_strength 116 | self.pitch_proj = nn.Conv1d(1, d_model, kernel_size=3, padding=1) 117 | self.energy_proj = nn.Conv1d(1, d_model, kernel_size=3, padding=1) 118 | 119 | def __repr__(self): 120 | num_params = sum([np.prod(p.size()) for p in self.parameters()]) 121 | return f'FastPitch, num params: {num_params}' 122 | 123 | def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 124 | x = batch['x'] 125 | mel = batch['mel'] 126 | dur = batch['dur'] 127 | mel_lens = batch['mel_len'] 128 | pitch = batch['pitch'].unsqueeze(1) 129 | energy = batch['energy'].unsqueeze(1) 130 | 131 | if self.training: 132 | self.step += 1 133 | 134 | len_mask = make_token_len_mask(x.transpose(0, 1)) 135 | dur_hat = self.dur_pred(x, src_pad_mask=len_mask).squeeze(-1) 136 | pitch_hat = self.pitch_pred(x, src_pad_mask=len_mask).transpose(1, 2) 137 | energy_hat = self.energy_pred(x, src_pad_mask=len_mask).transpose(1, 2) 138 | 139 | x = self.embedding(x) 140 | x = self.prenet(x, src_pad_mask=len_mask) 141 | 142 | pitch_proj = self.pitch_proj(pitch) 143 | pitch_proj = pitch_proj.transpose(1, 2) 144 | x = x + pitch_proj * self.pitch_strength 145 | 146 | energy_proj = self.energy_proj(energy) 147 | energy_proj = energy_proj.transpose(1, 2) 148 | x = x + energy_proj * self.energy_strength 149 | 150 | x = self.lr(x, dur) 151 | 152 | len_mask = torch.zeros((x.size(0), x.size(1))).bool().to(x.device) 153 | for i, mel_len in enumerate(mel_lens): 154 | len_mask[i, mel_len:] = True 155 | 156 | x = self.postnet(x, src_pad_mask=len_mask) 157 | 158 | x = self.lin(x) 159 | x = x.transpose(1, 2) 160 | 161 | x_post = self.pad(x, mel.size(2)) 162 | x = self.pad(x, mel.size(2)) 163 | 164 | return {'mel': x, 'mel_post': x_post, 165 | 'dur': dur_hat, 'pitch': pitch_hat, 'energy': energy_hat} 166 | 167 | def generate(self, 168 | x: torch.Tensor, 169 | alpha=1.0, 170 | pitch_function: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, 171 | energy_function: Callable[[torch.Tensor], torch.Tensor] = lambda x: x) -> Dict[str, torch.Tensor]: 172 | self.eval() 173 | with torch.no_grad(): 174 | dur_hat = self.dur_pred(x, alpha=alpha) 175 | dur_hat = dur_hat.squeeze(2) 176 | if torch.sum(dur_hat.long()) <= 0: 177 | torch.fill_(dur_hat, value=2.) 
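            # The fill above is a fallback: if all rounded durations are zero, every token gets 2 frames
            # so the length regulator still produces output. The user-supplied pitch_function and
            # energy_function below can rescale the predicted contours (e.g. simple pitch amplification).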
178 | pitch_hat = self.pitch_pred(x).transpose(1, 2) 179 | pitch_hat = pitch_function(pitch_hat) 180 | energy_hat = self.energy_pred(x).transpose(1, 2) 181 | energy_hat = energy_function(energy_hat) 182 | return self._generate_mel(x=x, dur_hat=dur_hat, 183 | pitch_hat=pitch_hat, 184 | energy_hat=energy_hat) 185 | 186 | def pad(self, x: torch.Tensor, max_len: int) -> torch.Tensor: 187 | x = x[:, :, :max_len] 188 | x = F.pad(x, [0, max_len - x.size(2), 0, 0], 'constant', self.padding_value) 189 | return x 190 | 191 | def get_step(self) -> int: 192 | return self.step.data.item() 193 | 194 | def _generate_mel(self, 195 | x: torch.Tensor, 196 | dur_hat: torch.Tensor, 197 | pitch_hat: torch.Tensor, 198 | energy_hat: torch.Tensor) -> Dict[str, torch.Tensor]: 199 | 200 | len_mask = make_token_len_mask(x.transpose(0, 1)) 201 | 202 | x = self.embedding(x) 203 | x = self.prenet(x, src_pad_mask=len_mask) 204 | 205 | pitch_proj = self.pitch_proj(pitch_hat) 206 | pitch_proj = pitch_proj.transpose(1, 2) 207 | x = x + pitch_proj * self.pitch_strength 208 | 209 | energy_proj = self.energy_proj(energy_hat) 210 | energy_proj = energy_proj.transpose(1, 2) 211 | x = x + energy_proj * self.energy_strength 212 | 213 | x = self.lr(x, dur_hat) 214 | 215 | x = self.postnet(x, src_pad_mask=None) 216 | 217 | x = self.lin(x) 218 | x = x.transpose(1, 2) 219 | 220 | return {'mel': x, 'mel_post': x, 'dur': dur_hat, 221 | 'pitch': pitch_hat, 'energy': energy_hat} 222 | 223 | @classmethod 224 | def from_config(cls, config: Dict[str, Any]) -> 'FastPitch': 225 | model_config = config['fast_pitch']['model'] 226 | model_config['num_chars'] = len(phonemes) 227 | model_config['n_mels'] = config['dsp']['num_mels'] 228 | return FastPitch(**model_config) 229 | 230 | @classmethod 231 | def from_checkpoint(cls, path: Union[Path, str]) -> 'FastPitch': 232 | checkpoint = torch.load(path, map_location=torch.device('cpu')) 233 | model = FastPitch.from_config(checkpoint['config']) 234 | model.load_state_dict(checkpoint['model']) 235 | return model 236 | -------------------------------------------------------------------------------- /models/forward_tacotron.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Union, Callable, Dict, Any 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn import Embedding 8 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 9 | 10 | from models.common_layers import CBHG, LengthRegulator, BatchNormConv 11 | from utils.text.symbols import phonemes 12 | 13 | 14 | class SeriesPredictor(nn.Module): 15 | 16 | def __init__(self, num_chars, emb_dim=64, conv_dims=256, rnn_dims=64, dropout=0.5): 17 | super().__init__() 18 | self.embedding = Embedding(num_chars, emb_dim) 19 | self.convs = torch.nn.ModuleList([ 20 | BatchNormConv(emb_dim, conv_dims, 5, relu=True), 21 | BatchNormConv(conv_dims, conv_dims, 5, relu=True), 22 | BatchNormConv(conv_dims, conv_dims, 5, relu=True), 23 | ]) 24 | self.rnn = nn.GRU(conv_dims, rnn_dims, batch_first=True, bidirectional=True) 25 | self.lin = nn.Linear(2 * rnn_dims, 1) 26 | self.dropout = dropout 27 | 28 | def forward(self, 29 | x: torch.Tensor, 30 | alpha: float = 1.0) -> torch.Tensor: 31 | x = self.embedding(x) 32 | x = x.transpose(1, 2) 33 | for conv in self.convs: 34 | x = conv(x) 35 | x = F.dropout(x, p=self.dropout, training=self.training) 36 | x = x.transpose(1, 2) 37 | x, _ = self.rnn(x) 38 | x = 
self.lin(x) 39 | return x / alpha 40 | 41 | 42 | class ForwardTacotron(nn.Module): 43 | 44 | def __init__(self, 45 | embed_dims: int, 46 | series_embed_dims: int, 47 | num_chars: int, 48 | durpred_conv_dims: int, 49 | durpred_rnn_dims: int, 50 | durpred_dropout: float, 51 | pitch_conv_dims: int, 52 | pitch_rnn_dims: int, 53 | pitch_dropout: float, 54 | pitch_strength: float, 55 | energy_conv_dims: int, 56 | energy_rnn_dims: int, 57 | energy_dropout: float, 58 | energy_strength: float, 59 | rnn_dims: int, 60 | prenet_dims: int, 61 | prenet_k: int, 62 | postnet_num_highways: int, 63 | prenet_dropout: float, 64 | postnet_dims: int, 65 | postnet_k: int, 66 | prenet_num_highways: int, 67 | postnet_dropout: float, 68 | n_mels: int, 69 | padding_value=-11.5129): 70 | super().__init__() 71 | self.rnn_dims = rnn_dims 72 | self.padding_value = padding_value 73 | self.embedding = nn.Embedding(num_chars, embed_dims) 74 | self.lr = LengthRegulator() 75 | self.dur_pred = SeriesPredictor(num_chars=num_chars, 76 | emb_dim=series_embed_dims, 77 | conv_dims=durpred_conv_dims, 78 | rnn_dims=durpred_rnn_dims, 79 | dropout=durpred_dropout) 80 | self.pitch_pred = SeriesPredictor(num_chars=num_chars, 81 | emb_dim=series_embed_dims, 82 | conv_dims=pitch_conv_dims, 83 | rnn_dims=pitch_rnn_dims, 84 | dropout=pitch_dropout) 85 | self.energy_pred = SeriesPredictor(num_chars=num_chars, 86 | emb_dim=series_embed_dims, 87 | conv_dims=energy_conv_dims, 88 | rnn_dims=energy_rnn_dims, 89 | dropout=energy_dropout) 90 | self.prenet = CBHG(K=prenet_k, 91 | in_channels=embed_dims, 92 | channels=prenet_dims, 93 | proj_channels=[prenet_dims, embed_dims], 94 | num_highways=prenet_num_highways, 95 | dropout=prenet_dropout) 96 | self.lstm = nn.LSTM(2 * prenet_dims, 97 | rnn_dims, 98 | batch_first=True, 99 | bidirectional=True) 100 | self.lin = torch.nn.Linear(2 * rnn_dims, n_mels) 101 | self.register_buffer('step', torch.zeros(1, dtype=torch.long)) 102 | self.postnet = CBHG(K=postnet_k, 103 | in_channels=n_mels, 104 | channels=postnet_dims, 105 | proj_channels=[postnet_dims, n_mels], 106 | num_highways=postnet_num_highways, 107 | dropout=postnet_dropout) 108 | self.post_proj = nn.Linear(2 * postnet_dims, n_mels, bias=False) 109 | self.pitch_strength = pitch_strength 110 | self.energy_strength = energy_strength 111 | self.pitch_proj = nn.Conv1d(1, 2 * prenet_dims, kernel_size=3, padding=1) 112 | self.energy_proj = nn.Conv1d(1, 2 * prenet_dims, kernel_size=3, padding=1) 113 | 114 | def __repr__(self): 115 | num_params = sum([np.prod(p.size()) for p in self.parameters()]) 116 | return f'ForwardTacotron, num params: {num_params}' 117 | 118 | def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 119 | x = batch['x'] 120 | mel = batch['mel'] 121 | dur = batch['dur'] 122 | mel_lens = batch['mel_len'] 123 | pitch = batch['pitch'].unsqueeze(1) 124 | energy = batch['energy'].unsqueeze(1) 125 | 126 | if self.training: 127 | self.step += 1 128 | 129 | dur_hat = self.dur_pred(x).squeeze(-1) 130 | pitch_hat = self.pitch_pred(x).transpose(1, 2) 131 | energy_hat = self.energy_pred(x).transpose(1, 2) 132 | 133 | x = self.embedding(x) 134 | x = x.transpose(1, 2) 135 | x = self.prenet(x) 136 | 137 | pitch_proj = self.pitch_proj(pitch) 138 | pitch_proj = pitch_proj.transpose(1, 2) 139 | x = x + pitch_proj * self.pitch_strength 140 | 141 | energy_proj = self.energy_proj(energy) 142 | energy_proj = energy_proj.transpose(1, 2) 143 | x = x + energy_proj * self.energy_strength 144 | 145 | x = self.lr(x, dur) 146 | 147 | x = 
pack_padded_sequence(x, lengths=mel_lens.cpu(), enforce_sorted=False, 148 | batch_first=True) 149 | 150 | x, _ = self.lstm(x) 151 | 152 | x, _ = pad_packed_sequence(x, padding_value=self.padding_value, batch_first=True) 153 | 154 | x = self.lin(x) 155 | x = x.transpose(1, 2) 156 | 157 | x_post = self.postnet(x) 158 | x_post = self.post_proj(x_post) 159 | x_post = x_post.transpose(1, 2) 160 | 161 | x_post = self._pad(x_post, mel.size(2)) 162 | x = self._pad(x, mel.size(2)) 163 | 164 | return {'mel': x, 'mel_post': x_post, 165 | 'dur': dur_hat, 'pitch': pitch_hat, 'energy': energy_hat} 166 | 167 | def generate(self, 168 | x: torch.Tensor, 169 | alpha=1.0, 170 | pitch_function: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, 171 | energy_function: Callable[[torch.Tensor], torch.Tensor] = lambda x: x) -> Dict[str, torch.Tensor]: 172 | self.eval() 173 | with torch.no_grad(): 174 | dur_hat = self.dur_pred(x, alpha=alpha) 175 | dur_hat = dur_hat.squeeze(2) 176 | if torch.sum(dur_hat.long()) <= 0: 177 | torch.fill_(dur_hat, value=2.) 178 | pitch_hat = self.pitch_pred(x).transpose(1, 2) 179 | pitch_hat = pitch_function(pitch_hat) 180 | energy_hat = self.energy_pred(x).transpose(1, 2) 181 | energy_hat = energy_function(energy_hat) 182 | return self._generate_mel(x=x, dur_hat=dur_hat, 183 | pitch_hat=pitch_hat, 184 | energy_hat=energy_hat) 185 | 186 | @torch.jit.export 187 | def generate_jit(self, 188 | x: torch.Tensor, 189 | alpha: float = 1.0, 190 | beta: float = 1.0) -> Dict[str, torch.Tensor]: 191 | with torch.no_grad(): 192 | dur_hat = self.dur_pred(x, alpha=alpha) 193 | dur_hat = dur_hat.squeeze(2) 194 | if torch.sum(dur_hat.long()) <= 0: 195 | torch.fill_(dur_hat, value=2.) 196 | pitch_hat = self.pitch_pred(x).transpose(1, 2) * beta 197 | energy_hat = self.energy_pred(x).transpose(1, 2) 198 | return self._generate_mel(x=x, dur_hat=dur_hat, 199 | pitch_hat=pitch_hat, 200 | energy_hat=energy_hat) 201 | 202 | def get_step(self) -> int: 203 | return self.step.data.item() 204 | 205 | def _generate_mel(self, 206 | x: torch.Tensor, 207 | dur_hat: torch.Tensor, 208 | pitch_hat: torch.Tensor, 209 | energy_hat: torch.Tensor) -> Dict[str, torch.Tensor]: 210 | x = self.embedding(x) 211 | x = x.transpose(1, 2) 212 | x = self.prenet(x) 213 | 214 | pitch_proj = self.pitch_proj(pitch_hat) 215 | pitch_proj = pitch_proj.transpose(1, 2) 216 | x = x + pitch_proj * self.pitch_strength 217 | 218 | energy_proj = self.energy_proj(energy_hat) 219 | energy_proj = energy_proj.transpose(1, 2) 220 | x = x + energy_proj * self.energy_strength 221 | 222 | x = self.lr(x, dur_hat) 223 | 224 | x, _ = self.lstm(x) 225 | 226 | x = self.lin(x) 227 | x = x.transpose(1, 2) 228 | 229 | x_post = self.postnet(x) 230 | x_post = self.post_proj(x_post) 231 | x_post = x_post.transpose(1, 2) 232 | 233 | return {'mel': x, 'mel_post': x_post, 'dur': dur_hat, 234 | 'pitch': pitch_hat, 'energy': energy_hat} 235 | 236 | def _pad(self, x: torch.Tensor, max_len: int) -> torch.Tensor: 237 | x = x[:, :, :max_len] 238 | x = F.pad(x, [0, max_len - x.size(2), 0, 0], 'constant', self.padding_value) 239 | return x 240 | 241 | 242 | @classmethod 243 | def from_config(cls, config: Dict[str, Any]) -> 'ForwardTacotron': 244 | model_config = config['forward_tacotron']['model'] 245 | model_config['num_chars'] = len(phonemes) 246 | model_config['n_mels'] = config['dsp']['num_mels'] 247 | return ForwardTacotron(**model_config) 248 | 249 | @classmethod 250 | def from_checkpoint(cls, path: Union[Path, str]) -> 'ForwardTacotron': 251 | checkpoint = 
torch.load(path, map_location=torch.device('cpu')) 252 | model = ForwardTacotron.from_config(checkpoint['config']) 253 | model.load_state_dict(checkpoint['model']) 254 | return model -------------------------------------------------------------------------------- /notebook_utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /notebook_utils/synthesize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Callable 4 | 5 | from utils.checkpoints import init_tts_model 6 | from utils.dsp import DSP 7 | from utils.text.cleaners import Cleaner 8 | from utils.text.tokenizer import Tokenizer 9 | 10 | 11 | class Synthesizer: 12 | 13 | def __init__(self, 14 | tts_path: str, 15 | device='cuda'): 16 | self.device = torch.device(device) 17 | tts_checkpoint = torch.load(tts_path, map_location=self.device) 18 | tts_config = tts_checkpoint['config'] 19 | tts_model = init_tts_model(tts_config) 20 | tts_model.load_state_dict(tts_checkpoint['model']) 21 | self.tts_model = tts_model 22 | self.melgan = torch.hub.load('seungwonpark/melgan', 'melgan') 23 | self.melgan.to(device).eval() 24 | self.cleaner = Cleaner.from_config(tts_config) 25 | self.tokenizer = Tokenizer() 26 | self.dsp = DSP.from_config(tts_config) 27 | 28 | def __call__(self, 29 | text: str, 30 | voc_model: str, 31 | alpha=1.0, 32 | pitch_function: Callable[[torch.tensor], torch.tensor] = lambda x: x, 33 | energy_function: Callable[[torch.tensor], torch.tensor] = lambda x: x, 34 | ) -> np.array: 35 | x = self.cleaner(text) 36 | x = self.tokenizer(x) 37 | x = torch.tensor(x).unsqueeze(0) 38 | gen = self.tts_model.generate(x, 39 | alpha=alpha, 40 | pitch_function=pitch_function, 41 | energy_function=energy_function) 42 | m = gen['mel_post'].cpu() 43 | if voc_model == 'griffinlim': 44 | wav = self.dsp.griffinlim(m.squeeze().numpy(), n_iter=32) 45 | else: 46 | m = m.cuda() 47 | with torch.no_grad(): 48 | wav = self.melgan.inference(m).cpu().numpy() 49 | return wav 50 | -------------------------------------------------------------------------------- /pitch_extraction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/pitch_extraction/__init__.py -------------------------------------------------------------------------------- /pitch_extraction/pitch_extractor.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from enum import Enum 3 | from typing import Dict, Any, Union 4 | 5 | import librosa 6 | import numpy as np 7 | import torchaudio.functional as F 8 | import torch 9 | 10 | from utils.dataset import tensor_to_ndarray 11 | 12 | try: 13 | import pyworld as pw 14 | except ImportError as e: 15 | print('WARNING: Could not import pyworld! 
Please use pitch_extraction_method: librosa.') 16 | 17 | 18 | class PitchExtractionMethod(Enum): 19 | LIBROSA = 'librosa' 20 | PYWORLD = 'pyworld' 21 | TORCHAUDIO = 'torchaudio' 22 | 23 | 24 | class PitchExtractor(ABC): 25 | 26 | def __call__(self, wav: np.array) -> np.array: 27 | raise NotImplementedError() 28 | 29 | 30 | class LibrosaPitchExtractor(PitchExtractor): 31 | 32 | def __init__(self, 33 | fmin: int, 34 | fmax: int, 35 | sample_rate: int, 36 | frame_length: int, 37 | hop_length: int) -> None: 38 | 39 | self.fmin = fmin 40 | self.fmax= fmax 41 | self.sample_rate = sample_rate 42 | self.frame_length = frame_length 43 | self.hop_length = hop_length 44 | 45 | def __call__(self, wav: Union[torch.Tensor, np.array]) -> np.array: 46 | if torch.is_tensor(wav): 47 | wav = tensor_to_ndarray(wav) 48 | pitch, _, _ = librosa.pyin(wav, 49 | fmin=self.fmin, 50 | fmax=self.fmax, 51 | sr=self.sample_rate, 52 | frame_length=self.frame_length, 53 | hop_length=self.hop_length) 54 | np.nan_to_num(pitch, copy=False, nan=0.) 55 | return pitch 56 | 57 | 58 | class PyworldPitchExtractor(PitchExtractor): 59 | 60 | def __init__(self, 61 | sample_rate: int, 62 | hop_length: int) -> None: 63 | self.sample_rate = sample_rate 64 | self.hop_length = hop_length 65 | 66 | def __call__(self, wav: Union[torch.Tensor, np.array]) -> np.array: 67 | if torch.is_tensor(wav): 68 | wav = tensor_to_ndarray(wav) 69 | return pw.dio(wav.astype(np.float64), self.sample_rate, 70 | frame_period=self.hop_length / self.sample_rate * 1000)[0] 71 | 72 | 73 | class TorchAudioPitchExtractor(PitchExtractor): 74 | 75 | def __init__(self, 76 | sample_rate: int, 77 | hop_length: int, 78 | freq_min: int, 79 | freq_max: int) -> None: 80 | self.sample_rate = sample_rate 81 | self.hop_length = hop_length 82 | self.freq_min = freq_min 83 | self.freq_max = freq_max 84 | 85 | def __call__(self, wav: Union[torch.Tensor, np.array]) -> np.array: 86 | if torch.is_tensor(wav): 87 | wav = tensor_to_ndarray(wav) 88 | return F.detect_pitch_frequency(waveform=wav, 89 | sample_rate=self.sample_rate, 90 | frame_time=self.hop_length / self.sample_rate, 91 | freq_low=self.freq_min, 92 | freq_high=self.freq_max) 93 | 94 | 95 | def new_pitch_extractor_from_config(config: Dict[str, Any]) -> PitchExtractor: 96 | preproc_config = config['preprocessing'] 97 | pitch_extractor_type = preproc_config['pitch_extractor'] 98 | if pitch_extractor_type == 'librosa': 99 | pitch_extractor = LibrosaPitchExtractor(fmin=preproc_config['pitch_min_freq'], 100 | fmax=preproc_config['pitch_max_freq'], 101 | frame_length=preproc_config['pitch_frame_length'], 102 | sample_rate=config['dsp']['sample_rate'], 103 | hop_length=config['dsp']['hop_length']) 104 | elif pitch_extractor_type == 'pyworld': 105 | pitch_extractor = PyworldPitchExtractor(hop_length=config['dsp']['hop_length'], 106 | sample_rate=config['dsp']['sample_rate']) 107 | elif pitch_extractor_type == 'torchaudio': 108 | pitch_extractor = TorchAudioPitchExtractor(freq_min=preproc_config['pitch_min_freq'], 109 | freq_max=preproc_config['pitch_max_freq'], 110 | hop_length=config['dsp']['hop_length'], 111 | sample_rate=config['dsp']['sample_rate']) 112 | else: 113 | raise ValueError(f'Invalid pitch extractor type: {pitch_extractor_type}, choices: [librosa, pyworld].') 114 | return pitch_extractor 115 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numba==0.56.4 2 | librosa==0.10.0 3 | 
pyworld >= 0.2.10 4 | torch>=1.2.0 5 | phonemizer>=2.2 6 | webrtcvad>=2.0.10 7 | PyYAML>=5.1 8 | tqdm 9 | dataclasses 10 | soundfile 11 | scipy 12 | tensorboard 13 | matplotlib 14 | unidecode 15 | inflect 16 | resemblyzer==0.1.3 17 | pandas 18 | tabulate 19 | torchaudio==2.0.2 -------------------------------------------------------------------------------- /sentences.txt: -------------------------------------------------------------------------------- 1 | Scientists at the CERN laboratory say they have discovered a new particle. 2 | There's a way to measure the acute emotional intelligence that has never gone out of style. 3 | President Trump met with other leaders at the Group of 20 conference. 4 | The Senate's bill to repeal and replace the Affordable Care-Act is now imperiled. 5 | Generative adversarial network or variational auto-encoder. 6 | Basilar membrane and otolaryngology are not auto-correlations. 7 | Hi, I am a non recurrent neural network. 8 | In a statement announcing his resignation, Mr Ross said: "While the intentions may have been well meaning, the reaction to this news shows that Mr Cummings interpretation of the government advice was not shared by the vast majority of people who have done as the government asked. -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/tests/__init__.py -------------------------------------------------------------------------------- /tests/resources/test_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | voc_model_id: 'ljspeech_raw' 3 | tts_model_id: 'ljspeech_tts' 4 | 5 | data_path: 'data/' # output data path 6 | 7 | dsp: 8 | sample_rate: 22050 9 | n_fft: 1024 10 | num_mels: 80 11 | hop_length: 256 # 12.5ms - in line with Tacotron 2 paper 12 | win_length: 1024 # 50ms - same reason as above 13 | fmin: 0 14 | fmax: 8000 15 | 16 | peak_norm: False # Normalise to the peak of each wav file 17 | trim_start_end_silence: True # Whether to trim leading and trailing silence 18 | trim_silence_top_db: 60 # Threshold in decibels below reference to consider silence for for trimming 19 | # start and end silences with librosa (no trimming if really high) 20 | trim_long_silences: False # Whether to reduce long silence using WebRTC Voice Activity Detector 21 | vad_window_length: 30 # In milliseconds 22 | vad_moving_average_width: 8 23 | vad_max_silence_length: 12 24 | vad_sample_rate: 16000 25 | 26 | preprocessing: 27 | seed: 42 28 | n_val: 200 29 | language: 'en-us' 30 | cleaner_name: 'english_cleaners' 31 | min_text_len: 2 32 | 33 | # Duration Extraction from Attention 34 | extract_durations_with_dijkstra: True # slower but much more robust than simply counting attention peaks 35 | 36 | duration_extraction: 37 | 38 | silence_threshold: -11 # normalized mel value below which the voice is considered silent 39 | # minimum mel value = -11.512925465 for zeros in the wav array (=log(1e-5), 40 | # where 1e-5 is a cutoff value) 41 | silence_prob_shift: 0.25 # Increase probability for silent characters in periods of silence 42 | # for better durations during non voiced periods 43 | batch_size: 1 # batch size for tacotron inference to obtain attention matrices 44 | num_workers: 1 # number of processes for costly dijkstra duration extraction 45 | 46 | 47 | vocoder: 48 | model: 49 | mode: 
'RAW' # either 'RAW' (softmax on raw bits) or 'MOL' (sample from mixture of logistics) 50 | upsample_factors: [4, 8, 8] # NB - this needs to correctly factorise hop_length 51 | rnn_dims: 512 52 | fc_dims: 512 53 | compute_dims: 128 54 | res_out_dims: 128 55 | res_blocks: 10 56 | pad: 2 # this will pad the input so that the resnet can 'see' wider than input length 57 | 58 | 59 | training: 60 | schedule: 61 | - 1e-4, 300_000, 32 # progressive training schedule 62 | - 1e-5, 600_000, 32 # lr, step, batch_size 63 | 64 | checkpoint_every: 25_000 65 | gen_samples_every: 5000 # how often to generate samples for cherry-picking models 66 | num_gen_samples: 3 # number of samples to generate for cherry-picking models 67 | keep_top_k: 3 # how many top performing models to keep 68 | seq_len: 1280 # must be a multiple of hop_length 69 | clip_grad_norm: 4 # set to None if no gradient clipping needed 70 | max_mel_len: 20000 71 | 72 | # Generating / Synthesizing 73 | gen_batched: True # very fast (realtime+) single utterance batched generation 74 | target: 11_000 # target number of samples to be generated in each batch entry 75 | overlap: 550 # number of samples for crossfading between batches 76 | 77 | 78 | tacotron: 79 | model: 80 | embed_dims: 256 # embedding dimension for the graphemes/phoneme inputs 81 | encoder_dims: 128 82 | decoder_dims: 256 83 | postnet_dims: 128 84 | encoder_K: 16 85 | lstm_dims: 512 86 | postnet_K: 8 87 | num_highways: 4 88 | dropout: 0.5 89 | stop_threshold: -11 # Value below which audio generation ends. 90 | 91 | training: 92 | schedule: 93 | - 10, 1e-3, 10_000, 32 # progressive training schedule 94 | - 5, 1e-4, 20_000, 16 # (r, lr, step, batch_size) 95 | - 2, 1e-4, 30_000, 8 96 | - 1, 1e-4, 50_000, 8 97 | 98 | max_mel_len: 1250 # if you have a couple of extremely long spectrograms you might want to use this 99 | clip_grad_norm: 1.0 # clips the gradient norm to prevent explosion - set to None if not needed 100 | checkpoint_every: 10_000 # checkpoints the model every X steps 101 | plot_every: 1000 102 | 103 | 104 | forward_tacotron: 105 | model: 106 | embed_dims: 256 # embedding dimension for the graphemes/phoneme inputs 107 | prenet_dims: 256 108 | postnet_dims: 256 109 | durpred_conv_dims: 256 110 | durpred_rnn_dims: 64 111 | durpred_dropout: 0.5 112 | 113 | pitch_conv_dims: 256 114 | pitch_rnn_dims: 128 115 | pitch_dropout: 0.5 116 | pitch_emb_dims: 64 # embedding dimension of pitch, set to 0 if you don't want pitch conditioning 117 | pitch_proj_dropout: 0. 
118 | 119 | prenet_K: 16 120 | postnet_K: 8 121 | rnn_dims: 512 122 | num_highways: 4 123 | dropout: 0.1 124 | 125 | training: 126 | schedule: 127 | - 1e-3, 150_000, 32 # progressive training schedule 128 | - 1e-4, 300_000, 32 # lr, step, batch_size 129 | 130 | max_mel_len: 1250 # if you have a couple of extremely long spectrograms you might want to use this 131 | clip_grad_norm: 1.0 # clips the gradient norm to prevent explosion - set to None if not needed 132 | checkpoint_every: 10_000 # checkpoints the model every X steps 133 | plot_every: 1000 134 | 135 | filter_attention: True # whether to filter data with bad attention scores 136 | min_attention_sharpness: 0.5 # filter data with bad attention sharpness score, if 0 then no filter 137 | min_attention_alignment: 0.95 # filter data with bad attention alignment score, if 0 then no filter -------------------------------------------------------------------------------- /tests/resources/test_mel.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/tests/resources/test_mel.npy -------------------------------------------------------------------------------- /tests/resources/wavs/0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/tests/resources/wavs/0.wav -------------------------------------------------------------------------------- /tests/resources/wavs/1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/tests/resources/wavs/1.wav -------------------------------------------------------------------------------- /tests/test_cleaner.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from utils.text.cleaners import Cleaner 4 | 5 | 6 | class TestCleaner(unittest.TestCase): 7 | 8 | def test_call_happy_path(self) -> None: 9 | cleaner = Cleaner(cleaner_name='no_cleaners', 10 | use_phonemes=True, lang='en-us') 11 | cleaned = cleaner('hello there!') 12 | self.assertEqual('həloʊ ðɛɹ!', cleaned) 13 | 14 | cleaned = cleaner('hello there?!.') 15 | self.assertEqual('həloʊ ðɛɹ?!.', cleaned) 16 | 17 | cleaner = Cleaner(cleaner_name='no_cleaners', 18 | use_phonemes=False, lang='en-us') 19 | cleaned = cleaner(' Hello there!') 20 | self.assertEqual('Hello there!', cleaned) 21 | 22 | cleaner = Cleaner(cleaner_name='english_cleaners', 23 | use_phonemes=False, lang='en-us') 24 | cleaned = cleaner('hello there Mr. 
1!') 25 | self.assertEqual('hello there mister one!', cleaned) 26 | -------------------------------------------------------------------------------- /tests/test_collator.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import torch 4 | 5 | from utils.dataset import TacoCollator, ForwardCollator 6 | 7 | 8 | class TestDataset(unittest.TestCase): 9 | 10 | def test_collate_forward(self) -> None: 11 | items = [ 12 | { 13 | 'item_id': 0, 14 | 'mel': np.full((2, 5), fill_value=1.), 15 | 'x': np.full(2, fill_value=2.), 16 | 'mel_len': 5, 17 | 'x_len': 2, 18 | 'dur': np.full(2, fill_value=3.), 19 | 'pitch': np.full(2, fill_value=4.), 20 | 'pitch_cond': np.full(2, fill_value=5.), 21 | 'energy': np.full(2, fill_value=5.), 22 | 'speaker_emb': np.full(1, fill_value=4.), 23 | 'speaker_name': 'speaker_1' 24 | }, 25 | { 26 | 'item_id': 1, 27 | 'mel': np.full((2, 6), fill_value=1.), 28 | 'x': np.full(3, fill_value=2.), 29 | 'mel_len': 6, 30 | 'x_len': 3, 31 | 'dur': np.full(3, fill_value=3.), 32 | 'pitch': np.full(3, fill_value=4.), 33 | 'pitch_cond': np.full(3, fill_value=5.), 34 | 'energy': np.full(3, fill_value=5.), 35 | 'speaker_emb': np.full(1, fill_value=5.), 36 | 'speaker_name': 'speaker_2' 37 | } 38 | ] 39 | 40 | collator = ForwardCollator(taco_collator=TacoCollator(r=1)) 41 | batch = collator(items) 42 | self.assertEqual(0, batch['item_id'][0]) 43 | self.assertEqual(1, batch['item_id'][1]) 44 | self.assertEqual((2, 7), batch['mel'][0].size()) 45 | self.assertEqual((2, 7), batch['mel'][1].size()) 46 | self.assertEqual([2., 2., 2., 2., 2., -11.5129*2, -11.5129*2], torch.sum(batch['mel'][0], dim=0).tolist()) 47 | self.assertEqual([2., 2., 2., 2., 2., 2., -11.5129*2], torch.sum(batch['mel'][1], dim=0).tolist()) 48 | self.assertEqual(2, batch['x_len'][0]) 49 | self.assertEqual(3, batch['x_len'][1]) 50 | self.assertEqual(5, batch['mel_len'][0]) 51 | self.assertEqual(6, batch['mel_len'][1]) 52 | self.assertEqual([2., 2., 0], batch['x'][0].tolist()) 53 | self.assertEqual([2., 2., 2.], batch['x'][1].tolist()) 54 | self.assertEqual([3., 3., 0], batch['dur'][0].tolist()) 55 | self.assertEqual([3., 3., 3.], batch['dur'][1].tolist()) 56 | self.assertEqual([4., 4., 0], batch['pitch'][0].tolist()) 57 | self.assertEqual([4., 4., 4.], batch['pitch'][1].tolist()) 58 | self.assertEqual([5., 5., 0.], batch['pitch_cond'][0].tolist()) 59 | self.assertEqual([5., 5., 5.], batch['pitch_cond'][1].tolist()) 60 | self.assertEqual([5., 5., 0], batch['energy'][0].tolist()) 61 | self.assertEqual([5., 5., 5.], batch['energy'][1].tolist()) 62 | self.assertEqual([4.], batch['speaker_emb'][0].tolist()) 63 | self.assertEqual([5.], batch['speaker_emb'][1].tolist()) 64 | self.assertEqual('speaker_1', batch['speaker_name'][0]) 65 | self.assertEqual('speaker_2', batch['speaker_name'][1]) 66 | 67 | -------------------------------------------------------------------------------- /tests/test_dataset_filter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from utils.dataset import DurationStats, DataFilter 4 | 5 | 6 | class TestDatasetFilter(unittest.TestCase): 7 | 8 | def test_filter_happy_path(self) -> None: 9 | 10 | dur_stats = { 11 | 'id_1': DurationStats(att_align_score=1., att_sharpness_score=1., max_consecutive_ones=1, max_duration=2), 12 | 'id_2': DurationStats(att_align_score=0.5, att_sharpness_score=1., max_consecutive_ones=1, max_duration=2), 13 | 'id_3': 
DurationStats(att_align_score=1., att_sharpness_score=0.5, max_consecutive_ones=1, max_duration=2), 14 | 'id_4': DurationStats(att_align_score=1., att_sharpness_score=1., max_consecutive_ones=6, max_duration=2), 15 | 'id_5': DurationStats(att_align_score=1., att_sharpness_score=1., max_consecutive_ones=1, max_duration=20), 16 | } 17 | 18 | dataset = [ 19 | ('id_1', 1000), 20 | ('id_2', 1000), 21 | ('id_3', 1000), 22 | ('id_4', 1000), 23 | ('id_5', 5000), 24 | ] 25 | 26 | data_filter = DataFilter(duration_stats=dur_stats, 27 | min_attention_alignment=1., 28 | min_attention_sharpness=1., 29 | max_consecutive_duration_ones=1, 30 | max_duration=2) 31 | 32 | result = data_filter(dataset) 33 | 34 | self.assertEqual(['id_1'], [r for r, _ in result]) 35 | 36 | data_filter = DataFilter(duration_stats=dur_stats, 37 | min_attention_alignment=0., 38 | min_attention_sharpness=0., 39 | max_consecutive_duration_ones=5, 40 | max_duration=10) 41 | 42 | result = data_filter(dataset) 43 | 44 | self.assertEqual(['id_1', 'id_2', 'id_3'], [r for r, _ in result]) 45 | 46 | -------------------------------------------------------------------------------- /tests/test_dsp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from pathlib import Path 4 | 5 | import librosa 6 | import numpy as np 7 | import torch 8 | 9 | from utils.dataset import tensor_to_ndarray 10 | from utils.dsp import DSP 11 | from utils.files import read_config 12 | import torch.nn.functional as F 13 | 14 | 15 | class TestDSP(unittest.TestCase): 16 | 17 | def setUp(self) -> None: 18 | test_path = os.path.dirname(os.path.abspath(__file__)) 19 | self.resource_path = Path(test_path) / 'resources' 20 | config = read_config(self.resource_path / 'test_config.yaml') 21 | self.dsp = DSP.from_config(config) 22 | 23 | def load_wavs(self): 24 | wav_dir_path = self.resource_path / 'wavs' 25 | wav_files = [file.resolve() for file in wav_dir_path.iterdir() if file.suffix == '.wav'] 26 | waveforms = [self.dsp.load_wav(file_path) for file_path in wav_files] 27 | waveforms = [waveform.to(self.dsp.device) for waveform in waveforms] 28 | return waveforms 29 | 30 | def test_melspectrogram(self) -> None: 31 | file = librosa.util.example('brahms') 32 | y = self.dsp.load_wav(file)[:, :10000] 33 | y = y.to(self.dsp.device) 34 | mel = self.dsp.waveform_to_mel(y) 35 | mel = tensor_to_ndarray(mel) 36 | expected = np.load(str(self.resource_path / 'test_mel.npy')) 37 | np.testing.assert_allclose(expected, mel, rtol=1e-5) 38 | 39 | def test_batched_wav_to_mel(self) -> None: 40 | # read wav files 41 | waveforms = self.load_wavs() 42 | 43 | # process in batch 44 | mels_batched = self.dsp.waveform_to_mel_batched(waveforms) 45 | 46 | # process one by one 47 | mels_single_processing = [self.dsp.waveform_to_mel(waveform) for waveform in waveforms] 48 | 49 | # compare results 50 | for mel_batched, mel_single in zip(mels_batched, mels_single_processing): 51 | mse = F.mse_loss(mel_batched, mel_single).item() 52 | self.assertLess(mse, 1e-10) 53 | 54 | def test_batched_volume_adjustment(self) -> None: 55 | # read wav files 56 | waveforms = self.load_wavs() 57 | 58 | target_dbfs = -30 59 | # process in batch 60 | normalized_batched = self.dsp.adjust_volume_batched(waveforms, target_dbfs=target_dbfs) 61 | 62 | # process one by one 63 | normalized_single_processing = [self.dsp.adjust_volume(waveform, target_dbfs=target_dbfs) for waveform in waveforms] 64 | 65 | # compare results 66 | for norm_batch, norm_single in 
zip(normalized_batched, normalized_single_processing): 67 | mse = F.mse_loss(norm_batch, norm_single).item() 68 | self.assertEqual(mse, 0) 69 | -------------------------------------------------------------------------------- /tests/test_duration_extraction_pipe.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from pathlib import Path 4 | from tempfile import TemporaryDirectory 5 | from typing import Tuple, Dict 6 | from unittest.mock import patch 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from duration_extraction.duration_extraction_pipe import DurationExtractionPipeline 12 | from duration_extraction.duration_extractor import DurationExtractor 13 | from models.tacotron import Tacotron 14 | from utils.files import read_config, pickle_binary 15 | from utils.paths import Paths 16 | 17 | 18 | def new_diagonal_attention(dims: Tuple[int, int, int]) -> torch.Tensor: 19 | """ Returns perfect diagonal attention matrix, assuming that the dimensions are almost square (1, M, M) """ 20 | att = torch.zeros(dims).float() 21 | for i in range(dims[1]): 22 | j = min(i, dims[2]-1) 23 | att[:, i, j] = 1 24 | return att 25 | 26 | 27 | class MockTacotron(torch.nn.Module): 28 | 29 | def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 30 | """ We just use the mock model to get the returned diagonal attention matrix. """ 31 | mel = batch['mel'] 32 | x = batch['x'] 33 | return {'att': new_diagonal_attention((1, mel.size(-1), x.size(-1)))} 34 | 35 | 36 | class TestDurationExtractionPipe(unittest.TestCase): 37 | 38 | def setUp(self) -> None: 39 | test_path = os.path.dirname(os.path.abspath(__file__)) 40 | self.resource_path = Path(test_path) / 'resources' 41 | self.config = read_config(self.resource_path / 'test_config.yaml') 42 | self.temp_dir = TemporaryDirectory(prefix='TestDurationExtractionPipeTmp') 43 | self.paths = Paths(data_path=self.temp_dir.name + '/data', tts_id='tts_test_id') 44 | self.train_dataset = [('id_1', 5), ('id_2', 10), ('id_3', 15)] 45 | self.val_dataset = [('id_4', 6), ('id_5', 12)] 46 | pickle_binary(self.train_dataset, self.paths.train_dataset) 47 | pickle_binary(self.val_dataset, self.paths.val_dataset) 48 | self.text_dict = {file_id: 'a' * length for file_id, length in self.train_dataset + self.val_dataset} 49 | self.speaker_dict = {file_id: 'default_speaker' for file_id, _ in self.train_dataset + self.val_dataset} 50 | pickle_binary(self.text_dict, self.paths.text_dict) 51 | pickle_binary(self.speaker_dict, self.paths.speaker_dict) 52 | for id, mel_len in self.train_dataset + self.val_dataset: 53 | np.save(self.paths.mel / f'{id}.npy', np.ones((5, mel_len)), allow_pickle=False) 54 | np.save(self.paths.speaker_emb / f'{id}.npy', np.ones(1), allow_pickle=False) 55 | 56 | def tearDown(self) -> None: 57 | self.temp_dir.cleanup() 58 | 59 | @patch.object(Tacotron, '__call__', new_callable=MockTacotron) 60 | def test_extract_attentions_durations(self, mock_tacotron: Tacotron) -> None: 61 | 62 | duration_extractor = DurationExtractor(silence_threshold=-11., 63 | silence_prob_shift=0.25) 64 | 65 | duration_extraction_pipe = DurationExtractionPipeline(paths=self.paths, config=self.config, 66 | duration_extractor=duration_extractor) 67 | 68 | avg_att_score = duration_extraction_pipe.extract_attentions(model=mock_tacotron, max_batch_size=1) 69 | self.assertEqual(1., avg_att_score) 70 | att_files = list(self.paths.att_pred.glob('**/*.npy')) 71 | self.assertEqual(5, len(att_files)) 72 | 73 | 
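        # Each stored attention matrix is expected to have shape (mel_len, text_len) for its item.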
for item_id, mel_len in (self.train_dataset + self.val_dataset): 74 | att = np.load(self.paths.att_pred / f'{item_id}.npy') 75 | x = self.text_dict[item_id] 76 | expected_att_size = (mel_len, len(x)) 77 | self.assertEqual(expected_att_size, att.shape) 78 | 79 | dur_stats = duration_extraction_pipe.extract_durations(num_workers=1, sampler_bin_size=1) 80 | 81 | for file_id, dur_stat in dur_stats.items(): 82 | self.assertEqual(dur_stat.att_align_score, 1.) 83 | self.assertEqual(dur_stat.att_sharpness_score, 1.) 84 | self.assertTrue(3 <= dur_stat.max_consecutive_ones <= 15) 85 | self.assertEqual(dur_stat.max_duration, 1) 86 | 87 | dur_files = list(self.paths.alg.glob('**/*.npy')) 88 | self.assertEqual(5, len(dur_files)) 89 | 90 | for dur_file in dur_files: 91 | dur = np.load(dur_file) 92 | # We expect durations of one due to the diagonal attention. 93 | expected = np.ones(len(dur)) 94 | np.testing.assert_allclose(expected, dur, rtol=1e-8) 95 | -------------------------------------------------------------------------------- /tests/test_duration_extractor.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from typing import Tuple 3 | 4 | import torch 5 | 6 | from duration_extraction.duration_extractor import DurationExtractor 7 | 8 | 9 | def new_diagonal_attention(dims: Tuple[int, int]) -> torch.Tensor: 10 | att = torch.zeros(dims).float() 11 | for i in range(dims[0]): 12 | att[i, i//2] = 1 13 | return att 14 | 15 | 16 | class TestDurationExtractor(unittest.TestCase): 17 | 18 | def setUp(self) -> None: 19 | pass 20 | 21 | def test_extract_happy_path(self) -> None: 22 | x = torch.tensor([15, 16, 10, 17, 18]).long() 23 | mel = torch.full((80, 10), fill_value=-10).float() 24 | attention = new_diagonal_attention((10, 5)) 25 | duration_extractor = DurationExtractor(silence_threshold=-11., 26 | silence_prob_shift=0.) 27 | durs, att_score = duration_extractor(x=x, mel=mel, attention=attention) 28 | expected = [2., 2., 2., 2., 2] 29 | self.assertEqual(expected, durs.tolist()) 30 | 31 | def test_extract_with_silent_part(self) -> None: 32 | """ Test extraction for mel with silent part that suffers from fuzzy attention. """ 33 | 34 | x = torch.tensor([15, 16, 10, 17, 18]).long() 35 | 36 | # Mock up mel that has silence at indices 4:6 37 | mel = torch.full((80, 10), fill_value=-10).float() 38 | mel[:, 4:6] = -11.51 39 | 40 | # Mock up simple diagonal attention which is fuzzy at mel indices 3:5, exactly where the model 41 | # should look at x[2], which is a silent token (token_index=10, which is a whitespace) 42 | attention = new_diagonal_attention((10, 5)) 43 | attention[3:5, :] = 1./len(x) 44 | 45 | # duration extractor with no probability shift delivers larger durations after the pause (at index=3) 46 | duration_extractor = DurationExtractor(silence_threshold=-11., 47 | silence_prob_shift=0.) 
48 | durs, att_score = duration_extractor(x=x, mel=mel, attention=attention) 49 | expected = [2., 3., 1., 2., 2] 50 | self.assertEqual(expected, durs.tolist()) 51 | 52 | # duration extractor with some probability shift delivers larger durations during the pause (at index=2) 53 | duration_extractor = DurationExtractor(silence_threshold=-11., 54 | silence_prob_shift=0.25) 55 | durs, att_score = duration_extractor(x=x, mel=mel, attention=attention) 56 | expected = [2., 2., 2., 2., 2] 57 | self.assertEqual(expected, durs.tolist()) 58 | -------------------------------------------------------------------------------- /tests/test_forward_dataset.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from pathlib import Path 3 | from tempfile import TemporaryDirectory 4 | 5 | import numpy as np 6 | 7 | from utils.dataset import ForwardDataset 8 | from utils.paths import Paths 9 | from utils.text.tokenizer import Tokenizer 10 | 11 | 12 | class TestForwardDataset(unittest.TestCase): 13 | 14 | def setUp(self) -> None: 15 | self.temp_dir = TemporaryDirectory(prefix='TestForwardDatasetTmp') 16 | 17 | def tearDown(self) -> None: 18 | self.temp_dir.cleanup() 19 | 20 | def test_get_items(self) -> None: 21 | text_dict = {'0': 'a', '1': 'bc'} 22 | speaker_dict = {'0': 'speaker_0', '1': 'speaker_1'} 23 | data_dir = Path(self.temp_dir.name + '/data') 24 | paths = Paths(data_path=data_dir, tts_id='test_forward') 25 | paths.data = data_dir 26 | 27 | mels = [np.full((2, 2), fill_value=1), np.full((2, 3), fill_value=2)] 28 | durs = [np.full(1, fill_value=2), np.full(2, fill_value=3)] 29 | pitches = [np.full(1, fill_value=5), np.full(2, fill_value=6)] 30 | energies = [np.full(1, fill_value=6), np.full(2, fill_value=7)] 31 | speaker_embs = [np.full(1, fill_value=6), np.full(1, fill_value=7)] 32 | 33 | for i in range(2): 34 | np.save(str(paths.mel / f'{i}.npy'), mels[i]) 35 | np.save(str(paths.alg / f'{i}.npy'), durs[i]) 36 | np.save(str(paths.phon_pitch / f'{i}.npy'), pitches[i]) 37 | np.save(str(paths.phon_energy / f'{i}.npy'), energies[i]) 38 | np.save(str(paths.speaker_emb / f'{i}.npy'), speaker_embs[i]) 39 | 40 | dataset = ForwardDataset(paths=paths, 41 | dataset_ids=['0', '1'], 42 | text_dict=text_dict, 43 | speaker_dict=speaker_dict, 44 | tokenizer=Tokenizer()) 45 | 46 | data = [dataset[i] for i in range(len(dataset))] 47 | 48 | np.testing.assert_allclose(data[0]['mel'], mels[0], rtol=1e-10) 49 | np.testing.assert_allclose(data[1]['mel'], mels[1], rtol=1e-10) 50 | np.testing.assert_allclose(data[0]['dur'], durs[0], rtol=1e-10) 51 | np.testing.assert_allclose(data[1]['dur'], durs[1], rtol=1e-10) 52 | np.testing.assert_allclose(data[0]['pitch'], pitches[0], rtol=1e-10) 53 | np.testing.assert_allclose(data[1]['pitch'], pitches[1], rtol=1e-10) 54 | np.testing.assert_allclose(data[0]['energy'], energies[0], rtol=1e-10) 55 | np.testing.assert_allclose(data[1]['energy'], energies[1], rtol=1e-10) 56 | np.testing.assert_allclose(data[0]['speaker_emb'], speaker_embs[0], rtol=1e-10) 57 | np.testing.assert_allclose(data[1]['speaker_emb'], speaker_embs[1], rtol=1e-10) 58 | 59 | self.assertEqual(1, data[0]['x_len']) 60 | self.assertEqual(2, data[1]['x_len']) 61 | self.assertEqual('0', data[0]['item_id']) 62 | self.assertEqual('1', data[1]['item_id']) 63 | self.assertEqual(2, data[0]['mel_len']) 64 | self.assertEqual(3, data[1]['mel_len']) 65 | self.assertEqual('speaker_0', data[0]['speaker_name']) 66 | self.assertEqual('speaker_1', data[1]['speaker_name']) 67 | 
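# Minimal usage sketch (for reference): a ForwardDataset like the one above is typically wrapped in a
# PyTorch DataLoader together with the ForwardCollator from utils.dataset, which pads variable-length
# items into batched tensors. The batch size and reduction factor r below are assumed example values:
#
#     from torch.utils.data import DataLoader
#     from utils.dataset import ForwardCollator, TacoCollator
#
#     collator = ForwardCollator(taco_collator=TacoCollator(r=1))
#     loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collator)
#     batch = next(iter(loader))  # dict of padded tensors: 'x', 'mel', 'dur', 'pitch', 'energy', ...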
-------------------------------------------------------------------------------- /tests/test_forward_tacotron.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from pathlib import Path 4 | 5 | import torch 6 | 7 | from models.forward_tacotron import ForwardTacotron 8 | from utils.files import read_config 9 | 10 | 11 | class TestForwardTacotron(unittest.TestCase): 12 | 13 | def setUp(self) -> None: 14 | test_path = os.path.dirname(os.path.abspath(__file__)) 15 | self.base_path = Path(test_path).parent 16 | config = read_config(self.base_path / 'configs/singlespeaker.yaml') 17 | self.model = ForwardTacotron.from_config(config) 18 | 19 | def test_forward(self) -> None: 20 | 21 | batch = { 22 | 'dur': torch.full((2, 10), fill_value=2).long(), 23 | 'mel': torch.ones((2, 80, 20)).float(), 24 | 'x': torch.ones((2, 10)).long(), 25 | 'speaker_emb': torch.ones((2, 256)).float(), 26 | 'mel_len': torch.full((2, ), fill_value=20).long(), 27 | 'pitch': torch.ones((2, 10)).float(), 28 | 'energy': torch.ones((2, 10)).float(), 29 | 'pitch_cond': torch.ones((2, 10)).long(), 30 | } 31 | 32 | pred = self.model(batch) 33 | 34 | self.assertEqual({'mel', 'mel_post', 'dur', 'pitch', 'energy'}, pred.keys()) 35 | self.assertEqual((2, 80, 20), pred['mel_post'].size()) 36 | self.assertEqual((2, 10), pred['dur'].size()) 37 | self.assertEqual((2, 1, 10), pred['pitch'].size()) 38 | self.assertEqual((2, 1, 10), pred['energy'].size()) 39 | 40 | def test_generate(self) -> None: 41 | gen = self.model.generate(x=torch.ones((1, 10)).long()) 42 | self.assertEqual({'mel', 'mel_post', 'dur', 'pitch', 'energy'}, gen.keys()) 43 | self.assertEqual(80, gen['mel_post'].size(1)) 44 | self.assertEqual((1, 10), gen['dur'].size()) 45 | self.assertEqual((1, 1, 10), gen['pitch'].size()) 46 | self.assertEqual((1, 1, 10), gen['energy'].size()) 47 | -------------------------------------------------------------------------------- /tests/test_guided_attention_matrix.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from trainer.common import new_guided_attention_matrix 7 | 8 | 9 | class TestGuidedAttentionMatrix(unittest.TestCase): 10 | 11 | def test_happy_path(self): 12 | attention = torch.ones(1, 4, 3) # Example input tensor with shape (batch_size, T, N) 13 | g = 1.0 14 | dia_mat = new_guided_attention_matrix(attention, g) 15 | 16 | expected = torch.tensor([ 17 | [[1., 0.9460, 0.8007], 18 | [0.9692, 0.9965, 0.9169], 19 | [0.8825, 0.9862, 0.9862], 20 | [0.7548, 0.9169, 0.9965]] 21 | ], dtype=torch.float32) 22 | 23 | np.testing.assert_allclose(expected.numpy(), dia_mat.numpy(), atol=1e-4) -------------------------------------------------------------------------------- /tests/test_multi_forward_tacotron.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from pathlib import Path 4 | 5 | import torch 6 | 7 | from models.multi_forward_tacotron import MultiForwardTacotron 8 | from utils.files import read_config 9 | 10 | 11 | class TestMultiForwardTacotron(unittest.TestCase): 12 | 13 | def setUp(self) -> None: 14 | test_path = os.path.dirname(os.path.abspath(__file__)) 15 | self.base_path = Path(test_path).parent 16 | config = read_config(self.base_path / 'configs/multispeaker.yaml') 17 | self.model = MultiForwardTacotron.from_config(config) 18 | 19 | def test_forward(self) -> None: 20 | """ Simple test 
of happy path forward pass, useful for debugging purposes. """ 21 | 22 | batch = { 23 | 'dur': torch.full((2, 10), fill_value=2).long(), 24 | 'mel': torch.ones((2, 80, 20)).float(), 25 | 'x': torch.ones((2, 10)).long(), 26 | 'speaker_emb': torch.ones((2, 256)).float(), 27 | 'mel_len': torch.full((2, ), fill_value=20).long(), 28 | 'pitch': torch.ones((2, 10)).float(), 29 | 'energy': torch.ones((2, 10)).float(), 30 | 'pitch_cond': torch.ones((2, 10)).long(), 31 | } 32 | 33 | pred = self.model(batch) 34 | 35 | self.assertEqual({'mel', 'mel_post', 'dur', 'pitch', 'energy', 'pitch_cond'}, pred.keys()) 36 | self.assertEqual((2, 80, 20), pred['mel_post'].size()) 37 | self.assertEqual((2, 10), pred['dur'].size()) 38 | self.assertEqual((2, 1, 10), pred['pitch'].size()) 39 | self.assertEqual((2, 1, 10), pred['energy'].size()) 40 | self.assertEqual((2, 10, 3), pred['pitch_cond'].size()) 41 | 42 | def test_generate(self) -> None: 43 | gen = self.model.generate(x=torch.ones((1, 10)).long(), speaker_emb=torch.ones((1, 256))) 44 | self.assertEqual({'mel', 'mel_post', 'dur', 'pitch', 'energy', 'pitch_cond'}, gen.keys()) 45 | self.assertEqual(80, gen['mel_post'].size(1)) 46 | self.assertEqual((1, 10), gen['dur'].size()) 47 | self.assertEqual((1, 1, 10), gen['pitch'].size()) 48 | self.assertEqual((1, 1, 10), gen['energy'].size()) 49 | self.assertEqual((1, 1, 10), gen['pitch_cond'].size()) 50 | -------------------------------------------------------------------------------- /tests/test_recipes.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import unittest 3 | from pathlib import Path 4 | 5 | import pandas as pd 6 | 7 | from utils.text.recipes import read_ljspeech_format, read_pandas_format 8 | 9 | 10 | # These tests have been written by ChatGPT and slightly adjusted 11 | 12 | 13 | class TestRecipes(unittest.TestCase): 14 | 15 | def test_read_ljspeech_format_multi(self): 16 | test_data = "file_id_1|speaker1|Text1\nfile_id_2|speaker2|Text2\nfile_id_3|Text3" 17 | 18 | with tempfile.NamedTemporaryFile(mode='w+') as f: 19 | f.write(test_data) 20 | f.seek(0) 21 | path = Path(f.name) 22 | 23 | text_dict, speaker_dict = read_ljspeech_format(path, multispeaker=True) 24 | 25 | self.assertEqual({'file_id_1': 'Text1', 'file_id_2': 'Text2', 'file_id_3': 'Text3'}, text_dict) 26 | self.assertEqual({'file_id_1': 'speaker1', 'file_id_2': 'speaker2', 27 | 'file_id_3': 'default_speaker'}, speaker_dict) 28 | 29 | def test_read_ljspeech_format(self): 30 | test_data = "file_id_1|Text1\nfile_id_2|Text2" 31 | with tempfile.NamedTemporaryFile(mode='w+') as f: 32 | f.write(test_data) 33 | f.seek(0) 34 | path = Path(f.name) 35 | 36 | text_dict, speaker_dict = read_ljspeech_format(path, multispeaker=False) 37 | 38 | self.assertEqual({'file_id_1': 'Text1', 'file_id_2': 'Text2'}, text_dict) 39 | self.assertEqual({'file_id_1': 'default_speaker', 'file_id_2': 'default_speaker'}, speaker_dict) 40 | 41 | def test_read_pandas_format(self): 42 | test_data = {'file_id': ['file1', 'file2', 'file3'], 43 | 'text': ['This is a test', 'This is another test', 'A third test'], 44 | 'speaker_id': ['speaker1', 'speaker2', 'speaker3']} 45 | df = pd.DataFrame(test_data) 46 | 47 | with tempfile.NamedTemporaryFile(mode='w+') as f: 48 | csv_path = Path(f.name) 49 | df.to_csv(csv_path, sep='\t', encoding='utf-8', index=False) 50 | 51 | # Test the function with the created file 52 | text_dict, speaker_dict = read_pandas_format(csv_path) 53 | self.assertEqual({'file1': 'This is a test', 'file2': 
'This is another test', 'file3': 'A third test'}, text_dict) 54 | self.assertEqual({'file1': 'speaker1', 'file2': 'speaker2', 'file3': 'speaker3'}, speaker_dict) -------------------------------------------------------------------------------- /tests/test_taco_binned_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from pathlib import Path 4 | from tempfile import TemporaryDirectory 5 | 6 | import numpy as np 7 | 8 | from utils.dataset import BinnedTacoDataLoader 9 | from utils.files import read_config, pickle_binary 10 | from utils.paths import Paths 11 | 12 | 13 | class TestTacoBinnedDataloader(unittest.TestCase): 14 | 15 | def setUp(self) -> None: 16 | test_path = os.path.dirname(os.path.abspath(__file__)) 17 | self.resource_path = Path(test_path) / 'resources' 18 | self.config = read_config(self.resource_path / 'test_config.yaml') 19 | self.temp_dir = TemporaryDirectory(prefix='forwardtaco_data_test_temp') 20 | self.paths = Paths(data_path=self.temp_dir.name + '/data', tts_id='tts_test_id') 21 | self.dataset = [('id_1', 2), ('id_2', 2), ('id_3', 3), ('id_4', 4), ('id_5', 4), ('id_6', 4)] 22 | self.text_dict = {'id_1': 'aa', 'id_2': 'aa', 'id_3': 'aaa', 'id_4': 'aaaa', 'id_5': 'aaaa', 'id_6': 'aaaa'} 23 | self.speaker_dict = {'id_1': 'speaker_1', 'id_2': 'speaker_2', 'id_3': 'speaker_3', 'id_4': 'speaker_4', 24 | 'id_5': 'speaker_5', 'id_6': 'speaker_6'} 25 | pickle_binary(self.text_dict, self.paths.text_dict) 26 | pickle_binary(self.speaker_dict, self.paths.speaker_dict) 27 | for id, mel_len in self.dataset: 28 | np.save(self.paths.mel / f'{id}.npy', np.ones((5, mel_len)), allow_pickle=False) 29 | np.save(self.paths.speaker_emb / f'{id}.npy', np.ones((1, mel_len)), allow_pickle=False) 30 | 31 | def tearDown(self) -> None: 32 | self.temp_dir.cleanup() 33 | 34 | def test_get_items(self) -> None: 35 | dataloader = BinnedTacoDataLoader(paths=self.paths, 36 | dataset=self.dataset, 37 | max_batch_size=2) 38 | 39 | expected_xlen_batches = [[2, 2], [3], [4, 4], [4]] 40 | self.assertEqual(len(expected_xlen_batches), len(dataloader)) 41 | 42 | batches = [d for d in dataloader] 43 | batches.sort(key=lambda b: b['x_len'][0]) 44 | 45 | actual_x_len_batches = [b['x_len'].tolist() for b in batches] 46 | self.assertEqual(expected_xlen_batches, actual_x_len_batches) 47 | 48 | -------------------------------------------------------------------------------- /tests/test_taco_dataset.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from pathlib import Path 3 | from tempfile import TemporaryDirectory 4 | 5 | import numpy as np 6 | 7 | from utils.dataset import TacoDataset 8 | from utils.paths import Paths 9 | from utils.text.tokenizer import Tokenizer 10 | 11 | 12 | class TestForwardDataset(unittest.TestCase): 13 | 14 | def setUp(self) -> None: 15 | self.temp_dir = TemporaryDirectory(prefix='TestForwardDatasetTmp') 16 | 17 | def tearDown(self) -> None: 18 | self.temp_dir.cleanup() 19 | 20 | def test_get_items(self) -> None: 21 | text_dict = {'0': 'a', '1': 'bc'} 22 | speaker_dict = {'0': 'speaker_0', '1': 'speaker_1'} 23 | data_dir = Path(self.temp_dir.name + '/data') 24 | paths = Paths(data_path=data_dir, tts_id='test_forward') 25 | paths.data = data_dir 26 | 27 | mels = [np.full((2, 2), fill_value=1), np.full((2, 3), fill_value=2)] 28 | speaker_embs = [np.full(1, fill_value=6), np.full(1, fill_value=7)] 29 | 30 | for i in range(2): 31 | np.save(str(paths.mel / 
f'{i}.npy'), mels[i]) 32 | np.save(str(paths.speaker_emb / f'{i}.npy'), speaker_embs[i]) 33 | 34 | dataset = TacoDataset(paths=paths, 35 | dataset_ids=['0', '1'], 36 | text_dict=text_dict, 37 | speaker_dict=speaker_dict, 38 | tokenizer=Tokenizer()) 39 | 40 | data = [dataset[i] for i in range(len(dataset))] 41 | 42 | np.testing.assert_allclose(data[0]['mel'], mels[0], rtol=1e-10) 43 | np.testing.assert_allclose(data[1]['mel'], mels[1], rtol=1e-10) 44 | np.testing.assert_allclose(data[0]['speaker_emb'], speaker_embs[0], rtol=1e-10) 45 | np.testing.assert_allclose(data[1]['speaker_emb'], speaker_embs[1], rtol=1e-10) 46 | 47 | self.assertEqual(1, data[0]['x_len']) 48 | self.assertEqual(2, data[1]['x_len']) 49 | self.assertEqual('0', data[0]['item_id']) 50 | self.assertEqual('1', data[1]['item_id']) 51 | self.assertEqual(2, data[0]['mel_len']) 52 | self.assertEqual(3, data[1]['mel_len']) 53 | self.assertEqual('speaker_0', data[0]['speaker_name']) 54 | self.assertEqual('speaker_1', data[1]['speaker_name']) 55 | -------------------------------------------------------------------------------- /tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from utils.text.tokenizer import Tokenizer 4 | 5 | 6 | class TestTokenizer(unittest.TestCase): 7 | 8 | def test_call_happy_path(self) -> None: 9 | tokenizer = Tokenizer() 10 | tokens = tokenizer('_ abc{') 11 | self.assertEqual([0, 10, 36, 52, 57], tokens) 12 | 13 | decoded = tokenizer.decode(tokens) 14 | self.assertEqual('_ abc', decoded) -------------------------------------------------------------------------------- /train_forward.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import os 4 | import subprocess 5 | from pathlib import Path 6 | from typing import Union 7 | 8 | import torch 9 | from torch import optim 10 | from torch.utils.data.dataloader import DataLoader 11 | 12 | from models.fast_pitch import FastPitch 13 | from models.forward_tacotron import ForwardTacotron 14 | from trainer.common import to_device 15 | from trainer.forward_trainer import ForwardTrainer 16 | from trainer.multi_forward_trainer import MultiForwardTrainer 17 | from utils.checkpoints import restore_checkpoint, init_tts_model 18 | from utils.dataset import get_forward_dataloaders 19 | from utils.display import * 20 | from utils.dsp import DSP 21 | from utils.files import read_config 22 | from utils.paths import Paths 23 | 24 | 25 | def try_get_git_hash() -> Union[str, None]: 26 | try: 27 | return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() 28 | except Exception as e: 29 | print(f'Could not retrieve git hash! 
{e}') 30 | return None 31 | 32 | 33 | def create_gta_features(model: Union[ForwardTacotron, FastPitch], 34 | train_set: DataLoader, 35 | val_set: DataLoader, 36 | save_path: Path) -> None: 37 | model.eval() 38 | device = next(model.parameters()).device # use same device as model parameters 39 | iters = len(train_set) + len(val_set) 40 | dataset = itertools.chain(train_set, val_set) 41 | for i, batch in enumerate(dataset, 1): 42 | batch = to_device(batch, device=device) 43 | with torch.no_grad(): 44 | pred = model(batch) 45 | gta = pred['mel_post'].cpu().numpy() 46 | for j, item_id in enumerate(batch['item_id']): 47 | mel = gta[j][:, :batch['mel_len'][j]] 48 | np.save(str(save_path/f'{item_id}.npy'), mel, allow_pickle=False) 49 | bar = progbar(i, iters) 50 | msg = f'{bar} {i}/{iters} Batches ' 51 | stream(msg) 52 | 53 | 54 | if __name__ == '__main__': 55 | parser = argparse.ArgumentParser(description='Train ForwardTacotron TTS') 56 | parser.add_argument('--force_gta', '-g', action='store_true', help='Force the model to create GTA features') 57 | parser.add_argument('--config', metavar='FILE', default='configs/singlespeaker.yaml', help='The config containing all hyperparams.') 58 | args = parser.parse_args() 59 | 60 | config = read_config(args.config) 61 | if 'git_hash' not in config or config['git_hash'] is None: 62 | config['git_hash'] = try_get_git_hash() 63 | dsp = DSP.from_config(config) 64 | paths = Paths(config['data_path'], config['tts_model_id']) 65 | 66 | assert len(os.listdir(paths.alg)) > 0, f'Could not find alignment files in {paths.alg}, please predict ' \ 67 | f'alignments first with python train_tacotron.py --force_align!' 68 | 69 | force_gta = args.force_gta 70 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 71 | print('Using device:', device) 72 | 73 | # Instantiate Forward TTS Model 74 | model = init_tts_model(config).to(device) 75 | print(f'\nInitialized tts model: {model}\n') 76 | optimizer = optim.Adam(model.parameters()) 77 | restore_checkpoint(model=model, optim=optimizer, 78 | path=paths.forward_checkpoints / 'latest_model.pt', 79 | device=device) 80 | 81 | if force_gta: 82 | print('Creating Ground Truth Aligned Dataset...\n') 83 | train_set, val_set = get_forward_dataloaders( 84 | paths=paths, batch_size=8, **config['training']['filter']) 85 | create_gta_features(model, train_set, val_set, paths.gta) 86 | elif config['tts_model'] in ['multi_forward_tacotron', 'multi_fast_pitch']: 87 | trainer = MultiForwardTrainer(paths=paths, dsp=dsp, config=config) 88 | trainer.train(model, optimizer) 89 | else: 90 | trainer = ForwardTrainer(paths=paths, dsp=dsp, config=config) 91 | trainer.train(model, optimizer) 92 | 93 | -------------------------------------------------------------------------------- /train_tacotron.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | from pathlib import Path 4 | from typing import Tuple, Dict, Any 5 | 6 | import torch 7 | from torch import optim 8 | from torch.utils.data.dataloader import DataLoader 9 | from tqdm import tqdm 10 | 11 | from duration_extraction.duration_extraction_pipe import DurationExtractionPipeline 12 | from duration_extraction.duration_extractor import DurationExtractor 13 | from models.tacotron import Tacotron 14 | from trainer.common import to_device 15 | from trainer.taco_trainer import TacoTrainer 16 | from utils.checkpoints import restore_checkpoint 17 | from utils.dataset import get_taco_dataloaders 18 | 
from utils.display import * 19 | from utils.dsp import DSP 20 | from utils.files import pickle_binary, unpickle_binary, read_config 21 | from utils.paths import Paths 22 | 23 | 24 | def normalize_values(phoneme_val): 25 | nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] 26 | for item_id, v in phoneme_val]) 27 | mean, std = np.mean(nonzeros), np.std(nonzeros) 28 | if not std > 0: 29 | std = 1e10 30 | for item_id, v in phoneme_val: 31 | zero_idxs = np.where(v == 0.0)[0] 32 | v -= mean 33 | v /= std 34 | v[zero_idxs] = 0.0 35 | return mean, std 36 | 37 | 38 | # adapted from https://github.com/NVIDIA/DeepLearningExamples/blob/ 39 | # 0b27e359a5869cd23294c1707c92f989c0bf201e/PyTorch/SpeechSynthesis/FastPitch/extract_mels.py 40 | def extract_pitch_energy(save_path_pitch: Path, 41 | save_path_energy: Path, 42 | pitch_min_freq: float, 43 | pitch_max_freq: float) -> Tuple[float, float]: 44 | speaker_dict = unpickle_binary(paths.speaker_dict) 45 | 46 | 47 | speaker_names = set([v for v in speaker_dict.values() if len(v) > 1]) 48 | mean, var = 0, 0 49 | 50 | train_data = unpickle_binary(paths.train_dataset) 51 | val_data = unpickle_binary(paths.val_dataset) 52 | all_data = train_data + val_data 53 | 54 | for speaker_name in tqdm(speaker_names, total=len(speaker_names), smoothing=0.1): 55 | all_data_speaker = [(item_id, mel_len) for item_id, mel_len in all_data if speaker_dict[item_id] == speaker_name] 56 | phoneme_pitches = [] 57 | phoneme_energies = [] 58 | for prog_idx, (item_id, mel_len) in enumerate(all_data_speaker, 1): 59 | try: 60 | dur = np.load(paths.alg / f'{item_id}.npy') 61 | mel = np.load(paths.mel / f'{item_id}.npy') 62 | energy = np.linalg.norm(np.exp(mel), axis=0, ord=2) 63 | assert np.sum(dur) == mel_len 64 | pitch = np.load(paths.raw_pitch / f'{item_id}.npy') 65 | durs_cum = np.cumsum(np.pad(dur, (1, 0))) 66 | pitch_char = np.zeros((dur.shape[0],), dtype=np.float32) 67 | energy_char = np.zeros((dur.shape[0],), dtype=np.float32) 68 | for idx, a, b in zip(range(mel_len), durs_cum[:-1], durs_cum[1:]): 69 | values = pitch[a:b][np.where(pitch[a:b] != 0.0)[0]] 70 | values = values[np.where((values >= pitch_min_freq) & (values <= pitch_max_freq))[0]] 71 | pitch_char[idx] = np.mean(values) if len(values) > 0 else 0.0 72 | energy_values = energy[a:b] 73 | energy_char[idx] = np.mean(energy_values)if len(energy_values) > 0 else 0.0 74 | phoneme_pitches.append((item_id, pitch_char)) 75 | phoneme_energies.append((item_id, energy_char)) 76 | bar = progbar(prog_idx, len(all_data)) 77 | msg = f'{bar} {prog_idx}/{len(all_data_speaker )} Files ' 78 | stream(msg) 79 | except Exception as e: 80 | print(e) 81 | 82 | for item_id, phoneme_energy in phoneme_energies: 83 | np.save(str(save_path_energy / f'{item_id}.npy'), phoneme_energy, allow_pickle=False) 84 | 85 | mean, var = normalize_values(phoneme_pitches) 86 | for item_id, phoneme_pitch in phoneme_pitches: 87 | np.save(str(save_path_pitch / f'{item_id}.npy'), phoneme_pitch, allow_pickle=False) 88 | 89 | return mean, var 90 | 91 | 92 | def create_gta_features(model: Tacotron, 93 | train_set: DataLoader, 94 | val_set: DataLoader, 95 | save_path: Path): 96 | model.eval() 97 | device = next(model.parameters()).device # use same device as model parameters 98 | iters = len(train_set) + len(val_set) 99 | dataset = itertools.chain(train_set, val_set) 100 | for i, batch in enumerate(dataset, 1): 101 | batch = to_device(batch, device=device) 102 | with torch.no_grad(): 103 | _, gta, _ = model(batch) 104 | gta = gta.cpu().numpy() 105 | for j, item_id in 
enumerate(batch['item_id']): 106 | mel = gta[j][:, :batch['mel_len'][j]] 107 | np.save(str(save_path/f'{item_id}.npy'), mel, allow_pickle=False) 108 | bar = progbar(i, iters) 109 | msg = f'{bar} {i}/{iters} Batches ' 110 | stream(msg) 111 | 112 | 113 | def create_align_features(model: Tacotron, 114 | paths: Paths, 115 | config: Dict[str, Any]) -> None: 116 | 117 | assert model.r == 1, f'Reduction factor of tacotron must be 1 for creating alignment features! ' \ 118 | f'Reduction factor was: {model.r}' 119 | model.eval() 120 | model.decoder.prenet.train() 121 | 122 | dur_extr_conf = config['duration_extraction'] 123 | 124 | duration_extractor = DurationExtractor(silence_threshold=dur_extr_conf['silence_threshold'], 125 | silence_prob_shift=dur_extr_conf['silence_prob_shift']) 126 | 127 | duration_extraction_pipe = DurationExtractionPipeline(paths=paths, config=config, 128 | duration_extractor=duration_extractor) 129 | 130 | print('Extracting attention matrices from tacotron...') 131 | duration_extraction_pipe.extract_attentions(model, max_batch_size=dur_extr_conf['max_batch_size']) 132 | 133 | num_workers = dur_extr_conf['num_workers'] 134 | print(f'Extracting durations from attention matrices (num workers={num_workers})...') 135 | duration_stats = duration_extraction_pipe.extract_durations(num_workers=num_workers, 136 | sampler_bin_size=num_workers*4) 137 | pickle_binary(duration_stats, paths.duration_stats) 138 | 139 | print('Extracting Pitch Values...') 140 | extract_pitch_energy(save_path_pitch=paths.phon_pitch, 141 | save_path_energy=paths.phon_energy, 142 | pitch_min_freq=config['preprocessing']['pitch_min_freq'], 143 | pitch_max_freq=config['preprocessing']['pitch_max_freq']) 144 | 145 | 146 | if __name__ == '__main__': 147 | parser = argparse.ArgumentParser(description='Train Tacotron TTS') 148 | parser.add_argument('--force_gta', '-g', action='store_true', help='Force the model to create GTA features') 149 | parser.add_argument('--force_align', '-a', action='store_true', help='Force the model to create attention alignment features') 150 | parser.add_argument('--extract_pitch', '-p', action='store_true', help='Extracts phoneme-pitch values only') 151 | parser.add_argument('--config', metavar='FILE', default='configs/singlespeaker.yaml', help='The config containing all hyperparams.') 152 | 153 | args = parser.parse_args() 154 | config = read_config(args.config) 155 | dsp = DSP.from_config(config) 156 | paths = Paths(config['data_path'], config['tts_model_id']) 157 | 158 | if args.extract_pitch: 159 | print('Extracting Pitch and Energy Values...') 160 | mean, var = extract_pitch_energy(save_path_pitch=paths.phon_pitch, 161 | save_path_energy=paths.phon_energy, 162 | pitch_min_freq=config['preprocessing']['pitch_min_freq'], 163 | pitch_max_freq=config['preprocessing']['pitch_max_freq']) 164 | print('\n\nYou can now train ForwardTacotron - use python train_forward.py\n') 165 | exit() 166 | 167 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 168 | print('Using device:', device) 169 | 170 | # Instantiate Tacotron Model 171 | print('\nInitialising Tacotron Model...\n') 172 | model = Tacotron.from_config(config).to(device) 173 | 174 | optimizer = optim.Adam(model.parameters()) 175 | restore_checkpoint(model=model, optim=optimizer, 176 | path=paths.taco_checkpoints / 'latest_model.pt', 177 | device=device) 178 | 179 | train_cfg = config['tacotron']['training'] 180 | if args.force_gta: 181 | print('Creating Ground Truth Aligned Dataset...\n') 182 | 
train_set, val_set = get_taco_dataloaders(paths.data, 1, model.r, **train_cfg['filter']) 183 | create_gta_features(model, train_set, val_set, paths.gta) 184 | print('\n\nYou can now train WaveRNN on GTA features - use python train_wavernn.py --gta\n') 185 | elif args.force_align: 186 | print('Creating Attention Alignments and Pitch Values...') 187 | train_set, val_set = get_taco_dataloaders(paths, 1, model.r, **train_cfg['filter']) 188 | create_align_features(model=model, config=config, paths=paths) 189 | print('\n\nYou can now train ForwardTacotron - use python train_forward.py\n') 190 | else: 191 | trainer = TacoTrainer(paths, config=config, dsp=dsp) 192 | trainer.train(model, optimizer) 193 | print('Training finished, now creating Attention Alignments and Pitch Values...') 194 | train_set, val_set = get_taco_dataloaders(paths, 1, model.r, **train_cfg['filter']) 195 | create_align_features(model=model, config=config, paths=paths) 196 | print('\n\nYou can now train ForwardTacotron - use python train_forward.py\n') 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/trainer/__init__.py -------------------------------------------------------------------------------- /trainer/common.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data.dataloader import DataLoader 6 | 7 | 8 | class TTSSession: 9 | 10 | def __init__(self, 11 | index: int, 12 | r: int, 13 | lr: int, 14 | max_step: int, 15 | bs: int, 16 | train_set: DataLoader, 17 | val_set: DataLoader) -> None: 18 | """ Container for TTS training variables. """ 19 | 20 | self.index = index 21 | self.r = r 22 | self.lr = lr 23 | self.max_step = max_step 24 | self.bs = bs 25 | self.train_set = train_set 26 | self.val_set = val_set 27 | self.val_sample = next(iter(val_set)) 28 | 29 | 30 | class VocSession: 31 | 32 | def __init__(self, 33 | index: int, 34 | lr: int, 35 | max_step: int, 36 | bs: int, 37 | train_set: DataLoader, 38 | val_set: list, 39 | val_set_samples: list) -> None: 40 | """ Container for WaveRNN training variables. """ 41 | 42 | self.index = index 43 | self.lr = lr 44 | self.max_step = max_step 45 | self.bs = bs 46 | self.train_set = train_set 47 | self.val_set = val_set 48 | self.val_set_samples = val_set_samples 49 | 50 | 51 | class Averager: 52 | 53 | def __init__(self) -> None: 54 | self.count = 0 55 | self.val = 0. 56 | 57 | def add(self, val: float) -> None: 58 | self.val += float(val) 59 | self.count += 1 60 | 61 | def reset(self) -> None: 62 | self.val = 0. 63 | self.count = 0 64 | 65 | def get(self) -> float: 66 | return self.val / self.count if self.count > 0. else 0. 
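# --- Illustrative usage sketch (added note, not part of the original module) ---
# Averager keeps a running mean of per-step scalars. The trainers in this repo
# (e.g. TacoTrainer.train_session further below) add the loss and the step duration
# on every iteration, read the running mean for the progress message, and reset the
# averagers at the end of each epoch:
#
#     loss_avg = Averager()
#     for batch in train_set:
#         loss = compute_loss(batch)           # hypothetical helper, for brevity
#         loss_avg.add(loss.item())
#         msg = f'Loss: {loss_avg.get():#.4}'
#     loss_avg.reset()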
67 | 68 | 69 | class MaskedL1(torch.nn.Module): 70 | 71 | def forward(self, x, target, lens): 72 | target.requires_grad = False 73 | max_len = target.size(2) 74 | mask = pad_mask(lens, max_len) 75 | mask = mask.unsqueeze(1).expand_as(x) 76 | loss = F.l1_loss( 77 | x * mask, target * mask, reduction='sum') 78 | return loss / mask.sum() 79 | 80 | 81 | class ForwardSumLoss(torch.nn.Module): 82 | 83 | def __init__(self, blank_logprob=-1): 84 | super(ForwardSumLoss, self).__init__() 85 | self.log_softmax = torch.nn.LogSoftmax(dim=3) 86 | self.blank_logprob = blank_logprob 87 | self.CTCLoss = torch.nn.CTCLoss(zero_infinity=True) 88 | 89 | def forward(self, 90 | attn_logprob: torch.Tensor, 91 | text_lens: torch.Tensor, 92 | mel_lens: torch.Tensor) -> torch.Tensor: 93 | 94 | # The CTC loss module assumes the existence of a blank token 95 | # that can be optionally inserted anywhere in the sequence for 96 | # a fixed probability. 97 | # A row must be added to the attention matrix to account for this 98 | attn_logprob_padded = F.pad(input=attn_logprob, 99 | pad=(1, 0, 0, 0, 0, 0), 100 | value=self.blank_logprob) 101 | batch_size = attn_logprob.size(0) 102 | steps = attn_logprob.size(-1) 103 | target_seq = torch.arange(1, steps+1).expand(batch_size, steps) 104 | attn_logprob_padded = attn_logprob_padded.permute(1, 0, 2) 105 | attn_logprob_padded = attn_logprob_padded.log_softmax(-1) 106 | cost = self.CTCLoss(attn_logprob_padded, 107 | target_seq, 108 | input_lengths=mel_lens, 109 | target_lengths=text_lens) 110 | return cost 111 | 112 | # Adapted from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 113 | def pad_mask(lens, max_len): 114 | batch_size = lens.size(0) 115 | seq_range = torch.arange(0, max_len).long() 116 | seq_range = seq_range.unsqueeze(0) 117 | seq_range = seq_range.expand(batch_size, max_len) 118 | if lens.is_cuda: 119 | seq_range = seq_range.cuda() 120 | lens = lens.unsqueeze(1) 121 | lens = lens.expand_as(seq_range) 122 | mask = seq_range < lens 123 | return mask.float() 124 | 125 | 126 | def new_guided_attention_matrix(attention: torch.Tensor, g: float) -> torch.Tensor: 127 | T = attention.size(1) 128 | N = attention.size(2) 129 | t_vals = torch.arange(T, device=attention.device, dtype=attention.dtype) 130 | n_vals = torch.arange(N, device=attention.device, dtype=attention.dtype) 131 | t_diff = t_vals[:, None] / T - n_vals[None, :] / N 132 | dia_mat = torch.exp(-t_diff**2 / (2 * g**2)).unsqueeze(0) 133 | return dia_mat 134 | 135 | 136 | def to_device(batch: Dict[str, torch.tensor], 137 | device: torch.device) -> Dict[str, torch.tensor]: 138 | output = {} 139 | for key, val in batch.items(): 140 | val = val.to(device) if torch.is_tensor(val) else val 141 | output[key] = val 142 | return output 143 | 144 | 145 | def np_now(x: torch.Tensor): return x.detach().cpu().numpy() 146 | -------------------------------------------------------------------------------- /trainer/taco_trainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.optim.optimizer import Optimizer 7 | from torch.utils.data import DataLoader 8 | from torch.utils.tensorboard import SummaryWriter 9 | from typing import Tuple, Dict, Any 10 | 11 | from models.tacotron import Tacotron 12 | from trainer.common import Averager, TTSSession, to_device, np_now, ForwardSumLoss, new_guided_attention_matrix 13 | from utils.checkpoints import save_checkpoint 14 | from 
utils.dataset import get_taco_dataloaders 15 | from utils.decorators import ignore_exception 16 | from utils.display import stream, simple_table, plot_mel, plot_attention 17 | from utils.dsp import DSP 18 | from utils.files import parse_schedule 19 | from utils.metrics import attention_score 20 | from utils.paths import Paths 21 | 22 | 23 | class TacoTrainer: 24 | 25 | def __init__(self, 26 | paths: Paths, 27 | dsp: DSP, 28 | config: Dict[str, Any]) -> None: 29 | self.paths = paths 30 | self.dsp = dsp 31 | self.config = config 32 | self.train_cfg = config['tacotron']['training'] 33 | self.writer = SummaryWriter(log_dir=paths.taco_log, comment='v1') 34 | self.forward_loss = ForwardSumLoss() 35 | 36 | def train(self, 37 | model: Tacotron, 38 | optimizer: Optimizer) -> None: 39 | tts_schedule = self.train_cfg['schedule'] 40 | tts_schedule = parse_schedule(tts_schedule) 41 | for i, session_params in enumerate(tts_schedule, 1): 42 | r, lr, max_step, bs = session_params 43 | if model.get_step() < max_step: 44 | train_set, val_set = get_taco_dataloaders( 45 | paths=self.paths, batch_size=bs, r=r, 46 | **self.train_cfg['filter'] 47 | ) 48 | session = TTSSession( 49 | index=i, r=r, lr=lr, max_step=max_step, 50 | bs=bs, train_set=train_set, val_set=val_set) 51 | self.train_session(model, optimizer, session=session) 52 | 53 | def train_session(self, model: Tacotron, 54 | optimizer: Optimizer, 55 | session: TTSSession) -> None: 56 | current_step = model.get_step() 57 | training_steps = session.max_step - current_step 58 | total_iters = len(session.train_set) 59 | epochs = training_steps // total_iters + 1 60 | model.r = session.r 61 | simple_table([(f'Steps with r={session.r}', str(training_steps // 1000) + 'k Steps'), 62 | ('Batch Size', session.bs), 63 | ('Learning Rate', session.lr), 64 | ('Outputs/Step (r)', model.r)]) 65 | for g in optimizer.param_groups: 66 | g['lr'] = session.lr 67 | 68 | loss_avg = Averager() 69 | duration_avg = Averager() 70 | device = next(model.parameters()).device # use same device as model parameters 71 | for e in range(1, epochs + 1): 72 | for i, batch in enumerate(session.train_set, 1): 73 | batch = to_device(batch, device=device) 74 | start = time.time() 75 | model.train() 76 | 77 | out = model(batch) 78 | m1_hat, m2_hat, attention, att_aligner = out['mel'], out['mel_post'], out['att'], out['att_aligner'] 79 | ctc_loss = self.forward_loss(att_aligner, text_lens=batch['x_len'], mel_lens=batch['mel_len']) 80 | 81 | m1_loss = F.l1_loss(m1_hat, batch['mel']) 82 | m2_loss = F.l1_loss(m2_hat, batch['mel']) 83 | 84 | dia_mat = new_guided_attention_matrix(attention=attention, 85 | g=self.train_cfg['dia_loss_matrix_g']) 86 | dia_loss = ((1 - dia_mat) * attention).mean() 87 | 88 | mel_loss = m1_loss + m2_loss 89 | loss = mel_loss + self.train_cfg['ctc_loss_factor'] * ctc_loss \ 90 | + self.train_cfg['dia_loss_factor'] * dia_loss 91 | 92 | optimizer.zero_grad() 93 | loss.backward() 94 | torch.nn.utils.clip_grad_norm_(model.parameters(), 95 | self.train_cfg['clip_grad_norm']) 96 | optimizer.step() 97 | loss_avg.add(loss.item()) 98 | step = model.get_step() 99 | k = step // 1000 100 | 101 | duration_avg.add(time.time() - start) 102 | speed = 1. 
/ duration_avg.get() 103 | msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {loss_avg.get():#.4} ' \ 104 | f'| {speed:#.2} steps/s | Step: {k}k | ' 105 | 106 | if step % self.train_cfg['checkpoint_every'] == 0: 107 | save_checkpoint(model=model, optim=optimizer, config=self.config, 108 | path=self.paths.taco_checkpoints / f'taco_step{k}k.pt') 109 | 110 | if step % self.train_cfg['plot_every'] == 0: 111 | self.generate_plots(model, session) 112 | 113 | _, att_score = attention_score(attention, batch['mel_len']) 114 | att_score = torch.mean(att_score) 115 | self.writer.add_scalar('Attention_Score/train', att_score, model.get_step()) 116 | self.writer.add_scalar('Mel_Loss/train', mel_loss, model.get_step()) 117 | self.writer.add_scalar('CTC_Loss/train', ctc_loss, model.get_step()) 118 | self.writer.add_scalar('Dia_Loss/train', dia_loss, model.get_step()) 119 | self.writer.add_scalar('Params/reduction_factor', session.r, model.get_step()) 120 | self.writer.add_scalar('Params/batch_size', session.bs, model.get_step()) 121 | self.writer.add_scalar('Params/learning_rate', session.lr, model.get_step()) 122 | 123 | stream(msg) 124 | 125 | val_loss, val_att_score = self.evaluate(model, session.val_set) 126 | self.writer.add_scalar('Mel_Loss/val', val_loss, model.get_step()) 127 | self.writer.add_scalar('Attention_Score/val', val_att_score, model.get_step()) 128 | save_checkpoint(model=model, optim=optimizer, config=self.config, 129 | path=self.paths.taco_checkpoints / 'latest_model.pt') 130 | 131 | loss_avg.reset() 132 | duration_avg.reset() 133 | print(' ') 134 | 135 | def evaluate(self, model: Tacotron, val_set: DataLoader) -> Tuple[float, float]: 136 | model.eval() 137 | model.decoder.prenet.train() 138 | val_loss = 0 139 | val_att_score = 0 140 | device = next(model.parameters()).device 141 | for i, batch in enumerate(val_set, 1): 142 | batch = to_device(batch, device=device) 143 | with torch.no_grad(): 144 | out = model(batch) 145 | m1_hat, m2_hat, attention = out['mel'], out['mel_post'], out['att'] 146 | m1_loss = F.l1_loss(m1_hat, batch['mel']) 147 | m2_loss = F.l1_loss(m2_hat, batch['mel']) 148 | val_loss += m1_loss.item() + m2_loss.item() 149 | _, att_score = attention_score(attention, batch['mel_len']) 150 | val_att_score += torch.mean(att_score).item() 151 | 152 | return val_loss / len(val_set), val_att_score / len(val_set) 153 | 154 | @ignore_exception 155 | def generate_plots(self, model: Tacotron, session: TTSSession) -> None: 156 | model.eval() 157 | device = next(model.parameters()).device 158 | batch = session.val_sample 159 | batch = to_device(batch, device=device) 160 | with torch.no_grad(): 161 | out = model(batch) 162 | m1_hat, m2_hat, att, att_aligner = out['mel'], out['mel_post'], out['att'], out['att_aligner'] 163 | att = np_now(att)[0] 164 | att_aligner = np_now(att_aligner.softmax(-1))[0] 165 | m1_hat = np_now(m1_hat)[0, :, :] 166 | m2_hat = np_now(m2_hat)[0, :, :] 167 | m_target = np_now(batch['mel'])[0, :, :] 168 | speaker = batch['speaker_name'][0] 169 | 170 | att_fig = plot_attention(att) 171 | att_aligner_fig = plot_attention(att_aligner) 172 | 173 | m1_hat_fig = plot_mel(m1_hat) 174 | m2_hat_fig = plot_mel(m2_hat) 175 | m_target_fig = plot_mel(m_target) 176 | 177 | self.writer.add_figure(f'Ground_Truth_Aligned/attention/{speaker}', att_fig, model.step) 178 | self.writer.add_figure(f'Ground_Truth_Aligned/attention_aligner/{speaker}', att_aligner_fig, model.step) 179 | self.writer.add_figure(f'Ground_Truth_Aligned/target/{speaker}', m_target_fig, model.step) 
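# The predicted mels (logged under the 'linear' and 'postnet' tags) and Griffin-Lim
# reconstructions of the postnet output and of the target mel are written to the
# same TensorBoard run below.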
180 | self.writer.add_figure(f'Ground_Truth_Aligned/linear/{speaker}', m1_hat_fig, model.step) 181 | self.writer.add_figure(f'Ground_Truth_Aligned/postnet/{speaker}', m2_hat_fig, model.step) 182 | 183 | m2_hat_wav = self.dsp.griffinlim(m2_hat) 184 | target_wav = self.dsp.griffinlim(m_target) 185 | 186 | self.writer.add_audio( 187 | tag=f'Ground_Truth_Aligned/target_wav/{speaker}', snd_tensor=target_wav, 188 | global_step=model.step, sample_rate=self.dsp.sample_rate) 189 | self.writer.add_audio( 190 | tag=f'Ground_Truth_Aligned/postnet_wav/{speaker}', snd_tensor=m2_hat_wav, 191 | global_step=model.step, sample_rate=self.dsp.sample_rate) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/utils/__init__.py -------------------------------------------------------------------------------- /utils/checkpoints.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict, Any, Union 3 | 4 | import torch 5 | import torch.optim.optimizer 6 | from models.fast_pitch import FastPitch 7 | from models.forward_tacotron import ForwardTacotron 8 | from models.multi_fast_pitch import MultiFastPitch 9 | from models.multi_forward_tacotron import MultiForwardTacotron 10 | from models.tacotron import Tacotron 11 | 12 | 13 | def save_checkpoint(model: torch.nn.Module, 14 | optim: torch.optim.Optimizer, 15 | config: Dict[str, Any], 16 | path: Path, 17 | meta: Dict[str, Any] = None) -> None: 18 | checkpoint = {'model': model.state_dict(), 19 | 'optim': optim.state_dict(), 20 | 'config': config} 21 | if meta is not None: 22 | checkpoint.update(meta) 23 | torch.save(checkpoint, str(path)) 24 | 25 | 26 | def restore_checkpoint(model: Union[FastPitch, ForwardTacotron, Tacotron, MultiForwardTacotron, MultiFastPitch], 27 | optim: torch.optim.Optimizer, 28 | path: Path, 29 | device: torch.device) -> None: 30 | if path.is_file(): 31 | checkpoint = torch.load(path, map_location=device) 32 | model.load_state_dict(checkpoint['model']) 33 | optim.load_state_dict(checkpoint['optim']) 34 | print(f'Restored model with step {model.get_step()}\n') 35 | 36 | 37 | def init_tts_model(config: Dict[str, Any]) -> Union[ForwardTacotron, FastPitch, MultiForwardTacotron, MultiFastPitch]: 38 | model_type = config.get('tts_model', 'forward_tacotron') 39 | if model_type == 'forward_tacotron': 40 | model = ForwardTacotron.from_config(config) 41 | elif model_type == 'fast_pitch': 42 | model = FastPitch.from_config(config) 43 | elif model_type == 'multi_forward_tacotron': 44 | model = MultiForwardTacotron.from_config(config) 45 | elif model_type == 'multi_fast_pitch': 46 | model = MultiFastPitch.from_config(config) 47 | else: 48 | raise ValueError(f'Model type not supported: {model_type}') 49 | return model 50 | -------------------------------------------------------------------------------- /utils/decorators.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | from time import time 3 | from typing import Callable 4 | 5 | 6 | def ignore_exception(f) -> Callable: 7 | def apply_func(*args, **kwargs): 8 | try: 9 | result = f(*args, **kwargs) 10 | return result 11 | except Exception: 12 | print(f'Catched exception in {f}:') 13 | traceback.print_exc() 14 | return None 15 | return 
apply_func 16 | 17 | 18 | def time_it(f) -> Callable: 19 | def apply_func(*args, **kwargs): 20 | t_start = time() 21 | result = f(*args, **kwargs) 22 | t_end = time() 23 | dur = round(t_end - t_start, ndigits=2) 24 | print(f'{f} took {dur}s') 25 | return result 26 | return apply_func -------------------------------------------------------------------------------- /utils/display.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | from matplotlib.figure import Figure 3 | mpl.use('agg') # Use non-interactive backend by default 4 | import matplotlib.pyplot as plt 5 | import time 6 | import numpy as np 7 | import sys 8 | 9 | 10 | def progbar(i, n, size=16): 11 | done = (i * size) // n 12 | bar = '' 13 | for i in range(size): 14 | bar += '█' if i <= done else '░' 15 | return bar 16 | 17 | 18 | def stream(message): 19 | sys.stdout.write(f"\r{message}") 20 | 21 | 22 | def simple_table(item_tuples): 23 | 24 | border_pattern = '+---------------------------------------' 25 | whitespace = ' ' 26 | 27 | headings, cells, = [], [] 28 | 29 | for item in item_tuples: 30 | 31 | heading, cell = str(item[0]), str(item[1]) 32 | 33 | pad_head = True if len(heading) < len(cell) else False 34 | 35 | pad = abs(len(heading) - len(cell)) 36 | pad = whitespace[:pad] 37 | 38 | pad_left = pad[:len(pad)//2] 39 | pad_right = pad[len(pad)//2:] 40 | 41 | if pad_head: 42 | heading = pad_left + heading + pad_right 43 | else: 44 | cell = pad_left + cell + pad_right 45 | 46 | headings += [heading] 47 | cells += [cell] 48 | 49 | border, head, body = '', '', '' 50 | 51 | for i in range(len(item_tuples)): 52 | 53 | temp_head = f'| {headings[i]} ' 54 | temp_body = f'| {cells[i]} ' 55 | 56 | border += border_pattern[:len(temp_head)] 57 | head += temp_head 58 | body += temp_body 59 | 60 | if i == len(item_tuples) - 1: 61 | head += '|' 62 | body += '|' 63 | border += '+' 64 | 65 | print(border) 66 | print(head) 67 | print(border) 68 | print(body) 69 | print(border) 70 | print(' ') 71 | 72 | 73 | def time_since(started): 74 | elapsed = time.time() - started 75 | m = int(elapsed // 60) 76 | s = int(elapsed % 60) 77 | if m >= 60: 78 | h = int(m // 60) 79 | m = m % 60 80 | return f'{h}h {m}m {s}s' 81 | else: 82 | return f'{m}m {s}s' 83 | 84 | 85 | def save_attention(attn, path): 86 | fig = plt.figure(figsize=(12, 6)) 87 | plt.imshow(attn.T, interpolation='nearest', aspect='auto') 88 | fig.savefig(path.parent/f'{path.stem}.png', bbox_inches='tight') 89 | plt.close(fig) 90 | 91 | 92 | def save_spectrogram(M, path, length=None): 93 | M = np.flip(M, axis=0) 94 | if length: M = M[:, :length] 95 | fig = plt.figure(figsize=(12, 6)) 96 | plt.imshow(M, interpolation='nearest', aspect='auto') 97 | fig.savefig(f'{path}.png', bbox_inches='tight') 98 | plt.close(fig) 99 | 100 | 101 | def plot(array): 102 | mpl.interactive(True) 103 | fig = plt.figure(figsize=(30, 5)) 104 | ax = fig.add_subplot(111) 105 | ax.xaxis.label.set_color('grey') 106 | ax.yaxis.label.set_color('grey') 107 | ax.xaxis.label.set_fontsize(23) 108 | ax.yaxis.label.set_fontsize(23) 109 | ax.tick_params(axis='x', colors='grey', labelsize=23) 110 | ax.tick_params(axis='y', colors='grey', labelsize=23) 111 | plt.plot(array) 112 | mpl.interactive(False) 113 | 114 | 115 | def plot_mel(mel: np.array) -> Figure: 116 | mel = np.flip(mel, axis=0) 117 | fig = plt.figure(figsize=(12, 6), dpi=150) 118 | plt.imshow(mel, interpolation='nearest', aspect='auto') 119 | return fig 120 | 121 | 122 | def plot_pitch(pitch: np.array, 
color='gray') -> Figure: 123 | fig = plt.figure(figsize=(12, 6), dpi=100) 124 | plt.plot(pitch, color=color) 125 | return fig 126 | 127 | 128 | def plot_attention(attn: np.array) -> Figure: 129 | fig = plt.figure(figsize=(12, 6)) 130 | plt.imshow(attn.T, interpolation='nearest', aspect='auto') 131 | return fig 132 | 133 | 134 | def plot_spec(M): 135 | mpl.interactive(True) 136 | M = np.flip(M, axis=0) 137 | plt.figure(figsize=(18,4)) 138 | plt.imshow(M, interpolation='nearest', aspect='auto') 139 | plt.show() 140 | mpl.interactive(False) 141 | 142 | -------------------------------------------------------------------------------- /utils/distribution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def log_sum_exp(x): 7 | """ numerically stable log_sum_exp implementation that prevents overflow """ 8 | # TF ordering 9 | axis = len(x.size()) - 1 10 | m, _ = torch.max(x, dim=axis) 11 | m2, _ = torch.max(x, dim=axis, keepdim=True) 12 | return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) 13 | 14 | 15 | # It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py 16 | def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, 17 | log_scale_min=None, reduce=True): 18 | if log_scale_min is None: 19 | log_scale_min = float(np.log(1e-14)) 20 | y_hat = y_hat.permute(0,2,1) 21 | assert y_hat.dim() == 3 22 | assert y_hat.size(1) % 3 == 0 23 | nr_mix = y_hat.size(1) // 3 24 | 25 | # (B x T x C) 26 | y_hat = y_hat.transpose(1, 2) 27 | 28 | # unpack parameters. (B, T, num_mixtures) x 3 29 | logit_probs = y_hat[:, :, :nr_mix] 30 | means = y_hat[:, :, nr_mix:2 * nr_mix] 31 | log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min) 32 | 33 | # B x T x 1 -> B x T x num_mixtures 34 | y = y.expand_as(means) 35 | 36 | centered_y = y - means 37 | inv_stdv = torch.exp(-log_scales) 38 | plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1)) 39 | cdf_plus = torch.sigmoid(plus_in) 40 | min_in = inv_stdv * (centered_y - 1. / (num_classes - 1)) 41 | cdf_min = torch.sigmoid(min_in) 42 | 43 | # log probability for edge case of 0 (before scaling) 44 | # equivalent: torch.log(F.sigmoid(plus_in)) 45 | log_cdf_plus = plus_in - F.softplus(plus_in) 46 | 47 | # log probability for edge case of 255 (before scaling) 48 | # equivalent: (1 - F.sigmoid(min_in)).log() 49 | log_one_minus_cdf_min = -F.softplus(min_in) 50 | 51 | # probability for all other cases 52 | cdf_delta = cdf_plus - cdf_min 53 | 54 | mid_in = inv_stdv * centered_y 55 | # log probability in the center of the bin, to be used in extreme cases 56 | # (not actually used in our code) 57 | log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in) 58 | 59 | # tf equivalent 60 | """ 61 | log_probs = tf.where(x < -0.999, log_cdf_plus, 62 | tf.where(x > 0.999, log_one_minus_cdf_min, 63 | tf.where(cdf_delta > 1e-5, 64 | tf.log(tf.maximum(cdf_delta, 1e-12)), 65 | log_pdf_mid - np.log(127.5)))) 66 | """ 67 | # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value 68 | # for num_classes=65536 case? 1e-7? not sure.. 69 | inner_inner_cond = (cdf_delta > 1e-5).float() 70 | 71 | inner_inner_out = inner_inner_cond * \ 72 | torch.log(torch.clamp(cdf_delta, min=1e-12)) + \ 73 | (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2)) 74 | inner_cond = (y > 0.999).float() 75 | inner_out = inner_cond * log_one_minus_cdf_min + (1. 
- inner_cond) * inner_inner_out 76 | cond = (y < -0.999).float() 77 | log_probs = cond * log_cdf_plus + (1. - cond) * inner_out 78 | 79 | log_probs = log_probs + F.log_softmax(logit_probs, -1) 80 | 81 | if reduce: 82 | return -torch.mean(log_sum_exp(log_probs)) 83 | else: 84 | return -log_sum_exp(log_probs).unsqueeze(-1) 85 | 86 | 87 | def sample_from_discretized_mix_logistic(y, log_scale_min=None): 88 | """ 89 | Sample from discretized mixture of logistic distributions 90 | Args: 91 | y (Tensor): B x C x T 92 | log_scale_min (float): Log scale minimum value 93 | Returns: 94 | Tensor: sample in range of [-1, 1]. 95 | """ 96 | if log_scale_min is None: 97 | log_scale_min = float(np.log(1e-14)) 98 | assert y.size(1) % 3 == 0 99 | nr_mix = y.size(1) // 3 100 | 101 | # B x T x C 102 | y = y.transpose(1, 2) 103 | logit_probs = y[:, :, :nr_mix] 104 | 105 | # sample mixture indicator from softmax 106 | temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) 107 | temp = logit_probs.data - torch.log(- torch.log(temp)) 108 | _, argmax = temp.max(dim=-1) 109 | 110 | # (B, T) -> (B, T, nr_mix) 111 | one_hot = F.one_hot(argmax, nr_mix).float() 112 | # select logistic parameters 113 | means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1) 114 | log_scales = torch.clamp(torch.sum( 115 | y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min) 116 | # sample from logistic & clip to interval 117 | # we don't actually round to the nearest 8bit value when sampling 118 | u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) 119 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u)) 120 | 121 | x = torch.clamp(torch.clamp(x, min=-1.), max=1.) 122 | 123 | return x 124 | 125 | ''' 126 | def to_one_hot(tensor, n, fill_with=1.): 127 | # we perform one hot encore with respect to the last axis 128 | one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_() 129 | if tensor.is_cuda: 130 | one_hot = one_hot.cuda() 131 | one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) 132 | return one_hot''' 133 | -------------------------------------------------------------------------------- /utils/dsp.py: -------------------------------------------------------------------------------- 1 | import struct 2 | from pathlib import Path 3 | from typing import Dict, Any, Union, List 4 | import numpy as np 5 | import librosa 6 | import torch 7 | import webrtcvad 8 | from scipy.ndimage import binary_dilation 9 | import torchaudio 10 | import torchaudio.transforms as transforms 11 | 12 | from utils.dataset import tensor_to_ndarray, ndarray_to_tensor 13 | 14 | 15 | class DSP: 16 | 17 | def __init__(self, 18 | num_mels: int, 19 | sample_rate: int, 20 | hop_length: int, 21 | win_length: int, 22 | n_fft: int, 23 | fmin: float, 24 | fmax: float, 25 | peak_norm: bool, 26 | trim_start_end_silence: bool, 27 | trim_silence_top_db: int, 28 | trim_long_silences: bool, 29 | vad_sample_rate: int, 30 | vad_window_length: float, 31 | vad_moving_average_width: float, 32 | vad_max_silence_length: int, 33 | **kwargs, # for backward compatibility 34 | ) -> None: 35 | 36 | self.n_mels = num_mels 37 | self.sample_rate = sample_rate 38 | self.hop_length = hop_length 39 | self.win_length = win_length 40 | self.n_fft = n_fft 41 | self.fmin = fmin 42 | self.fmax = fmax 43 | 44 | self.should_peak_norm = peak_norm 45 | self.should_trim_start_end_silence = trim_start_end_silence 46 | self.should_trim_long_silences = trim_long_silences 47 | self.trim_silence_top_db = 
trim_silence_top_db 48 | 49 | self.vad_sample_rate = vad_sample_rate 50 | self.vad_window_length = vad_window_length 51 | self.vad_moving_average_width = vad_moving_average_width 52 | self.vad_max_silence_length = vad_max_silence_length 53 | 54 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 55 | 56 | # init transformation 57 | self.mel_transform = self._init_mel_transform() 58 | 59 | @classmethod 60 | def from_config(cls, config: Dict[str, Any]) -> 'DSP': 61 | """Initialize from configuration object""" 62 | return DSP(**config['dsp']) 63 | 64 | def _init_mel_transform(self): 65 | """Initialize mel transformation""" 66 | mel_transform = transforms.MelSpectrogram( 67 | sample_rate=self.sample_rate, 68 | n_fft=self.n_fft, 69 | win_length=self.win_length, 70 | hop_length=self.hop_length, 71 | power=1, 72 | norm="slaney", 73 | n_mels=self.n_mels, 74 | mel_scale="slaney", 75 | f_min=self.fmin, 76 | f_max=self.fmax, 77 | ).to(self.device) 78 | 79 | return mel_transform 80 | 81 | def load_wav(self, path: Union[str, Path], mono: bool = True) -> torch.Tensor: 82 | """Load audio file into a tensor""" 83 | effects = [] 84 | metadata = torchaudio.info(path) 85 | 86 | # merge channels if source is multichannel 87 | if mono and metadata.num_channels > 1: 88 | effects.extend([ 89 | ["remix", "-"] # convert to mono 90 | ]) 91 | 92 | # resample if source sample rate is different from desired sample rate 93 | if metadata.sample_rate != self.sample_rate: 94 | effects.extend([ 95 | ["rate", f'{self.sample_rate}'], 96 | ]) 97 | 98 | waveform, _ = torchaudio.sox_effects.apply_effects_file(path, effects=effects) 99 | return waveform 100 | 101 | def save_wav(self, waveform: torch.Tensor, path: Union[str, Path]) -> None: 102 | """Save waveform to file""" 103 | torchaudio.save(filepath=path, src=waveform, sample_rate=self.sample_rate) 104 | 105 | def adjust_volume(self, waveform: torch.Tensor, target_dbfs: int = -30) -> torch.Tensor: 106 | """Adjust volume of the waveform""" 107 | volume_transform = transforms.Vol(gain=target_dbfs, gain_type='db').to(self.device) 108 | return volume_transform(waveform) 109 | 110 | def adjust_volume_batched(self, data: List[torch.Tensor], target_dbfs: int = -30) -> List[torch.Tensor]: 111 | """Adjust volume of the waveforms in the batch""" 112 | lengths = [tensor.size(1) for tensor in data] 113 | padded_batch = [torch.nn.functional.pad(x, (0, max(lengths) - x.size(1))) for x in data] 114 | stacked_tensor = torch.stack(padded_batch, dim=0) 115 | processed_batch = self.adjust_volume(stacked_tensor, target_dbfs=target_dbfs) 116 | result = [processed_waveform[:, :lengths[index]] for index, processed_waveform in enumerate(processed_batch)] 117 | return result 118 | 119 | def waveform_to_mel_batched(self, batch: List[torch.Tensor]) -> List[torch.Tensor]: 120 | """Convert waveform to mel spectrogram for the batch of waveforms""" 121 | lengths = [tensor.size(1) for tensor in batch] 122 | expected_mel_lengths = [x // self.hop_length + 1 for x in lengths] 123 | padded_batch = [torch.nn.functional.pad(x, (0, max(lengths) - x.size(1))) for x in batch] 124 | batch_tensor = torch.stack(padded_batch, dim=0).to(self.device) 125 | mels = self.waveform_to_mel(batch_tensor) 126 | list_of_mels = [mel[:, :, :expected_mel_lengths[index]] for index, mel in enumerate(mels)] 127 | return list_of_mels 128 | 129 | def waveform_to_mel(self, waveform: torch.Tensor, normalized: bool = True) -> torch.Tensor: 130 | """Convert waveform to mel spectrogram""" 131 | mel_spec = 
self.mel_transform(waveform) 132 | if normalized: 133 | mel_spec = self.normalize(mel_spec) 134 | return mel_spec 135 | 136 | def griffinlim(self, mel: np.array, n_iter: int = 32) -> np.array: 137 | mel = self.denormalize(mel) 138 | S = librosa.feature.inverse.mel_to_stft( 139 | mel, 140 | power=1, 141 | sr=self.sample_rate, 142 | n_fft=self.n_fft, 143 | fmin=self.fmin, 144 | fmax=self.fmax) 145 | wav = librosa.core.griffinlim( 146 | S, 147 | n_iter=n_iter, 148 | hop_length=self.hop_length, 149 | win_length=self.win_length) 150 | return wav 151 | 152 | @staticmethod 153 | def normalize(mel: torch.Tensor) -> torch.Tensor: 154 | """Normalize mel spectrogram""" 155 | mel = torch.clip(mel, min=1.e-5, max=None) 156 | return torch.log(mel) 157 | 158 | @staticmethod 159 | def denormalize(mel: np.ndarray) -> np.ndarray: 160 | """Denormalize mel spectrogram""" 161 | return np.exp(mel) 162 | 163 | def trim_silence(self, waveform: torch.Tensor) -> torch.Tensor: 164 | """Trim silence from the waveform""" 165 | waveform = tensor_to_ndarray(waveform) 166 | trimmed_waveform = librosa.effects.trim(waveform, 167 | top_db=self.trim_silence_top_db, 168 | frame_length=self.win_length, 169 | hop_length=self.hop_length) 170 | return ndarray_to_tensor(trimmed_waveform[0]) 171 | 172 | # borrowed from https://github.com/resemble-ai/Resemblyzer/blob/master/resemblyzer/audio.py 173 | def trim_long_silences(self, wav: torch.Tensor) -> torch.Tensor: 174 | wav = tensor_to_ndarray(wav) 175 | int16_max = (2 ** 15) - 1 176 | samples_per_window = (self.vad_window_length * self.vad_sample_rate) // 1000 177 | wav = wav[:len(wav) - (len(wav) % samples_per_window)] 178 | pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16)) 179 | voice_flags = [] 180 | vad = webrtcvad.Vad(mode=3) 181 | for window_start in range(0, len(wav), samples_per_window): 182 | window_end = window_start + samples_per_window 183 | voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2], 184 | sample_rate=self.vad_sample_rate)) 185 | voice_flags = np.array(voice_flags) 186 | def moving_average(array, width): 187 | array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2))) 188 | ret = np.cumsum(array_padded, dtype=float) 189 | ret[width:] = ret[width:] - ret[:-width] 190 | return ret[width - 1:] / width 191 | audio_mask = moving_average(voice_flags, self.vad_moving_average_width) 192 | audio_mask = np.round(audio_mask).astype(np.bool) 193 | audio_mask[:] = binary_dilation(audio_mask[:], np.ones(self.vad_max_silence_length + 1)) 194 | audio_mask = np.repeat(audio_mask, samples_per_window) 195 | return ndarray_to_tensor(wav[audio_mask]) 196 | -------------------------------------------------------------------------------- /utils/files.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import yaml 3 | from pathlib import Path 4 | from typing import Union, List, Any, Dict, Tuple 5 | 6 | 7 | def get_files(path: Path, extension='.wav') -> List[Path]: 8 | path = path.expanduser().resolve() 9 | return list(path.rglob(f'*{extension}')) 10 | 11 | 12 | def pickle_binary(data: object, file: Union[str, Path]) -> None: 13 | with open(str(file), 'wb') as f: 14 | pickle.dump(data, f) 15 | 16 | 17 | def unpickle_binary(file: Union[str, Path]) -> Any: 18 | with open(str(file), 'rb') as f: 19 | return pickle.load(f) 20 | 21 | 22 | def read_config(path: Union[Path, str]) -> Dict[str, Any]: 23 | with open(str(path), 'r') as stream: 24 | 
config = yaml.load(stream, Loader=yaml.FullLoader) 25 | return config 26 | 27 | 28 | def save_config(config: Dict[str, Any], path: str) -> None: 29 | with open(path, 'w+', encoding='utf-8') as stream: 30 | yaml.dump(config, stream, default_flow_style=False) 31 | 32 | 33 | def parse_schedule(schedule: List[str]) -> List[Tuple]: 34 | out = [] 35 | for line in schedule: 36 | split = line.split(',') 37 | if len(split) == 4: 38 | r, lr, step, bs = split 39 | out.append((int(r), float(lr), int(step), int(bs))) 40 | else: 41 | lr, step, bs = split 42 | out.append((float(lr), int(step), int(bs))) 43 | return out 44 | 45 | 46 | 47 | 48 | if __name__ == '__main__': 49 | config = read_config('../configs/default.yaml') 50 | print(config) -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def attention_score(att, mel_lens, r=1): 5 | """ 6 | Returns a tuple of scores (loc_score, sharp_score), where loc_score measures monotonicity and 7 | sharp_score measures the sharpness of attention peaks 8 | """ 9 | 10 | with torch.no_grad(): 11 | device = att.device 12 | mel_lens = mel_lens.to(device) 13 | b, t_max, c_max = att.size() 14 | 15 | # create mel padding mask 16 | mel_range = torch.arange(0, t_max, device=device) 17 | mel_lens = mel_lens // r 18 | mask = (mel_range[None, :] < mel_lens[:, None]).float() 19 | 20 | # score for how adjacent the attention loc is 21 | max_loc = torch.argmax(att, dim=2) 22 | max_loc_diff = torch.abs(max_loc[:, 1:] - max_loc[:, :-1]) 23 | loc_score = (max_loc_diff >= 0) * (max_loc_diff <= r) 24 | loc_score = torch.sum(loc_score * mask[:, 1:], dim=1) 25 | loc_score = loc_score / (mel_lens - 1) 26 | 27 | # score for attention sharpness 28 | sharp_score, inds = att.max(dim=2) 29 | sharp_score = torch.sum(sharp_score * mask, dim=1) / torch.sum(mask, dim=1) 30 | 31 | return loc_score, sharp_score -------------------------------------------------------------------------------- /utils/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | 5 | class Paths: 6 | """Manages and configures the paths used by WaveRNN, Tacotron, and the data.""" 7 | def __init__(self, data_path, tts_id): 8 | 9 | # directories 10 | self.base = Path(__file__).parent.parent.expanduser().resolve() 11 | self.data = Path(data_path).expanduser().resolve() 12 | self.quant = self.data/'quant' 13 | self.mel = self.data/'mel' 14 | self.gta = self.data/'gta' 15 | self.att_pred = self.data/'att_pred' 16 | self.alg = self.data/'alg' 17 | self.speaker_emb = self.data/'speaker_emb' 18 | self.mean_speaker_emb = self.data/'mean_speaker_emb' 19 | self.raw_pitch = self.data/'raw_pitch' 20 | self.phon_pitch = self.data/'phon_pitch' 21 | self.phon_energy = self.data/'phon_energy' 22 | self.model_output = self.base / 'model_output' 23 | self.taco_checkpoints = self.base / 'checkpoints' / f'{tts_id}.tacotron' 24 | self.taco_log = self.taco_checkpoints / 'logs' 25 | self.forward_checkpoints = self.base/'checkpoints'/f'{tts_id}.forward' 26 | self.forward_log = self.forward_checkpoints/'logs' 27 | 28 | # pickle objects 29 | self.train_dataset = self.data / 'train_dataset.pkl' 30 | self.val_dataset = self.data / 'val_dataset.pkl' 31 | self.text_dict = self.data / 'text_dict.pkl' 32 | self.speaker_dict = self.data / 'speaker_dict.pkl' 33 | self.duration_stats = self.data / 'duration_stats.pkl' 34 | 
35 | self.create_paths() 36 | 37 | def create_paths(self): 38 | os.makedirs(self.data, exist_ok=True) 39 | os.makedirs(self.quant, exist_ok=True) 40 | os.makedirs(self.mel, exist_ok=True) 41 | os.makedirs(self.gta, exist_ok=True) 42 | os.makedirs(self.alg, exist_ok=True) 43 | os.makedirs(self.speaker_emb, exist_ok=True) 44 | os.makedirs(self.mean_speaker_emb, exist_ok=True) 45 | os.makedirs(self.att_pred, exist_ok=True) 46 | os.makedirs(self.raw_pitch, exist_ok=True) 47 | os.makedirs(self.phon_pitch, exist_ok=True) 48 | os.makedirs(self.phon_energy, exist_ok=True) 49 | os.makedirs(self.taco_checkpoints, exist_ok=True) 50 | os.makedirs(self.forward_checkpoints, exist_ok=True) 51 | 52 | def get_tts_named_weights(self, name): 53 | """Gets the path for the weights in a named tts checkpoint.""" 54 | return self.taco_checkpoints / f'{name}_weights.pyt' 55 | 56 | def get_tts_named_optim(self, name): 57 | """Gets the path for the optimizer state in a named tts checkpoint.""" 58 | return self.taco_checkpoints / f'{name}_optim.pyt' 59 | 60 | def get_voc_named_weights(self, name): 61 | """Gets the path for the weights in a named voc checkpoint.""" 62 | return self.voc_checkpoints/f'{name}_weights.pyt' 63 | 64 | def get_voc_named_optim(self, name): 65 | """Gets the path for the optimizer state in a named voc checkpoint.""" 66 | return self.voc_checkpoints/f'{name}_optim.pyt' 67 | 68 | 69 | -------------------------------------------------------------------------------- /utils/text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /utils/text/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/ForwardTacotron/2c3d281c4501040bdb844357203504efa7c6614d/utils/text/__init__.py -------------------------------------------------------------------------------- /utils/text/cleaners.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict, Any 3 | 4 | from phonemizer.backend import EspeakBackend 5 | from unidecode import unidecode 6 | 7 | from utils.text.numbers import normalize_numbers 8 | from utils.text.symbols import phonemes_set 9 | 10 | # Regular expression matching whitespace: 11 | _whitespace_re = re.compile(r'\s+') 12 | 13 | # List of (regular expression, replacement) pairs for abbreviations: 14 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 15 | ('mrs', 'misess'), 16 | ('mr', 'mister'), 17 | ('dr', 'doctor'), 18 | ('st', 'saint'), 19 | ('co', 'company'), 20 | ('jr', 'junior'), 21 | ('maj', 'major'), 22 | ('gen', 'general'), 23 | ('drs', 'doctors'), 24 | ('rev', 'reverend'), 25 | ('lt', 'lieutenant'), 26 | ('hon', 'honorable'), 27 | ('sgt', 'sergeant'), 28 | ('capt', 'captain'), 29 | ('esq', 'esquire'), 30 | ('ltd', 'limited'), 31 | ('col', 'colonel'), 32 | ('ft', 'fort'), 33 | ]] 34 | 35 | 36 | def expand_abbreviations(text): 37 | for regex, replacement in _abbreviations: 38 | text = re.sub(regex, replacement, text) 39 | return text 40 | 41 | 42 | def collapse_whitespace(text): 43 | return re.sub(_whitespace_re, ' ', text) 44 | 45 | 46 | def no_cleaners(text): 47 | return text 48 | 49 | 50 | def english_cleaners(text): 51 | text = unidecode(text) 52 | text = normalize_numbers(text) 53 | text = expand_abbreviations(text) 54 | return text 55 | 56 | 57 | class Cleaner: 58 | 59 | def __init__(self, 60 | cleaner_name: str, 61 | use_phonemes: bool, 62 | lang: str) -> None: 63 | if cleaner_name == 'english_cleaners': 64 | self.clean_func = english_cleaners 65 | elif cleaner_name == 'no_cleaners': 66 | self.clean_func = no_cleaners 67 | else: 68 | raise ValueError(f'Cleaner not supported: {cleaner_name}! 
' 69 | f'Currently supported: [\'english_cleaners\', \'no_cleaners\']') 70 | self.use_phonemes = use_phonemes 71 | self.lang = lang 72 | if use_phonemes: 73 | self.backend = EspeakBackend(language=lang, 74 | preserve_punctuation=True, 75 | with_stress=False, 76 | punctuation_marks=';:,.!?¡¿—…"«»“”()', 77 | language_switch='remove-flags') 78 | 79 | def __call__(self, text: str) -> str: 80 | text = self.clean_func(text) 81 | if self.use_phonemes: 82 | text = self.backend.phonemize([text], strip=True)[0] 83 | text = ''.join([p for p in text if p in phonemes_set]) 84 | text = collapse_whitespace(text) 85 | text = text.strip() 86 | return text 87 | 88 | @classmethod 89 | def from_config(cls, config: Dict[str, Any]) -> 'Cleaner': 90 | return Cleaner( 91 | cleaner_name=config['preprocessing']['cleaner_name'], 92 | use_phonemes=config['preprocessing']['use_phonemes'], 93 | lang=config['preprocessing']['language'] 94 | ) 95 | 96 | -------------------------------------------------------------------------------- /utils/text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 13 | _number_re = re.compile(r'[0-9]+') 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(',', '') 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace('.', ' point ') 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split('.') 27 | if len(parts) > 2: 28 | return match + ' dollars' # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 33 | cent_unit = 'cent' if cents == 1 else 'cents' 34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 37 | return '%s %s' % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = 'cent' if cents == 1 else 'cents' 40 | return '%s %s' % (cents, cent_unit) 41 | else: 42 | return 'zero dollars' 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return 'two thousand' 54 | elif num > 2000 and num < 2010: 55 | return 'two thousand ' + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + ' hundred' 58 | else: 59 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 60 | else: 61 | return _inflect.number_to_words(num, andword='') 62 | 63 | 64 | def normalize_numbers(text): 65 | text = re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r'\1 pounds', text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | 
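
The two modules above are plain text-normalization helpers and can be exercised on their own. Below is a minimal usage sketch (not part of the repository) that assumes `inflect` and `unidecode` are installed and that the repository root is on the Python path; the outputs in the comments are what the expansion rules defined above should produce.

```python
# Hypothetical usage sketch for utils/text/numbers.py and utils/text/cleaners.py (not part of the repo).
from utils.text.numbers import normalize_numbers
from utils.text.cleaners import english_cleaners

# Currency amounts, ordinals and plain digits are expanded into words:
print(normalize_numbers('I paid $2.50 for the 3rd ticket.'))
# -> 'I paid two dollars, fifty cents for the third ticket.'

# english_cleaners additionally transliterates to ASCII and expands common abbreviations:
print(english_cleaners('Mr. Müller bought 2 apples.'))
# -> 'mister Muller bought two apples.'
```
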
-------------------------------------------------------------------------------- /utils/text/recipes.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | from pathlib import Path 3 | from typing import Tuple 4 | 5 | import pandas as pd 6 | import tqdm 7 | 8 | from utils.files import get_files 9 | 10 | DEFAULT_SPEAKER_NAME = 'default_speaker' 11 | 12 | 13 | def read_metadata(path: Path, 14 | metafile: str, 15 | format: str, 16 | n_workers=1) -> Tuple[dict, dict]: 17 | if format == 'ljspeech': 18 | return read_ljspeech_format(path/metafile, multispeaker=False) 19 | elif format == 'ljspeech_multi': 20 | return read_ljspeech_format(path/metafile, multispeaker=True) 21 | elif format == 'vctk': 22 | return read_vctk_format(path, n_workers=n_workers) 23 | elif format == 'pandas': 24 | return read_pandas_format(path/metafile) 25 | else: 26 | raise ValueError(f'Unsupported metadata format: {format}, expected one of [ljspeech, ljspeech_multi, vctk, pandas]') 27 | 28 | 29 | def read_ljspeech_format(path: Path, multispeaker: bool = False) -> Tuple[dict, dict]: 30 | if not path.is_file(): 31 | raise ValueError(f'Could not find metafile: {path}, ' 32 | f'please make sure that you set the correct path and metafile name!') 33 | text_dict = {} 34 | speaker_dict = {} 35 | with open(str(path), encoding='utf-8') as f: 36 | for line in f: 37 | split = line.split('|') 38 | speaker_name = split[-2] if multispeaker and len(split) > 2 else DEFAULT_SPEAKER_NAME 39 | file_id, text = split[0], split[-1] 40 | text_dict[file_id] = text.replace('\n', '') 41 | speaker_dict[file_id] = speaker_name 42 | return text_dict, speaker_dict 43 | 44 | 45 | def read_vctk_format(path: Path, 46 | n_workers: int, 47 | extension='.txt') -> Tuple[dict, dict]: 48 | files = get_files(path, extension=extension) 49 | text_dict = {} 50 | speaker_dict = {} 51 | pool = Pool(processes=n_workers) 52 | for i, (file, text) in tqdm.tqdm(enumerate(pool.imap_unordered(read_line, files), 1), total=len(files)): 53 | text_id = file.name.replace(extension, '') 54 | speaker_id = file.parent.stem 55 | text_dict[text_id] = text.replace('\n', '') 56 | speaker_dict[text_id] = speaker_id 57 | return text_dict, speaker_dict 58 | 59 | 60 | def read_pandas_format(path: Path) -> Tuple[dict, dict]: 61 | if not path.is_file(): 62 | raise ValueError(f'Could not find metafile: {path}, ' 63 | f'please make sure that you set the correct path and metafile name!') 64 | df = pd.read_csv(str(path), sep='\t', encoding='utf-8') 65 | text_dict = {} 66 | speaker_dict = {} 67 | for index, row in df.iterrows(): 68 | id = row['file_id'] 69 | text_dict[id] = row['text'] 70 | speaker_dict[id] = row['speaker_id'] 71 | return text_dict, speaker_dict 72 | 73 | 74 | def read_line(file: Path) -> Tuple[Path, str]: 75 | with open(str(file), encoding='utf-8') as f: 76 | line = f.readlines()[0] 77 | return file, line 78 | -------------------------------------------------------------------------------- /utils/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The symbol set below consists of padding, punctuation and IPA phonemes; extend the lists in this file if your data requires additional symbols. ''' 7 | 8 | _pad = '_' 9 | _punctuation = '!\'(),.:;? 
' 10 | _special = '-' 11 | 12 | # Phonemes 13 | _vowels = 'iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ' 14 | _non_pulmonic_consonants = 'ʘɓǀɗǃʄǂɠǁʛ' 15 | _pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ' 16 | _suprasegmentals = 'ˈˌːˑ' 17 | _other_symbols = 'ʍwɥʜʢʡɕʑɺɧ' 18 | _diacritics = 'ɚ˞ɫ' 19 | _extra_phons = ['g', 'ɝ', '̃', '̍', '̥', '̩', '̯', '͡'] # some extra symbols found in Wiktionary IPA annotations 20 | 21 | phonemes = list( 22 | _pad + _punctuation + _special + _vowels + _non_pulmonic_consonants 23 | + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacritics) + _extra_phons 24 | 25 | phonemes_set = set(phonemes) 26 | silent_phonemes_indices = [i for i, p in enumerate(phonemes) if p in _pad + _punctuation] -------------------------------------------------------------------------------- /utils/text/tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from utils.text.symbols import phonemes 4 | 5 | 6 | class Tokenizer: 7 | 8 | def __init__(self) -> None: 9 | self.symbol_to_id = {s: i for i, s in enumerate(phonemes)} 10 | self.id_to_symbol = {i: s for i, s in enumerate(phonemes)} 11 | 12 | def __call__(self, text: str) -> List[int]: 13 | return [self.symbol_to_id[t] for t in text if t in self.symbol_to_id] 14 | 15 | def decode(self, sequence: List[int]) -> str: 16 | text = [self.id_to_symbol[s] for s in sequence if s in self.id_to_symbol] 17 | return ''.join(text) --------------------------------------------------------------------------------
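
`Tokenizer` maps each symbol of a (phonemized) string to its index in the `phonemes` list from `symbols.py`, silently dropping anything outside that inventory, and `decode` inverts the mapping. Below is a small round-trip sketch (hypothetical, not part of the repository); in the full pipeline such phoneme strings would typically come from a `Cleaner` with `use_phonemes=True`, which additionally requires an installed espeak backend for `phonemizer`.

```python
# Hypothetical round-trip sketch for utils/text/tokenizer.py (not part of the repo).
from utils.text.tokenizer import Tokenizer

tokenizer = Tokenizer()

# Encode an IPA phoneme string into integer IDs; symbols that are not part of
# the phoneme inventory defined in symbols.py are dropped silently.
ids = tokenizer('həloʊ wɜːld!')
print(ids)  # a list of ints, one per symbol found in the inventory

# decode() inverts the mapping, so strings made of known symbols survive a round trip.
assert tokenizer.decode(ids) == 'həloʊ wɜːld!'
```
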