├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── aligner_test_sentences.txt ├── config ├── data_config_wavernn.yaml └── training_config.yaml ├── create_training_data.py ├── data ├── __init__.py ├── audio.py ├── datasets.py ├── metadata_readers.py └── text │ ├── __init__.py │ ├── symbols.py │ └── tokenizer.py ├── docs ├── .gitignore ├── 404.html ├── Gemfile ├── Gemfile.lock ├── README.md ├── _config.yml ├── _layouts │ └── default.html ├── assets │ └── css │ │ └── style.scss ├── favicon.png ├── index.md ├── tboard_demo.gif └── transformer_logo.png ├── extract_durations.py ├── model ├── __init__.py ├── factory.py ├── layers.py ├── models.py └── transformer_utils.py ├── notebooks └── synthesize_forward_melgan.ipynb ├── predict_tts.py ├── requirements.txt ├── test_sentences.txt ├── tests ├── __init__.py ├── test_char_tokenizer.py ├── test_config.yaml └── test_loss.py ├── train_aligner.py ├── train_tts.py └── utils ├── __init__.py ├── alignments.py ├── decorators.py ├── display.py ├── logging_utils.py ├── losses.py ├── metrics.py ├── scheduling.py ├── scripts_utils.py ├── spectrogram_ops.py ├── training_config_manager.py └── vec_ops.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-language=Python -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## CUSTOM 2 | samples 3 | .idea 4 | *.pkl 5 | *.hdf5 6 | .DS_Store 7 | private 8 | logs 9 | *.pt 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | COPYRIGHT 2 | 3 | Copyright (c) 2020 Axel Springer AI. All rights reserved. 4 | 5 | LICENSE 6 | 7 | The MIT License (MIT) 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | 4 | 5 | 6 | 7 | 8 | A Text-to-Speech Transformer in TensorFlow 2 9 |
10 | 11 | 12 | Implementation of a non-autoregressive Transformer-based neural network for Text-to-Speech (TTS).
13 | This repo is based, among others, on the following papers: 14 | - [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895) 15 | - [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263) 16 | - [FastSpeech 2: Fast and High-Quality End-to-End Text to Speech](https://arxiv.org/abs/2006.04558) 17 | - [FastPitch: Parallel Text-to-speech with Pitch Prediction](https://fastpitch.github.io/) 18 | 19 | Our pre-trained LJSpeech model is compatible with the pre-trained vocoders: 20 | - [MelGAN](https://github.com/seungwonpark/melgan) 21 | - [HiFiGAN](https://github.com/jik876/hifi-gan) 22 | 23 | (older versions are also available for [WaveRNN](https://github.com/fatchord/WaveRNN)) 24 | 25 | For quick inference with these vocoders, check out the [Vocoding branch](https://github.com/as-ideas/TransformerTTS/tree/vocoding). 26 | 27 | #### Non-Autoregressive 28 | Being non-autoregressive, this Transformer model is: 29 | - Robust: No repeats or failed attention modes for challenging sentences. 30 | - Fast: With no autoregression, predictions take a fraction of the time. 31 | - Controllable: It is possible to control the speed and pitch of the generated utterance. 32 | 33 | ## 🔈 Samples 34 | 35 | [Can be found here.](https://as-ideas.github.io/TransformerTTS/) 36 | 37 | These samples' spectrograms are converted to audio using the pre-trained [MelGAN](https://github.com/seungwonpark/melgan) vocoder.
38 | 39 | 40 | Try it out on Colab: 41 | 42 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/as-ideas/TransformerTTS/blob/main/notebooks/synthesize_forward_melgan.ipynb) 43 | 44 | ## Updates 45 | - 06/20: Added normalisation and pre-trained models compatible with the faster [MelGAN](https://github.com/seungwonpark/melgan) vocoder. 46 | - 11/20: Added pitch prediction. The autoregressive model is now specialized as an Aligner and Forward is now the only TTS model. Changed model architectures. Discontinued WaveRNN support. Improved duration extraction with the Dijkstra algorithm. 47 | - 03/20: Vocoding branch. 48 | 49 | ## 📖 Contents 50 | - [Installation](#installation) 51 | - [API](#pre-trained-ljspeech-api) 52 | - [Dataset](#dataset) 53 | - [Training](#training) 54 | - [Aligner](#train-aligner-model) 55 | - [TTS](#train-tts-model) 56 | - [Prediction](#prediction) 57 | - [Model Weights](#model-weights) 58 | 59 | ## Installation 60 | 61 | Make sure you have: 62 | 63 | * Python >= 3.6 64 | 65 | Install espeak as the phonemizer backend (for macOS use brew): 66 | ``` 67 | sudo apt-get install espeak 68 | ``` 69 | 70 | Then install the rest with pip: 71 | ``` 72 | pip install -r requirements.txt 73 | ``` 74 | 75 | Read the individual scripts for more command line arguments. 76 | 77 | ## Pre-Trained LJSpeech API 78 | Use our pre-trained model (with Griffin-Lim) from the command line with 79 | ```commandline 80 | python predict_tts.py -t "Please, say something." 81 | ``` 82 | Or in a Python script 83 | ```python 84 | from data.audio import Audio 85 | from model.factory import tts_ljspeech 86 | 87 | model = tts_ljspeech() 88 | audio = Audio.from_config(model.config) 89 | out = model.predict('Please, say something.') 90 | 91 | # Convert spectrogram to wav (with griffin lim) 92 | wav = audio.reconstruct_waveform(out['mel'].numpy().T) 93 | ``` 94 | 95 | You can specify the model step with the `--step` flag (CL) or `step` parameter (script).
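For example, to synthesize with a specific checkpoint and write the result to disk (a minimal sketch: the `step` keyword for `tts_ljspeech` mirrors the `--step` flag but is an assumption here, so check `model/factory.py` for the exact signature; `Audio.save_wav` is defined in `data/audio.py`):
```commandline
python predict_tts.py -t "Please, say something." --step 95000
```
```python
from data.audio import Audio
from model.factory import tts_ljspeech

model = tts_ljspeech(step='95000')  # assumed keyword; mirrors the --step CLI flag
audio = Audio.from_config(model.config)
out = model.predict('Please, say something.')

# Convert spectrogram to wav (with griffin lim) and save it
wav = audio.reconstruct_waveform(out['mel'].numpy().T)
audio.save_wav(wav, 'sample.wav')  # save_wav(y, wav_path) from data/audio.py
```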
96 | Steps from 60000 to 100000 are available in increments of 5,000 steps (60000, 65000, ..., 95000, 100000). 97 | 98 | IMPORTANT: make sure to check out the correct repository version to use the API.
99 | Currently 493be6345341af0df3ae829de79c2793c9afd0ec 100 | 101 | ## Dataset 102 | You can directly use [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) to create the training dataset. 103 | 104 | #### Configuration 105 | * If training on LJSpeech, or if unsure, simply use ```config/training_config.yaml``` to create [MelGAN](https://github.com/seungwonpark/melgan) or [HiFiGAN](https://github.com/jik876/hifi-gan) compatible models 106 | * swap the content of ```data_config_wavernn.yaml``` in ```config/training_config.yaml``` to create models compatible with [WaveRNN](https://github.com/fatchord/WaveRNN) 107 | * **EDIT PATHS**: in `config/training_config.yaml` edit the paths to point at your dataset and log folders 108 | 109 | #### Custom dataset 110 | Prepare a folder containing your metadata and wav files, for instance 111 | ``` 112 | |- dataset_folder/ 113 | | |- metadata.csv 114 | | |- wavs/ 115 | | |- file1.wav 116 | | |- ... 117 | ``` 118 | if `metadata.csv` has the following format 119 | ``` wav_file_name|transcription ``` 120 | you can use the ljspeech preprocessor in ```data/metadata_readers.py```, otherwise add your own under the same file. 121 | 122 | Make sure that: 123 | - the metadata reader function name is the same as ```data_name``` field in ```training_config.yaml```. 124 | - the metadata file (can be anything) is specified under ```metadata_path``` in ```training_config.yaml``` 125 | 126 | ## Training 127 | Change the ```--config``` argument based on the configuration of your choice. 128 | ### Train Aligner Model 129 | #### Create training dataset 130 | ```bash 131 | python create_training_data.py --config config/training_config.yaml 132 | ``` 133 | This will populate the training data directory (default `transformer_tts_data.ljspeech`). 134 | #### Training 135 | ```bash 136 | python train_aligner.py --config config/training_config.yaml 137 | ``` 138 | ### Train TTS Model 139 | #### Compute alignment dataset 140 | First use the aligner model to create the durations dataset 141 | ```bash 142 | python extract_durations.py --config config/training_config.yaml 143 | ``` 144 | this will add the `durations.` as well as the char-wise pitch folders to the training data directory. 145 | #### Training 146 | ```bash 147 | python train_tts.py --config config/training_config.yaml 148 | ``` 149 | #### Training & Model configuration 150 | - Training and model settings can be configured in `training_config.yaml` 151 | 152 | #### Resume or restart training 153 | - To resume training simply use the same configuration files 154 | - To restart training, delete the weights and/or the logs from the logs folder with the training flag `--reset_dir` (both) or `--reset_logs`, `--reset_weights` 155 | 156 | #### Monitor training 157 | ```bash 158 | tensorboard --logdir /logs/directory/ 159 | ``` 160 | 161 | ![Tensorboard Demo](https://raw.githubusercontent.com/as-ideas/TransformerTTS/master/docs/tboard_demo.gif) 162 | ## Prediction 163 | ### With model weights 164 | From command line with 165 | ```commandline 166 | python predict_tts.py -t "Please, say something." 
-p /path/to/weights/ 167 | ``` 168 | Or in a python script 169 | ```python 170 | from model.models import ForwardTransformer 171 | from data.audio import Audio 172 | model = ForwardTransformer.load_model('/path/to/weights/') 173 | audio = Audio.from_config(model.config) 174 | out = model.predict('Please, say something.') 175 | 176 | # Convert spectrogram to wav (with griffin lim) 177 | wav = audio.reconstruct_waveform(out['mel'].numpy().T) 178 | ``` 179 | 180 | ## Model Weights 181 | Access the pre-trained models with the API call. 182 | 183 | Old weights 184 | | Model URL | Commit | Vocoder Commit| 185 | |---|---|---| 186 | |[ljspeech_tts_model](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/ljspeech_weights_tts.zip)| 0cd7d33 | aca5990 | 187 | |[ljspeech_melgan_forward_model](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_melgan_forward_transformer.zip)| 1c1cb03| aca5990 | 188 | |[ljspeech_melgan_autoregressive_model_v2](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_melgan_autoregressive_transformer.zip)| 1c1cb03| aca5990 | 189 | |[ljspeech_wavernn_forward_model](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_wavernn_forward_transformer.zip)| 1c1cb03| 3595219 | 190 | |[ljspeech_wavernn_autoregressive_model_v2](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_wavernn_autoregressive_transformer.zip)| 1c1cb03| 3595219 | 191 | |[ljspeech_wavernn_forward_model](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_forward_transformer.zip)| d9ccee6| 3595219 | 192 | |[ljspeech_wavernn_autoregressive_model_v2](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_autoregressive_transformer.zip)| d9ccee6| 3595219 | 193 | |[ljspeech_wavernn_autoregressive_model_v1](https://github.com/as-ideas/tts_model_outputs/tree/master/ljspeech_transformertts)| 2f3a1b5| 3595219 | 194 | ## Maintainers 195 | * Francesco Cardinale, github: [cfrancesco](https://github.com/cfrancesco) 196 | 197 | ## Special thanks 198 | [MelGAN](https://github.com/seungwonpark/melgan) and [WaveRNN](https://github.com/fatchord/WaveRNN): data normalization and samples' vocoders are from these repos. 199 | 200 | [Erogol](https://github.com/erogol) and the Mozilla TTS team for the lively exchange on the topic. 201 | 202 | 203 | ## Copyright 204 | See [LICENSE](LICENSE) for details. 205 | -------------------------------------------------------------------------------- /aligner_test_sentences.txt: -------------------------------------------------------------------------------- 1 | Scientists at the CERN laboratory say they have discovered a new particle. 2 | If this is that, then what are those. 
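As noted in the Custom dataset section of the README above, metadata formats other than the LJSpeech one need their own reader in `data/metadata_readers.py`. A minimal sketch of such a reader follows (the function name `my_dataset` is hypothetical; it must match the `data_name` field in `training_config.yaml` and return a `{filename: text}` dictionary like the `ljspeech` reader shown further below):
```python
# Hypothetical custom reader for data/metadata_readers.py.
# The function name must match the `data_name` field in training_config.yaml.
def my_dataset(metadata_path: str, column_sep='|') -> dict:
    text_dict = {}
    with open(metadata_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split(column_sep)
            if len(parts) < 2:
                continue  # skip malformed lines
            file_name, text = parts[0], parts[-1]
            text_dict[file_name.replace('.wav', '')] = text
    return text_dict
```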
-------------------------------------------------------------------------------- /config/data_config_wavernn.yaml: -------------------------------------------------------------------------------- 1 | audio_settings_name: WaveRNN_default 2 | text_settings_name: Stress_NoBreathing 3 | 4 | # TRAINING DATA SETTINGS 5 | n_samples: 100000 6 | n_test: 100 7 | mel_start_value: .5 8 | mel_end_value: -.5 9 | max_mel_len: 1_200 10 | min_mel_len: 80 11 | bucket_boundaries: [200, 300, 400, 500, 600, 700, 800, 900, 1000, 1200] # mel bucketing 12 | bucket_batch_sizes: [64, 42, 32, 25, 21, 18, 16, 14, 12, 6, 1] 13 | val_bucket_batch_size: [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 1] 14 | 15 | # AUDIO SETTINGS 16 | sampling_rate: 22050 17 | n_fft: 2048 18 | mel_channels: 80 19 | hop_length: 275 20 | win_length: 1100 21 | f_min: 40 22 | f_max: null 23 | normalizer: WaveRNN # which mel normalization to use from utils.audio.py [MelGAN or WaveRNN] 24 | 25 | # SILENCE CUTTING 26 | trim_silence_top_db: 60 27 | trim_silence: False 28 | trim_long_silences: True 29 | # Params for trimming long silences, from https://github.com/resemble-ai/Resemblyzer/blob/master/resemblyzer/hparams.py 30 | vad_window_length: 30 # In milliseconds 31 | vad_moving_average_width: 8 32 | vad_max_silence_length: 12 33 | vad_sample_rate: 16000 34 | 35 | # TOKENIZER 36 | phoneme_language: 'en-us' 37 | with_stress: True # use stress symbols in phonemization 38 | model_breathing: false # add a token for the initial breathing -------------------------------------------------------------------------------- /config/training_config.yaml: -------------------------------------------------------------------------------- 1 | paths: 2 | # PATHS: change accordingly 3 | wav_directory: '/path/to/wav_directory' # path to directory cointaining the wavs 4 | metadata_path: '/path/to/metadata.csv' # name of metadata file under wav_directory 5 | log_directory: '/path/to/logs_directory' # weights and logs are stored here 6 | train_data_directory: 'transformer_tts_data' # training data is stored here 7 | 8 | naming: 9 | data_name: ljspeech # raw data naming for default data reader (select function from data/metadata_readers.py) 10 | audio_settings_name: MelGAN_default 11 | text_settings_name: Stress_NoBreathing 12 | aligner_settings_name: alinger_extralayer_layernorm 13 | tts_settings_name: tts_swap_conv_dims 14 | 15 | # TRAINING DATA SETTINGS 16 | training_data_settings: 17 | n_test: 100 18 | mel_start_value: .5 19 | mel_end_value: -.5 20 | max_mel_len: 1_200 21 | min_mel_len: 80 22 | bucket_boundaries: [200, 300, 400, 500, 600, 700, 800, 900, 1000, 1200] # mel bucketing 23 | bucket_batch_sizes: [64, 42, 32, 25, 21, 18, 16, 14, 12, 6, 1] 24 | val_bucket_batch_size: [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 1] 25 | 26 | # AUDIO SETTINGS 27 | audio_settings: 28 | sampling_rate: 22050 29 | n_fft: 1024 30 | mel_channels: 80 31 | hop_length: 256 32 | win_length: 1024 33 | f_min: 0 34 | f_max: 8000 35 | normalizer: MelGAN # which mel normalization to use from utils.audio.py [MelGAN or WaveRNN] 36 | 37 | # SILENCE CUTTING 38 | trim_silence_top_db: 60 39 | trim_silence: False 40 | trim_long_silences: True 41 | # Params for trimming long silences, from https://github.com/resemble-ai/Resemblyzer/blob/master/resemblyzer/hparams.py 42 | vad_window_length: 30 # In milliseconds 43 | vad_moving_average_width: 8 44 | vad_max_silence_length: 12 45 | vad_sample_rate: 16000 46 | 47 | # Wav normalization 48 | norm_wav: True 49 | target_dBFS: -30 50 | int16_max: 32767 51 | 52 | text_settings: 53 | # 
TOKENIZER 54 | phoneme_language: 'en-us' 55 | with_stress: True # use stress symbols in phonemization 56 | model_breathing: false # add a token for the initial breathing 57 | 58 | aligner_settings: 59 | # ARCHITECTURE 60 | decoder_model_dimension: 256 61 | encoder_model_dimension: 256 62 | decoder_num_heads: [4, 4, 4, 4, 1] # the length of this defines the number of layers 63 | encoder_num_heads: [4, 4, 4, 4] # the length of this defines the number of layers 64 | encoder_feed_forward_dimension: 512 65 | decoder_feed_forward_dimension: 512 66 | decoder_prenet_dimension: 256 67 | encoder_prenet_dimension: 256 68 | encoder_max_position_encoding: 10000 69 | decoder_max_position_encoding: 10000 70 | 71 | # LOSSES 72 | stop_loss_scaling: 8 73 | 74 | # TRAINING 75 | dropout_rate: 0.1 76 | decoder_prenet_dropout: 0.1 77 | learning_rate_schedule: 78 | - [0, 1.0e-4] 79 | reduction_factor_schedule: 80 | - [0, 10] 81 | - [80_000, 5] 82 | - [100_000, 2] 83 | - [130_000, 1] 84 | max_steps: 260_000 85 | force_encoder_diagonal_steps: 500 86 | force_decoder_diagonal_steps: 7_000 87 | extract_attention_weighted: False # weighted average between last layer decoder attention heads when extracting durations 88 | debug: False 89 | 90 | # LOGGING 91 | validation_frequency: 5_000 92 | weights_save_frequency: 5_000 93 | train_images_plotting_frequency: 1_000 94 | keep_n_weights: 2 95 | keep_checkpoint_every_n_hours: 12 96 | n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps 97 | prediction_start_step: 10_000 # step after which to predict durations at validation time 98 | prediction_frequency: 5_000 99 | test_stencences: 100 | - aligner_test_sentences.txt 101 | 102 | tts_settings: 103 | # ARCHITECTURE 104 | decoder_model_dimension: 384 105 | encoder_model_dimension: 384 106 | decoder_num_heads: [2, 2, 2, 2, 2, 2] # the length of this defines the number of layers 107 | encoder_num_heads: [2, 2, 2, 2, 2, 2] # the length of this defines the number of layers 108 | encoder_feed_forward_dimension: null 109 | decoder_feed_forward_dimension: null 110 | encoder_attention_conv_filters: [1536, 384] 111 | decoder_attention_conv_filters: [1536, 384] 112 | encoder_attention_conv_kernel: 3 113 | decoder_attention_conv_kernel: 3 114 | encoder_max_position_encoding: 2000 115 | decoder_max_position_encoding: 10000 116 | encoder_dense_blocks: 0 117 | decoder_dense_blocks: 0 118 | transposed_attn_convs: True # if True, convolutions after MHA are over time. 
119 | 120 | # STATS PREDICTORS ARCHITECTURE 121 | duration_conv_filters: [256, 226] 122 | pitch_conv_filters: [256, 226] 123 | duration_kernel_size: 3 124 | pitch_kernel_size: 3 125 | 126 | # TRAINING 127 | predictors_dropout: 0.1 128 | dropout_rate: 0.1 129 | learning_rate_schedule: 130 | - [0, 1.0e-4] 131 | max_steps: 100_000 132 | debug: False 133 | 134 | # LOGGING 135 | validation_frequency: 5_000 136 | prediction_frequency: 5_000 137 | weights_save_frequency: 5_000 138 | weights_save_starting_step: 5_000 139 | train_images_plotting_frequency: 1_000 140 | keep_n_weights: 5 141 | keep_checkpoint_every_n_hours: 12 142 | n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps 143 | prediction_start_step: 4_000 144 | text_prediction: 145 | - test_sentences.txt -------------------------------------------------------------------------------- /create_training_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import pickle 4 | 5 | import numpy as np 6 | from p_tqdm import p_uimap, p_umap 7 | 8 | from utils.logging_utils import SummaryManager 9 | from data.text import TextToTokens 10 | from data.datasets import DataReader 11 | from utils.training_config_manager import TrainingConfigManager 12 | from data.audio import Audio 13 | from data.text.symbols import _alphabet 14 | 15 | np.random.seed(42) 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--config', type=str, required=True) 19 | parser.add_argument('--skip_phonemes', action='store_true') 20 | parser.add_argument('--skip_mels', action='store_true') 21 | 22 | args = parser.parse_args() 23 | for arg in vars(args): 24 | print('{}: {}'.format(arg, getattr(args, arg))) 25 | 26 | cm = TrainingConfigManager(args.config, aligner=True) 27 | cm.create_remove_dirs() 28 | metadatareader = DataReader.from_config(cm, kind='original', scan_wavs=True) 29 | summary_manager = SummaryManager(model=None, log_dir=cm.log_dir / 'data_preprocessing', config=cm.config, 30 | default_writer='data_preprocessing') 31 | file_ids_from_wavs = list(metadatareader.wav_paths.keys()) 32 | print(f"Reading wavs from {metadatareader.wav_directory}") 33 | print(f"Reading metadata from {metadatareader.metadata_path}") 34 | print(f'\nFound {len(metadatareader.filenames)} metadata lines.') 35 | print(f'\nFound {len(file_ids_from_wavs)} wav files.') 36 | cross_file_ids = [fid for fid in file_ids_from_wavs if fid in metadatareader.filenames] 37 | print(f'\nThere are {len(cross_file_ids)} wav file names that correspond to metadata lines.') 38 | 39 | if not args.skip_mels: 40 | 41 | def process_wav(wav_path: Path): 42 | file_name = wav_path.stem 43 | y, sr = audio.load_wav(str(wav_path)) 44 | pitch = audio.extract_pitch(y) 45 | mel = audio.mel_spectrogram(y) 46 | assert mel.shape[1] == audio.config['mel_channels'], len(mel.shape) == 2 47 | assert mel.shape[0] == pitch.shape[0], f'{mel.shape[0]} == {pitch.shape[0]} (wav {y.shape})' 48 | mel_path = (cm.mel_dir / file_name).with_suffix('.npy') 49 | pitch_path = (cm.pitch_dir / file_name).with_suffix('.npy') 50 | np.save(mel_path, mel) 51 | np.save(pitch_path, pitch) 52 | return {'fname': file_name, 'mel.len': mel.shape[0], 'pitch.path': pitch_path, 'pitch': pitch} 53 | 54 | 55 | print(f"\nMels will be stored stored under") 56 | print(f"{cm.mel_dir}") 57 | audio = Audio.from_config(config=cm.config) 58 | wav_files = [metadatareader.wav_paths[k] for k in cross_file_ids] 59 
| len_dict = {} 60 | remove_files = [] 61 | mel_lens = [] 62 | pitches = {} 63 | wav_iter = p_uimap(process_wav, wav_files) 64 | for out_dict in wav_iter: 65 | len_dict.update({out_dict['fname']: out_dict['mel.len']}) 66 | pitches.update({out_dict['pitch.path']: out_dict['pitch']}) 67 | if out_dict['mel.len'] > cm.config['max_mel_len'] or out_dict['mel.len'] < cm.config['min_mel_len']: 68 | remove_files.append(out_dict['fname']) 69 | else: 70 | mel_lens.append(out_dict['mel.len']) 71 | 72 | 73 | def normalize_pitch_vectors(pitch_vecs): 74 | nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] 75 | for v in pitch_vecs.values()]) 76 | mean, std = np.mean(nonzeros), np.std(nonzeros) 77 | return mean, std 78 | 79 | 80 | def process_pitches(item: tuple): 81 | fname, pitch = item 82 | zero_idxs = np.where(pitch == 0.0)[0] 83 | pitch -= mean 84 | pitch /= std 85 | pitch[zero_idxs] = 0.0 86 | np.save(fname, pitch) 87 | 88 | 89 | mean, std = normalize_pitch_vectors(pitches) 90 | pickle.dump({'pitch_mean': mean, 'pitch_std': std}, open(cm.data_dir / 'pitch_stats.pkl', 'wb')) 91 | pitch_iter = p_umap(process_pitches, pitches.items()) 92 | 93 | pickle.dump(len_dict, open(cm.data_dir / 'mel_len.pkl', 'wb')) 94 | pickle.dump(remove_files, open(cm.data_dir / 'under-over_sized_mels.pkl', 'wb')) 95 | summary_manager.add_histogram('Mel Lengths', values=np.array(mel_lens)) 96 | total_mel_len = np.sum(mel_lens) 97 | total_wav_len = total_mel_len * audio.config['hop_length'] 98 | summary_manager.display_scalar('Total duration (hours)', 99 | scalar_value=total_wav_len / audio.config['sampling_rate'] / 60. ** 2) 100 | 101 | if not args.skip_phonemes: 102 | remove_files = pickle.load(open(cm.data_dir / 'under-over_sized_mels.pkl', 'rb')) 103 | phonemized_metadata_path = cm.phonemized_metadata_path 104 | train_metadata_path = cm.train_metadata_path 105 | test_metadata_path = cm.valid_metadata_path 106 | print(f'\nReading metadata from {metadatareader.metadata_path}') 107 | print(f'\nFound {len(metadatareader.filenames)} lines.') 108 | filter_metadata = [] 109 | for fname in cross_file_ids: 110 | item = metadatareader.text_dict[fname] 111 | non_p = [c for c in item if c in _alphabet] 112 | if len(non_p) < 1: 113 | filter_metadata.append(fname) 114 | if len(filter_metadata) > 0: 115 | print(f'Removing {len(filter_metadata)} suspiciously short line(s):') 116 | for fname in filter_metadata: 117 | print(f'{fname}: {metadatareader.text_dict[fname]}') 118 | print(f'\nRemoving {len(remove_files)} line(s) due to mel filtering.') 119 | remove_files += filter_metadata 120 | metadata_file_ids = [fname for fname in cross_file_ids if fname not in remove_files] 121 | metadata_len = len(metadata_file_ids) 122 | sample_items = np.random.choice(metadata_file_ids, 5) 123 | test_len = cm.config['n_test'] 124 | train_len = metadata_len - test_len 125 | print(f'\nMetadata contains {metadata_len} lines.') 126 | print(f'\nFiles will be stored under {cm.data_dir}') 127 | print(f' - all: {phonemized_metadata_path}') 128 | print(f' - {train_len} training lines: {train_metadata_path}') 129 | print(f' - {test_len} validation lines: {test_metadata_path}') 130 | 131 | print('\nMetadata samples:') 132 | for i in sample_items: 133 | print(f'{i}:{metadatareader.text_dict[i]}') 134 | summary_manager.add_text(f'{i}/text', text=metadatareader.text_dict[i]) 135 | 136 | # run cleaner on raw text 137 | text_proc = TextToTokens.default(cm.config['phoneme_language'], add_start_end=False, 138 | with_stress=cm.config['with_stress'], 
model_breathing=cm.config['model_breathing'], 139 | njobs=1) 140 | 141 | 142 | def process_phonemes(file_id): 143 | text = metadatareader.text_dict[file_id] 144 | try: 145 | phon = text_proc.phonemizer(text) 146 | except Exception as e: 147 | print(f'{e}\nFile id {file_id}') 148 | raise BrokenPipeError 149 | return (file_id, phon) 150 | 151 | 152 | print('\nPHONEMIZING') 153 | phonemized_data = {} 154 | phon_iter = p_uimap(process_phonemes, metadata_file_ids) 155 | for (file_id, phonemes) in phon_iter: 156 | phonemized_data.update({file_id: phonemes}) 157 | 158 | print('\nPhonemized metadata samples:') 159 | for i in sample_items: 160 | print(f'{i}:{phonemized_data[i]}') 161 | summary_manager.add_text(f'{i}/phonemes', text=phonemized_data[i]) 162 | 163 | new_metadata = [f'{k}|{v}\n' for k, v in phonemized_data.items()] 164 | shuffled_metadata = np.random.permutation(new_metadata) 165 | train_metadata = shuffled_metadata[0:train_len] 166 | test_metadata = shuffled_metadata[-test_len:] 167 | 168 | with open(phonemized_metadata_path, 'w+', encoding='utf-8') as file: 169 | file.writelines(new_metadata) 170 | with open(train_metadata_path, 'w+', encoding='utf-8') as file: 171 | file.writelines(train_metadata) 172 | with open(test_metadata_path, 'w+', encoding='utf-8') as file: 173 | file.writelines(test_metadata) 174 | # some checks 175 | assert metadata_len == len(set(list(phonemized_data.keys()))), \ 176 | f'Length of metadata ({metadata_len}) does not match the length of the phoneme array ({len(set(list(phonemized_data.keys())))}). Check for empty text lines in metadata.' 177 | assert len(train_metadata) + len(test_metadata) == metadata_len, \ 178 | f'Train and/or validation lengths incorrect. ({len(train_metadata)} + {len(test_metadata)} != {metadata_len})' 179 | 180 | print('\nDone') 181 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/TransformerTTS/363805548abdd93b33508da2c027ae514bfc1a07/data/__init__.py -------------------------------------------------------------------------------- /data/audio.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import struct 3 | 4 | import librosa 5 | import numpy as np 6 | import librosa.display 7 | from matplotlib import pyplot as plt 8 | import soundfile as sf 9 | import webrtcvad 10 | from scipy.ndimage import binary_dilation 11 | import pyworld as pw 12 | 13 | 14 | class Audio(): 15 | def __init__(self, 16 | sampling_rate: int, 17 | n_fft: int, 18 | mel_channels: int, 19 | hop_length: int, 20 | win_length: int, 21 | f_min: int, 22 | f_max: int, 23 | normalizer: str, 24 | norm_wav: bool = None, 25 | target_dBFS: int = None, 26 | int16_max: int = None, 27 | trim_long_silences: bool = None, 28 | trim_silence: bool = None, 29 | trim_silence_top_db: int = None, 30 | vad_window_length: int = None, 31 | vad_sample_rate: int = None, 32 | vad_moving_average_width: int = None, 33 | vad_max_silence_length: int = None, 34 | **kwargs): 35 | 36 | self.config = self._make_config(locals()) 37 | self.sampling_rate = sampling_rate 38 | self.n_fft = n_fft 39 | self.mel_channels = mel_channels 40 | self.hop_length = hop_length 41 | self.win_length = win_length 42 | self.f_min = f_min 43 | self.f_max = f_max 44 | self.norm_wav = norm_wav 45 | self.target_dBFS = target_dBFS 46 | self.int16_max = int16_max 47 | self.trim_long_silences = 
trim_long_silences 48 | self.trim_silence = trim_silence 49 | self.trim_silence_top_db = trim_silence_top_db 50 | self.vad_window_length = vad_window_length 51 | self.vad_sample_rate = vad_sample_rate 52 | self.vad_moving_average_width = vad_moving_average_width 53 | self.vad_max_silence_length = vad_max_silence_length 54 | self.normalizer = getattr(sys.modules[__name__], normalizer)() 55 | 56 | def _make_config(self, locals) -> dict: 57 | config = {} 58 | for k in locals: 59 | if (k != 'self') and (k != '__class__'): 60 | if isinstance(locals[k], dict): 61 | config.update(locals[k]) 62 | else: 63 | config.update({k: locals[k]}) 64 | return dict(config) 65 | 66 | def _normalize(self, S): 67 | return self.normalizer.normalize(S) 68 | 69 | def _denormalize(self, S): 70 | return self.normalizer.denormalize(S) 71 | 72 | def _linear_to_mel(self, spectrogram): 73 | return librosa.feature.melspectrogram( 74 | S=spectrogram, 75 | sr=self.sampling_rate, 76 | n_fft=self.n_fft, 77 | n_mels=self.mel_channels, 78 | fmin=self.f_min, 79 | fmax=self.f_max) 80 | 81 | def _stft(self, y): 82 | return librosa.stft( 83 | y=y, 84 | n_fft=self.n_fft, 85 | hop_length=self.hop_length, 86 | win_length=self.win_length) 87 | 88 | def mel_spectrogram(self, wav): 89 | """ This is what the model is trained to reproduce. """ 90 | D = self._stft(wav) 91 | S = self._linear_to_mel(np.abs(D)) 92 | return self._normalize(S).T 93 | 94 | def reconstruct_waveform(self, mel, n_iter=32): 95 | """ Uses Griffin-Lim phase reconstruction to convert from a normalized 96 | mel spectrogram back into a waveform. """ 97 | amp_mel = self._denormalize(mel) 98 | S = librosa.feature.inverse.mel_to_stft( 99 | amp_mel, 100 | power=1, 101 | sr=self.sampling_rate, 102 | n_fft=self.n_fft, 103 | fmin=self.f_min, 104 | fmax=self.f_max) 105 | wav = librosa.core.griffinlim( 106 | S, 107 | n_iter=n_iter, 108 | hop_length=self.hop_length, 109 | win_length=self.win_length) 110 | return wav 111 | 112 | def display_mel(self, mel, is_normal=True): 113 | if is_normal: 114 | mel = self._denormalize(mel) 115 | f = plt.figure(figsize=(10, 4)) 116 | s_db = librosa.power_to_db(mel, ref=np.max) 117 | ax = librosa.display.specshow(s_db, 118 | x_axis='time', 119 | y_axis='mel', 120 | sr=self.sampling_rate, 121 | fmin=self.f_min, 122 | fmax=self.f_max) 123 | f.add_subplot(ax) 124 | return f 125 | 126 | def load_wav(self, wav_path, preprocess=True): 127 | y, sr = librosa.load(wav_path, sr=self.sampling_rate) 128 | if preprocess: 129 | y = self.preprocess(y) 130 | return y, sr 131 | 132 | def preprocess(self, y): 133 | if self.norm_wav: 134 | y = self.normalize_volume(y, increase_only=True) 135 | if self.trim_long_silences: 136 | y = self.trim_audio_long_silences(y) 137 | if self.trim_silence: 138 | y = self.trim_audio_silence(y) 139 | if y.shape[0] % self.hop_length == 0: 140 | y = np.pad(y, (0, 1)) 141 | return y 142 | 143 | def save_wav(self, y, wav_path): 144 | sf.write(wav_path, data=y, samplerate=self.sampling_rate) 145 | 146 | def extract_pitch(self, y): 147 | _f0, t = pw.dio(y.astype(np.float64), fs=self.sampling_rate, 148 | frame_period=self.hop_length / self.sampling_rate * 1000) 149 | f0 = pw.stonemask(y.astype(np.float64), _f0, t, fs=self.sampling_rate) # pitch refinement 150 | 151 | return f0 152 | 153 | # from https://github.com/resemble-ai/Resemblyzer/blob/master/resemblyzer/audio.py 154 | def normalize_volume(self, wav, increase_only=False, decrease_only=False): 155 | if increase_only and decrease_only: 156 | raise ValueError("Both increase only and 
decrease only are set") 157 | rms = np.sqrt(np.mean((wav * self.int16_max) ** 2)) 158 | wave_dBFS = 20 * np.log10(rms / self.int16_max) 159 | dBFS_change = self.target_dBFS - wave_dBFS 160 | if dBFS_change < 0 and increase_only or dBFS_change > 0 and decrease_only: 161 | return wav 162 | return wav * (10 ** (dBFS_change / 20)) 163 | 164 | def trim_audio_silence(self, wav): 165 | trimmed = librosa.effects.trim(wav, 166 | top_db=self.trim_silence_top_db, 167 | frame_length=256, 168 | hop_length=64) 169 | return trimmed[0] 170 | 171 | # from https://github.com/resemble-ai/Resemblyzer/blob/master/resemblyzer/audio.py 172 | def trim_audio_long_silences(self, wav): 173 | samples_per_window = (self.vad_window_length * self.vad_sample_rate) // 1000 174 | wav = wav[:len(wav) - (len(wav) % samples_per_window)] 175 | pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * self.int16_max)).astype(np.int16)) 176 | voice_flags = [] 177 | vad = webrtcvad.Vad(mode=3) 178 | for window_start in range(0, len(wav), samples_per_window): 179 | window_end = window_start + samples_per_window 180 | voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2], 181 | sample_rate=self.vad_sample_rate)) 182 | voice_flags = np.array(voice_flags) 183 | 184 | def moving_average(array, width): 185 | array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2))) 186 | ret = np.cumsum(array_padded, dtype=float) 187 | ret[width:] = ret[width:] - ret[:-width] 188 | return ret[width - 1:] / width 189 | 190 | audio_mask = moving_average(voice_flags, self.vad_moving_average_width) 191 | audio_mask = np.round(audio_mask).astype(np.bool) 192 | audio_mask[:] = binary_dilation(audio_mask[:], np.ones(self.vad_max_silence_length + 1)) 193 | audio_mask = np.repeat(audio_mask, samples_per_window) 194 | return wav[audio_mask] 195 | 196 | @classmethod 197 | def from_config(cls, config: dict): 198 | return cls(**config) 199 | 200 | 201 | class Normalizer: 202 | def normalize(self, S): 203 | raise NotImplementedError 204 | 205 | def denormalize(self, S): 206 | raise NotImplementedError 207 | 208 | 209 | class MelGAN(Normalizer): 210 | def __init__(self): 211 | super().__init__() 212 | self.clip_min = 1.0e-5 213 | 214 | def normalize(self, S): 215 | S = np.clip(S, a_min=self.clip_min, a_max=None) 216 | return np.log(S) 217 | 218 | def denormalize(self, S): 219 | return np.exp(S) 220 | 221 | 222 | class WaveRNN(Normalizer): 223 | def __init__(self): 224 | super().__init__() 225 | self.min_level_db = - 100 226 | self.max_norm = 4 227 | 228 | def normalize(self, S): 229 | S = self.amp_to_db(S) 230 | S = np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1) 231 | return (S * 2 * self.max_norm) - self.max_norm 232 | 233 | def denormalize(self, S): 234 | S = (S + self.max_norm) / (2 * self.max_norm) 235 | S = (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db 236 | return self.db_to_amp(S) 237 | 238 | def amp_to_db(self, x): 239 | return 20 * np.log10(np.maximum(1e-5, x)) 240 | 241 | def db_to_amp(self, x): 242 | return np.power(10.0, x * 0.05) 243 | -------------------------------------------------------------------------------- /data/datasets.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from random import Random 3 | from typing import List, Union 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from utils.training_config_manager import TrainingConfigManager 9 | from data.text.tokenizer import Tokenizer 
10 | from data.metadata_readers import get_preprocessor_by_name 11 | 12 | 13 | def get_files(path: Union[Path, str], extension='.wav') -> List[Path]: 14 | """ Get all files from all subdirs with given extension. """ 15 | path = Path(path).expanduser().resolve() 16 | return list(path.rglob(f'*{extension}')) 17 | 18 | 19 | class DataReader: 20 | """ 21 | Reads dataset folder and constructs three useful objects: 22 | text_dict: {filename: text} 23 | wav_paths: {filename: path/to/filename.wav} 24 | filenames: [filename1, filename2, ...] 25 | 26 | IMPORTANT: Use only for information available from source dataset, not for 27 | training data. 28 | """ 29 | 30 | def __init__(self, wav_directory: str, metadata_path: str, metadata_reading_function=None, scan_wavs=False, 31 | training=False, is_processed=False): 32 | self.metadata_reading_function = metadata_reading_function 33 | self.wav_directory = Path(wav_directory) 34 | self.metadata_path = Path(metadata_path) 35 | if not is_processed: 36 | self.text_dict = self.metadata_reading_function(self.metadata_path) 37 | self.filenames = list(self.text_dict.keys()) 38 | else: 39 | self.text_dict, self.upsample = self.metadata_reading_function(self.metadata_path) 40 | self.filenames = list(self.text_dict.keys()) 41 | if training: 42 | self.filenames += self.upsample 43 | if scan_wavs: 44 | all_wavs = get_files(self.wav_directory, extension='.wav') 45 | self.wav_paths = {w.with_suffix('').name: w for w in all_wavs} 46 | 47 | @classmethod 48 | def from_config(cls, config_manager: TrainingConfigManager, kind: str, scan_wavs=False): 49 | kinds = ['original', 'phonemized', 'train', 'valid'] 50 | if kind not in kinds: 51 | raise ValueError(f'Invalid kind type. Expected one of: {kinds}') 52 | reader = get_preprocessor_by_name('post_processed_reader') 53 | training = False 54 | is_processed = True 55 | if kind == 'train': 56 | metadata = config_manager.train_metadata_path 57 | training = True 58 | elif kind == 'original': 59 | metadata = config_manager.metadata_path 60 | reader = get_preprocessor_by_name(config_manager.config['data_name']) 61 | is_processed = False 62 | elif kind == 'valid': 63 | metadata = config_manager.valid_metadata_path 64 | elif kind == 'phonemized': 65 | metadata = config_manager.phonemized_metadata_path 66 | 67 | return cls(wav_directory=config_manager.wav_directory, 68 | metadata_reading_function=reader, 69 | metadata_path=metadata, 70 | scan_wavs=scan_wavs, 71 | training=training, 72 | is_processed=is_processed) 73 | 74 | 75 | class AlignerPreprocessor: 76 | 77 | def __init__(self, 78 | mel_channels: int, 79 | mel_start_value: float, 80 | mel_end_value: float, 81 | tokenizer: Tokenizer): 82 | self.output_types = (tf.float32, tf.int32, tf.int32, tf.string) 83 | self.padded_shapes = ([None, mel_channels], [None], [None], []) 84 | self.start_vec = np.ones((1, mel_channels)) * mel_start_value 85 | self.end_vec = np.ones((1, mel_channels)) * mel_end_value 86 | self.tokenizer = tokenizer 87 | 88 | def __call__(self, mel, text, sample_name): 89 | encoded_phonemes = self.tokenizer(text) 90 | norm_mel = np.concatenate([self.start_vec, mel, self.end_vec], axis=0) 91 | stop_probs = np.ones((norm_mel.shape[0])) 92 | stop_probs[-1] = 2 93 | return norm_mel, encoded_phonemes, stop_probs, sample_name 94 | 95 | def get_sample_length(self, norm_mel, encoded_phonemes, stop_probs, sample_name): 96 | return tf.shape(norm_mel)[0] 97 | 98 | @classmethod 99 | def from_config(cls, config: TrainingConfigManager, tokenizer: Tokenizer): 100 | return 
cls(mel_channels=config.config['mel_channels'], 101 | mel_start_value=config.config['mel_start_value'], 102 | mel_end_value=config.config['mel_end_value'], 103 | tokenizer=tokenizer) 104 | 105 | 106 | class AlignerDataset: 107 | def __init__(self, 108 | data_reader: DataReader, 109 | preprocessor, 110 | mel_directory: str): 111 | self.metadata_reader = data_reader 112 | self.preprocessor = preprocessor 113 | self.mel_directory = Path(mel_directory) 114 | 115 | def _read_sample(self, sample_name): 116 | text = self.metadata_reader.text_dict[sample_name] 117 | mel = np.load((self.mel_directory / sample_name).with_suffix('.npy').as_posix()) 118 | return mel, text 119 | 120 | def _process_sample(self, sample_name): 121 | mel, text = self._read_sample(sample_name) 122 | return self.preprocessor(mel=mel, text=text, sample_name=sample_name) 123 | 124 | def get_dataset(self, bucket_batch_sizes, bucket_boundaries, shuffle=True, drop_remainder=False): 125 | return Dataset( 126 | samples=self.metadata_reader.filenames, 127 | preprocessor=self._process_sample, 128 | output_types=self.preprocessor.output_types, 129 | padded_shapes=self.preprocessor.padded_shapes, 130 | shuffle=shuffle, 131 | drop_remainder=drop_remainder, 132 | len_function=self.preprocessor.get_sample_length, 133 | bucket_batch_sizes=bucket_batch_sizes, 134 | bucket_boundaries=bucket_boundaries) 135 | 136 | @classmethod 137 | def from_config(cls, 138 | config: TrainingConfigManager, 139 | preprocessor, 140 | kind: str, 141 | mel_directory: str = None, ): 142 | kinds = ['original', 'phonemized', 'train', 'valid'] 143 | if kind not in kinds: 144 | raise ValueError(f'Invalid kind type. Expected one of: {kinds}') 145 | if mel_directory is None: 146 | mel_directory = config.mel_dir 147 | metadata_reader = DataReader.from_config(config, kind=kind) 148 | return cls(preprocessor=preprocessor, 149 | data_reader=metadata_reader, 150 | mel_directory=mel_directory) 151 | 152 | 153 | class TTSPreprocessor: 154 | def __init__(self, mel_channels, tokenizer: Tokenizer): 155 | self.output_types = (tf.float32, tf.int32, tf.int32, tf.float32, tf.string) 156 | self.padded_shapes = ([None, mel_channels], [None], [None], [None], []) 157 | self.tokenizer = tokenizer 158 | 159 | def __call__(self, text, mel, durations, pitch, sample_name): 160 | encoded_phonemes = self.tokenizer(text) 161 | return mel, encoded_phonemes, durations, pitch, sample_name 162 | 163 | def get_sample_length(self, mel, encoded_phonemes, durations, pitch, sample_name): 164 | return tf.shape(mel)[0] 165 | 166 | @classmethod 167 | def from_config(cls, config: TrainingConfigManager, tokenizer: Tokenizer): 168 | return cls(mel_channels=config.config['mel_channels'], 169 | tokenizer=tokenizer) 170 | 171 | 172 | class TTSDataset: 173 | def __init__(self, 174 | data_reader: DataReader, 175 | preprocessor: TTSPreprocessor, 176 | mel_directory: str, 177 | pitch_directory: str, 178 | duration_directory: str, 179 | pitch_per_char_directory: str): 180 | self.metadata_reader = data_reader 181 | self.preprocessor = preprocessor 182 | self.mel_directory = Path(mel_directory) 183 | self.duration_directory = Path(duration_directory) 184 | self.pitch_directory = Path(pitch_directory) 185 | self.pitch_per_char_directory = Path(pitch_per_char_directory) 186 | 187 | def _read_sample(self, sample_name: str): 188 | text = self.metadata_reader.text_dict[sample_name] 189 | mel = np.load((self.mel_directory / sample_name).with_suffix('.npy').as_posix()) 190 | durations = np.load( 191 | (self.duration_directory 
/ sample_name).with_suffix('.npy').as_posix()) 192 | char_wise_pitch = np.load((self.pitch_per_char_directory / sample_name).with_suffix('.npy').as_posix()) 193 | return mel, text, durations, char_wise_pitch 194 | 195 | def _process_sample(self, sample_name: str): 196 | mel, text, durations, pitch = self._read_sample(sample_name) 197 | return self.preprocessor(mel=mel, text=text, durations=durations, pitch=pitch, sample_name=sample_name) 198 | 199 | def get_dataset(self, bucket_batch_sizes, bucket_boundaries, shuffle=True, drop_remainder=False): 200 | return Dataset( 201 | samples=self.metadata_reader.filenames, 202 | preprocessor=self._process_sample, 203 | output_types=self.preprocessor.output_types, 204 | padded_shapes=self.preprocessor.padded_shapes, 205 | len_function=self.preprocessor.get_sample_length, 206 | shuffle=shuffle, 207 | drop_remainder=drop_remainder, 208 | bucket_batch_sizes=bucket_batch_sizes, 209 | bucket_boundaries=bucket_boundaries) 210 | 211 | @classmethod 212 | def from_config(cls, 213 | config: TrainingConfigManager, 214 | preprocessor, 215 | kind: str, 216 | mel_directory: str = None, 217 | duration_directory: str = None, 218 | pitch_directory: str = None): 219 | kinds = ['phonemized', 'train', 'valid'] 220 | if kind not in kinds: 221 | raise ValueError(f'Invalid kind type. Expected one of: {kinds}') 222 | if mel_directory is None: 223 | mel_directory = config.mel_dir 224 | if duration_directory is None: 225 | duration_directory = config.duration_dir 226 | if pitch_directory is None: 227 | pitch_directory = config.pitch_dir 228 | metadata_reader = DataReader.from_config(config, 229 | kind=kind) 230 | return cls(preprocessor=preprocessor, 231 | data_reader=metadata_reader, 232 | mel_directory=mel_directory, 233 | duration_directory=duration_directory, 234 | pitch_directory=pitch_directory, 235 | pitch_per_char_directory=config.pitch_per_char) 236 | 237 | 238 | class Dataset: 239 | """ Model digestible dataset. 
""" 240 | 241 | def __init__(self, 242 | samples: list, 243 | preprocessor, 244 | len_function, 245 | padded_shapes: tuple, 246 | output_types: tuple, 247 | bucket_boundaries: list, 248 | bucket_batch_sizes: list, 249 | padding_values: tuple = None, 250 | shuffle=True, 251 | drop_remainder=True, 252 | seed=42): 253 | self._random = Random(seed) 254 | self._samples = samples[:] 255 | self.preprocessor = preprocessor 256 | dataset = tf.data.Dataset.from_generator(lambda: self._datagen(shuffle), 257 | output_types=output_types) 258 | # TODO: pass bin args 259 | binned_data = dataset.apply( 260 | tf.data.experimental.bucket_by_sequence_length( 261 | len_function, 262 | bucket_boundaries=bucket_boundaries, 263 | bucket_batch_sizes=bucket_batch_sizes, 264 | padded_shapes=padded_shapes, 265 | drop_remainder=drop_remainder, 266 | padding_values=padding_values 267 | )) 268 | self.dataset = binned_data 269 | self.data_iter = iter(binned_data.repeat(-1)) 270 | 271 | def next_batch(self): 272 | return next(self.data_iter) 273 | 274 | def all_batches(self): 275 | return iter(self.dataset) 276 | 277 | def _datagen(self, shuffle): 278 | """ 279 | Shuffle once before generating to avoid buffering 280 | """ 281 | samples = self._samples[:] 282 | if shuffle: 283 | self._random.shuffle(samples) 284 | return (self.preprocessor(s) for s in samples) 285 | -------------------------------------------------------------------------------- /data/metadata_readers.py: -------------------------------------------------------------------------------- 1 | """ 2 | methods for reading a dataset and return a dictionary of the form: 3 | { 4 | filename: text_line, 5 | ... 6 | } 7 | """ 8 | 9 | import sys 10 | from typing import Dict, List, Tuple 11 | 12 | 13 | def get_preprocessor_by_name(name: str): 14 | """ 15 | Returns the respective data function. 16 | Taken from https://github.com/mozilla/TTS/blob/master/TTS/tts/datasets/preprocess.py 17 | """ 18 | thismodule = sys.modules[__name__] 19 | return getattr(thismodule, name.lower()) 20 | 21 | 22 | def ljspeech(metadata_path: str, column_sep='|') -> dict: 23 | text_dict = {} 24 | with open(metadata_path, 'r', encoding='utf-8') as f: 25 | for l in f.readlines(): 26 | l_split = l.split(column_sep) 27 | filename, text = l_split[0], l_split[-1] 28 | if filename.endswith('.wav'): 29 | filename = filename.split('.')[0] 30 | text = text.replace('\n', '') 31 | text_dict.update({filename: text}) 32 | return text_dict 33 | 34 | 35 | def post_processed_reader(metadata_path: str, column_sep='|', upsample_indicators='?!', upsample_factor=10) -> Tuple[ 36 | Dict, List]: 37 | """ 38 | Used to read metadata files created within the repo. 
39 | """ 40 | text_dict = {} 41 | upsample = [] 42 | with open(metadata_path, 'r', encoding='utf-8') as f: 43 | for l in f.readlines(): 44 | l_split = l.split(column_sep) 45 | filename, text = l_split[0], l_split[1] 46 | text = text.replace('\n', '') 47 | if any(el in text for el in list(upsample_indicators)): 48 | upsample.extend([filename] * upsample_factor) 49 | text_dict.update({filename: text}) 50 | return text_dict, upsample 51 | 52 | 53 | if __name__ == '__main__': 54 | metadata_path = '/Volumes/data/datasets/LJSpeech-1.1/metadata.csv' 55 | d = get_preprocessor_by_name('ljspeech')(metadata_path) 56 | key_list = list(d.keys()) 57 | print('metadata head') 58 | for key in key_list[:5]: 59 | print(f'{key}: {d[key]}') 60 | print('metadata tail') 61 | for key in key_list[-5:]: 62 | print(f'{key}: {d[key]}') 63 | -------------------------------------------------------------------------------- /data/text/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from data.text.symbols import all_phonemes 4 | from data.text.tokenizer import Phonemizer, Tokenizer 5 | 6 | 7 | class TextToTokens: 8 | def __init__(self, phonemizer: Phonemizer, tokenizer: Tokenizer): 9 | self.phonemizer = phonemizer 10 | self.tokenizer = tokenizer 11 | 12 | def __call__(self, input_text: Union[str, list]) -> list: 13 | phons = self.phonemizer(input_text) 14 | tokens = self.tokenizer(phons) 15 | return tokens 16 | 17 | @classmethod 18 | def default(cls, language: str, add_start_end: bool, with_stress: bool, model_breathing: bool, njobs=1): 19 | phonemizer = Phonemizer(language=language, njobs=njobs, with_stress=with_stress) 20 | tokenizer = Tokenizer(add_start_end=add_start_end, model_breathing=model_breathing) 21 | return cls(phonemizer=phonemizer, tokenizer=tokenizer) 22 | -------------------------------------------------------------------------------- /data/text/symbols.py: -------------------------------------------------------------------------------- 1 | _vowels = 'iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ' 2 | _non_pulmonic_consonants = 'ʘɓǀɗǃʄǂɠǁʛ' 3 | _pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ' 4 | _suprasegmentals = 'ˈˌːˑ' 5 | _other_symbols = 'ʍwɥʜʢʡɕʑɺɧ' 6 | _diacrilics = 'ɚ˞ɫ' 7 | _phonemes = sorted(list( 8 | _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics)) 9 | _punctuations = '!,-.:;? 
\'()' 10 | _alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzäüößÄÖÜ' 11 | 12 | all_phonemes = sorted(list(_phonemes) + list(_punctuations)) 13 | -------------------------------------------------------------------------------- /data/text/tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import re 3 | 4 | from phonemizer.phonemize import phonemize 5 | 6 | from data.text.symbols import all_phonemes, _punctuations 7 | 8 | 9 | class Tokenizer: 10 | 11 | def __init__(self, start_token='>', end_token='<', pad_token='/', add_start_end=True, alphabet=None, 12 | model_breathing=True): 13 | if not alphabet: 14 | self.alphabet = all_phonemes 15 | else: 16 | self.alphabet = sorted(list(set(alphabet))) # for testing 17 | self.idx_to_token = {i: s for i, s in enumerate(self.alphabet, start=1)} 18 | self.idx_to_token[0] = pad_token 19 | self.token_to_idx = {s: [i] for i, s in self.idx_to_token.items()} 20 | self.vocab_size = len(self.alphabet) + 1 21 | self.add_start_end = add_start_end 22 | if add_start_end: 23 | self.start_token_index = len(self.alphabet) + 1 24 | self.end_token_index = len(self.alphabet) + 2 25 | self.vocab_size += 2 26 | self.idx_to_token[self.start_token_index] = start_token 27 | self.idx_to_token[self.end_token_index] = end_token 28 | self.model_breathing = model_breathing 29 | if model_breathing: 30 | self.breathing_token_index = self.vocab_size 31 | self.token_to_idx[' '] = self.token_to_idx[' '] + [self.breathing_token_index] 32 | self.vocab_size += 1 33 | self.breathing_token = '@' 34 | self.idx_to_token[self.breathing_token_index] = self.breathing_token 35 | self.token_to_idx[self.breathing_token] = [self.breathing_token_index] 36 | 37 | def __call__(self, sentence: str) -> list: 38 | sequence = [self.token_to_idx[c] for c in sentence] # No filtering: text should only contain known chars. 39 | sequence = [item for items in sequence for item in items] 40 | if self.model_breathing: 41 | sequence = [self.breathing_token_index] + sequence 42 | if self.add_start_end: 43 | sequence = [self.start_token_index] + sequence + [self.end_token_index] 44 | return sequence 45 | 46 | def decode(self, sequence: list) -> str: 47 | return ''.join([self.idx_to_token[int(t)] for t in sequence]) 48 | 49 | 50 | class Phonemizer: 51 | def __init__(self, language: str, with_stress: bool, njobs=4): 52 | self.language = language 53 | self.njobs = njobs 54 | self.with_stress = with_stress 55 | self.special_hyphen = '—' 56 | self.punctuation = ';:,.!?¡¿—…"«»“”' 57 | self._whitespace_re = re.compile(r'\s+') 58 | self._whitespace_punctuation_re = re.compile(f'\s*([{_punctuations}])\s*') 59 | 60 | def __call__(self, text: Union[str, list], with_stress=None, njobs=None, language=None) -> Union[str, list]: 61 | language = language or self.language 62 | njobs = njobs or self.njobs 63 | with_stress = with_stress or self.with_stress 64 | # phonemizer does not like hyphens. 
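# ('-' is replaced with self.special_hyphen in _preprocess and restored in _postprocess_string)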
65 | text = self._preprocess(text) 66 | phonemes = phonemize(text, 67 | language=language, 68 | backend='espeak', 69 | strip=True, 70 | preserve_punctuation=True, 71 | with_stress=with_stress, 72 | punctuation_marks=self.punctuation, 73 | njobs=njobs, 74 | language_switch='remove-flags') 75 | return self._postprocess(phonemes) 76 | 77 | def _preprocess_string(self, text: str): 78 | text = text.replace('-', self.special_hyphen) 79 | return text 80 | 81 | def _preprocess(self, text: Union[str, list]) -> Union[str, list]: 82 | if isinstance(text, list): 83 | return [self._preprocess_string(t) for t in text] 84 | elif isinstance(text, str): 85 | return self._preprocess_string(text) 86 | else: 87 | raise TypeError(f'{self} input must be list or str, not {type(text)}') 88 | 89 | def _collapse_whitespace(self, text: str) -> str: 90 | text = re.sub(self._whitespace_re, ' ', text) 91 | return re.sub(self._whitespace_punctuation_re, r'\1', text) 92 | 93 | def _postprocess_string(self, text: str) -> str: 94 | text = text.replace(self.special_hyphen, '-') 95 | text = ''.join([c for c in text if c in all_phonemes]) 96 | text = self._collapse_whitespace(text) 97 | text = text.strip() 98 | return text 99 | 100 | def _postprocess(self, text: Union[str, list]) -> Union[str, list]: 101 | if isinstance(text, list): 102 | return [self._postprocess_string(t) for t in text] 103 | elif isinstance(text, str): 104 | return self._postprocess_string(text) 105 | else: 106 | raise TypeError(f'{self} input must be list or str, not {type(text)}') 107 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _site 2 | .sass-cache 3 | .jekyll-cache 4 | .jekyll-metadata 5 | vendor 6 | .bundle -------------------------------------------------------------------------------- /docs/404.html: -------------------------------------------------------------------------------- 1 | --- 2 | permalink: /404.html 3 | layout: default 4 | --- 5 | 6 | 19 | 20 |
21 |

404

22 | 23 |

Page not found :(

24 |

The requested page could not be found.

25 |
26 | -------------------------------------------------------------------------------- /docs/Gemfile: -------------------------------------------------------------------------------- 1 | source "https://rubygems.org" 2 | # Hello! This is where you manage which Jekyll version is used to run. 3 | # When you want to use a different version, change it below, save the 4 | # file and run `bundle install`. Run Jekyll with `bundle exec`, like so: 5 | # 6 | # bundle exec jekyll serve 7 | # 8 | # This will help ensure the proper Jekyll version is running. 9 | # Happy Jekylling! 10 | # gem "jekyll", "~> 4.0.0" 11 | # This is the default theme for new Jekyll sites. You may change this to anything you like. 12 | gem "minima", "~> 2.5" 13 | # If you want to use GitHub Pages, remove the "gem "jekyll"" above and 14 | # uncomment the line below. To upgrade, run `bundle update github-pages`. 15 | gem "github-pages", group: :jekyll_plugins 16 | # If you have any plugins, put them here! 17 | group :jekyll_plugins do 18 | gem "jekyll-feed", "~> 0.12" 19 | end 20 | 21 | # Windows and JRuby does not include zoneinfo files, so bundle the tzinfo-data gem 22 | # and associated library. 23 | install_if -> { RUBY_PLATFORM =~ %r!mingw|mswin|java! } do 24 | gem "tzinfo", "~> 1.2" 25 | gem "tzinfo-data" 26 | end 27 | 28 | # Performance-booster for watching directories on Windows 29 | gem "wdm", "~> 0.1.1", :install_if => Gem.win_platform? 30 | 31 | gem "jekyll", "~> 3.8" 32 | -------------------------------------------------------------------------------- /docs/Gemfile.lock: -------------------------------------------------------------------------------- 1 | GEM 2 | remote: https://rubygems.org/ 3 | specs: 4 | activesupport (6.0.3.1) 5 | concurrent-ruby (~> 1.0, >= 1.0.2) 6 | i18n (>= 0.7, < 2) 7 | minitest (~> 5.1) 8 | tzinfo (~> 1.1) 9 | zeitwerk (~> 2.2, >= 2.2.2) 10 | addressable (2.7.0) 11 | public_suffix (>= 2.0.2, < 5.0) 12 | coffee-script (2.4.1) 13 | coffee-script-source 14 | execjs 15 | coffee-script-source (1.11.1) 16 | colorator (1.1.0) 17 | commonmarker (0.17.13) 18 | ruby-enum (~> 0.5) 19 | concurrent-ruby (1.1.6) 20 | dnsruby (1.61.3) 21 | addressable (~> 2.5) 22 | em-websocket (0.5.1) 23 | eventmachine (>= 0.12.9) 24 | http_parser.rb (~> 0.6.0) 25 | ethon (0.12.0) 26 | ffi (>= 1.3.0) 27 | eventmachine (1.2.7) 28 | execjs (2.7.0) 29 | faraday (1.0.1) 30 | multipart-post (>= 1.2, < 3) 31 | ffi (1.12.2) 32 | forwardable-extended (2.6.0) 33 | gemoji (3.0.1) 34 | github-pages (204) 35 | github-pages-health-check (= 1.16.1) 36 | jekyll (= 3.8.5) 37 | jekyll-avatar (= 0.7.0) 38 | jekyll-coffeescript (= 1.1.1) 39 | jekyll-commonmark-ghpages (= 0.1.6) 40 | jekyll-default-layout (= 0.1.4) 41 | jekyll-feed (= 0.13.0) 42 | jekyll-gist (= 1.5.0) 43 | jekyll-github-metadata (= 2.13.0) 44 | jekyll-mentions (= 1.5.1) 45 | jekyll-optional-front-matter (= 0.3.2) 46 | jekyll-paginate (= 1.1.0) 47 | jekyll-readme-index (= 0.3.0) 48 | jekyll-redirect-from (= 0.15.0) 49 | jekyll-relative-links (= 0.6.1) 50 | jekyll-remote-theme (= 0.4.1) 51 | jekyll-sass-converter (= 1.5.2) 52 | jekyll-seo-tag (= 2.6.1) 53 | jekyll-sitemap (= 1.4.0) 54 | jekyll-swiss (= 1.0.0) 55 | jekyll-theme-architect (= 0.1.1) 56 | jekyll-theme-cayman (= 0.1.1) 57 | jekyll-theme-dinky (= 0.1.1) 58 | jekyll-theme-hacker (= 0.1.1) 59 | jekyll-theme-leap-day (= 0.1.1) 60 | jekyll-theme-merlot (= 0.1.1) 61 | jekyll-theme-midnight (= 0.1.1) 62 | jekyll-theme-minimal (= 0.1.1) 63 | jekyll-theme-modernist (= 0.1.1) 64 | jekyll-theme-primer (= 0.5.4) 65 | 
jekyll-theme-slate (= 0.1.1) 66 | jekyll-theme-tactile (= 0.1.1) 67 | jekyll-theme-time-machine (= 0.1.1) 68 | jekyll-titles-from-headings (= 0.5.3) 69 | jemoji (= 0.11.1) 70 | kramdown (= 1.17.0) 71 | liquid (= 4.0.3) 72 | mercenary (~> 0.3) 73 | minima (= 2.5.1) 74 | nokogiri (>= 1.10.4, < 2.0) 75 | rouge (= 3.13.0) 76 | terminal-table (~> 1.4) 77 | github-pages-health-check (1.16.1) 78 | addressable (~> 2.3) 79 | dnsruby (~> 1.60) 80 | octokit (~> 4.0) 81 | public_suffix (~> 3.0) 82 | typhoeus (~> 1.3) 83 | html-pipeline (2.12.3) 84 | activesupport (>= 2) 85 | nokogiri (>= 1.4) 86 | http_parser.rb (0.6.0) 87 | i18n (0.9.5) 88 | concurrent-ruby (~> 1.0) 89 | jekyll (3.8.5) 90 | addressable (~> 2.4) 91 | colorator (~> 1.0) 92 | em-websocket (~> 0.5) 93 | i18n (~> 0.7) 94 | jekyll-sass-converter (~> 1.0) 95 | jekyll-watch (~> 2.0) 96 | kramdown (~> 1.14) 97 | liquid (~> 4.0) 98 | mercenary (~> 0.3.3) 99 | pathutil (~> 0.9) 100 | rouge (>= 1.7, < 4) 101 | safe_yaml (~> 1.0) 102 | jekyll-avatar (0.7.0) 103 | jekyll (>= 3.0, < 5.0) 104 | jekyll-coffeescript (1.1.1) 105 | coffee-script (~> 2.2) 106 | coffee-script-source (~> 1.11.1) 107 | jekyll-commonmark (1.3.1) 108 | commonmarker (~> 0.14) 109 | jekyll (>= 3.7, < 5.0) 110 | jekyll-commonmark-ghpages (0.1.6) 111 | commonmarker (~> 0.17.6) 112 | jekyll-commonmark (~> 1.2) 113 | rouge (>= 2.0, < 4.0) 114 | jekyll-default-layout (0.1.4) 115 | jekyll (~> 3.0) 116 | jekyll-feed (0.13.0) 117 | jekyll (>= 3.7, < 5.0) 118 | jekyll-gist (1.5.0) 119 | octokit (~> 4.2) 120 | jekyll-github-metadata (2.13.0) 121 | jekyll (>= 3.4, < 5.0) 122 | octokit (~> 4.0, != 4.4.0) 123 | jekyll-mentions (1.5.1) 124 | html-pipeline (~> 2.3) 125 | jekyll (>= 3.7, < 5.0) 126 | jekyll-optional-front-matter (0.3.2) 127 | jekyll (>= 3.0, < 5.0) 128 | jekyll-paginate (1.1.0) 129 | jekyll-readme-index (0.3.0) 130 | jekyll (>= 3.0, < 5.0) 131 | jekyll-redirect-from (0.15.0) 132 | jekyll (>= 3.3, < 5.0) 133 | jekyll-relative-links (0.6.1) 134 | jekyll (>= 3.3, < 5.0) 135 | jekyll-remote-theme (0.4.1) 136 | addressable (~> 2.0) 137 | jekyll (>= 3.5, < 5.0) 138 | rubyzip (>= 1.3.0) 139 | jekyll-sass-converter (1.5.2) 140 | sass (~> 3.4) 141 | jekyll-seo-tag (2.6.1) 142 | jekyll (>= 3.3, < 5.0) 143 | jekyll-sitemap (1.4.0) 144 | jekyll (>= 3.7, < 5.0) 145 | jekyll-swiss (1.0.0) 146 | jekyll-theme-architect (0.1.1) 147 | jekyll (~> 3.5) 148 | jekyll-seo-tag (~> 2.0) 149 | jekyll-theme-cayman (0.1.1) 150 | jekyll (~> 3.5) 151 | jekyll-seo-tag (~> 2.0) 152 | jekyll-theme-dinky (0.1.1) 153 | jekyll (~> 3.5) 154 | jekyll-seo-tag (~> 2.0) 155 | jekyll-theme-hacker (0.1.1) 156 | jekyll (~> 3.5) 157 | jekyll-seo-tag (~> 2.0) 158 | jekyll-theme-leap-day (0.1.1) 159 | jekyll (~> 3.5) 160 | jekyll-seo-tag (~> 2.0) 161 | jekyll-theme-merlot (0.1.1) 162 | jekyll (~> 3.5) 163 | jekyll-seo-tag (~> 2.0) 164 | jekyll-theme-midnight (0.1.1) 165 | jekyll (~> 3.5) 166 | jekyll-seo-tag (~> 2.0) 167 | jekyll-theme-minimal (0.1.1) 168 | jekyll (~> 3.5) 169 | jekyll-seo-tag (~> 2.0) 170 | jekyll-theme-modernist (0.1.1) 171 | jekyll (~> 3.5) 172 | jekyll-seo-tag (~> 2.0) 173 | jekyll-theme-primer (0.5.4) 174 | jekyll (> 3.5, < 5.0) 175 | jekyll-github-metadata (~> 2.9) 176 | jekyll-seo-tag (~> 2.0) 177 | jekyll-theme-slate (0.1.1) 178 | jekyll (~> 3.5) 179 | jekyll-seo-tag (~> 2.0) 180 | jekyll-theme-tactile (0.1.1) 181 | jekyll (~> 3.5) 182 | jekyll-seo-tag (~> 2.0) 183 | jekyll-theme-time-machine (0.1.1) 184 | jekyll (~> 3.5) 185 | jekyll-seo-tag (~> 2.0) 186 | jekyll-titles-from-headings (0.5.3) 187 | 
jekyll (>= 3.3, < 5.0) 188 | jekyll-watch (2.2.1) 189 | listen (~> 3.0) 190 | jemoji (0.11.1) 191 | gemoji (~> 3.0) 192 | html-pipeline (~> 2.2) 193 | jekyll (>= 3.0, < 5.0) 194 | kramdown (1.17.0) 195 | liquid (4.0.3) 196 | listen (3.2.1) 197 | rb-fsevent (~> 0.10, >= 0.10.3) 198 | rb-inotify (~> 0.9, >= 0.9.10) 199 | mercenary (0.3.6) 200 | mini_portile2 (2.5.1) 201 | minima (2.5.1) 202 | jekyll (>= 3.5, < 5.0) 203 | jekyll-feed (~> 0.9) 204 | jekyll-seo-tag (~> 2.1) 205 | minitest (5.14.1) 206 | multipart-post (2.1.1) 207 | nokogiri (1.11.5) 208 | mini_portile2 (~> 2.5.0) 209 | racc (~> 1.4) 210 | octokit (4.18.0) 211 | faraday (>= 0.9) 212 | sawyer (~> 0.8.0, >= 0.5.3) 213 | pathutil (0.16.2) 214 | forwardable-extended (~> 2.6) 215 | public_suffix (3.1.1) 216 | racc (1.5.2) 217 | rb-fsevent (0.10.3) 218 | rb-inotify (0.10.1) 219 | ffi (~> 1.0) 220 | rouge (3.13.0) 221 | ruby-enum (0.8.0) 222 | i18n 223 | rubyzip (2.3.0) 224 | safe_yaml (1.0.5) 225 | sass (3.7.4) 226 | sass-listen (~> 4.0.0) 227 | sass-listen (4.0.0) 228 | rb-fsevent (~> 0.9, >= 0.9.4) 229 | rb-inotify (~> 0.9, >= 0.9.7) 230 | sawyer (0.8.2) 231 | addressable (>= 2.3.5) 232 | faraday (> 0.8, < 2.0) 233 | terminal-table (1.8.0) 234 | unicode-display_width (~> 1.1, >= 1.1.1) 235 | thread_safe (0.3.6) 236 | typhoeus (1.3.1) 237 | ethon (>= 0.9.0) 238 | tzinfo (1.2.7) 239 | thread_safe (~> 0.1) 240 | tzinfo-data (1.2019.3) 241 | tzinfo (>= 1.0.0) 242 | unicode-display_width (1.7.0) 243 | wdm (0.1.1) 244 | zeitwerk (2.3.0) 245 | 246 | PLATFORMS 247 | ruby 248 | 249 | DEPENDENCIES 250 | github-pages 251 | jekyll (~> 3.8) 252 | jekyll-feed (~> 0.12) 253 | minima (~> 2.5) 254 | tzinfo (~> 1.2) 255 | tzinfo-data 256 | wdm (~> 0.1.1) 257 | 258 | BUNDLED WITH 259 | 2.1.4 260 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # TransformerTTS Project Page 2 | 3 | We use [Jekyll](https://jekyllrb.com/) to generate our project page. 4 | 5 | ### Usage 6 | 7 | To serve locally run: 8 | ``` 9 | bundle exec jekyll serve 10 | ``` -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | # Welcome to Jekyll! 2 | # 3 | # This config file is meant for settings that affect your whole blog, values 4 | # which you are expected to set up once and rarely edit after that. If you find 5 | # yourself editing this file very often, consider using Jekyll's data files 6 | # feature for the data you need to update frequently. 7 | # 8 | # For technical reasons, this file is *NOT* reloaded automatically when you use 9 | # 'bundle exec jekyll serve'. If you change this file, please restart the server process. 10 | # 11 | # If you need help with YAML syntax, here are some quick references for you: 12 | # https://learn-the-web.algonquindesign.ca/topics/markdown-yaml-cheat-sheet/#yaml 13 | # https://learnxinyminutes.com/docs/yaml/ 14 | # 15 | # Site settings 16 | # These are used to personalize your new site. If you look in the HTML files, 17 | # you will see them accessed via {{ site.title }}, {{ site.email }}, and so on. 18 | # You can create any custom variable you would like, and they will be accessible 19 | # in the templates via {{ site.myvariable }}. 
20 | 21 | title: TransformerTTS 22 | email: ai@axelspringer.com 23 | description: >- # this means to ignore newlines until "baseurl:" 24 | Implementation of a Transformer based neural network for text to speech. 25 | google_analytics: UA-137434942-7 26 | baseurl: "/TransformerTTS" # the subpath of your site, e.g. /blog 27 | url: "" # the base hostname & protocol for your site, e.g. http://example.com 28 | twitter_username: axelspringerai 29 | github_username: as-ideas 30 | 31 | # Build settings 32 | markdown: kramdown 33 | theme: jekyll-theme-primer 34 | plugins: 35 | - jekyll-feed 36 | 37 | # Exclude from processing. 38 | # The following items will not be processed, by default. 39 | # Any item listed under the `exclude:` key here will be automatically added to 40 | # the internal "default list". 41 | # 42 | # Excluded items can be processed by explicitly listing the directories or 43 | # their entries' file path in the `include:` list. 44 | # 45 | # exclude: 46 | # - .sass-cache/ 47 | # - .jekyll-cache/ 48 | # - gemfiles/ 49 | # - Gemfile 50 | # - Gemfile.lock 51 | # - node_modules/ 52 | # - vendor/bundle/ 53 | # - vendor/cache/ 54 | # - vendor/gems/ 55 | # - vendor/ruby/ 56 | -------------------------------------------------------------------------------- /docs/_layouts/default.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | {% seo %} 11 | 12 | 13 | 14 | Fork me on GitHub 15 |
16 | 17 | {{ content }} 18 | 19 | {% if site.github.private != true and site.github.license %} 20 | 23 | {% endif %} 24 |
25 | 26 | 27 | {% if site.google_analytics %} 28 | 36 | {% endif %} 37 | 38 | -------------------------------------------------------------------------------- /docs/assets/css/style.scss: -------------------------------------------------------------------------------- 1 | --- 2 | --- 3 | 4 | @import "{{ site.theme }}"; 5 | 6 | .text { 7 | font-size: 20px; 8 | } -------------------------------------------------------------------------------- /docs/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/TransformerTTS/363805548abdd93b33508da2c027ae514bfc1a07/docs/favicon.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 |

2 |
3 | 4 |
5 |

6 | 7 |

8 |

A Text-to-Speech Transformer in TensorFlow 2

9 |

10 | 11 |

Samples are converted using the pre-trained HiFiGAN vocoder and, for comparison, with the standard Griffin-Lim algorithm. 12 |
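The Griffin-Lim versions can be reproduced with the pretrained LJSpeech model, roughly as in `notebooks/synthesize_forward_melgan.ipynb`; the neural-vocoder samples additionally use an external vocoder (for example the MelGAN setup shown in that notebook). A minimal sketch, assuming the pretrained weights are available via `model.factory.tts_ljspeech`:

```
# Sketch only — mirrors the synthesis notebook in this repo.
from model.factory import tts_ljspeech
from data.audio import Audio

model, config = tts_ljspeech()   # downloads pretrained LJSpeech weights
audio = Audio(config)

out = model.predict('Peter Piper picked a peck of pickled peppers.')
wav = audio.reconstruct_waveform(out['mel'].numpy().T)  # Griffin-Lim reconstruction
# wav is a waveform array at config['sampling_rate'] Hz
```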

13 | 14 | ## 🎧 Model samples 15 | 16 |

Introductory speech ODSC Boston 2021

17 | 18 | 19 | 20 | 21 |

Peter Piper picked a peck of pickled peppers.

22 | 23 | | HiFiGAN | Griffin-Lim | 24 | |:---:|:---:| 25 | ||| 26 | 27 |

President Trump met with other leaders at the Group of twenty conference.

28 | 29 | | HiFiGAN | Griffin-Lim | 30 | |:---:|:---:| 31 | ||| 32 | 33 |

Scientists at the CERN laboratory say they have discovered a new particle.

34 | 35 | | HiFiGAN | Griffin-Lim | 36 | |:---:|:---:| 37 | ||| 38 | 39 |

There’s a way to measure the acute emotional intelligence that has never gone out of style.

40 | 41 | | HiFiGAN | Griffin-Lim | 42 | |:---:|:---:| 43 | ||| 44 | 45 |

The Senate's bill to repeal and replace the Affordable Care Act is now imperiled.

46 | 47 | | HiFiGAN | Griffin-Lim | 48 | |:---:|:---:| 49 | ||| 50 | 51 | 52 |

If I were to talk to a human, I would definitely try to sound normal. Wouldn't I?

53 | 54 | | HiFiGAN | Griffin-Lim | 55 | |:---:|:---:| 56 | ||| 57 | 58 | 59 | ### Robustness 60 | 61 |

To deliver interfaces that are significantly better suited to create and process RFC eight twenty one, RFC eight twenty two, RFC nine seventy seven, and MIME content.

62 | 63 | | HiFiGAN | Griffin-Lim | 64 | |:---:|:---:| 65 | ||| 66 | 67 | 68 | ### Comparison with [ForwardTacotron](https://github.com/as-ideas/ForwardTacotron) 69 |

In a statement announcing his resignation, Mr Ross said: "While the intentions may have been well meaning, the reaction to this news shows that Mr Cummings interpretation of the government advice was not shared by the vast majority of people who have done as the government asked."

70 | 71 | | TransformerTTS | ForwardTacotron | 72 | |:---:|:---:| 73 | ||| 74 | -------------------------------------------------------------------------------- /docs/tboard_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/TransformerTTS/363805548abdd93b33508da2c027ae514bfc1a07/docs/tboard_demo.gif -------------------------------------------------------------------------------- /docs/transformer_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/TransformerTTS/363805548abdd93b33508da2c027ae514bfc1a07/docs/transformer_logo.png -------------------------------------------------------------------------------- /extract_durations.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | from tqdm import tqdm 7 | from p_tqdm import p_umap 8 | 9 | from utils.training_config_manager import TrainingConfigManager 10 | from utils.logging_utils import SummaryManager 11 | from data.datasets import AlignerPreprocessor 12 | from utils.alignments import get_durations_from_alignment 13 | from utils.scripts_utils import dynamic_memory_allocation 14 | from data.datasets import AlignerDataset 15 | from data.datasets import DataReader 16 | 17 | np.random.seed(42) 18 | tf.random.set_seed(42) 19 | dynamic_memory_allocation() 20 | 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--config', dest='config', type=str) 24 | parser.add_argument('--best', dest='best', action='store_true', 25 | help='Use best head instead of weighted average of heads.') 26 | parser.add_argument('--autoregressive_weights', type=str, default=None, 27 | help='Explicit path to autoregressive model weights.') 28 | parser.add_argument('--skip_char_pitch', dest='skip_char_pitch', action='store_true') 29 | parser.add_argument('--skip_durations', dest='skip_durations', action='store_true') 30 | args = parser.parse_args() 31 | weighted = not args.best 32 | tag_description = ''.join([ 33 | f'{"_weighted" * weighted}{"_best" * (not weighted)}', 34 | ]) 35 | writer_tag = f'DurationExtraction{tag_description}' 36 | print(writer_tag) 37 | config_manager = TrainingConfigManager(config_path=args.config, aligner=True) 38 | config = config_manager.config 39 | config_manager.print_config() 40 | 41 | if not args.skip_durations: 42 | model = config_manager.load_model(args.autoregressive_weights) 43 | if model.r != 1: 44 | print(f"ERROR: model's reduction factor is greater than 1, check config. 
(r={model.r}") 45 | 46 | data_prep = AlignerPreprocessor.from_config(config=config_manager, 47 | tokenizer=model.text_pipeline.tokenizer) 48 | data_handler = AlignerDataset.from_config(config_manager, 49 | preprocessor=data_prep, 50 | kind='phonemized') 51 | target_dir = config_manager.duration_dir 52 | config_manager.dump_config() 53 | dataset = data_handler.get_dataset(bucket_batch_sizes=config['bucket_batch_sizes'], 54 | bucket_boundaries=config['bucket_boundaries'], 55 | shuffle=False, 56 | drop_remainder=False) 57 | 58 | last_layer_key = 'Decoder_LastBlock_CrossAttention' 59 | print(f'Extracting attention from layer {last_layer_key}') 60 | 61 | summary_manager = SummaryManager(model=model, log_dir=config_manager.log_dir / 'Duration Extraction', 62 | config=config, 63 | default_writer='Duration Extraction') 64 | all_durations = np.array([]) 65 | new_alignments = [] 66 | iterator = tqdm(enumerate(dataset.all_batches())) 67 | step = 0 68 | for c, (mel_batch, text_batch, stop_batch, file_name_batch) in iterator: 69 | iterator.set_description(f'Processing dataset') 70 | outputs = model.val_step(inp=text_batch, 71 | tar=mel_batch, 72 | stop_prob=stop_batch) 73 | attention_values = outputs['decoder_attention'][last_layer_key].numpy() 74 | text = text_batch.numpy() 75 | 76 | mel = mel_batch.numpy() 77 | 78 | durations, final_align, jumpiness, peakiness, diag_measure = get_durations_from_alignment( 79 | batch_alignments=attention_values, 80 | mels=mel, 81 | phonemes=text, 82 | weighted=weighted) 83 | batch_avg_jumpiness = tf.reduce_mean(jumpiness, axis=0) 84 | batch_avg_peakiness = tf.reduce_mean(peakiness, axis=0) 85 | batch_avg_diag_measure = tf.reduce_mean(diag_measure, axis=0) 86 | for i in range(tf.shape(jumpiness)[1]): 87 | summary_manager.display_scalar(tag=f'DurationAttentionJumpiness/head{i}', 88 | scalar_value=tf.reduce_mean(batch_avg_jumpiness[i]), step=c) 89 | summary_manager.display_scalar(tag=f'DurationAttentionPeakiness/head{i}', 90 | scalar_value=tf.reduce_mean(batch_avg_peakiness[i]), step=c) 91 | summary_manager.display_scalar(tag=f'DurationAttentionDiagonality/head{i}', 92 | scalar_value=tf.reduce_mean(batch_avg_diag_measure[i]), step=c) 93 | 94 | for i, name in enumerate(file_name_batch): 95 | all_durations = np.append(all_durations, durations[i]) # for plotting only 96 | summary_manager.add_image(tag='ExtractedAlignments', 97 | image=tf.expand_dims(tf.expand_dims(final_align[i], 0), -1), 98 | step=step) 99 | 100 | step += 1 101 | np.save(str(target_dir / f"{name.numpy().decode('utf-8')}.npy"), durations[i]) 102 | 103 | all_durations[all_durations >= 20] = 20 # for plotting only 104 | buckets = len(set(all_durations)) # for plotting only 105 | summary_manager.add_histogram(values=all_durations, tag='ExtractedDurations', buckets=buckets) 106 | 107 | if not args.skip_char_pitch: 108 | def _pitch_per_char(pitch, durations, mel_len): 109 | durs_cum = np.cumsum(np.pad(durations, (1, 0))) 110 | pitch_char = np.zeros((durations.shape[0],), dtype=np.float) 111 | for idx, a, b in zip(range(mel_len), durs_cum[:-1], durs_cum[1:]): 112 | values = pitch[a:b][np.where(pitch[a:b] != 0.0)[0]] 113 | values = values[np.where((values * pitch_stats['pitch_std'] + pitch_stats['pitch_mean']) < 400)[0]] 114 | pitch_char[idx] = np.mean(values) if len(values) > 0 else 0.0 115 | return pitch_char 116 | 117 | 118 | def process_per_char_pitch(sample_name: str): 119 | pitch = np.load((config_manager.pitch_dir / sample_name).with_suffix('.npy').as_posix()) 120 | durations = 
np.load((config_manager.duration_dir / sample_name).with_suffix('.npy').as_posix()) 121 | mel = np.load((config_manager.mel_dir / sample_name).with_suffix('.npy').as_posix()) 122 | char_wise_pitch = _pitch_per_char(pitch, durations, mel.shape[0]) 123 | np.save((config_manager.pitch_per_char / sample_name).with_suffix('.npy').as_posix(), char_wise_pitch) 124 | 125 | 126 | metadatareader = DataReader.from_config(config_manager, kind='phonemized', scan_wavs=False) 127 | pitch_stats = pickle.load(open(config_manager.data_dir / 'pitch_stats.pkl', 'rb')) 128 | print(f'\nComputing phoneme-wise pitch') 129 | print(f'{len(metadatareader.filenames)} items found in {metadatareader.metadata_path}.') 130 | wav_iter = p_umap(process_per_char_pitch, metadatareader.filenames) 131 | print('Done.') 132 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/TransformerTTS/363805548abdd93b33508da2c027ae514bfc1a07/model/__init__.py -------------------------------------------------------------------------------- /model/factory.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from pathlib import Path 3 | 4 | import tensorflow as tf 5 | import ruamel.yaml 6 | 7 | from model.models import ForwardTransformer, Aligner 8 | 9 | 10 | def tts_ljspeech(step='95000') -> Tuple[ForwardTransformer, dict]: 11 | model_name = f'bdf06b9_ljspeech_step_{step}.zip' 12 | remote_dir = 'https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/api_weights/bdf06b9_ljspeech/' 13 | model_path = tf.keras.utils.get_file(model_name, 14 | remote_dir + model_name, 15 | extract=True, 16 | archive_format='zip', 17 | cache_subdir='TransformerTTS_models') 18 | model_path = Path(model_path).with_suffix('') # remove extension 19 | return ForwardTransformer.load_model(model_path.as_posix()) 20 | 21 | 22 | def tts_custom(config_path: str, weights_path: str) -> Tuple[ForwardTransformer, dict]: 23 | yaml = ruamel.yaml.YAML() 24 | with open(config_path, 'rb') as session_yaml: 25 | config = yaml.load(session_yaml) 26 | model = ForwardTransformer.from_config(config) 27 | model.build_model_weights() 28 | model.load_weights(weights_path) 29 | return model, config 30 | 31 | 32 | def aligner_custom(config_path: str, weights_path: str) -> Tuple[Aligner, dict]: 33 | yaml = ruamel.yaml.YAML() 34 | with open(config_path, 'rb') as session_yaml: 35 | config = yaml.load(session_yaml) 36 | model = Aligner.from_config(config) 37 | model.build_model_weights() 38 | model.load_weights(weights_path) 39 | return model, config 40 | -------------------------------------------------------------------------------- /model/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from model.transformer_utils import positional_encoding 4 | 5 | 6 | class CNNResNorm(tf.keras.layers.Layer): 7 | """ 8 | Module used in attention blocks, after MHA 9 | """ 10 | 11 | def __init__(self, 12 | filters: list, 13 | kernel_size: int, 14 | inner_activation: str, 15 | padding: str, 16 | dout_rate: float): 17 | super(CNNResNorm, self).__init__() 18 | self.n_layers = len(filters) 19 | self.convolutions = [tf.keras.layers.Conv1D(filters=f, 20 | kernel_size=kernel_size, 21 | padding=padding) 22 | for f in filters[:-1]] 23 | self.inner_activations = 
[tf.keras.layers.Activation(inner_activation) for _ in range(self.n_layers - 1)] 24 | self.last_conv = tf.keras.layers.Conv1D(filters=filters[-1], 25 | kernel_size=kernel_size, 26 | padding=padding) 27 | self.normalization = tf.keras.layers.LayerNormalization(epsilon=1e-6) 28 | self.dropout = tf.keras.layers.Dropout(rate=dout_rate) 29 | 30 | def call_convs(self, x): 31 | for i in range(0, self.n_layers - 1): 32 | x = self.convolutions[i](x) 33 | x = self.inner_activations[i](x) 34 | return x 35 | 36 | def call(self, inputs, training): 37 | x = self.call_convs(inputs) 38 | x = self.last_conv(x) 39 | x = self.dropout(x, training=training) 40 | return self.normalization(inputs + x) 41 | 42 | 43 | class TransposedCNNResNorm(tf.keras.layers.Layer): 44 | """ 45 | Module used in attention blocks, after MHA 46 | """ 47 | 48 | def __init__(self, 49 | filters: list, 50 | kernel_size: int, 51 | inner_activation: str, 52 | padding: str, 53 | dout_rate: float): 54 | super(TransposedCNNResNorm, self).__init__() 55 | self.n_layers = len(filters) 56 | self.convolutions = [tf.keras.layers.Conv1D(filters=f, 57 | kernel_size=kernel_size, 58 | padding=padding) 59 | for f in filters[:-1]] 60 | self.inner_activations = [tf.keras.layers.Activation(inner_activation) for _ in range(self.n_layers - 1)] 61 | self.last_conv = tf.keras.layers.Conv1D(filters=filters[-1], 62 | kernel_size=kernel_size, 63 | padding=padding) 64 | self.normalization = tf.keras.layers.LayerNormalization(epsilon=1e-6) 65 | self.dropout = tf.keras.layers.Dropout(rate=dout_rate) 66 | 67 | def call_convs(self, x): 68 | for i in range(0, self.n_layers - 1): 69 | x = self.convolutions[i](x) 70 | x = self.inner_activations[i](x) 71 | return x 72 | 73 | def call(self, inputs, training): 74 | x = tf.transpose(inputs, (0, 1, 2)) 75 | x = self.call_convs(x) 76 | x = self.last_conv(x) 77 | x = tf.transpose(x, (0, 1, 2)) 78 | x = self.dropout(x, training=training) 79 | return self.normalization(inputs + x) 80 | 81 | 82 | class FFNResNorm(tf.keras.layers.Layer): 83 | """ 84 | Module used in attention blocks, after MHA 85 | """ 86 | 87 | def __init__(self, 88 | model_dim: int, 89 | dense_hidden_units: int, 90 | dropout_rate: float, 91 | **kwargs): 92 | super(FFNResNorm, self).__init__(**kwargs) 93 | self.d1 = tf.keras.layers.Dense(dense_hidden_units, 'relu') 94 | self.d2 = tf.keras.layers.Dense(model_dim) 95 | self.dropout = tf.keras.layers.Dropout(dropout_rate) 96 | self.last_ln = tf.keras.layers.LayerNormalization(epsilon=1e-6) 97 | 98 | def call(self, x, training): 99 | ffn_out = self.d1(x) 100 | ffn_out = self.d2(ffn_out) # (batch_size, input_seq_len, model_dim) 101 | ffn_out = self.dropout(ffn_out, training=training) 102 | return self.last_ln(ffn_out + x) 103 | 104 | 105 | class MultiHeadAttention(tf.keras.layers.Layer): 106 | 107 | def __init__(self, model_dim: int, num_heads: int, dropout: float, **kwargs): 108 | super(MultiHeadAttention, self).__init__(**kwargs) 109 | self.num_heads = num_heads 110 | self.model_dim = model_dim 111 | 112 | assert model_dim % self.num_heads == 0 113 | 114 | self.depth = model_dim // self.num_heads 115 | 116 | self.wq = tf.keras.layers.Dense(model_dim) 117 | self.wk = tf.keras.layers.Dense(model_dim) 118 | self.wv = tf.keras.layers.Dense(model_dim) 119 | self.attention = ScaledDotProductAttention(dropout=dropout) 120 | self.dense = tf.keras.layers.Dense(model_dim) 121 | self.dropout = tf.keras.layers.Dropout(dropout) 122 | 123 | def split_heads(self, x, batch_size: int): 124 | """ Split the last dimension into 
(num_heads, depth). 125 | Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth) 126 | """ 127 | 128 | x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth)) 129 | return tf.transpose(x, perm=[0, 2, 1, 3]) 130 | 131 | def call(self, v, k, q_in, mask, training): 132 | batch_size = tf.shape(q_in)[0] 133 | 134 | q = self.wq(q_in) # (batch_size, seq_len, model_dim) 135 | k = self.wk(k) # (batch_size, seq_len, model_dim) 136 | v = self.wv(v) # (batch_size, seq_len, model_dim) 137 | 138 | q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth) 139 | k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth) 140 | v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth) 141 | 142 | scaled_attention, attention_weights = self.attention([q, k, v, mask], training=training) 143 | 144 | scaled_attention = tf.transpose(scaled_attention, 145 | perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth) 146 | concat_attention = tf.reshape(scaled_attention, 147 | (batch_size, -1, self.model_dim)) # (batch_size, seq_len_q, model_dim) 148 | concat_query = tf.concat([q_in, concat_attention], axis=-1) 149 | output = self.dense(concat_query) # (batch_size, seq_len_q, model_dim) 150 | output = self.dropout(output, training=training) 151 | return output, attention_weights 152 | 153 | 154 | class ScaledDotProductAttention(tf.keras.layers.Layer): 155 | """ Calculate the attention weights. 156 | q, k, v must have matching leading dimensions. 157 | k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v. 158 | The mask has different shapes depending on its type(padding or look ahead) 159 | but it must be broadcastable for addition. 160 | 161 | Args: 162 | q: query shape == (..., seq_len_q, depth) 163 | k: key shape == (..., seq_len_k, depth) 164 | v: value shape == (..., seq_len_v, depth_v) 165 | mask: Float tensor with shape broadcastable 166 | to (..., seq_len_q, seq_len_k). Defaults to None. 167 | 168 | Returns: 169 | output, attention_weights 170 | """ 171 | 172 | def __init__(self, dropout: float): 173 | super(ScaledDotProductAttention, self).__init__() 174 | self.dropout = tf.keras.layers.Dropout(rate=dropout) 175 | 176 | def call(self, inputs, training=False): 177 | q, k, v, mask = inputs 178 | 179 | matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k) 180 | 181 | # scale matmul_qk 182 | dk = tf.cast(tf.shape(k)[-1], tf.float32) 183 | scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) 184 | 185 | # add the mask to the scaled tensor. 186 | if mask is not None: 187 | scaled_attention_logits += mask * -1e9 # TODO: add mask expansion here and remove from create padding mask 188 | 189 | # softmax is normalized on the last axis (seq_len_k) so that the scores 190 | # add up to 1. 
191 | attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k) 192 | attention_weights = self.dropout(attention_weights, training=training) 193 | output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v) 194 | 195 | return output, attention_weights 196 | 197 | 198 | class SelfAttentionResNorm(tf.keras.layers.Layer): 199 | 200 | def __init__(self, 201 | model_dim: int, 202 | num_heads: int, 203 | dropout_rate: float, 204 | **kwargs): 205 | super(SelfAttentionResNorm, self).__init__(**kwargs) 206 | self.mha = MultiHeadAttention(model_dim, num_heads, dropout=dropout_rate) 207 | self.last_ln = tf.keras.layers.LayerNormalization(epsilon=1e-6) 208 | 209 | def call(self, x, training, mask): 210 | attn_out, attn_weights = self.mha(x, x, x, mask, training=training) # (batch_size, input_seq_len, model_dim) 211 | return self.last_ln(attn_out + x), attn_weights 212 | 213 | 214 | class SelfAttentionDenseBlock(tf.keras.layers.Layer): 215 | 216 | def __init__(self, 217 | model_dim: int, 218 | num_heads: int, 219 | dense_hidden_units: int, 220 | dropout_rate: float, 221 | **kwargs): 222 | super(SelfAttentionDenseBlock, self).__init__(**kwargs) 223 | self.sarn = SelfAttentionResNorm(model_dim, num_heads, dropout_rate=dropout_rate) 224 | self.ffn = FFNResNorm(model_dim, dense_hidden_units, dropout_rate=dropout_rate) 225 | 226 | def call(self, x, training, mask): 227 | attn_out, attn_weights = self.sarn(x, mask=mask, training=training) 228 | dense_mask = 1. - tf.squeeze(mask, axis=(1, 2))[:, :, None] 229 | attn_out = attn_out * dense_mask 230 | return self.ffn(attn_out, training=training) * dense_mask, attn_weights 231 | 232 | 233 | class SelfAttentionConvBlock(tf.keras.layers.Layer): 234 | 235 | def __init__(self, 236 | model_dim: int, 237 | num_heads: int, 238 | dropout_rate: float, 239 | conv_filters: list, 240 | kernel_size: int, 241 | conv_activation: str, 242 | transposed_convs: bool, 243 | **kwargs): 244 | super(SelfAttentionConvBlock, self).__init__(**kwargs) 245 | self.sarn = SelfAttentionResNorm(model_dim, num_heads, dropout_rate=dropout_rate) 246 | if transposed_convs: 247 | self.conv = TransposedCNNResNorm(filters=conv_filters, 248 | kernel_size=kernel_size, 249 | inner_activation=conv_activation, 250 | dout_rate=dropout_rate, 251 | padding='same') 252 | else: 253 | self.conv = CNNResNorm(filters=conv_filters, 254 | kernel_size=kernel_size, 255 | inner_activation=conv_activation, 256 | dout_rate=dropout_rate, 257 | padding='same') 258 | 259 | def call(self, x, training, mask): 260 | attn_out, attn_weights = self.sarn(x, mask=mask, training=training) 261 | conv_mask = 1. - tf.squeeze(mask, axis=(1, 2))[:, :, None] 262 | attn_out = attn_out * conv_mask 263 | conv = self.conv(attn_out, training=training) 264 | return conv * conv_mask, attn_weights 265 | 266 | 267 | class SelfAttentionBlocks(tf.keras.layers.Layer): 268 | def __init__(self, 269 | model_dim: int, 270 | feed_forward_dimension: int, 271 | num_heads: list, 272 | maximum_position_encoding: int, 273 | conv_filters: list, 274 | dropout_rate: float, 275 | dense_blocks: int, 276 | kernel_size: int, 277 | conv_activation: str, 278 | transposed_convs: bool = None, 279 | **kwargs): 280 | super(SelfAttentionBlocks, self).__init__(**kwargs) 281 | self.model_dim = model_dim 282 | self.pos_encoding_scalar = tf.Variable(1.) 
283 | self.pos_encoding = positional_encoding(maximum_position_encoding, model_dim) 284 | self.dropout = tf.keras.layers.Dropout(dropout_rate) 285 | self.encoder_SADB = [ 286 | SelfAttentionDenseBlock(model_dim=model_dim, dropout_rate=dropout_rate, num_heads=n_heads, 287 | dense_hidden_units=feed_forward_dimension, name=f'{self.name}_SADB_{i}') 288 | for i, n_heads in enumerate(num_heads[:dense_blocks])] 289 | self.encoder_SACB = [ 290 | SelfAttentionConvBlock(model_dim=model_dim, dropout_rate=dropout_rate, num_heads=n_heads, 291 | name=f'{self.name}_SACB_{i}', kernel_size=kernel_size, 292 | conv_activation=conv_activation, conv_filters=conv_filters, 293 | transposed_convs=transposed_convs) 294 | for i, n_heads in enumerate(num_heads[dense_blocks:])] 295 | self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6) 296 | 297 | def call(self, inputs, training, padding_mask, reduction_factor=1): 298 | seq_len = tf.shape(inputs)[1] 299 | x = self.layernorm(inputs) 300 | x += self.pos_encoding_scalar * self.pos_encoding[:, :seq_len * reduction_factor:reduction_factor, :] 301 | x = self.dropout(x, training=training) 302 | attention_weights = {} 303 | for i, block in enumerate(self.encoder_SADB): 304 | x, attn_weights = block(x, training=training, mask=padding_mask) 305 | attention_weights[f'{self.name}_DenseBlock{i + 1}_SelfAttention'] = attn_weights 306 | for i, block in enumerate(self.encoder_SACB): 307 | x, attn_weights = block(x, training=training, mask=padding_mask) 308 | attention_weights[f'{self.name}_ConvBlock{i + 1}_SelfAttention'] = attn_weights 309 | 310 | return x, attention_weights 311 | 312 | 313 | class CrossAttentionResnorm(tf.keras.layers.Layer): 314 | 315 | def __init__(self, 316 | model_dim: int, 317 | num_heads: int, 318 | dropout_rate: float, 319 | **kwargs): 320 | super(CrossAttentionResnorm, self).__init__(**kwargs) 321 | self.mha = MultiHeadAttention(model_dim, num_heads, dropout=dropout_rate) 322 | self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6) 323 | 324 | def call(self, q, k, v, training, mask): 325 | attn_values, attn_weights = self.mha(v, k=k, q_in=q, mask=mask, training=training) 326 | out = self.layernorm(attn_values + q) 327 | return out, attn_weights 328 | 329 | 330 | class CrossAttentionDenseBlock(tf.keras.layers.Layer): 331 | 332 | def __init__(self, 333 | model_dim: int, 334 | num_heads: int, 335 | dense_hidden_units: int, 336 | dropout_rate: float, 337 | **kwargs): 338 | super(CrossAttentionDenseBlock, self).__init__(**kwargs) 339 | self.sarn = SelfAttentionResNorm(model_dim, num_heads, dropout_rate=dropout_rate) 340 | self.carn = CrossAttentionResnorm(model_dim, num_heads, dropout_rate=dropout_rate) 341 | self.ffn = FFNResNorm(model_dim, dense_hidden_units, dropout_rate=dropout_rate) 342 | 343 | def call(self, x, enc_output, training, look_ahead_mask, padding_mask): 344 | attn1, attn_weights_block1 = self.sarn(x, mask=look_ahead_mask, training=training) 345 | 346 | attn2, attn_weights_block2 = self.carn(attn1, v=enc_output, k=enc_output, 347 | mask=padding_mask, training=training) 348 | ffn_out = self.ffn(attn2, training=training) 349 | return ffn_out, attn_weights_block1, attn_weights_block2 350 | 351 | # This is never used. 
352 | # class CrossAttentionConvBlock(tf.keras.layers.Layer): 353 | # 354 | # def __init__(self, 355 | # model_dim: int, 356 | # num_heads: int, 357 | # conv_filters: list, 358 | # dropout_rate: float, 359 | # kernel_size: int, 360 | # conv_padding: str, 361 | # conv_activation: str, 362 | # **kwargs): 363 | # super(CrossAttentionConvBlock, self).__init__(**kwargs) 364 | # self.sarn = SelfAttentionResNorm(model_dim, num_heads, dropout_rate=dropout_rate) 365 | # self.carn = CrossAttentionResnorm(model_dim, num_heads, dropout_rate=dropout_rate) 366 | # self.conv = CNNResNorm(filters=conv_filters, 367 | # kernel_size=kernel_size, 368 | # inner_activation=conv_activation, 369 | # padding=conv_padding, 370 | # dout_rate=dropout_rate) 371 | # 372 | # def call(self, x, enc_output, training, look_ahead_mask, padding_mask): 373 | # attn1, attn_weights_block1 = self.sarn(x, mask=look_ahead_mask, training=training) 374 | # 375 | # attn2, attn_weights_block2 = self.carn(attn1, v=enc_output, k=enc_output, 376 | # mask=padding_mask, training=training) 377 | # ffn_out = self.conv(attn2, training=training) 378 | # return ffn_out, attn_weights_block1, attn_weights_block2 379 | 380 | 381 | class CrossAttentionBlocks(tf.keras.layers.Layer): 382 | 383 | def __init__(self, 384 | model_dim: int, 385 | feed_forward_dimension: int, 386 | num_heads: list, 387 | maximum_position_encoding: int, 388 | dropout_rate: float, 389 | **kwargs): 390 | super(CrossAttentionBlocks, self).__init__(**kwargs) 391 | self.model_dim = model_dim 392 | self.pos_encoding_scalar = tf.Variable(1.) 393 | self.pos_encoding = positional_encoding(maximum_position_encoding, model_dim) 394 | self.dropout = tf.keras.layers.Dropout(dropout_rate) 395 | self.CADB = [ 396 | CrossAttentionDenseBlock(model_dim=model_dim, dropout_rate=dropout_rate, num_heads=n_heads, 397 | dense_hidden_units=feed_forward_dimension, name=f'{self.name}_CADB_{i}') 398 | for i, n_heads in enumerate(num_heads[:-1])] 399 | self.last_CADB = CrossAttentionDenseBlock(model_dim=model_dim, dropout_rate=dropout_rate, 400 | num_heads=num_heads[-1], 401 | dense_hidden_units=feed_forward_dimension, 402 | name=f'{self.name}_CADB_last') 403 | self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6) 404 | 405 | def call(self, inputs, enc_output, training, decoder_padding_mask, encoder_padding_mask, 406 | reduction_factor=1): 407 | seq_len = tf.shape(inputs)[1] 408 | x = self.layernorm(inputs) 409 | x += self.pos_encoding_scalar * self.pos_encoding[:, :seq_len * reduction_factor:reduction_factor, :] 410 | x = self.dropout(x, training=training) 411 | attention_weights = {} 412 | for i, block in enumerate(self.CADB): 413 | x, _, attn_weights = block(x, enc_output, training, decoder_padding_mask, encoder_padding_mask) 414 | attention_weights[f'{self.name}_DenseBlock{i + 1}_CrossAttention'] = attn_weights 415 | x, _, attn_weights = self.last_CADB(x, enc_output, training, decoder_padding_mask, encoder_padding_mask) 416 | attention_weights[f'{self.name}_LastBlock_CrossAttention'] = attn_weights 417 | return x, attention_weights 418 | 419 | 420 | class DecoderPrenet(tf.keras.layers.Layer): 421 | 422 | def __init__(self, 423 | model_dim: int, 424 | dense_hidden_units: int, 425 | dropout_rate: float, 426 | **kwargs): 427 | super(DecoderPrenet, self).__init__(**kwargs) 428 | self.d1 = tf.keras.layers.Dense(dense_hidden_units, 429 | activation='relu') # (batch_size, seq_len, dense_hidden_units) 430 | self.d2 = tf.keras.layers.Dense(model_dim, activation='relu') # (batch_size, seq_len, 
model_dim) 431 | self.rate = tf.Variable(dropout_rate, trainable=False) 432 | self.dropout_1 = tf.keras.layers.Dropout(self.rate) 433 | self.dropout_2 = tf.keras.layers.Dropout(self.rate) 434 | 435 | def call(self, x, training): 436 | self.dropout_1.rate = self.rate 437 | self.dropout_2.rate = self.rate 438 | x = self.d1(x) 439 | # use dropout also in inference for positional encoding relevance 440 | x = self.dropout_1(x, training=training) 441 | x = self.d2(x) 442 | x = self.dropout_2(x, training=training) 443 | return x 444 | 445 | 446 | class Postnet(tf.keras.layers.Layer): 447 | 448 | def __init__(self, mel_channels: int, **kwargs): 449 | super(Postnet, self).__init__(**kwargs) 450 | self.mel_channels = mel_channels 451 | self.stop_linear = tf.keras.layers.Dense(3) 452 | self.mel_out = tf.keras.layers.Dense(mel_channels) 453 | 454 | def call(self, x): 455 | stop = self.stop_linear(x) 456 | mel = self.mel_out(x) 457 | return { 458 | 'mel': mel, 459 | 'stop_prob': stop, 460 | } 461 | 462 | 463 | class StatPredictor(tf.keras.layers.Layer): 464 | def __init__(self, 465 | conv_filters: list, 466 | kernel_size: int, 467 | conv_padding: str, 468 | conv_activation: str, 469 | dense_activation: str, 470 | dropout_rate: float, 471 | **kwargs): 472 | super(StatPredictor, self).__init__(**kwargs) 473 | self.conv_blocks = CNNDropout(filters=conv_filters, 474 | kernel_size=kernel_size, 475 | padding=conv_padding, 476 | inner_activation=conv_activation, 477 | last_activation=conv_activation, 478 | dout_rate=dropout_rate) 479 | self.linear = tf.keras.layers.Dense(1, activation=dense_activation) 480 | 481 | def call(self, x, training, mask): 482 | x = x * mask 483 | x = self.conv_blocks(x, training=training) 484 | x = self.linear(x) 485 | return x * mask 486 | 487 | 488 | class CNNDropout(tf.keras.layers.Layer): 489 | def __init__(self, 490 | filters: list, 491 | kernel_size: int, 492 | inner_activation: str, 493 | last_activation: str, 494 | padding: str, 495 | dout_rate: float): 496 | super(CNNDropout, self).__init__() 497 | self.n_layers = len(filters) 498 | self.convolutions = [tf.keras.layers.Conv1D(filters=f, 499 | kernel_size=kernel_size, 500 | padding=padding) 501 | for f in filters[:-1]] 502 | self.inner_activations = [tf.keras.layers.Activation(inner_activation) for _ in range(self.n_layers - 1)] 503 | self.last_conv = tf.keras.layers.Conv1D(filters=filters[-1], 504 | kernel_size=kernel_size, 505 | padding=padding) 506 | self.last_activation = tf.keras.layers.Activation(last_activation) 507 | self.dropouts = [tf.keras.layers.Dropout(rate=dout_rate) for _ in range(self.n_layers)] 508 | self.normalization = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(self.n_layers)] 509 | 510 | def call_convs(self, x, training): 511 | for i in range(0, self.n_layers - 1): 512 | x = self.convolutions[i](x) 513 | x = self.inner_activations[i](x) 514 | x = self.normalization[i](x) 515 | x = self.dropouts[i](x, training=training) 516 | return x 517 | 518 | def call(self, inputs, training): 519 | x = self.call_convs(inputs, training=training) 520 | x = self.last_conv(x) 521 | x = self.last_activation(x) 522 | x = self.normalization[-1](x) 523 | x = self.dropouts[-1](x, training=training) 524 | return x 525 | 526 | 527 | class Expand(tf.keras.layers.Layer): 528 | """ Expands a 3D tensor on its second axis given a list of dimensions. 
529 | Tensor should be: 530 | batch_size, seq_len, dimension 531 | 532 | E.g: 533 | input = tf.Tensor([[[0.54710746 0.8943467 ] 534 | [0.7140938 0.97968304] 535 | [0.5347662 0.15213418]]], shape=(1, 3, 2), dtype=float32) 536 | dimensions = tf.Tensor([1 3 2], shape=(3,), dtype=int32) 537 | output = tf.Tensor([[[0.54710746 0.8943467 ] 538 | [0.7140938 0.97968304] 539 | [0.7140938 0.97968304] 540 | [0.7140938 0.97968304] 541 | [0.5347662 0.15213418] 542 | [0.5347662 0.15213418]]], shape=(1, 6, 2), dtype=float32) 543 | """ 544 | 545 | def __init__(self, model_dim, **kwargs): 546 | super(Expand, self).__init__(**kwargs) 547 | self.model_dimension = model_dim 548 | 549 | def call(self, x, dimensions): 550 | dimensions = tf.squeeze(dimensions, axis=-1) 551 | dimensions = tf.cast(tf.math.round(dimensions), tf.int32) 552 | seq_len = tf.shape(x)[1] 553 | batch_size = tf.shape(x)[0] 554 | # build masks from dimensions 555 | max_dim = tf.math.reduce_max(dimensions) 556 | tot_dim = tf.math.reduce_sum(dimensions) 557 | index_masks = tf.RaggedTensor.from_row_lengths(tf.ones(tot_dim), tf.reshape(dimensions, [-1])).to_tensor() 558 | index_masks = tf.cast(tf.reshape(index_masks, (batch_size, seq_len * max_dim)), tf.float32) 559 | non_zeros = seq_len * max_dim - tf.reduce_sum(max_dim - dimensions, axis=1) 560 | # stack and mask 561 | tiled = tf.tile(x, [1, 1, max_dim]) 562 | reshaped = tf.reshape(tiled, (batch_size, seq_len * max_dim, self.model_dimension)) 563 | mask_reshape = tf.multiply(reshaped, index_masks[:, :, tf.newaxis]) 564 | ragged = tf.RaggedTensor.from_row_lengths(mask_reshape[index_masks > 0], non_zeros) 565 | return ragged.to_tensor() 566 | -------------------------------------------------------------------------------- /model/transformer_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def get_angles(pos, i, model_dim): 6 | angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(model_dim)) 7 | return pos * angle_rates 8 | 9 | 10 | def positional_encoding(position, model_dim): 11 | angle_rads = get_angles(np.arange(position)[:, np.newaxis], np.arange(model_dim)[np.newaxis, :], model_dim) 12 | 13 | # apply sin to even indices in the array; 2i 14 | angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2]) 15 | 16 | # apply cos to odd indices in the array; 2i+1 17 | angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2]) 18 | 19 | pos_encoding = angle_rads[np.newaxis, ...] 20 | 21 | return tf.cast(pos_encoding, dtype=tf.float32) 22 | 23 | 24 | def create_encoder_padding_mask(seq): 25 | seq = tf.cast(tf.math.equal(seq, 0), tf.float32) 26 | return seq[:, tf.newaxis, tf.newaxis, :] # (batch_size, 1, y, x) 27 | 28 | 29 | def create_mel_padding_mask(seq): 30 | seq = tf.reduce_sum(tf.math.abs(seq), axis=-1) 31 | seq = tf.cast(tf.math.equal(seq, 0), tf.float32) 32 | return seq[:, tf.newaxis, tf.newaxis, :] # (batch_size, 1, y, x) 33 | 34 | 35 | def create_look_ahead_mask(size): 36 | mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0) 37 | return mask 38 | -------------------------------------------------------------------------------- /notebooks/synthesize_forward_melgan.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "zdMgfG7GMF_R" 8 | }, 9 | "source": [ 10 | "
\n", 11 | "

Transformer TTS: A Text-to-Speech Transformer in TensorFlow 2

\n", 12 | "

Audio synthesis with Forward Transformer TTS and MelGAN Vocoder

\n", 13 | "
\n", 14 | "\n", 15 | "## Forward Model" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": { 22 | "colab": { 23 | "base_uri": "https://localhost:8080/", 24 | "height": 225 25 | }, 26 | "colab_type": "code", 27 | "id": "JQ5YuFPAxXUy", 28 | "outputId": "23b61caa-9bc7-4bb8-b31b-0da6a029d6e6" 29 | }, 30 | "outputs": [], 31 | "source": [ 32 | "# Clone the Transformer TTS and MelGAN repos\n", 33 | "!git clone https://github.com/as-ideas/TransformerTTS.git\n", 34 | "!git clone https://github.com/seungwonpark/melgan.git" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "colab": { 42 | "base_uri": "https://localhost:8080/", 43 | "height": 1000 44 | }, 45 | "colab_type": "code", 46 | "id": "9bIzkIGjMRwm", 47 | "outputId": "c078ee93-da4c-4c93-daf2-e092acd8b929" 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "# Install requirements\n", 52 | "!apt-get install -y espeak\n", 53 | "!pip install -r TransformerTTS/requirements.txt" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": { 60 | "colab": { 61 | "base_uri": "https://localhost:8080/", 62 | "height": 225 63 | }, 64 | "colab_type": "code", 65 | "id": "cOxdx6L5Hjcf", 66 | "outputId": "e919b138-a9d4-431e-dcf6-873925e8f0a9" 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "!cd TransformerTTS/; git checkout c3405c53e435a06c809533aa4453923469081147" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": { 77 | "colab": {}, 78 | "colab_type": "code", 79 | "id": "W3tlwOlRbABh" 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "# Set up the paths\n", 84 | "from pathlib import Path\n", 85 | "MelGAN_path = 'melgan/'\n", 86 | "TTS_path = 'TransformerTTS/'\n", 87 | "\n", 88 | "import sys\n", 89 | "sys.path.append(TTS_path)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": { 96 | "colab": { 97 | "base_uri": "https://localhost:8080/", 98 | "height": 69 99 | }, 100 | "colab_type": "code", 101 | "id": "LucwkAK1yEVq", 102 | "outputId": "c2000cef-7533-4e95-e1c0-93559a93aec4" 103 | }, 104 | "outputs": [], 105 | "source": [ 106 | "# Load pretrained model\n", 107 | "from model.factory import tts_ljspeech\n", 108 | "from data.audio import Audio\n", 109 | "\n", 110 | "model, config = tts_ljspeech()\n", 111 | "audio = Audio(config)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": { 118 | "colab": {}, 119 | "colab_type": "code", 120 | "id": "_5RKHIDQyZvo" 121 | }, 122 | "outputs": [], 123 | "source": [ 124 | "# Synthesize text\n", 125 | "sentence = 'Scientists at the CERN laboratory, say they have discovered a new particle.'\n", 126 | "out_normal = model.predict(sentence)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": { 133 | "colab": { 134 | "base_uri": "https://localhost:8080/", 135 | "height": 62 136 | }, 137 | "colab_type": "code", 138 | "id": "GXxdDHOAyZ6f", 139 | "outputId": "a5d611c2-316a-4aec-834b-93562b0f487a" 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "# Convert spectrogram to wav (with griffin lim)\n", 144 | "wav = audio.reconstruct_waveform(out_normal['mel'].numpy().T)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": { 151 | "colab": { 152 | "base_uri": "https://localhost:8080/", 153 | "height": 62 154 | }, 155 | "colab_type": "code", 156 | "id": "GXxdDHOAyZ6f", 157 | 
"outputId": "a5d611c2-316a-4aec-834b-93562b0f487a" 158 | }, 159 | "outputs": [], 160 | "source": [ 161 | "import IPython.display as ipd\n", 162 | "\n", 163 | "ipd.display(ipd.Audio(wav, rate=config['sampling_rate']))" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": { 169 | "colab_type": "text", 170 | "id": "uyidCx84bAB1" 171 | }, 172 | "source": [ 173 | "You can also vary the speech speed" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": { 180 | "colab": { 181 | "base_uri": "https://localhost:8080/", 182 | "height": 62 183 | }, 184 | "colab_type": "code", 185 | "id": "_N3rM6qmbAB2", 186 | "outputId": "c68e3da8-e15d-431c-e377-c98457b83a52" 187 | }, 188 | "outputs": [], 189 | "source": [ 190 | "# 20% faster\n", 191 | "sentence = 'Scientists at the CERN laboratory, say they have discovered a new particle.'\n", 192 | "out = model.predict(sentence, speed_regulator=1.20)\n", 193 | "wav = audio.reconstruct_waveform(out['mel'].numpy().T)\n", 194 | "ipd.display(ipd.Audio(wav, rate=config['sampling_rate']))" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": { 201 | "colab": { 202 | "base_uri": "https://localhost:8080/", 203 | "height": 62 204 | }, 205 | "colab_type": "code", 206 | "id": "QxAIl9LkbAB6", 207 | "outputId": "a8e0531e-268f-4fd4-8a55-2c7914cc67d6" 208 | }, 209 | "outputs": [], 210 | "source": [ 211 | "# 10% slower\n", 212 | "sentence = 'Scientists at the CERN laboratory, say they have discovered a new particle.'\n", 213 | "out = model.predict(sentence, speed_regulator=.9)\n", 214 | "wav = audio.reconstruct_waveform(out['mel'].numpy().T)\n", 215 | "ipd.display(ipd.Audio(wav, rate=config['sampling_rate']))" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": { 221 | "colab_type": "text", 222 | "id": "eZJo81viVus-" 223 | }, 224 | "source": [ 225 | "### MelGAN" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": { 232 | "colab": { 233 | "base_uri": "https://localhost:8080/", 234 | "height": 34 235 | }, 236 | "colab_type": "code", 237 | "id": "WjIuQALHTr-R", 238 | "outputId": "503b1a50-b658-4f38-acd3-ed8346571f24" 239 | }, 240 | "outputs": [], 241 | "source": [ 242 | "# Do some sys cleaning\n", 243 | "sys.path.remove(TTS_path)\n", 244 | "sys.modules.pop('model')" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": { 251 | "colab": { 252 | "base_uri": "https://localhost:8080/", 253 | "height": 101, 254 | "referenced_widgets": [ 255 | "d2d68febdf104697bbf2d4a7d132891a", 256 | "2820cb67ddb6435a9abc1b7252e7fe44", 257 | "a29f9e1975f248348c32254fe5b3d026", 258 | "b0b25d43cc6949d9b921d33984deea0f", 259 | "171439a16fba452a8fdab03a5711a6dc", 260 | "7b4b771dce304335933820d574f49e4f", 261 | "4c9a01915a2648f192071342e0cd3202", 262 | "90ad1e7188db4f98be3081554bb06cbc" 263 | ] 264 | }, 265 | "colab_type": "code", 266 | "id": "L4gZZOgmbACF", 267 | "outputId": "d7d953a7-9fe5-44e3-c750-4327dcf3d541" 268 | }, 269 | "outputs": [], 270 | "source": [ 271 | "sys.path.append(MelGAN_path)\n", 272 | "import torch\n", 273 | "import numpy as np\n", 274 | "\n", 275 | "vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')\n", 276 | "vocoder.eval()\n", 277 | "\n", 278 | "mel = torch.tensor(out_normal['mel'].numpy().T[np.newaxis,:,:])" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": { 285 | "colab": {}, 286 | "colab_type": 
"code", 287 | "id": "iqXjJY_2bACJ" 288 | }, 289 | "outputs": [], 290 | "source": [ 291 | "if torch.cuda.is_available():\n", 292 | " vocoder = vocoder.cuda()\n", 293 | " mel = mel.cuda()\n", 294 | "\n", 295 | "with torch.no_grad():\n", 296 | " audio = vocoder.inference(mel)" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": { 303 | "colab": { 304 | "base_uri": "https://localhost:8080/", 305 | "height": 62 306 | }, 307 | "colab_type": "code", 308 | "id": "vQYaZawLXTJI", 309 | "outputId": "0137a959-76c6-4c2b-cf0d-87d68c585651" 310 | }, 311 | "outputs": [], 312 | "source": [ 313 | "# Display audio\n", 314 | "ipd.display(ipd.Audio(audio.cpu().numpy(), rate=22050))" 315 | ] 316 | } 317 | ], 318 | "metadata": { 319 | "accelerator": "GPU", 320 | "colab": { 321 | "collapsed_sections": [], 322 | "name": "synthesize_forward_melgan", 323 | "provenance": [] 324 | }, 325 | "kernelspec": { 326 | "display_name": "ttsTF", 327 | "language": "python", 328 | "name": "ttstf" 329 | }, 330 | "language_info": { 331 | "codemirror_mode": { 332 | "name": "ipython", 333 | "version": 3 334 | }, 335 | "file_extension": ".py", 336 | "mimetype": "text/x-python", 337 | "name": "python", 338 | "nbconvert_exporter": "python", 339 | "pygments_lexer": "ipython3", 340 | "version": "3.6.9" 341 | }, 342 | "widgets": { 343 | "application/vnd.jupyter.widget-state+json": { 344 | "171439a16fba452a8fdab03a5711a6dc": { 345 | "model_module": "@jupyter-widgets/controls", 346 | "model_name": "ProgressStyleModel", 347 | "state": { 348 | "_model_module": "@jupyter-widgets/controls", 349 | "_model_module_version": "1.5.0", 350 | "_model_name": "ProgressStyleModel", 351 | "_view_count": null, 352 | "_view_module": "@jupyter-widgets/base", 353 | "_view_module_version": "1.2.0", 354 | "_view_name": "StyleView", 355 | "bar_color": null, 356 | "description_width": "initial" 357 | } 358 | }, 359 | "2820cb67ddb6435a9abc1b7252e7fe44": { 360 | "model_module": "@jupyter-widgets/base", 361 | "model_name": "LayoutModel", 362 | "state": { 363 | "_model_module": "@jupyter-widgets/base", 364 | "_model_module_version": "1.2.0", 365 | "_model_name": "LayoutModel", 366 | "_view_count": null, 367 | "_view_module": "@jupyter-widgets/base", 368 | "_view_module_version": "1.2.0", 369 | "_view_name": "LayoutView", 370 | "align_content": null, 371 | "align_items": null, 372 | "align_self": null, 373 | "border": null, 374 | "bottom": null, 375 | "display": null, 376 | "flex": null, 377 | "flex_flow": null, 378 | "grid_area": null, 379 | "grid_auto_columns": null, 380 | "grid_auto_flow": null, 381 | "grid_auto_rows": null, 382 | "grid_column": null, 383 | "grid_gap": null, 384 | "grid_row": null, 385 | "grid_template_areas": null, 386 | "grid_template_columns": null, 387 | "grid_template_rows": null, 388 | "height": null, 389 | "justify_content": null, 390 | "justify_items": null, 391 | "left": null, 392 | "margin": null, 393 | "max_height": null, 394 | "max_width": null, 395 | "min_height": null, 396 | "min_width": null, 397 | "object_fit": null, 398 | "object_position": null, 399 | "order": null, 400 | "overflow": null, 401 | "overflow_x": null, 402 | "overflow_y": null, 403 | "padding": null, 404 | "right": null, 405 | "top": null, 406 | "visibility": null, 407 | "width": null 408 | } 409 | }, 410 | "4c9a01915a2648f192071342e0cd3202": { 411 | "model_module": "@jupyter-widgets/controls", 412 | "model_name": "DescriptionStyleModel", 413 | "state": { 414 | "_model_module": "@jupyter-widgets/controls", 415 | 
"_model_module_version": "1.5.0", 416 | "_model_name": "DescriptionStyleModel", 417 | "_view_count": null, 418 | "_view_module": "@jupyter-widgets/base", 419 | "_view_module_version": "1.2.0", 420 | "_view_name": "StyleView", 421 | "description_width": "" 422 | } 423 | }, 424 | "7b4b771dce304335933820d574f49e4f": { 425 | "model_module": "@jupyter-widgets/base", 426 | "model_name": "LayoutModel", 427 | "state": { 428 | "_model_module": "@jupyter-widgets/base", 429 | "_model_module_version": "1.2.0", 430 | "_model_name": "LayoutModel", 431 | "_view_count": null, 432 | "_view_module": "@jupyter-widgets/base", 433 | "_view_module_version": "1.2.0", 434 | "_view_name": "LayoutView", 435 | "align_content": null, 436 | "align_items": null, 437 | "align_self": null, 438 | "border": null, 439 | "bottom": null, 440 | "display": null, 441 | "flex": null, 442 | "flex_flow": null, 443 | "grid_area": null, 444 | "grid_auto_columns": null, 445 | "grid_auto_flow": null, 446 | "grid_auto_rows": null, 447 | "grid_column": null, 448 | "grid_gap": null, 449 | "grid_row": null, 450 | "grid_template_areas": null, 451 | "grid_template_columns": null, 452 | "grid_template_rows": null, 453 | "height": null, 454 | "justify_content": null, 455 | "justify_items": null, 456 | "left": null, 457 | "margin": null, 458 | "max_height": null, 459 | "max_width": null, 460 | "min_height": null, 461 | "min_width": null, 462 | "object_fit": null, 463 | "object_position": null, 464 | "order": null, 465 | "overflow": null, 466 | "overflow_x": null, 467 | "overflow_y": null, 468 | "padding": null, 469 | "right": null, 470 | "top": null, 471 | "visibility": null, 472 | "width": null 473 | } 474 | }, 475 | "90ad1e7188db4f98be3081554bb06cbc": { 476 | "model_module": "@jupyter-widgets/base", 477 | "model_name": "LayoutModel", 478 | "state": { 479 | "_model_module": "@jupyter-widgets/base", 480 | "_model_module_version": "1.2.0", 481 | "_model_name": "LayoutModel", 482 | "_view_count": null, 483 | "_view_module": "@jupyter-widgets/base", 484 | "_view_module_version": "1.2.0", 485 | "_view_name": "LayoutView", 486 | "align_content": null, 487 | "align_items": null, 488 | "align_self": null, 489 | "border": null, 490 | "bottom": null, 491 | "display": null, 492 | "flex": null, 493 | "flex_flow": null, 494 | "grid_area": null, 495 | "grid_auto_columns": null, 496 | "grid_auto_flow": null, 497 | "grid_auto_rows": null, 498 | "grid_column": null, 499 | "grid_gap": null, 500 | "grid_row": null, 501 | "grid_template_areas": null, 502 | "grid_template_columns": null, 503 | "grid_template_rows": null, 504 | "height": null, 505 | "justify_content": null, 506 | "justify_items": null, 507 | "left": null, 508 | "margin": null, 509 | "max_height": null, 510 | "max_width": null, 511 | "min_height": null, 512 | "min_width": null, 513 | "object_fit": null, 514 | "object_position": null, 515 | "order": null, 516 | "overflow": null, 517 | "overflow_x": null, 518 | "overflow_y": null, 519 | "padding": null, 520 | "right": null, 521 | "top": null, 522 | "visibility": null, 523 | "width": null 524 | } 525 | }, 526 | "a29f9e1975f248348c32254fe5b3d026": { 527 | "model_module": "@jupyter-widgets/controls", 528 | "model_name": "FloatProgressModel", 529 | "state": { 530 | "_dom_classes": [], 531 | "_model_module": "@jupyter-widgets/controls", 532 | "_model_module_version": "1.5.0", 533 | "_model_name": "FloatProgressModel", 534 | "_view_count": null, 535 | "_view_module": "@jupyter-widgets/controls", 536 | "_view_module_version": "1.5.0", 537 | "_view_name": 
"ProgressView", 538 | "bar_style": "success", 539 | "description": "100%", 540 | "description_tooltip": null, 541 | "layout": "IPY_MODEL_7b4b771dce304335933820d574f49e4f", 542 | "max": 17090302, 543 | "min": 0, 544 | "orientation": "horizontal", 545 | "style": "IPY_MODEL_171439a16fba452a8fdab03a5711a6dc", 546 | "value": 17090302 547 | } 548 | }, 549 | "b0b25d43cc6949d9b921d33984deea0f": { 550 | "model_module": "@jupyter-widgets/controls", 551 | "model_name": "HTMLModel", 552 | "state": { 553 | "_dom_classes": [], 554 | "_model_module": "@jupyter-widgets/controls", 555 | "_model_module_version": "1.5.0", 556 | "_model_name": "HTMLModel", 557 | "_view_count": null, 558 | "_view_module": "@jupyter-widgets/controls", 559 | "_view_module_version": "1.5.0", 560 | "_view_name": "HTMLView", 561 | "description": "", 562 | "description_tooltip": null, 563 | "layout": "IPY_MODEL_90ad1e7188db4f98be3081554bb06cbc", 564 | "placeholder": "​", 565 | "style": "IPY_MODEL_4c9a01915a2648f192071342e0cd3202", 566 | "value": " 16.3M/16.3M [00:00<00:00, 91.6MB/s]" 567 | } 568 | }, 569 | "d2d68febdf104697bbf2d4a7d132891a": { 570 | "model_module": "@jupyter-widgets/controls", 571 | "model_name": "HBoxModel", 572 | "state": { 573 | "_dom_classes": [], 574 | "_model_module": "@jupyter-widgets/controls", 575 | "_model_module_version": "1.5.0", 576 | "_model_name": "HBoxModel", 577 | "_view_count": null, 578 | "_view_module": "@jupyter-widgets/controls", 579 | "_view_module_version": "1.5.0", 580 | "_view_name": "HBoxView", 581 | "box_style": "", 582 | "children": [ 583 | "IPY_MODEL_a29f9e1975f248348c32254fe5b3d026", 584 | "IPY_MODEL_b0b25d43cc6949d9b921d33984deea0f" 585 | ], 586 | "layout": "IPY_MODEL_2820cb67ddb6435a9abc1b7252e7fe44" 587 | } 588 | } 589 | } 590 | } 591 | }, 592 | "nbformat": 4, 593 | "nbformat_minor": 1 594 | } 595 | -------------------------------------------------------------------------------- /predict_tts.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | 6 | from model.factory import tts_ljspeech 7 | from data.audio import Audio 8 | from model.models import ForwardTransformer 9 | 10 | if __name__ == '__main__': 11 | parser = ArgumentParser() 12 | parser.add_argument('--path', '-p', dest='path', default=None, type=str) 13 | parser.add_argument('--step', dest='step', default='90000', type=str) 14 | parser.add_argument('--text', '-t', dest='text', default=None, type=str) 15 | parser.add_argument('--file', '-f', dest='file', default=None, type=str) 16 | parser.add_argument('--outdir', '-o', dest='outdir', default=None, type=str) 17 | parser.add_argument('--store_mel', '-m', dest='store_mel', action='store_true') 18 | parser.add_argument('--verbose', '-v', dest='verbose', action='store_true') 19 | parser.add_argument('--single', '-s', dest='single', action='store_true') 20 | args = parser.parse_args() 21 | 22 | if args.file is not None: 23 | with open(args.file, 'r') as file: 24 | text = file.readlines() 25 | fname = Path(args.file).stem 26 | elif args.text is not None: 27 | text = [args.text] 28 | fname = 'custom_text' 29 | else: 30 | fname = None 31 | text = None 32 | print(f'Specify either an input text (-t "some text") or a text input file (-f /path/to/file.txt)') 33 | exit() 34 | # load the appropriate model 35 | outdir = Path(args.outdir) if args.outdir is not None else Path('.') 36 | if args.path is not None: 37 | print(f'Loading model from {args.path}') 38 | 
model = ForwardTransformer.load_model(args.path) 39 | else: 40 | model = tts_ljspeech(args.step) 41 | file_name = f"{fname}_{model.config['data_name']}_{model.config['git_hash']}_{model.config['step']}" 42 | outdir = outdir / 'outputs' / f'{fname}' 43 | outdir.mkdir(exist_ok=True, parents=True) 44 | output_path = (outdir / file_name).with_suffix('.wav') 45 | audio = Audio.from_config(model.config) 46 | print(f'Output wav under {output_path.parent}') 47 | wavs = [] 48 | for i, text_line in enumerate(text): 49 | phons = model.text_pipeline.phonemizer(text_line) 50 | tokens = model.text_pipeline.tokenizer(phons) 51 | if args.verbose: 52 | print(f'Predicting {text_line}') 53 | print(f'Phonemes: "{phons}"') 54 | print(f'Tokens: "{tokens}"') 55 | out = model.predict(tokens, encode=False, phoneme_max_duration=None) 56 | mel = out['mel'].numpy().T 57 | wav = audio.reconstruct_waveform(mel) 58 | wavs.append(wav) 59 | if args.store_mel: 60 | np.save((outdir / (file_name + f'_{i}')).with_suffix('.mel'), out['mel'].numpy()) 61 | if args.single: 62 | audio.save_wav(wav, (outdir / (file_name + f'_{i}')).with_suffix('.wav')) 63 | audio.save_wav(np.concatenate(wavs), output_path) 64 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.2.2 2 | librosa==0.7.1 3 | numba==0.48 4 | numpy>=1.17.4 5 | phonemizer~=2.2.1 6 | ruamel.yaml>=0.16.6 7 | tensorflow>=2.2.0 8 | tqdm==4.40.1 9 | p_tqdm 10 | soundfile 11 | webrtcvad 12 | scipy 13 | pyworld 14 | -------------------------------------------------------------------------------- /test_sentences.txt: -------------------------------------------------------------------------------- 1 | This is a nice test. 2 | Is this a nice test? 3 | President Trump met with other leaders at the Group of twenty conference. 4 | Scientists at the CERN laboratory say they have discovered a new particle. 5 | There’s a way to measure the acute emotional intelligence that has never gone out of style. 6 | The Senate's bill to repeal and replace the Affordable Care-Act is now imperiled. 7 | Peter piper picked a peck of pickled peppers. 8 | If I were to talk to a human, I would definitely try to sound normal. Wouldn't I? 
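A minimal sketch of synthesizing one of these sentences programmatically, using the same API that predict_tts.py above wraps (an editorial illustration, not a file of this repository; it assumes the pretrained LJSpeech weights for step 90000 can be fetched via model.factory.tts_ljspeech and that the requirements listed above are installed):

# Editor's example only -- mirrors the calls made in predict_tts.py and the notebook; not shipped with the repo.
from pathlib import Path

from model.factory import tts_ljspeech
from data.audio import Audio

model = tts_ljspeech('90000')            # or: ForwardTransformer.load_model('/path/to/weights')
audio = Audio.from_config(model.config)  # audio settings come from the checkpoint's config

sentence = 'Scientists at the CERN laboratory say they have discovered a new particle.'
out = model.predict(sentence, speed_regulator=1.0)      # values above 1.0 speak faster, below 1.0 slower
wav = audio.reconstruct_waveform(out['mel'].numpy().T)  # rebuild a waveform from the predicted mel
audio.save_wav(wav, Path('cern_sample.wav'))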
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/TransformerTTS/363805548abdd93b33508da2c027ae514bfc1a07/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_char_tokenizer.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | from data.text.tokenizer import Tokenizer 7 | 8 | 9 | class TestCharTokenizer(unittest.TestCase): 10 | 11 | def test_tokenizer(self): 12 | tokenizer = Tokenizer(alphabet=list('ab c')) 13 | self.assertEqual(5, tokenizer.start_token_index) 14 | self.assertEqual(6, tokenizer.end_token_index) 15 | self.assertEqual(7, tokenizer.vocab_size) 16 | 17 | seq = tokenizer('a b d') 18 | self.assertEqual([5, 1, 3, 2, 3, 6], seq) 19 | 20 | seq = np.array([5, 1, 3, 2, 8, 6]) 21 | seq = tf.convert_to_tensor(seq) 22 | text = tokenizer.decode(seq) 23 | self.assertEqual('>a b<', text) 24 | -------------------------------------------------------------------------------- /tests/test_config.yaml: -------------------------------------------------------------------------------- 1 | # ARCHITECTURE 2 | decoder_model_dimension: 128 3 | encoder_model_dimension: 128 4 | decoder_num_heads: [1] 5 | encoder_num_heads: [1] 6 | encoder_feed_forward_dimension: 128 7 | decoder_feed_forward_dimension: 128 8 | decoder_prenet_dimension: 128 9 | max_position_encoding: 10000 10 | postnet_conv_filters: 64 11 | postnet_conv_layers: 1 12 | postnet_kernel_size: 5 13 | dropout_rate: 0.1 14 | # DATA 15 | n_samples: 600 16 | mel_channels: 80 17 | sr: 22050 18 | mel_start_value: -3 19 | mel_end_value: 1 20 | # TRAINING 21 | use_decoder_prenet_dropout_schedule: True 22 | decoder_prenet_dropout_schedule_max: 0.9 23 | decoder_prenet_dropout_schedule_min: 0.6 24 | decoder_prenet_dropout_schedule_max_steps: 30_000 25 | fixed_decoder_prenet_dropout: 0.6 26 | epochs: 10 27 | batch_size: 2 28 | learning_rate_schedule: 29 | - [0, 1.0e-3] 30 | reduction_factor_schedule: 31 | - [0, 10] 32 | mask_prob: 0.3 33 | use_block_attention: False 34 | debug: True 35 | # LOGGING 36 | text_freq: 5000 37 | image_freq: 1000 38 | weights_save_freq: 10 39 | plot_attention_freq: 500 40 | keep_n_weights: 5 41 | warmup_steps: 1_000 42 | warmup_lr: 1.0e-6 43 | #TOKENIZER 44 | use_phonemes: True 45 | phoneme_language: 'en' 46 | tokenizer_alphabet: "!,.:;?'- abcdefghijklmnopqrstuvwxyzäüöß" -------------------------------------------------------------------------------- /tests/test_loss.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from utils.losses import new_scaled_crossentropy, masked_crossentropy 6 | 7 | 8 | class TestCharTokenizer(unittest.TestCase): 9 | 10 | def test_crossentropy(self): 11 | scaled_crossent = new_scaled_crossentropy(index=2, scaling=5) 12 | 13 | targets = np.array([[0, 1, 2]]) 14 | logits = np.array([[[.3, .2, .1], [.3, .2, .1], [.3, .2, .1]]]) 15 | 16 | loss = scaled_crossent(targets, logits) 17 | self.assertAlmostEqual(2.3705523014068604, float(loss)) 18 | 19 | scaled_crossent = new_scaled_crossentropy(index=2, scaling=1) 20 | loss = scaled_crossent(targets, logits) 21 | self.assertAlmostEqual(0.7679619193077087, float(loss)) 22 | 23 | loss = masked_crossentropy(targets, logits) 24 | 
self.assertAlmostEqual(0.7679619193077087, float(loss)) 25 | -------------------------------------------------------------------------------- /train_aligner.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tqdm import trange 4 | 5 | from utils.training_config_manager import TrainingConfigManager 6 | from data.datasets import AlignerDataset, AlignerPreprocessor 7 | from utils.decorators import ignore_exception, time_it 8 | from utils.scheduling import piecewise_linear_schedule, reduction_schedule 9 | from utils.logging_utils import SummaryManager 10 | from utils.scripts_utils import dynamic_memory_allocation, basic_train_parser 11 | from utils.metrics import attention_score 12 | from utils.spectrogram_ops import mel_lengths, phoneme_lengths 13 | from utils.alignments import get_durations_from_alignment 14 | 15 | np.random.seed(42) 16 | tf.random.set_seed(42) 17 | 18 | dynamic_memory_allocation() 19 | parser = basic_train_parser() 20 | args = parser.parse_args() 21 | 22 | 23 | def cut_with_durations(durations, mel, phonemes, snippet_len=10): 24 | phon_dur = np.pad(durations, (1, 0)) 25 | starts = np.cumsum(phon_dur)[:-1] 26 | ends = np.cumsum(phon_dur)[1:] 27 | cut_mels = [] 28 | cut_texts = [] 29 | for end_idx in range(snippet_len, len(phon_dur), snippet_len): 30 | start_idx = end_idx - snippet_len 31 | cut_mels.append(mel[starts[start_idx]: ends[end_idx - 1], :]) 32 | cut_texts.append(phonemes[start_idx: end_idx]) 33 | return cut_mels, cut_texts 34 | 35 | 36 | @ignore_exception 37 | @time_it 38 | def validate(model, 39 | val_dataset, 40 | summary_manager, 41 | weighted_durations): 42 | val_loss = {'loss': 0.} 43 | norm = 0. 44 | current_r = model.r 45 | model.set_constants(reduction_factor=1) 46 | for val_mel, val_text, val_stop, fname in val_dataset.all_batches(): 47 | model_out = model.val_step(inp=val_text, 48 | tar=val_mel, 49 | stop_prob=val_stop) 50 | norm += 1 51 | val_loss['loss'] += model_out['loss'] 52 | val_loss['loss'] /= norm 53 | summary_manager.display_loss(model_out, tag='Validation', plot_all=True) 54 | summary_manager.display_last_attention(model_out, tag='ValidationAttentionHeads', fname=fname) 55 | attention_values = model_out['decoder_attention']['Decoder_LastBlock_CrossAttention'].numpy() 56 | text = val_text.numpy() 57 | mel = val_mel.numpy() 58 | model.set_constants(reduction_factor=current_r) 59 | modes = list({False, weighted_durations}) 60 | for mode in modes: 61 | durations, final_align, jumpiness, peakiness, diag_measure = get_durations_from_alignment( 62 | batch_alignments=attention_values, 63 | mels=mel, 64 | phonemes=text, 65 | weighted=mode) 66 | for k in range(len(durations)): 67 | phon_dur = durations[k] 68 | imel = mel[k][1:] # remove start token (is padded so end token can't be removed/not an issue) 69 | itext = text[k][1:] # remove start token (is padded so end token can't be removed/not an issue) 70 | iphon = model.text_pipeline.tokenizer.decode(itext).replace('/', '') 71 | cut_mels, cut_texts = cut_with_durations(durations=phon_dur, mel=imel, phonemes=iphon) 72 | for cut_idx, cut_text in enumerate(cut_texts): 73 | weighted_label = 'weighted_' * mode 74 | summary_manager.display_audio( 75 | tag=f'CutAudio {weighted_label}{fname[k].numpy().decode("utf-8")}/{cut_idx}/{cut_text}', 76 | mel=cut_mels[cut_idx], description=iphon) 77 | return val_loss['loss'] 78 | 79 | 80 | config_manager = TrainingConfigManager(config_path=args.config, aligner=True) 81 | config = 
config_manager.config 82 | config_manager.create_remove_dirs(clear_dir=args.clear_dir, 83 | clear_logs=args.clear_logs, 84 | clear_weights=args.clear_weights) 85 | config_manager.dump_config() 86 | config_manager.print_config() 87 | 88 | # get model, prepare data for model, create datasets 89 | model = config_manager.get_model() 90 | config_manager.compile_model(model) 91 | data_prep = AlignerPreprocessor.from_config(config_manager, 92 | tokenizer=model.text_pipeline.tokenizer) # TODO: tokenizer is now static 93 | train_data_handler = AlignerDataset.from_config(config_manager, 94 | preprocessor=data_prep, 95 | kind='train') 96 | valid_data_handler = AlignerDataset.from_config(config_manager, 97 | preprocessor=data_prep, 98 | kind='valid') 99 | 100 | train_dataset = train_data_handler.get_dataset(bucket_batch_sizes=config['bucket_batch_sizes'], 101 | bucket_boundaries=config['bucket_boundaries'], 102 | shuffle=True) 103 | valid_dataset = valid_data_handler.get_dataset(bucket_batch_sizes=config['val_bucket_batch_size'], 104 | bucket_boundaries=config['bucket_boundaries'], 105 | shuffle=False, drop_remainder=True) 106 | 107 | # create logger and checkpointer and restore latest model 108 | 109 | summary_manager = SummaryManager(model=model, log_dir=config_manager.log_dir, config=config) 110 | checkpoint = tf.train.Checkpoint(step=tf.Variable(1), 111 | optimizer=model.optimizer, 112 | net=model) 113 | manager = tf.train.CheckpointManager(checkpoint, str(config_manager.weights_dir), 114 | max_to_keep=config['keep_n_weights'], 115 | keep_checkpoint_every_n_hours=config['keep_checkpoint_every_n_hours']) 116 | manager_training = tf.train.CheckpointManager(checkpoint, str(config_manager.weights_dir / 'latest'), 117 | max_to_keep=1, checkpoint_name='latest') 118 | 119 | checkpoint.restore(manager_training.latest_checkpoint) 120 | if manager_training.latest_checkpoint: 121 | print(f'\nresuming training from step {model.step} ({manager_training.latest_checkpoint})') 122 | else: 123 | print(f'\nstarting training from scratch') 124 | 125 | if config['debug'] is True: 126 | print('\nWARNING: DEBUG is set to True. 
Training in eager mode.') 127 | # main event 128 | print('\nTRAINING') 129 | 130 | texts = [] 131 | for text_file in config['test_stencences']: 132 | with open(text_file, 'r') as file: 133 | text = file.readlines() 134 | texts.append(text) 135 | 136 | losses = [] 137 | test_mel, test_phonemes, _, test_fname = valid_dataset.next_batch() 138 | val_test_sample, val_test_fname, val_test_mel = test_phonemes[0], test_fname[0], test_mel[0] 139 | val_test_sample = tf.boolean_mask(val_test_sample, val_test_sample!=0) 140 | 141 | _ = train_dataset.next_batch() 142 | t = trange(model.step, config['max_steps'], leave=True) 143 | for _ in t: 144 | t.set_description(f'step {model.step}') 145 | mel, phonemes, stop, sample_name = train_dataset.next_batch() 146 | learning_rate = piecewise_linear_schedule(model.step, config['learning_rate_schedule']) 147 | reduction_factor = reduction_schedule(model.step, config['reduction_factor_schedule']) 148 | t.display(f'reduction factor {reduction_factor}', pos=10) 149 | force_encoder_diagonal = model.step < config['force_encoder_diagonal_steps'] 150 | force_decoder_diagonal = model.step < config['force_decoder_diagonal_steps'] 151 | model.set_constants(learning_rate=learning_rate, 152 | reduction_factor=reduction_factor, 153 | force_encoder_diagonal=force_encoder_diagonal, 154 | force_decoder_diagonal=force_decoder_diagonal) 155 | 156 | output = model.train_step(inp=phonemes, 157 | tar=mel, 158 | stop_prob=stop) 159 | losses.append(float(output['loss'])) 160 | 161 | t.display(f'step loss: {losses[-1]}', pos=1) 162 | for pos, n_steps in enumerate(config['n_steps_avg_losses']): 163 | if len(losses) > n_steps: 164 | t.display(f'{n_steps}-steps average loss: {sum(losses[-n_steps:]) / n_steps}', pos=pos + 2) 165 | 166 | summary_manager.display_loss(output, tag='Train') 167 | summary_manager.display_scalar(tag='Meta/learning_rate', scalar_value=model.optimizer.lr) 168 | summary_manager.display_scalar(tag='Meta/reduction_factor', scalar_value=model.r) 169 | summary_manager.display_scalar(scalar_value=t.avg_time, tag='Meta/iter_time') 170 | summary_manager.display_scalar(scalar_value=tf.shape(sample_name)[0], tag='Meta/batch_size') 171 | if model.step % config['train_images_plotting_frequency'] == 0: 172 | summary_manager.display_attention_heads(output, tag='TrainAttentionHeads') 173 | summary_manager.display_mel(mel=output['mel'][0], tag=f'Train/predicted_mel') 174 | for layer, k in enumerate(output['decoder_attention'].keys()): 175 | mel_lens = mel_lengths(mel_batch=mel, padding_value=0) // model.r # [N] 176 | phon_len = phoneme_lengths(phonemes) 177 | loc_score, peak_score, diag_measure = attention_score(att=output['decoder_attention'][k], 178 | mel_len=mel_lens, 179 | phon_len=phon_len, 180 | r=model.r) 181 | loc_score = tf.reduce_mean(loc_score, axis=0) 182 | peak_score = tf.reduce_mean(peak_score, axis=0) 183 | diag_measure = tf.reduce_mean(diag_measure, axis=0) 184 | for i in range(tf.shape(loc_score)[0]): 185 | summary_manager.display_scalar(tag=f'TrainDecoderAttentionJumpiness/layer{layer}_head{i}', 186 | scalar_value=tf.reduce_mean(loc_score[i])) 187 | summary_manager.display_scalar(tag=f'TrainDecoderAttentionPeakiness/layer{layer}_head{i}', 188 | scalar_value=tf.reduce_mean(peak_score[i])) 189 | summary_manager.display_scalar(tag=f'TrainDecoderAttentionDiagonality/layer{layer}_head{i}', 190 | scalar_value=tf.reduce_mean(diag_measure[i])) 191 | 192 | if model.step % 1000 == 0: 193 | save_path = manager_training.save() 194 | if model.step % 
config['weights_save_frequency'] == 0: 195 | save_path = manager.save() 196 | t.display(f'checkpoint at step {model.step}: {save_path}', pos=len(config['n_steps_avg_losses']) + 2) 197 | 198 | if model.step % config['validation_frequency'] == 0 and (model.step >= config['prediction_start_step']): 199 | val_loss, time_taken = validate(model=model, 200 | val_dataset=valid_dataset, 201 | summary_manager=summary_manager, 202 | weighted_durations=config['extract_attention_weighted']) 203 | t.display(f'validation loss at step {model.step}: {val_loss} (took {time_taken}s)', 204 | pos=len(config['n_steps_avg_losses']) + 3) 205 | 206 | if model.step % config['prediction_frequency'] == 0 and (model.step >= config['prediction_start_step']): 207 | for j, text in enumerate(texts): 208 | for i, text_line in enumerate(text): 209 | out = model.predict(text_line, encode=True) 210 | wav = summary_manager.audio.reconstruct_waveform(out['mel'].numpy().T) 211 | wav = tf.expand_dims(wav, 0) 212 | wav = tf.expand_dims(wav, -1) 213 | summary_manager.add_audio(f'Predictions/{text_line}', wav.numpy(), sr=summary_manager.config['sampling_rate'], 214 | step=summary_manager.global_step) 215 | 216 | out = model.predict(val_test_sample, encode=False)#, max_length=tf.shape(val_test_mel)[-2]) 217 | wav = summary_manager.audio.reconstruct_waveform(out['mel'].numpy().T) 218 | wav = tf.expand_dims(wav, 0) 219 | wav = tf.expand_dims(wav, -1) 220 | summary_manager.add_audio(f'Predictions/val_sample {val_test_fname.numpy().decode("utf-8")}', wav.numpy(), sr=summary_manager.config['sampling_rate'], 221 | step=summary_manager.global_step) 222 | print('Done.') 223 | -------------------------------------------------------------------------------- /train_tts.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tqdm import trange 4 | 5 | from utils.training_config_manager import TrainingConfigManager 6 | from data.datasets import TTSDataset, TTSPreprocessor 7 | from utils.decorators import ignore_exception, time_it 8 | from utils.scheduling import piecewise_linear_schedule 9 | from utils.logging_utils import SummaryManager 10 | from model.transformer_utils import create_mel_padding_mask 11 | from utils.scripts_utils import dynamic_memory_allocation, basic_train_parser 12 | from data.metadata_readers import post_processed_reader 13 | 14 | np.random.seed(42) 15 | tf.random.set_seed(42) 16 | dynamic_memory_allocation() 17 | 18 | 19 | def display_target_symbol_duration_distributions(): 20 | phon_data, ups = post_processed_reader(config.phonemized_metadata_path) 21 | dur_dict = {} 22 | for key in phon_data.keys(): 23 | dur_dict[key] = np.load((config.duration_dir / key).with_suffix('.npy')) 24 | symbol_durs = {} 25 | for key in dur_dict: 26 | for i, phoneme in enumerate(phon_data[key]): 27 | symbol_durs.setdefault(phoneme, []).append(dur_dict[key][i]) 28 | for symbol in symbol_durs.keys(): 29 | summary_manager.add_histogram(tag=f'"{symbol}"/Target durations', values=symbol_durs[symbol], 30 | buckets=len(set(symbol_durs[symbol])) + 1, step=0) 31 | 32 | 33 | def display_predicted_symbol_duration_distributions(all_durations): 34 | phon_data, ups = post_processed_reader(config.phonemized_metadata_path) 35 | symbol_durs = {} 36 | for key in all_durations.keys(): 37 | clean_key = key.decode('utf-8') 38 | for i, phoneme in enumerate(phon_data[clean_key]): 39 | symbol_durs.setdefault(phoneme, []).append(all_durations[key][i]) 40 | for symbol in 
symbol_durs.keys(): 41 | summary_manager.add_histogram(tag=f'"{symbol}"/Predicted durations', values=symbol_durs[symbol]) 42 | 43 | 44 | @ignore_exception 45 | @time_it 46 | def validate(model, 47 | val_dataset, 48 | summary_manager): 49 | val_loss = {'loss': 0.} 50 | norm = 0. 51 | for mel, phonemes, durations, pitch, fname in val_dataset.all_batches(): 52 | model_out = model.val_step(input_sequence=phonemes, 53 | target_sequence=mel, 54 | target_durations=durations, 55 | target_pitch=pitch) 56 | norm += 1 57 | val_loss['loss'] += model_out['loss'] 58 | val_loss['loss'] /= norm 59 | summary_manager.display_loss(model_out, tag='Validation', plot_all=True) 60 | summary_manager.display_attention_heads(model_out, tag='ValidationAttentionHeads') 61 | summary_manager.add_histogram(tag=f'Validation/Predicted durations', values=model_out['duration']) 62 | summary_manager.add_histogram(tag=f'Validation/Target durations', values=durations) 63 | summary_manager.display_plot1D(tag=f'Validation/{fname[0].numpy().decode("utf-8")} predicted pitch', 64 | y=model_out['pitch'][0]) 65 | summary_manager.display_plot1D(tag=f'Validation/{fname[0].numpy().decode("utf-8")} target pitch', y=pitch[0]) 66 | summary_manager.display_mel(mel=model_out['mel'][0], 67 | tag=f'Validation/{fname[0].numpy().decode("utf-8")} predicted_mel') 68 | summary_manager.display_mel(mel=mel[0], tag=f'Validation/{fname[0].numpy().decode("utf-8")} target_mel') 69 | summary_manager.display_audio(tag=f'Validation {fname[0].numpy().decode("utf-8")}/prediction', 70 | mel=model_out['mel'][0]) 71 | summary_manager.display_audio(tag=f'Validation {fname[0].numpy().decode("utf-8")}/target', mel=mel[0]) 72 | # predict withoyt enforcing durations and pitch 73 | model_out = model.predict(phonemes, encode=False) 74 | pred_lengths = tf.cast(tf.reduce_sum(1 - model_out['expanded_mask'], axis=-1), tf.int32) 75 | pred_lengths = tf.squeeze(pred_lengths) 76 | tar_lengths = tf.cast(tf.reduce_sum(1 - create_mel_padding_mask(mel), axis=-1), tf.int32) 77 | tar_lengths = tf.squeeze(tar_lengths) 78 | for j, pred_mel in enumerate(model_out['mel']): 79 | predval = pred_mel[:pred_lengths[j], :] 80 | tar_value = mel[j, :tar_lengths[j], :] 81 | summary_manager.display_mel(mel=predval, tag=f'Test/{fname[j].numpy().decode("utf-8")}/predicted') 82 | summary_manager.display_mel(mel=tar_value, tag=f'Test/{fname[j].numpy().decode("utf-8")}/target') 83 | summary_manager.display_audio(tag=f'Prediction {fname[j].numpy().decode("utf-8")}/target', mel=tar_value) 84 | summary_manager.display_audio(tag=f'Prediction {fname[j].numpy().decode("utf-8")}/prediction', 85 | mel=predval) 86 | return val_loss['loss'] 87 | 88 | 89 | parser = basic_train_parser() 90 | args = parser.parse_args() 91 | 92 | config = TrainingConfigManager(config_path=args.config) 93 | config_dict = config.config 94 | config.create_remove_dirs(clear_dir=args.clear_dir, 95 | clear_logs=args.clear_logs, 96 | clear_weights=args.clear_weights) 97 | config.dump_config() 98 | config.print_config() 99 | 100 | model = config.get_model() 101 | config.compile_model(model) 102 | 103 | data_prep = TTSPreprocessor.from_config(config=config, 104 | tokenizer=model.text_pipeline.tokenizer) 105 | train_data_handler = TTSDataset.from_config(config, 106 | preprocessor=data_prep, 107 | kind='train') 108 | valid_data_handler = TTSDataset.from_config(config, 109 | preprocessor=data_prep, 110 | kind='valid') 111 | train_dataset = train_data_handler.get_dataset(bucket_batch_sizes=config_dict['bucket_batch_sizes'], 112 | 
bucket_boundaries=config_dict['bucket_boundaries'], 113 | shuffle=True) 114 | valid_dataset = valid_data_handler.get_dataset(bucket_batch_sizes=config_dict['val_bucket_batch_size'], 115 | bucket_boundaries=config_dict['bucket_boundaries'], 116 | shuffle=False, 117 | drop_remainder=True) 118 | 119 | # create logger and checkpointer and restore latest model 120 | summary_manager = SummaryManager(model=model, log_dir=config.log_dir, config=config_dict) 121 | checkpoint = tf.train.Checkpoint(step=tf.Variable(1), 122 | optimizer=model.optimizer, 123 | net=model) 124 | manager_training = tf.train.CheckpointManager(checkpoint, str(config.weights_dir / 'latest'), 125 | max_to_keep=1, checkpoint_name='latest') 126 | 127 | checkpoint.restore(manager_training.latest_checkpoint) 128 | if manager_training.latest_checkpoint: 129 | print(f'\nresuming training from step {model.step} ({manager_training.latest_checkpoint})') 130 | else: 131 | print(f'\nstarting training from scratch') 132 | 133 | if config_dict['debug'] is True: 134 | print('\nWARNING: DEBUG is set to True. Training in eager mode.') 135 | 136 | display_target_symbol_duration_distributions() 137 | # main event 138 | print('\nTRAINING') 139 | losses = [] 140 | texts = [] 141 | for text_file in config_dict['text_prediction']: 142 | with open(text_file, 'r') as file: 143 | text = file.readlines() 144 | texts.append(text) 145 | 146 | all_files = len(set(train_data_handler.metadata_reader.filenames)) # without duplicates 147 | all_durations = {} 148 | t = trange(model.step, config_dict['max_steps'], leave=True) 149 | for _ in t: 150 | t.set_description(f'step {model.step}') 151 | mel, phonemes, durations, pitch, fname = train_dataset.next_batch() 152 | learning_rate = piecewise_linear_schedule(model.step, config_dict['learning_rate_schedule']) 153 | model.set_constants(learning_rate=learning_rate) 154 | output = model.train_step(input_sequence=phonemes, 155 | target_sequence=mel, 156 | target_durations=durations, 157 | target_pitch=pitch) 158 | losses.append(float(output['loss'])) 159 | 160 | predicted_durations = dict(zip(fname.numpy(), output['duration'].numpy())) 161 | all_durations.update(predicted_durations) 162 | if len(all_durations) >= all_files: # all the dataset has been processed 163 | display_predicted_symbol_duration_distributions(all_durations) 164 | all_durations = {} 165 | 166 | t.display(f'step loss: {losses[-1]}', pos=1) 167 | for pos, n_steps in enumerate(config_dict['n_steps_avg_losses']): 168 | if len(losses) > n_steps: 169 | t.display(f'{n_steps}-steps average loss: {sum(losses[-n_steps:]) / n_steps}', pos=pos + 2) 170 | 171 | summary_manager.display_loss(output, tag='Train') 172 | summary_manager.display_scalar(scalar_value=t.avg_time, tag='Meta/iter_time') 173 | summary_manager.display_scalar(scalar_value=tf.shape(fname)[0], tag='Meta/batch_size') 174 | summary_manager.display_scalar(tag='Meta/learning_rate', scalar_value=model.optimizer.lr) 175 | if model.step % config_dict['train_images_plotting_frequency'] == 0: 176 | summary_manager.display_attention_heads(output, tag='TrainAttentionHeads') 177 | summary_manager.display_mel(mel=output['mel'][0], tag=f'Train/predicted_mel') 178 | summary_manager.display_mel(mel=mel[0], tag=f'Train/target_mel') 179 | summary_manager.display_plot1D(tag=f'Train/Predicted pitch', y=output['pitch'][0]) 180 | summary_manager.display_plot1D(tag=f'Train/Target pitch', y=pitch[0]) 181 | 182 | if model.step % 1000 == 0: 183 | save_path = manager_training.save() 184 | if (model.step % 
config_dict['weights_save_frequency'] == 0) & ( 185 | model.step >= config_dict['weights_save_starting_step']): 186 | model.save_model(config.weights_dir / f'step_{model.step}') 187 | t.display(f'checkpoint at step {model.step}: {config.weights_dir / f"step_{model.step}"}', 188 | pos=len(config_dict['n_steps_avg_losses']) + 2) 189 | 190 | if model.step % config_dict['validation_frequency'] == 0: 191 | t.display(f'Validating', pos=len(config_dict['n_steps_avg_losses']) + 3) 192 | val_loss, time_taken = validate(model=model, 193 | val_dataset=valid_dataset, 194 | summary_manager=summary_manager) 195 | t.display(f'validation loss at step {model.step}: {val_loss} (took {time_taken}s)', 196 | pos=len(config_dict['n_steps_avg_losses']) + 3) 197 | 198 | if model.step % config_dict['prediction_frequency'] == 0 and (model.step >= config_dict['prediction_start_step']): 199 | for i, text in enumerate(texts): 200 | wavs = [] 201 | for i, text_line in enumerate(text): 202 | out = model.predict(text_line, encode=True) 203 | wav = summary_manager.audio.reconstruct_waveform(out['mel'].numpy().T) 204 | wavs.append(wav) 205 | wavs = np.concatenate(wavs) 206 | wavs = tf.expand_dims(wavs, 0) 207 | wavs = tf.expand_dims(wavs, -1) 208 | summary_manager.add_audio(f'Text file input', wavs.numpy(), sr=summary_manager.config['sampling_rate'], 209 | step=summary_manager.global_step) 210 | 211 | print('Done.') 212 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spring-media/TransformerTTS/363805548abdd93b33508da2c027ae514bfc1a07/utils/__init__.py -------------------------------------------------------------------------------- /utils/alignments.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from utils.metrics import attention_score 4 | from utils.spectrogram_ops import mel_lengths, phoneme_lengths 5 | 6 | logger = tf.get_logger() 7 | logger.setLevel('ERROR') 8 | import numpy as np 9 | from scipy.sparse import coo_matrix 10 | from scipy.sparse.csgraph import dijkstra 11 | 12 | 13 | def to_node_index(i, j, cols): 14 | return cols * i + j 15 | 16 | 17 | def from_node_index(node_index, cols): 18 | return node_index // cols, node_index % cols 19 | 20 | 21 | def to_adj_matrix(mat): 22 | rows = mat.shape[0] 23 | cols = mat.shape[1] 24 | 25 | row_ind = [] 26 | col_ind = [] 27 | data = [] 28 | 29 | for i in range(rows): 30 | for j in range(cols): 31 | 32 | node = to_node_index(i, j, cols) 33 | 34 | if j < cols - 1: 35 | right_node = to_node_index(i, j + 1, cols) 36 | weight_right = mat[i, j + 1] 37 | row_ind.append(node) 38 | col_ind.append(right_node) 39 | data.append(weight_right) 40 | 41 | if i < rows - 1 and j < cols: 42 | bottom_node = to_node_index(i + 1, j, cols) 43 | weight_bottom = mat[i + 1, j] 44 | row_ind.append(node) 45 | col_ind.append(bottom_node) 46 | data.append(weight_bottom) 47 | 48 | if i < rows - 1 and j < cols - 1: 49 | bottom_right_node = to_node_index(i + 1, j + 1, cols) 50 | weight_bottom_right = mat[i + 1, j + 1] 51 | row_ind.append(node) 52 | col_ind.append(bottom_right_node) 53 | data.append(weight_bottom_right) 54 | 55 | adj_mat = coo_matrix((data, (row_ind, col_ind)), shape=(rows * cols, rows * cols)) 56 | return adj_mat.tocsr() 57 | 58 | 59 | def extract_durations_with_dijkstra(attention_map: np.array) -> np.array: 60 | """ 61 | Extracts durations from the 
attention matrix by finding the shortest monotonic path from 62 | top left to bottom right. 63 | """ 64 | attn_max = np.max(attention_map) 65 | path_probs = attn_max - attention_map 66 | adj_matrix = to_adj_matrix(path_probs) 67 | dist_matrix, predecessors = dijkstra(csgraph=adj_matrix, directed=True, 68 | indices=0, return_predecessors=True) 69 | path = [] 70 | pr_index = predecessors[-1] 71 | while pr_index != 0: 72 | path.append(pr_index) 73 | pr_index = predecessors[pr_index] 74 | path.reverse() 75 | 76 | # append first and last node 77 | path = [0] + path + [dist_matrix.size - 1] 78 | cols = path_probs.shape[1] 79 | mel_text = {} 80 | durations = np.zeros(attention_map.shape[1], dtype=np.int32) 81 | 82 | # collect indices (mel, text) along the path 83 | for node_index in path: 84 | i, j = from_node_index(node_index, cols) 85 | mel_text[i] = j 86 | 87 | for j in mel_text.values(): 88 | durations[j] += 1 89 | 90 | return durations 91 | 92 | 93 | def duration_to_alignment_matrix(durations): 94 | starts = np.cumsum(np.append([0], durations[:-1])) 95 | tot_duration = np.sum(durations) 96 | pads = tot_duration - starts - durations 97 | alignments = [np.concatenate([np.zeros(starts[i]), np.ones(durations[i]), np.zeros(pads[i])]) for i in 98 | range(len(durations))] 99 | return np.array(alignments) 100 | 101 | 102 | def get_durations_from_alignment(batch_alignments, mels, phonemes, weighted=False): 103 | """ 104 | 105 | :param batch_alignments: attention weights from autoregressive model. 106 | :param mels: mel spectrograms. 107 | :param phonemes: phoneme sequence. 108 | :param weighted: if True use weighted average of durations of heads, best head if False. 109 | :param binary: if True take maximum attention peak, sum if False. 110 | :param fill_gaps: if True fills zeros durations with ones. 111 | :param fix_jumps: if True, tries to scan alingments for attention jumps and interpolate. 112 | :param fill_mode: used only if fill_gaps is True. Is either 'max' or 'next'. Defines where to take the duration 113 | needed to fill the gap. Next takes it from the next non-zeros duration value, max from the sequence maximum. 114 | :return: 115 | """ 116 | # mel_len - 1 because we remove last timestep, which is end_vector. start vector is not predicted (or removed from GTA) 117 | mel_len = mel_lengths(mels, padding_value=0.) 
- 1 # [N] 118 | # phonemes contain start and end tokens (start will be removed later) 119 | phon_len = phoneme_lengths(phonemes) - 1 120 | jumpiness, peakiness, diag_measure = attention_score(att=batch_alignments, mel_len=mel_len, phon_len=phon_len, r=1) 121 | attn_scores = diag_measure + jumpiness + peakiness 122 | durations = [] 123 | final_alignment = [] 124 | for batch_num, al in enumerate(batch_alignments): 125 | unpad_mel_len = mel_len[batch_num] 126 | unpad_phon_len = phon_len[batch_num] 127 | unpad_alignments = al[:, 1:unpad_mel_len, 1:unpad_phon_len] # first dim is heads 128 | scored_attention = unpad_alignments * attn_scores[batch_num][:, None, None] 129 | 130 | if weighted: 131 | ref_attention_weights = np.sum(scored_attention, axis=0) 132 | else: 133 | best_head = np.argmax(attn_scores[batch_num]) 134 | ref_attention_weights = unpad_alignments[best_head] 135 | integer_durations = extract_durations_with_dijkstra(ref_attention_weights) 136 | 137 | assert np.sum(integer_durations) == mel_len[batch_num]-1, f'{np.sum(integer_durations)} vs {mel_len[batch_num]-1}' 138 | new_alignment = duration_to_alignment_matrix(integer_durations.astype(int)) 139 | best_head = np.argmax(attn_scores[batch_num]) 140 | best_attention = unpad_alignments[best_head] 141 | final_alignment.append(best_attention.T + new_alignment) 142 | durations.append(integer_durations) 143 | return durations, final_alignment, jumpiness, peakiness, diag_measure 144 | -------------------------------------------------------------------------------- /utils/decorators.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | from time import time 3 | 4 | 5 | def ignore_exception(f): 6 | def apply_func(*args, **kwargs): 7 | try: 8 | result = f(*args, **kwargs) 9 | return result 10 | except Exception: 11 | print(f'Catched exception in {f}:') 12 | traceback.print_exc() 13 | return None 14 | 15 | return apply_func 16 | 17 | 18 | def time_it(f): 19 | def apply_func(*args, **kwargs): 20 | t_start = time() 21 | result = f(*args, **kwargs) 22 | t_end = time() 23 | dur = round(t_end - t_start, ndigits=2) 24 | return result, dur 25 | 26 | return apply_func 27 | -------------------------------------------------------------------------------- /utils/display.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | 7 | def buffer_image(figure): 8 | buf = io.BytesIO() 9 | figure.savefig(buf, format='png') 10 | buf.seek(0) 11 | plt.close('all') 12 | return buf 13 | 14 | def plot1D(y, figsize=None, title='', x=None): 15 | f = plt.figure(figsize=figsize) 16 | if x is None: 17 | x = np.arange(len(y)) 18 | plt.plot(x,y) 19 | plt.title(title) 20 | buf = buffer_image(f) 21 | return buf 22 | 23 | 24 | def plot_image(image, with_bar, figsize=None, title=''): 25 | """Create a pyplot plot and save to buffer.""" 26 | f = plt.figure(figsize=figsize) 27 | plt.imshow(image) 28 | plt.title(title) 29 | if with_bar: 30 | plt.colorbar() 31 | buf = buffer_image(f) 32 | return buf 33 | 34 | 35 | def tight_grid(images): 36 | images = np.array(images) 37 | images = np.pad(images, [[0, 0], [1, 1], [1, 1]], 'constant', constant_values=1) # add borders 38 | if len(images.shape) != 3: 39 | raise Exception 40 | else: 41 | n, y, x = images.shape 42 | ratio = y / x 43 | if ratio > 1: 44 | ny = max(int(np.sqrt(n / ratio)), 1) 45 | nx = int(n / ny) 46 | nx += n - (nx * ny) 47 | extra = nx * ny - n 48 | else: 
49 | nx = max(int(np.sqrt(n * ratio)), 1) 50 | ny = int(n / nx) 51 | ny += n - (nx * ny) 52 | extra = nx * ny - n 53 | tot = np.append(images, np.zeros((extra, y, x)), axis=0) 54 | img = np.block([[*tot[i * nx:(i + 1) * nx]] for i in range(ny)]) 55 | return img 56 | -------------------------------------------------------------------------------- /utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import tensorflow as tf 4 | 5 | from data.audio import Audio 6 | from utils.display import tight_grid, buffer_image, plot_image, plot1D 7 | from utils.vec_ops import norm_tensor 8 | from utils.decorators import ignore_exception 9 | 10 | 11 | def control_frequency(f): 12 | def apply_func(*args, **kwargs): 13 | # args[0] is self 14 | plot_all = ('plot_all' in kwargs) and kwargs['plot_all'] 15 | if (args[0].global_step % args[0].plot_frequency == 0) or plot_all: 16 | result = f(*args, **kwargs) 17 | return result 18 | else: 19 | return None 20 | 21 | return apply_func 22 | 23 | 24 | class SummaryManager: 25 | """ Writes tensorboard logs during training. 26 | 27 | :arg model: model object that is trained 28 | :arg log_dir: base directory where logs of a config are created 29 | :arg config: configuration dictionary 30 | :arg max_plot_frequency: every how many steps to plot 31 | """ 32 | 33 | def __init__(self, 34 | model: tf.keras.models.Model, 35 | log_dir: str, 36 | config: dict, 37 | max_plot_frequency=10, 38 | default_writer='log_dir'): 39 | self.model = model 40 | self.log_dir = Path(log_dir) 41 | self.config = config 42 | self.audio = Audio.from_config(config) 43 | self.plot_frequency = max_plot_frequency 44 | self.default_writer = default_writer 45 | self.writers = {} 46 | self.add_writer(tag=default_writer, path=self.log_dir, default=True) 47 | 48 | def add_writer(self, path, tag=None, default=False): 49 | """ Adds a writer to self.writers if the writer does not exist already. 50 | To avoid spamming writers on disk. 
51 | 52 | :returns the writer on path with tag tag or path 53 | """ 54 | if not tag: 55 | tag = path 56 | if tag not in self.writers.keys(): 57 | self.writers[tag] = tf.summary.create_file_writer(str(path)) 58 | if default: 59 | self.default_writer = tag 60 | return self.writers[tag] 61 | 62 | @property 63 | def global_step(self): 64 | if self.model is not None: 65 | return self.model.step 66 | else: 67 | return 0 68 | 69 | def add_scalars(self, tag, dictionary, step=None): 70 | if step is None: 71 | step = self.global_step 72 | for k in dictionary.keys(): 73 | with self.add_writer(str(self.log_dir / k)).as_default(): 74 | tf.summary.scalar(name=tag, data=dictionary[k], step=step) 75 | 76 | def add_scalar(self, tag, scalar_value, step=None): 77 | if step is None: 78 | step = self.global_step 79 | with self.writers[self.default_writer].as_default(): 80 | tf.summary.scalar(name=tag, data=scalar_value, step=step) 81 | 82 | def add_image(self, tag, image, step=None): 83 | if step is None: 84 | step = self.global_step 85 | with self.writers[self.default_writer].as_default(): 86 | tf.summary.image(name=tag, data=image, step=step, max_outputs=4) 87 | 88 | def add_histogram(self, tag, values, buckets=None, step=None): 89 | if step is None: 90 | step = self.global_step 91 | with self.writers[self.default_writer].as_default(): 92 | tf.summary.histogram(name=tag, data=values, step=step, buckets=buckets) 93 | 94 | def add_audio(self, tag, wav, sr, step=None, description=None): 95 | if step is None: 96 | step = self.global_step 97 | with self.writers[self.default_writer].as_default(): 98 | tf.summary.audio(name=tag, 99 | data=wav, 100 | sample_rate=sr, 101 | step=step, 102 | description=description) 103 | 104 | def add_text(self, tag, text, step=None): 105 | if step is None: 106 | step = self.global_step 107 | with self.writers[self.default_writer].as_default(): 108 | tf.summary.text(name=tag, 109 | data=text, 110 | step=step) 111 | 112 | @ignore_exception 113 | def display_attention_heads(self, outputs: dict, tag='', step: int=None, fname: list=None): 114 | if step is None: 115 | step = self.global_step 116 | for layer in ['encoder_attention', 'decoder_attention']: 117 | for k in outputs[layer].keys(): 118 | if fname is None: 119 | image = tight_grid(norm_tensor(outputs[layer][k][0])) # dim 0 of image_batch is now number of heads 120 | if k == 'Decoder_LastBlock_CrossAttention': 121 | batch_plot_path = f'{tag}_Decoder_Final_Attention' 122 | else: 123 | batch_plot_path = f'{tag}_{layer}/{k}' 124 | self.add_image(str(batch_plot_path), tf.expand_dims(tf.expand_dims(image, 0), -1), step=step) 125 | else: 126 | for j, file in enumerate(fname): 127 | image = tight_grid( 128 | norm_tensor(outputs[layer][k][j])) # dim 0 of image_batch is now number of heads 129 | if k == 'Decoder_LastBlock_CrossAttention': 130 | batch_plot_path = f'{tag}_Decoder_Final_Attention/{file.numpy().decode("utf-8")}' 131 | else: 132 | batch_plot_path = f'{tag}_{layer}/{k}/{file.numpy().decode("utf-8")}' 133 | self.add_image(str(batch_plot_path), tf.expand_dims(tf.expand_dims(image, 0), -1), step=step) 134 | 135 | @ignore_exception 136 | def display_last_attention(self, outputs, tag='', step=None, fname=None): 137 | if step is None: 138 | step = self.global_step 139 | 140 | if fname is None: 141 | image = tight_grid(norm_tensor(outputs['decoder_attention']['Decoder_LastBlock_CrossAttention'][0])) # dim 0 of image_batch is now number of heads 142 | batch_plot_path = f'{tag}_Decoder_Final_Attention' 143 | 
self.add_image(str(batch_plot_path), tf.expand_dims(tf.expand_dims(image, 0), -1), step=step) 144 | else: 145 | for j, file in enumerate(fname): 146 | image = tight_grid( 147 | norm_tensor(outputs['decoder_attention']['Decoder_LastBlock_CrossAttention'][j])) # dim 0 of image_batch is now number of heads 148 | batch_plot_path = f'{tag}_Decoder_Final_Attention/{file.numpy().decode("utf-8")}' 149 | self.add_image(str(batch_plot_path), tf.expand_dims(tf.expand_dims(image, 0), -1), step=step) 150 | 151 | @ignore_exception 152 | def display_mel(self, mel, tag='', step=None): 153 | if step is None: 154 | step = self.global_step 155 | img = tf.transpose(mel) 156 | figure = self.audio.display_mel(img, is_normal=True) 157 | buf = buffer_image(figure) 158 | img_tf = tf.image.decode_png(buf.getvalue(), channels=3) 159 | self.add_image(tag, tf.expand_dims(img_tf, 0), step=step) 160 | 161 | @ignore_exception 162 | def display_image(self, image, with_bar=False, figsize=None, tag='', step=None): 163 | if step is None: 164 | step = self.global_step 165 | buf = plot_image(image, with_bar=with_bar, figsize=figsize) 166 | image = tf.image.decode_png(buf.getvalue(), channels=4) 167 | image = tf.expand_dims(image, 0) 168 | self.add_image(tag=tag, image=image, step=step) 169 | 170 | @ignore_exception 171 | def display_plot1D(self, y, x=None, figsize=None, tag='', step=None): 172 | if step is None: 173 | step = self.global_step 174 | buf = plot1D(y, x=x, figsize=figsize) 175 | image = tf.image.decode_png(buf.getvalue(), channels=4) 176 | image = tf.expand_dims(image, 0) 177 | self.add_image(tag=tag, image=image, step=step) 178 | 179 | @control_frequency 180 | @ignore_exception 181 | def display_loss(self, output, tag='', plot_all=False, step=None): 182 | if step is None: 183 | step = self.global_step 184 | self.add_scalars(tag=f'{tag}/losses', dictionary=output['losses'], step=step) 185 | self.add_scalar(tag=f'{tag}/loss', scalar_value=output['loss'], step=step) 186 | 187 | @control_frequency 188 | @ignore_exception 189 | def display_scalar(self, tag, scalar_value, plot_all=False, step=None): 190 | if step is None: 191 | step = self.global_step 192 | self.add_scalar(tag=tag, scalar_value=scalar_value, step=step) 193 | 194 | @ignore_exception 195 | def display_audio(self, tag, mel, step=None, description=None): 196 | wav = tf.transpose(mel) 197 | wav = self.audio.reconstruct_waveform(wav) 198 | wav = tf.expand_dims(wav, 0) 199 | wav = tf.expand_dims(wav, -1) 200 | self.add_audio(tag, wav.numpy(), sr=self.config['sampling_rate'], step=step, description=description) 201 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def new_scaled_crossentropy(index=2, scaling=1.0): 5 | """ 6 | Returns masked crossentropy with extra scaling: 7 | Scales the loss for given stop_index by stop_scaling 8 | """ 9 | 10 | def masked_crossentropy(targets: tf.Tensor, logits: tf.Tensor) -> tf.Tensor: 11 | crossentropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 12 | padding_mask = tf.math.equal(targets, 0) 13 | padding_mask = tf.math.logical_not(padding_mask) 14 | padding_mask = tf.cast(padding_mask, dtype=tf.float32) 15 | stop_mask = tf.math.equal(targets, index) 16 | stop_mask = tf.cast(stop_mask, dtype=tf.float32) * (scaling - 1.) 
17 | combined_mask = padding_mask + stop_mask 18 | loss = crossentropy(targets, logits, sample_weight=combined_mask) 19 | return loss 20 | 21 | return masked_crossentropy 22 | 23 | 24 | def masked_crossentropy(targets: tf.Tensor, logits: tf.Tensor) -> tf.Tensor: 25 | crossentropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 26 | mask = tf.math.logical_not(tf.math.equal(targets, 0)) 27 | mask = tf.cast(mask, dtype=tf.int32) 28 | loss = crossentropy(targets, logits, sample_weight=mask) 29 | return loss 30 | 31 | 32 | def masked_mean_squared_error(targets: tf.Tensor, logits: tf.Tensor) -> tf.Tensor: 33 | mse = tf.keras.losses.MeanSquaredError() 34 | mask = tf.math.logical_not(tf.math.equal(targets, 0)) 35 | mask = tf.cast(mask, dtype=tf.int32) 36 | mask = tf.reduce_max(mask, axis=-1) 37 | loss = mse(targets, logits, sample_weight=mask) 38 | return loss 39 | 40 | 41 | def masked_mean_absolute_error(targets: tf.Tensor, logits: tf.Tensor, mask_value=0, 42 | mask: tf.Tensor = None) -> tf.Tensor: 43 | mae = tf.keras.losses.MeanAbsoluteError() 44 | if mask is not None: 45 | mask = tf.math.logical_not(tf.math.equal(targets, mask_value)) 46 | mask = tf.cast(mask, dtype=tf.int32) 47 | mask = tf.reduce_max(mask, axis=-1) 48 | loss = mae(targets, logits, sample_weight=mask) 49 | return loss 50 | 51 | 52 | def masked_binary_crossentropy(targets: tf.Tensor, logits: tf.Tensor, mask_value=-1) -> tf.Tensor: 53 | bc = tf.keras.losses.BinaryCrossentropy(reduction='none') 54 | mask = tf.math.logical_not(tf.math.equal(logits, 55 | mask_value)) # TODO: masking based on the logits requires a masking layer. But masking layer produces 0. as outputs. 56 | # Need explicit masking 57 | mask = tf.cast(mask, dtype=tf.int32) 58 | loss_ = bc(targets, logits) 59 | loss_ *= mask 60 | return tf.reduce_mean(loss_) 61 | 62 | 63 | def weighted_sum_losses(targets, pred, loss_functions, coeffs): 64 | total_loss = 0 65 | loss_vals = [] 66 | for i in range(len(loss_functions)): 67 | loss = loss_functions[i](targets[i], pred[i]) 68 | loss_vals.append(loss) 69 | total_loss += coeffs[i] * loss 70 | return total_loss, loss_vals 71 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def attention_score(att, mel_len, phon_len, r): 5 | """ 6 | returns a tuple of scores (loc_score, sharp_score), where loc_score measures monotonicity and 7 | sharp_score measures the sharpness of attention peaks 8 | attn_weights : [N, n_heads, mel_dim, phoneme_dim] 9 | """ 10 | assert len(tf.shape(att)) == 4 11 | 12 | mask = tf.range(tf.shape(att)[2])[None, :] < mel_len[:, None] 13 | mask = tf.cast(mask, tf.int32)[:, None, :] # [N, 1, mel_dim] 14 | 15 | # distance between max (jumpiness) 16 | loc_score = attention_jumps_score(att=att, mel_mask=mask, mel_len=mel_len, r=r) 17 | 18 | # variance 19 | peak_score = attention_peak_score(att, mask) 20 | 21 | # diagonality 22 | diag_score = diagonality_score(att, mel_len, phon_len) 23 | 24 | return loc_score, peak_score, 3. 
/ diag_score 25 | 26 | 27 | def attention_jumps_score(att, mel_mask, mel_len, r): 28 | max_loc = tf.argmax(att, axis=3) # [N, n_heads, mel_max] 29 | max_loc_diff = tf.abs(max_loc[:, :, 1:] - max_loc[:, :, :-1]) # [N, h_heads, mel_max - 1] 30 | loc_score = tf.cast(max_loc_diff >= 0, tf.int32) * tf.cast(max_loc_diff <= r, tf.int32) # [N, h_heads, mel_max - 1] 31 | loc_score = tf.reduce_sum(loc_score * mel_mask[:, :, 1:], axis=-1) 32 | loc_score = loc_score / (mel_len - 1)[:, None] 33 | return tf.cast(loc_score, tf.float32) 34 | 35 | 36 | def attention_peak_score(att, mel_mask): 37 | max_loc = tf.reduce_max(att, axis=3) # [N, n_heads, mel_dim] 38 | peak_score = tf.reduce_mean(max_loc * tf.cast(mel_mask, tf.float32), axis=-1) 39 | return tf.cast(peak_score, tf.float32) 40 | 41 | def diagonality_score(att, mel_len, phon_len, diag_mask=None): 42 | if diag_mask is None: 43 | diag_mask = batch_diagonal_mask(att, mel_len, phon_len) 44 | diag_score = tf.reduce_sum(att * diag_mask, axis=(-2, -1)) 45 | return diag_score 46 | 47 | def batch_diagonal_mask(att, mel_len, phon_len): 48 | batch_size = tf.shape(att)[0] 49 | mel_size = tf.shape(att)[2] 50 | phon_size = tf.shape(att)[3] 51 | diag_mask = tf.TensorArray(tf.float32, size=batch_size) 52 | for i in range(batch_size): 53 | d_mask = diagonal_mask(mel_len[i], phon_len[i], padded_shape=(mel_size, phon_size)) 54 | diag_mask = diag_mask.write(i, d_mask) 55 | diag_mask = tf.cast(diag_mask.stack(), tf.float32) 56 | diag_mask = tf.expand_dims(diag_mask, 1) 57 | return diag_mask 58 | 59 | 60 | def diagonal_mask(mel_len, phon_len, padded_shape): 61 | """ exponential loss mask based on distance from euclidean diagonal""" 62 | max_m = tf.cast(mel_len, tf.int32) 63 | if max_m > padded_shape[0]: # this can happen due to rounding errors when calculating mel lengths with r>1 64 | max_m = padded_shape[0] 65 | max_n = tf.cast(phon_len, tf.int32) 66 | i = tf.tile(tf.range(max_n)[None, :], [max_m, 1]) / max_n 67 | j = tf.tile(tf.range(max_m)[:, None], [1, max_n]) / max_m 68 | diag_mask = tf.math.sqrt(tf.square(i - j)) 69 | expanded_mask = tf.pad(diag_mask, [[0, padded_shape[0] - max_m], [0, padded_shape[1] - max_n]]) 70 | return tf.cast(expanded_mask, tf.float32) 71 | -------------------------------------------------------------------------------- /utils/scheduling.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def linear_function(x, x0, x1, y0, y1): 6 | m = (y1 - y0) / (x1 - x0) 7 | b = y0 - m * x0 8 | return m * x + b 9 | 10 | 11 | def piecewise_linear(step, X, Y): 12 | """ 13 | Piecewise linear function. 14 | 15 | :param step: current step. 
16 | :param X: list of breakpoints 17 | :param Y: list of values at breakpoints 18 | :return: value of piecewise linear function with values Y_i at step X_i 19 | """ 20 | assert len(X) == len(Y) 21 | X = np.array(X) 22 | if step < X[0]: 23 | return Y[0] 24 | idx = np.where(step >= X)[0][-1] 25 | if idx == (len(Y) - 1): 26 | return Y[-1] 27 | else: 28 | return linear_function(step, X[idx], X[idx + 1], Y[idx], Y[idx + 1]) 29 | 30 | 31 | def piecewise_linear_schedule(step, schedule): 32 | schedule = np.array(schedule) 33 | x_schedule = schedule[:, 0] 34 | y_schedule = schedule[:, 1] 35 | value = piecewise_linear(step, x_schedule, y_schedule) 36 | return tf.cast(value, tf.float32) 37 | 38 | 39 | def reduction_schedule(step, schedule): 40 | schedule = np.array(schedule) 41 | r = schedule[0, 0] 42 | for i in range(schedule.shape[0]): 43 | if schedule[i, 0] <= step: 44 | r = schedule[i, 1] 45 | else: 46 | break 47 | return int(r) 48 | -------------------------------------------------------------------------------- /utils/scripts_utils.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | import argparse 3 | 4 | import tensorflow as tf 5 | 6 | 7 | def dynamic_memory_allocation(): 8 | gpus = tf.config.experimental.list_physical_devices('GPU') 9 | if gpus: 10 | try: 11 | # Currently, memory growth needs to be the same across GPUs 12 | for gpu in gpus: 13 | tf.config.experimental.set_memory_growth(gpu, True) 14 | logical_gpus = tf.config.experimental.list_logical_devices('GPU') 15 | print(len(gpus), 'Physical GPUs,', len(logical_gpus), 'Logical GPUs') 16 | except Exception: 17 | traceback.print_exc() 18 | 19 | 20 | def basic_train_parser(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--config', dest='config', type=str) 23 | parser.add_argument('--reset_dir', dest='clear_dir', action='store_true', 24 | help="deletes everything under this config's folder.") 25 | parser.add_argument('--reset_logs', dest='clear_logs', action='store_true', 26 | help="deletes logs under this config's folder.") 27 | parser.add_argument('--reset_weights', dest='clear_weights', action='store_true', 28 | help="deletes weights under this config's folder.") 29 | return parser 30 | -------------------------------------------------------------------------------- /utils/spectrogram_ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def mel_padding_mask(mel_batch, padding_value=0): 5 | return 1.0 - tf.cast(mel_batch == padding_value, tf.float32) 6 | 7 | 8 | def mel_lengths(mel_batch, padding_value=0): 9 | mask = mel_padding_mask(mel_batch, padding_value=padding_value) 10 | mel_channels = tf.shape(mel_batch)[-1] 11 | sum_tot = tf.cast(mel_channels, tf.float32) * padding_value 12 | idxs = tf.cast(tf.reduce_sum(mask, axis=-1) != sum_tot, tf.int32) 13 | return tf.reduce_sum(idxs, axis=-1) 14 | 15 | 16 | def phoneme_lengths(phonemes, phoneme_padding=0): 17 | return tf.reduce_sum(tf.cast(phonemes != phoneme_padding, tf.int32), axis=-1) 18 | -------------------------------------------------------------------------------- /utils/training_config_manager.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import shutil 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | import ruamel.yaml 8 | 9 | from model.models import Aligner, ForwardTransformer 10 | from utils.scheduling import 
reduction_schedule 11 | 12 | 13 | class TrainingConfigManager: 14 | def __init__(self, config_path: str, aligner=False): 15 | if aligner: 16 | model_kind = 'aligner' 17 | else: 18 | model_kind = 'tts' 19 | self.config_path = Path(config_path) 20 | self.model_kind = model_kind 21 | self.yaml = ruamel.yaml.YAML() 22 | self.config = self._load_config() 23 | self.git_hash = self._get_git_hash() 24 | self.data_name = self.config['data_name'] # raw data 25 | # make session names 26 | self.session_names = {'data': f"{self.config['text_settings_name']}.{self.config['audio_settings_name']}"} 27 | self.session_names['aligner'] = f"{self.config['aligner_settings_name']}.{self.session_names['data']}" 28 | self.session_names['tts'] = f"{self.config['tts_settings_name']}.{self.config['aligner_settings_name']}" 29 | # create paths 30 | self.wav_directory = Path(self.config['wav_directory']) 31 | self.data_dir = Path(f"{self.config['train_data_directory']}.{self.data_name}") 32 | self.metadata_path = Path(self.config['metadata_path']) 33 | self.base_dir = Path(self.config['log_directory']) / self.data_name / self.session_names[model_kind] 34 | self.log_dir = self.base_dir / 'logs' 35 | self.weights_dir = self.base_dir / 'weights' 36 | self.train_metadata_path = self.data_dir / f"train_metadata.{self.config['text_settings_name']}.txt" 37 | self.valid_metadata_path = self.data_dir / f"valid_metadata.{self.config['text_settings_name']}.txt" 38 | self.phonemized_metadata_path = self.data_dir / f"phonemized_metadata.{self.config['text_settings_name']}.txt" 39 | self.mel_dir = self.data_dir / f"mels.{self.config['audio_settings_name']}" 40 | self.pitch_dir = self.data_dir / f"pitch.{self.config['audio_settings_name']}" 41 | self.duration_dir = self.data_dir / f"durations.{self.session_names['aligner']}" 42 | self.pitch_per_char = self.data_dir / f"char_pitch.{self.session_names['aligner']}" 43 | # training parameters 44 | self.learning_rate = np.array(self.config['learning_rate_schedule'])[0, 1].astype(np.float32) 45 | if model_kind == 'aligner': 46 | self.max_r = np.array(self.config['reduction_factor_schedule'])[0, 1].astype(np.int32) 47 | self.stop_scaling = self.config.get('stop_loss_scaling', 1.) 48 | 49 | def _load_config(self): 50 | all_config = {} 51 | with open(str(self.config_path), 'rb') as session_yaml: 52 | session_config = self.yaml.load(session_yaml) 53 | for key in ['paths', 'naming', 'training_data_settings','audio_settings', 54 | 'text_settings', f'{self.model_kind}_settings']: 55 | all_config.update(session_config[key]) 56 | return all_config 57 | 58 | @staticmethod 59 | def _get_git_hash(): 60 | try: 61 | return subprocess.check_output(['git', 'describe', '--always']).strip().decode() 62 | except Exception as e: 63 | print(f'WARNING: could not retrieve git hash. {e}') 64 | 65 | def _check_hash(self): 66 | try: 67 | git_hash = subprocess.check_output(['git', 'describe', '--always']).strip().decode() 68 | if self.config['git_hash'] != git_hash: 69 | print(f"WARNING: git hash mismatch. Current: {git_hash}. Training config hash: {self.config['git_hash']}") 70 | except Exception as e: 71 | print(f'WARNING: could not check git hash. 
{e}')
72 | 
73 |     @staticmethod
74 |     def _print_dict_values(values, key_name, level=0, tab_size=2):
75 |         tab = level * tab_size * ' '
76 |         print(tab + '-', key_name, ':', values)
77 | 
78 |     def _print_dictionary(self, dictionary, recursion_level=0):
79 |         for key in dictionary.keys():
80 |             if isinstance(dictionary[key], dict):  # recurse into nested config sections
81 |                 recursion_level += 1
82 |                 self._print_dictionary(dictionary[key], recursion_level)
83 |             else:
84 |                 self._print_dict_values(dictionary[key], key_name=key, level=recursion_level)
85 | 
86 |     def print_config(self):
87 |         print('\nCONFIGURATION', self.session_names[self.model_kind])
88 |         self._print_dictionary(self.config)
89 | 
90 |     def update_config(self):
91 |         self.config['git_hash'] = self.git_hash
92 |         self.config['automatic'] = True
93 | 
94 |     def get_model(self, ignore_hash=False):
95 |         if not ignore_hash:
96 |             self._check_hash()
97 |         if self.model_kind == 'aligner':
98 |             return Aligner.from_config(self.config, max_r=self.max_r)
99 |         else:
100 |             return ForwardTransformer.from_config(self.config)
101 | 
102 |     def compile_model(self, model, beta_1=0.9, beta_2=0.98):
103 |         optimizer = tf.keras.optimizers.Adam(self.learning_rate,
104 |                                              beta_1=beta_1,
105 |                                              beta_2=beta_2,
106 |                                              epsilon=1e-9)
107 |         if self.model_kind == 'aligner':
108 |             model._compile(stop_scaling=self.stop_scaling, optimizer=optimizer)
109 |         else:
110 |             model._compile(optimizer=optimizer)
111 | 
112 |     def dump_config(self):
113 |         self.update_config()
114 |         with open(self.base_dir / f"config.yaml", 'w') as model_yaml:
115 |             self.yaml.dump(self.config, model_yaml)
116 | 
117 |     def create_remove_dirs(self, clear_dir=False, clear_logs=False, clear_weights=False):
118 |         self.base_dir.mkdir(exist_ok=True, parents=True)
119 |         self.data_dir.mkdir(exist_ok=True)
120 |         self.pitch_dir.mkdir(exist_ok=True)
121 |         self.pitch_per_char.mkdir(exist_ok=True)
122 |         self.mel_dir.mkdir(exist_ok=True)
123 |         self.duration_dir.mkdir(exist_ok=True)
124 |         if clear_dir:
125 |             delete = input(f'Delete {self.log_dir} AND {self.weights_dir}? (y/[n])')
126 |             if delete == 'y':
127 |                 shutil.rmtree(self.log_dir, ignore_errors=True)
128 |                 shutil.rmtree(self.weights_dir, ignore_errors=True)
129 |         if clear_logs:
130 |             delete = input(f'Delete {self.log_dir}? (y/[n])')
131 |             if delete == 'y':
132 |                 shutil.rmtree(self.log_dir, ignore_errors=True)
133 |         if clear_weights:
134 |             delete = input(f'Delete {self.weights_dir}? (y/[n])')
135 |             if delete == 'y':
136 |                 shutil.rmtree(self.weights_dir, ignore_errors=True)
137 |         self.log_dir.mkdir(exist_ok=True)
138 |         self.weights_dir.mkdir(exist_ok=True)
139 | 
140 |     def load_model(self, checkpoint_path: str = None, verbose=True):
141 |         model = self.get_model()
142 |         self.compile_model(model)
143 |         ckpt = tf.train.Checkpoint(net=model)
144 |         manager = tf.train.CheckpointManager(ckpt, self.weights_dir,
145 |                                              max_to_keep=None)
146 |         if checkpoint_path:
147 |             ckpt.restore(checkpoint_path)
148 |             if verbose:
149 |                 print(f'restored weights from {checkpoint_path} at step {model.step}')
150 |         else:
151 |             if manager.latest_checkpoint is None:
152 |                 print(f'WARNING: could not find weights file. 
Trying to load from \n {self.weights_dir}.') 153 | print('Edit config to point at the right log directory.') 154 | ckpt.restore(manager.latest_checkpoint) 155 | if verbose: 156 | print(f'restored weights from {manager.latest_checkpoint} at step {model.step}') 157 | if self.model_kind == 'aligner': 158 | reduction_factor = reduction_schedule(model.step, self.config['reduction_factor_schedule']) 159 | model.set_constants(reduction_factor=reduction_factor) 160 | return model 161 | -------------------------------------------------------------------------------- /utils/vec_ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def norm_tensor(tensor): 5 | return tf.math.divide( 6 | tf.math.subtract( 7 | tensor, 8 | tf.math.reduce_min(tensor) 9 | ), 10 | tf.math.subtract( 11 | tf.math.reduce_max(tensor), 12 | tf.math.reduce_min(tensor) 13 | ) 14 | ) 15 | --------------------------------------------------------------------------------
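The schedule helpers in `utils/scheduling.py` take a list of `[step, value]` breakpoints, which is also the format of `learning_rate_schedule` and `reduction_factor_schedule` in the training config. A minimal usage sketch (the breakpoint values below are invented for illustration, not taken from the shipped configs):

```python
from utils.scheduling import piecewise_linear_schedule, reduction_schedule

# hypothetical schedules as [step, value] breakpoints
lr_schedule = [[0, 1e-4], [10_000, 1e-3], [100_000, 1e-4]]
r_schedule = [[0, 10], [20_000, 5], [50_000, 2]]

# linear interpolation between breakpoints: step 5_000 -> 5.5e-4 (tf.float32 scalar)
lr = piecewise_linear_schedule(step=5_000, schedule=lr_schedule)

# step function: takes the last breakpoint at or below the current step: step 30_000 -> 5
r = reduction_schedule(step=30_000, schedule=r_schedule)

print(float(lr), r)
```

`TrainingConfigManager.load_model` uses `reduction_schedule` in the same way to pick the aligner's reduction factor for the restored training step.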