├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── aligner_test_sentences.txt
├── config
│   ├── data_config_wavernn.yaml
│   └── training_config.yaml
├── create_training_data.py
├── data
│   ├── __init__.py
│   ├── audio.py
│   ├── datasets.py
│   ├── metadata_readers.py
│   └── text
│       ├── __init__.py
│       ├── symbols.py
│       └── tokenizer.py
├── docs
│   ├── .gitignore
│   ├── 404.html
│   ├── Gemfile
│   ├── Gemfile.lock
│   ├── README.md
│   ├── _config.yml
│   ├── _layouts
│   │   └── default.html
│   ├── assets
│   │   └── css
│   │       └── style.scss
│   ├── favicon.png
│   ├── index.md
│   ├── tboard_demo.gif
│   └── transformer_logo.png
├── extract_durations.py
├── model
│   ├── __init__.py
│   ├── factory.py
│   ├── layers.py
│   ├── models.py
│   └── transformer_utils.py
├── notebooks
│   └── synthesize_forward_melgan.ipynb
├── predict_tts.py
├── requirements.txt
├── test_sentences.txt
├── tests
│   ├── __init__.py
│   ├── test_char_tokenizer.py
│   ├── test_config.yaml
│   └── test_loss.py
├── train_aligner.py
├── train_tts.py
└── utils
    ├── __init__.py
    ├── alignments.py
    ├── decorators.py
    ├── display.py
    ├── logging_utils.py
    ├── losses.py
    ├── metrics.py
    ├── scheduling.py
    ├── scripts_utils.py
    ├── spectrogram_ops.py
    ├── training_config_manager.py
    └── vec_ops.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-language=Python
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ## CUSTOM
2 | samples
3 | .idea
4 | *.pkl
5 | *.hdf5
6 | .DS_Store
7 | private
8 | logs
9 | *.pt
10 |
11 | # Byte-compiled / optimized / DLL files
12 | __pycache__/
13 | *.py[cod]
14 | *$py.class
15 |
16 | # C extensions
17 | *.so
18 |
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | pip-wheel-metadata/
34 | share/python-wheels/
35 | *.egg-info/
36 | .installed.cfg
37 | *.egg
38 | MANIFEST
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .nox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | *.py,cover
61 | .hypothesis/
62 | .pytest_cache/
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # Django stuff:
69 | *.log
70 | local_settings.py
71 | db.sqlite3
72 | db.sqlite3-journal
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 |
94 | # pyenv
95 | .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105 | __pypackages__/
106 |
107 | # Celery stuff
108 | celerybeat-schedule
109 | celerybeat.pid
110 |
111 | # SageMath parsed files
112 | *.sage.py
113 |
114 | # Environments
115 | .env
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 |
138 | # Pyre type checker
139 | .pyre/
140 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | COPYRIGHT
2 |
3 | Copyright (c) 2020 Axel Springer AI. All rights reserved.
4 |
5 | LICENSE
6 |
7 | The MIT License (MIT)
8 |
9 | Permission is hereby granted, free of charge, to any person obtaining a copy
10 | of this software and associated documentation files (the "Software"), to deal
11 | in the Software without restriction, including without limitation the rights
12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | copies of the Software, and to permit persons to whom the Software is
14 | furnished to do so, subject to the following conditions:
15 |
16 | The above copyright notice and this permission notice shall be included in all
17 | copies or substantial portions of the Software.
18 |
19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | A Text-to-Speech Transformer in TensorFlow 2
9 |
10 |
11 |
12 | Implementation of a non-autoregressive Transformer based neural network for Text-to-Speech (TTS).
13 | This repo is based, among others, on the following papers:
14 | - [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895)
15 | - [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263)
16 | - [FastSpeech 2: Fast and High-Quality End-to-End Text to Speech](https://arxiv.org/abs/2006.04558)
17 | - [FastPitch: Parallel Text-to-speech with Pitch Prediction](https://fastpitch.github.io/)
18 |
19 | Our pre-trained LJSpeech model is compatible with the pre-trained vocoders:
20 | - [MelGAN](https://github.com/seungwonpark/melgan)
21 | - [HiFiGAN](https://github.com/jik876/hifi-gan)
22 |
23 | (older versions are available also for [WaveRNN](https://github.com/fatchord/WaveRNN))
24 |
25 | For quick inference with these vocoders, check out the [Vocoding branch](https://github.com/as-ideas/TransformerTTS/tree/vocoding)
26 |
27 | #### Non-Autoregressive
28 | Being non-autoregressive, this Transformer model is:
29 | - Robust: no repeats or failed attention modes on challenging sentences.
30 | - Fast: With no autoregression, predictions take a fraction of the time.
31 | - Controllable: It is possible to control the speed and pitch of the generated utterance.
32 |
33 | ## 🔈 Samples
34 |
35 | [Can be found here.](https://as-ideas.github.io/TransformerTTS/)
36 |
37 | The spectrograms of these samples are converted to audio with the pre-trained [MelGAN](https://github.com/seungwonpark/melgan) vocoder.
38 |
39 |
40 | Try it out on Colab:
41 |
42 | [Open in Colab](https://colab.research.google.com/github/as-ideas/TransformerTTS/blob/main/notebooks/synthesize_forward_melgan.ipynb)
43 |
44 | ## Updates
45 | - 06/20: Added normalisation and pre-trained models compatible with the faster [MelGAN](https://github.com/seungwonpark/melgan) vocoder.
46 | - 11/20: Added pitch prediction. The autoregressive model is now specialized as an aligner and the forward model is now the only TTS model. Changed model architectures. Discontinued WaveRNN support. Improved duration extraction with the Dijkstra algorithm.
47 | - 03/20: Vocoding branch.
48 |
49 | ## 📖 Contents
50 | - [Installation](#installation)
51 | - [API](#pre-trained-ljspeech-api)
52 | - [Dataset](#dataset)
53 | - [Training](#training)
54 | - [Aligner](#train-aligner-model)
55 | - [TTS](#train-tts-model)
56 | - [Prediction](#prediction)
57 | - [Model Weights](#model-weights)
58 |
59 | ## Installation
60 |
61 | Make sure you have:
62 |
63 | * Python >= 3.6
64 |
65 | Install espeak as phonemizer backend (for macOS use brew):
66 | ```
67 | sudo apt-get install espeak
68 | ```
69 |
70 | Then install the rest with pip:
71 | ```
72 | pip install -r requirements.txt
73 | ```
74 |
75 | Read the individual scripts for more command line arguments.
76 |
77 | ## Pre-Trained LJSpeech API
78 | Use our pre-trained model (with Griffin-Lim) from command line with
79 | ```commandline
80 | python predict_tts.py -t "Please, say something."
81 | ```
82 | Or in a python script
83 | ```python
84 | from data.audio import Audio
85 | from model.factory import tts_ljspeech
86 |
87 | model = tts_ljspeech()
88 | audio = Audio.from_config(model.config)
89 | out = model.predict('Please, say something.')
90 |
91 | # Convert spectrogram to wav (with griffin lim)
92 | wav = audio.reconstruct_waveform(out['mel'].numpy().T)
93 | ```
94 |
95 | You can specify the model step with the `--step` flag (command line) or the `step` parameter (python script).
96 | Steps from 60000 to 100000 are available at intervals of 5K steps (60000, 65000, ..., 95000, 100000).
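For example, to load a specific checkpoint in a script (a sketch; it assumes the weights for that step have been published and that `step` is passed as a string, see `model/factory.py`):
```python
from model.factory import tts_ljspeech

# load the 90000-step checkpoint instead of the default one
model = tts_ljspeech(step='90000')
out = model.predict('Please, say something.')
```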
97 |
98 | IMPORTANT: make sure to check out the correct repository version to use the API.
99 | Currently this is `493be6345341af0df3ae829de79c2793c9afd0ec`.
100 |
101 | ## Dataset
102 | You can directly use [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) to create the training dataset.
103 |
104 | #### Configuration
105 | * If training on LJSpeech, or if unsure, simply use ```config/training_config.yaml``` to create [MelGAN](https://github.com/seungwonpark/melgan) or [HiFiGAN](https://github.com/jik876/hifi-gan) compatible models
106 | * swap the content of ```data_config_wavernn.yaml``` into ```config/training_config.yaml``` to create models compatible with [WaveRNN](https://github.com/fatchord/WaveRNN)
107 | * **EDIT PATHS**: in `config/training_config.yaml` edit the paths to point at your dataset and log folders
108 |
109 | #### Custom dataset
110 | Prepare a folder containing your metadata and wav files, for instance
111 | ```
112 | |- dataset_folder/
113 | | |- metadata.csv
114 | | |- wavs/
115 | | |- file1.wav
116 | | |- ...
117 | ```
118 | If `metadata.csv` has the following format
119 | ``` wav_file_name|transcription ```
120 | you can use the ljspeech preprocessor in ```data/metadata_readers.py```; otherwise add your own reader to the same file, for instance as sketched below.
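For reference, a custom reader only needs to map file names to transcriptions. A minimal sketch for a hypothetical tab-separated metadata file (adapt the parsing to your format; the function name `my_dataset` is just an example) could look like:
```python
# in data/metadata_readers.py -- name the function after your dataset (lowercase)
def my_dataset(metadata_path: str, column_sep='\t') -> dict:
    text_dict = {}
    with open(metadata_path, 'r', encoding='utf-8') as f:
        for line in f:
            l_split = line.split(column_sep)
            filename, text = l_split[0], l_split[-1].strip()
            text_dict[filename] = text
    return text_dict
```
Then set ```data_name: my_dataset``` in ```training_config.yaml``` so the reader is picked up (see the checklist below).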
121 |
122 | Make sure that:
123 | - the metadata reader function name matches the ```data_name``` field in ```training_config.yaml```.
124 | - the metadata file (any name works) is specified under ```metadata_path``` in ```training_config.yaml```.
125 |
126 | ## Training
127 | Change the ```--config``` argument based on the configuration of your choice.
128 | ### Train Aligner Model
129 | #### Create training dataset
130 | ```bash
131 | python create_training_data.py --config config/training_config.yaml
132 | ```
133 | This will populate the training data directory (default `transformer_tts_data.ljspeech`).
134 | #### Training
135 | ```bash
136 | python train_aligner.py --config config/training_config.yaml
137 | ```
138 | ### Train TTS Model
139 | #### Compute alignment dataset
140 | First use the aligner model to create the durations dataset
141 | ```bash
142 | python extract_durations.py --config config/training_config.yaml
143 | ```
144 | This will add the durations as well as the char-wise pitch folders to the training data directory.
145 | #### Training
146 | ```bash
147 | python train_tts.py --config config/training_config.yaml
148 | ```
149 | #### Training & Model configuration
150 | - Training and model settings can be configured in `training_config.yaml`
151 |
152 | #### Resume or restart training
153 | - To resume training simply use the same configuration files
154 | - To restart training, delete the weights and/or the logs from the log folder using the training flags `--reset_dir` (both), `--reset_logs`, or `--reset_weights`
155 |
156 | #### Monitor training
157 | ```bash
158 | tensorboard --logdir /logs/directory/
159 | ```
160 |
161 | 
162 | ## Prediction
163 | ### With model weights
164 | From command line with
165 | ```commandline
166 | python predict_tts.py -t "Please, say something." -p /path/to/weights/
167 | ```
168 | Or in a python script
169 | ```python
170 | from model.models import ForwardTransformer
171 | from data.audio import Audio
172 | model = ForwardTransformer.load_model('/path/to/weights/')
173 | audio = Audio.from_config(model.config)
174 | out = model.predict('Please, say something.')
175 |
176 | # Convert spectrogram to wav (with griffin lim)
177 | wav = audio.reconstruct_waveform(out['mel'].numpy().T)
178 | ```
179 |
180 | ## Model Weights
181 | Access the pre-trained models through the API call described above.
182 |
183 | Old weights
184 | | Model URL | Commit | Vocoder Commit|
185 | |---|---|---|
186 | |[ljspeech_tts_model](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/ljspeech_weights_tts.zip)| 0cd7d33 | aca5990 |
187 | |[ljspeech_melgan_forward_model](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_melgan_forward_transformer.zip)| 1c1cb03| aca5990 |
188 | |[ljspeech_melgan_autoregressive_model_v2](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_melgan_autoregressive_transformer.zip)| 1c1cb03| aca5990 |
189 | |[ljspeech_wavernn_forward_model](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_wavernn_forward_transformer.zip)| 1c1cb03| 3595219 |
190 | |[ljspeech_wavernn_autoregressive_model_v2](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_wavernn_autoregressive_transformer.zip)| 1c1cb03| 3595219 |
191 | |[ljspeech_wavernn_forward_model](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_forward_transformer.zip)| d9ccee6| 3595219 |
192 | |[ljspeech_wavernn_autoregressive_model_v2](https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_autoregressive_transformer.zip)| d9ccee6| 3595219 |
193 | |[ljspeech_wavernn_autoregressive_model_v1](https://github.com/as-ideas/tts_model_outputs/tree/master/ljspeech_transformertts)| 2f3a1b5| 3595219 |
194 | ## Maintainers
195 | * Francesco Cardinale, github: [cfrancesco](https://github.com/cfrancesco)
196 |
197 | ## Special thanks
198 | [MelGAN](https://github.com/seungwonpark/melgan) and [WaveRNN](https://github.com/fatchord/WaveRNN): data normalization and the vocoders used for the samples come from these repos.
199 |
200 | [Erogol](https://github.com/erogol) and the Mozilla TTS team for the lively exchange on the topic.
201 |
202 |
203 | ## Copyright
204 | See [LICENSE](LICENSE) for details.
205 |
--------------------------------------------------------------------------------
/aligner_test_sentences.txt:
--------------------------------------------------------------------------------
1 | Scientists at the CERN laboratory say they have discovered a new particle.
2 | If this is that, then what are those.
--------------------------------------------------------------------------------
/config/data_config_wavernn.yaml:
--------------------------------------------------------------------------------
1 | audio_settings_name: WaveRNN_default
2 | text_settings_name: Stress_NoBreathing
3 |
4 | # TRAINING DATA SETTINGS
5 | n_samples: 100000
6 | n_test: 100
7 | mel_start_value: .5
8 | mel_end_value: -.5
9 | max_mel_len: 1_200
10 | min_mel_len: 80
11 | bucket_boundaries: [200, 300, 400, 500, 600, 700, 800, 900, 1000, 1200] # mel bucketing
12 | bucket_batch_sizes: [64, 42, 32, 25, 21, 18, 16, 14, 12, 6, 1]
13 | val_bucket_batch_size: [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 1]
14 |
15 | # AUDIO SETTINGS
16 | sampling_rate: 22050
17 | n_fft: 2048
18 | mel_channels: 80
19 | hop_length: 275
20 | win_length: 1100
21 | f_min: 40
22 | f_max: null
23 | normalizer: WaveRNN # which mel normalization to use from data/audio.py [MelGAN or WaveRNN]
24 |
25 | # SILENCE CUTTING
26 | trim_silence_top_db: 60
27 | trim_silence: False
28 | trim_long_silences: True
29 | # Params for trimming long silences, from https://github.com/resemble-ai/Resemblyzer/blob/master/resemblyzer/hparams.py
30 | vad_window_length: 30 # In milliseconds
31 | vad_moving_average_width: 8
32 | vad_max_silence_length: 12
33 | vad_sample_rate: 16000
34 |
35 | # TOKENIZER
36 | phoneme_language: 'en-us'
37 | with_stress: True # use stress symbols in phonemization
38 | model_breathing: false # add a token for the initial breathing
--------------------------------------------------------------------------------
/config/training_config.yaml:
--------------------------------------------------------------------------------
1 | paths:
2 | # PATHS: change accordingly
3 | wav_directory: '/path/to/wav_directory' # path to directory containing the wavs
4 | metadata_path: '/path/to/metadata.csv' # name of metadata file under wav_directory
5 | log_directory: '/path/to/logs_directory' # weights and logs are stored here
6 | train_data_directory: 'transformer_tts_data' # training data is stored here
7 |
8 | naming:
9 | data_name: ljspeech # raw data naming for default data reader (select function from data/metadata_readers.py)
10 | audio_settings_name: MelGAN_default
11 | text_settings_name: Stress_NoBreathing
12 | aligner_settings_name: alinger_extralayer_layernorm
13 | tts_settings_name: tts_swap_conv_dims
14 |
15 | # TRAINING DATA SETTINGS
16 | training_data_settings:
17 | n_test: 100
18 | mel_start_value: .5
19 | mel_end_value: -.5
20 | max_mel_len: 1_200
21 | min_mel_len: 80
22 | bucket_boundaries: [200, 300, 400, 500, 600, 700, 800, 900, 1000, 1200] # mel bucketing
23 | bucket_batch_sizes: [64, 42, 32, 25, 21, 18, 16, 14, 12, 6, 1]
24 | val_bucket_batch_size: [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 1]
25 |
26 | # AUDIO SETTINGS
27 | audio_settings:
28 | sampling_rate: 22050
29 | n_fft: 1024
30 | mel_channels: 80
31 | hop_length: 256
32 | win_length: 1024
33 | f_min: 0
34 | f_max: 8000
35 | normalizer: MelGAN # which mel normalization to use from data/audio.py [MelGAN or WaveRNN]
36 |
37 | # SILENCE CUTTING
38 | trim_silence_top_db: 60
39 | trim_silence: False
40 | trim_long_silences: True
41 | # Params for trimming long silences, from https://github.com/resemble-ai/Resemblyzer/blob/master/resemblyzer/hparams.py
42 | vad_window_length: 30 # In milliseconds
43 | vad_moving_average_width: 8
44 | vad_max_silence_length: 12
45 | vad_sample_rate: 16000
46 |
47 | # Wav normalization
48 | norm_wav: True
49 | target_dBFS: -30
50 | int16_max: 32767
51 |
52 | text_settings:
53 | # TOKENIZER
54 | phoneme_language: 'en-us'
55 | with_stress: True # use stress symbols in phonemization
56 | model_breathing: false # add a token for the initial breathing
57 |
58 | aligner_settings:
59 | # ARCHITECTURE
60 | decoder_model_dimension: 256
61 | encoder_model_dimension: 256
62 | decoder_num_heads: [4, 4, 4, 4, 1] # the length of this defines the number of layers
63 | encoder_num_heads: [4, 4, 4, 4] # the length of this defines the number of layers
64 | encoder_feed_forward_dimension: 512
65 | decoder_feed_forward_dimension: 512
66 | decoder_prenet_dimension: 256
67 | encoder_prenet_dimension: 256
68 | encoder_max_position_encoding: 10000
69 | decoder_max_position_encoding: 10000
70 |
71 | # LOSSES
72 | stop_loss_scaling: 8
73 |
74 | # TRAINING
75 | dropout_rate: 0.1
76 | decoder_prenet_dropout: 0.1
77 | learning_rate_schedule:
78 | - [0, 1.0e-4]
79 | reduction_factor_schedule:
80 | - [0, 10]
81 | - [80_000, 5]
82 | - [100_000, 2]
83 | - [130_000, 1]
84 | max_steps: 260_000
85 | force_encoder_diagonal_steps: 500
86 | force_decoder_diagonal_steps: 7_000
87 | extract_attention_weighted: False # weighted average between last layer decoder attention heads when extracting durations
88 | debug: False
89 |
90 | # LOGGING
91 | validation_frequency: 5_000
92 | weights_save_frequency: 5_000
93 | train_images_plotting_frequency: 1_000
94 | keep_n_weights: 2
95 | keep_checkpoint_every_n_hours: 12
96 | n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
97 | prediction_start_step: 10_000 # step after which to predict durations at validation time
98 | prediction_frequency: 5_000
99 | test_stencences:
100 | - aligner_test_sentences.txt
101 |
102 | tts_settings:
103 | # ARCHITECTURE
104 | decoder_model_dimension: 384
105 | encoder_model_dimension: 384
106 | decoder_num_heads: [2, 2, 2, 2, 2, 2] # the length of this defines the number of layers
107 | encoder_num_heads: [2, 2, 2, 2, 2, 2] # the length of this defines the number of layers
108 | encoder_feed_forward_dimension: null
109 | decoder_feed_forward_dimension: null
110 | encoder_attention_conv_filters: [1536, 384]
111 | decoder_attention_conv_filters: [1536, 384]
112 | encoder_attention_conv_kernel: 3
113 | decoder_attention_conv_kernel: 3
114 | encoder_max_position_encoding: 2000
115 | decoder_max_position_encoding: 10000
116 | encoder_dense_blocks: 0
117 | decoder_dense_blocks: 0
118 | transposed_attn_convs: True # if True, convolutions after MHA are over time.
119 |
120 | # STATS PREDICTORS ARCHITECTURE
121 | duration_conv_filters: [256, 226]
122 | pitch_conv_filters: [256, 226]
123 | duration_kernel_size: 3
124 | pitch_kernel_size: 3
125 |
126 | # TRAINING
127 | predictors_dropout: 0.1
128 | dropout_rate: 0.1
129 | learning_rate_schedule:
130 | - [0, 1.0e-4]
131 | max_steps: 100_000
132 | debug: False
133 |
134 | # LOGGING
135 | validation_frequency: 5_000
136 | prediction_frequency: 5_000
137 | weights_save_frequency: 5_000
138 | weights_save_starting_step: 5_000
139 | train_images_plotting_frequency: 1_000
140 | keep_n_weights: 5
141 | keep_checkpoint_every_n_hours: 12
142 | n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
143 | prediction_start_step: 4_000
144 | text_prediction:
145 | - test_sentences.txt
--------------------------------------------------------------------------------
/create_training_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | import pickle
4 |
5 | import numpy as np
6 | from p_tqdm import p_uimap, p_umap
7 |
8 | from utils.logging_utils import SummaryManager
9 | from data.text import TextToTokens
10 | from data.datasets import DataReader
11 | from utils.training_config_manager import TrainingConfigManager
12 | from data.audio import Audio
13 | from data.text.symbols import _alphabet
14 |
15 | np.random.seed(42)
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--config', type=str, required=True)
19 | parser.add_argument('--skip_phonemes', action='store_true')
20 | parser.add_argument('--skip_mels', action='store_true')
21 |
22 | args = parser.parse_args()
23 | for arg in vars(args):
24 | print('{}: {}'.format(arg, getattr(args, arg)))
25 |
26 | cm = TrainingConfigManager(args.config, aligner=True)
27 | cm.create_remove_dirs()
28 | metadatareader = DataReader.from_config(cm, kind='original', scan_wavs=True)
29 | summary_manager = SummaryManager(model=None, log_dir=cm.log_dir / 'data_preprocessing', config=cm.config,
30 | default_writer='data_preprocessing')
31 | file_ids_from_wavs = list(metadatareader.wav_paths.keys())
32 | print(f"Reading wavs from {metadatareader.wav_directory}")
33 | print(f"Reading metadata from {metadatareader.metadata_path}")
34 | print(f'\nFound {len(metadatareader.filenames)} metadata lines.')
35 | print(f'\nFound {len(file_ids_from_wavs)} wav files.')
36 | cross_file_ids = [fid for fid in file_ids_from_wavs if fid in metadatareader.filenames]
37 | print(f'\nThere are {len(cross_file_ids)} wav file names that correspond to metadata lines.')
38 |
39 | if not args.skip_mels:
40 |
41 | def process_wav(wav_path: Path):
42 | file_name = wav_path.stem
43 | y, sr = audio.load_wav(str(wav_path))
44 | pitch = audio.extract_pitch(y)
45 | mel = audio.mel_spectrogram(y)
46 | assert mel.shape[1] == audio.config['mel_channels'] and len(mel.shape) == 2
47 | assert mel.shape[0] == pitch.shape[0], f'{mel.shape[0]} == {pitch.shape[0]} (wav {y.shape})'
48 | mel_path = (cm.mel_dir / file_name).with_suffix('.npy')
49 | pitch_path = (cm.pitch_dir / file_name).with_suffix('.npy')
50 | np.save(mel_path, mel)
51 | np.save(pitch_path, pitch)
52 | return {'fname': file_name, 'mel.len': mel.shape[0], 'pitch.path': pitch_path, 'pitch': pitch}
53 |
54 |
55 | print(f"\nMels will be stored stored under")
56 | print(f"{cm.mel_dir}")
57 | audio = Audio.from_config(config=cm.config)
58 | wav_files = [metadatareader.wav_paths[k] for k in cross_file_ids]
59 | len_dict = {}
60 | remove_files = []
61 | mel_lens = []
62 | pitches = {}
63 | wav_iter = p_uimap(process_wav, wav_files)
64 | for out_dict in wav_iter:
65 | len_dict.update({out_dict['fname']: out_dict['mel.len']})
66 | pitches.update({out_dict['pitch.path']: out_dict['pitch']})
67 | if out_dict['mel.len'] > cm.config['max_mel_len'] or out_dict['mel.len'] < cm.config['min_mel_len']:
68 | remove_files.append(out_dict['fname'])
69 | else:
70 | mel_lens.append(out_dict['mel.len'])
71 |
72 |
73 | def normalize_pitch_vectors(pitch_vecs):
74 | nonzeros = np.concatenate([v[np.where(v != 0.0)[0]]
75 | for v in pitch_vecs.values()])
76 | mean, std = np.mean(nonzeros), np.std(nonzeros)
77 | return mean, std
78 |
79 |
80 | def process_pitches(item: tuple):
81 | fname, pitch = item
82 | zero_idxs = np.where(pitch == 0.0)[0]
83 | pitch -= mean
84 | pitch /= std
85 | pitch[zero_idxs] = 0.0
86 | np.save(fname, pitch)
87 |
88 |
89 | mean, std = normalize_pitch_vectors(pitches)
90 | pickle.dump({'pitch_mean': mean, 'pitch_std': std}, open(cm.data_dir / 'pitch_stats.pkl', 'wb'))
91 | pitch_iter = p_umap(process_pitches, pitches.items())
92 |
93 | pickle.dump(len_dict, open(cm.data_dir / 'mel_len.pkl', 'wb'))
94 | pickle.dump(remove_files, open(cm.data_dir / 'under-over_sized_mels.pkl', 'wb'))
95 | summary_manager.add_histogram('Mel Lengths', values=np.array(mel_lens))
96 | total_mel_len = np.sum(mel_lens)
97 | total_wav_len = total_mel_len * audio.config['hop_length']
98 | summary_manager.display_scalar('Total duration (hours)',
99 | scalar_value=total_wav_len / audio.config['sampling_rate'] / 60. ** 2)
100 |
101 | if not args.skip_phonemes:
102 | remove_files = pickle.load(open(cm.data_dir / 'under-over_sized_mels.pkl', 'rb'))
103 | phonemized_metadata_path = cm.phonemized_metadata_path
104 | train_metadata_path = cm.train_metadata_path
105 | test_metadata_path = cm.valid_metadata_path
106 | print(f'\nReading metadata from {metadatareader.metadata_path}')
107 | print(f'\nFound {len(metadatareader.filenames)} lines.')
108 | filter_metadata = []
109 | for fname in cross_file_ids:
110 | item = metadatareader.text_dict[fname]
111 | non_p = [c for c in item if c in _alphabet]
112 | if len(non_p) < 1:
113 | filter_metadata.append(fname)
114 | if len(filter_metadata) > 0:
115 | print(f'Removing {len(filter_metadata)} suspiciously short line(s):')
116 | for fname in filter_metadata:
117 | print(f'{fname}: {metadatareader.text_dict[fname]}')
118 | print(f'\nRemoving {len(remove_files)} line(s) due to mel filtering.')
119 | remove_files += filter_metadata
120 | metadata_file_ids = [fname for fname in cross_file_ids if fname not in remove_files]
121 | metadata_len = len(metadata_file_ids)
122 | sample_items = np.random.choice(metadata_file_ids, 5)
123 | test_len = cm.config['n_test']
124 | train_len = metadata_len - test_len
125 | print(f'\nMetadata contains {metadata_len} lines.')
126 | print(f'\nFiles will be stored under {cm.data_dir}')
127 | print(f' - all: {phonemized_metadata_path}')
128 | print(f' - {train_len} training lines: {train_metadata_path}')
129 | print(f' - {test_len} validation lines: {test_metadata_path}')
130 |
131 | print('\nMetadata samples:')
132 | for i in sample_items:
133 | print(f'{i}:{metadatareader.text_dict[i]}')
134 | summary_manager.add_text(f'{i}/text', text=metadatareader.text_dict[i])
135 |
136 | # run cleaner on raw text
137 | text_proc = TextToTokens.default(cm.config['phoneme_language'], add_start_end=False,
138 | with_stress=cm.config['with_stress'], model_breathing=cm.config['model_breathing'],
139 | njobs=1)
140 |
141 |
142 | def process_phonemes(file_id):
143 | text = metadatareader.text_dict[file_id]
144 | try:
145 | phon = text_proc.phonemizer(text)
146 | except Exception as e:
147 | print(f'{e}\nFile id {file_id}')
148 | raise BrokenPipeError
149 | return (file_id, phon)
150 |
151 |
152 | print('\nPHONEMIZING')
153 | phonemized_data = {}
154 | phon_iter = p_uimap(process_phonemes, metadata_file_ids)
155 | for (file_id, phonemes) in phon_iter:
156 | phonemized_data.update({file_id: phonemes})
157 |
158 | print('\nPhonemized metadata samples:')
159 | for i in sample_items:
160 | print(f'{i}:{phonemized_data[i]}')
161 | summary_manager.add_text(f'{i}/phonemes', text=phonemized_data[i])
162 |
163 | new_metadata = [f'{k}|{v}\n' for k, v in phonemized_data.items()]
164 | shuffled_metadata = np.random.permutation(new_metadata)
165 | train_metadata = shuffled_metadata[0:train_len]
166 | test_metadata = shuffled_metadata[-test_len:]
167 |
168 | with open(phonemized_metadata_path, 'w+', encoding='utf-8') as file:
169 | file.writelines(new_metadata)
170 | with open(train_metadata_path, 'w+', encoding='utf-8') as file:
171 | file.writelines(train_metadata)
172 | with open(test_metadata_path, 'w+', encoding='utf-8') as file:
173 | file.writelines(test_metadata)
174 | # some checks
175 | assert metadata_len == len(set(list(phonemized_data.keys()))), \
176 | f'Length of metadata ({metadata_len}) does not match the length of the phoneme array ({len(set(list(phonemized_data.keys())))}). Check for empty text lines in metadata.'
177 | assert len(train_metadata) + len(test_metadata) == metadata_len, \
178 | f'Train and/or validation lengths incorrect. ({len(train_metadata)} + {len(test_metadata)} != {metadata_len})'
179 |
180 | print('\nDone')
181 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spring-media/TransformerTTS/363805548abdd93b33508da2c027ae514bfc1a07/data/__init__.py
--------------------------------------------------------------------------------
/data/audio.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import struct
3 |
4 | import librosa
5 | import numpy as np
6 | import librosa.display
7 | from matplotlib import pyplot as plt
8 | import soundfile as sf
9 | import webrtcvad
10 | from scipy.ndimage import binary_dilation
11 | import pyworld as pw
12 |
13 |
14 | class Audio():
15 | def __init__(self,
16 | sampling_rate: int,
17 | n_fft: int,
18 | mel_channels: int,
19 | hop_length: int,
20 | win_length: int,
21 | f_min: int,
22 | f_max: int,
23 | normalizer: str,
24 | norm_wav: bool = None,
25 | target_dBFS: int = None,
26 | int16_max: int = None,
27 | trim_long_silences: bool = None,
28 | trim_silence: bool = None,
29 | trim_silence_top_db: int = None,
30 | vad_window_length: int = None,
31 | vad_sample_rate: int = None,
32 | vad_moving_average_width: int = None,
33 | vad_max_silence_length: int = None,
34 | **kwargs):
35 |
36 | self.config = self._make_config(locals())
37 | self.sampling_rate = sampling_rate
38 | self.n_fft = n_fft
39 | self.mel_channels = mel_channels
40 | self.hop_length = hop_length
41 | self.win_length = win_length
42 | self.f_min = f_min
43 | self.f_max = f_max
44 | self.norm_wav = norm_wav
45 | self.target_dBFS = target_dBFS
46 | self.int16_max = int16_max
47 | self.trim_long_silences = trim_long_silences
48 | self.trim_silence = trim_silence
49 | self.trim_silence_top_db = trim_silence_top_db
50 | self.vad_window_length = vad_window_length
51 | self.vad_sample_rate = vad_sample_rate
52 | self.vad_moving_average_width = vad_moving_average_width
53 | self.vad_max_silence_length = vad_max_silence_length
54 | self.normalizer = getattr(sys.modules[__name__], normalizer)()
55 |
56 | def _make_config(self, locals) -> dict:
57 | config = {}
58 | for k in locals:
59 | if (k != 'self') and (k != '__class__'):
60 | if isinstance(locals[k], dict):
61 | config.update(locals[k])
62 | else:
63 | config.update({k: locals[k]})
64 | return dict(config)
65 |
66 | def _normalize(self, S):
67 | return self.normalizer.normalize(S)
68 |
69 | def _denormalize(self, S):
70 | return self.normalizer.denormalize(S)
71 |
72 | def _linear_to_mel(self, spectrogram):
73 | return librosa.feature.melspectrogram(
74 | S=spectrogram,
75 | sr=self.sampling_rate,
76 | n_fft=self.n_fft,
77 | n_mels=self.mel_channels,
78 | fmin=self.f_min,
79 | fmax=self.f_max)
80 |
81 | def _stft(self, y):
82 | return librosa.stft(
83 | y=y,
84 | n_fft=self.n_fft,
85 | hop_length=self.hop_length,
86 | win_length=self.win_length)
87 |
88 | def mel_spectrogram(self, wav):
89 | """ This is what the model is trained to reproduce. """
90 | D = self._stft(wav)
91 | S = self._linear_to_mel(np.abs(D))
92 | return self._normalize(S).T
93 |
94 | def reconstruct_waveform(self, mel, n_iter=32):
95 | """ Uses Griffin-Lim phase reconstruction to convert from a normalized
96 | mel spectrogram back into a waveform. """
97 | amp_mel = self._denormalize(mel)
98 | S = librosa.feature.inverse.mel_to_stft(
99 | amp_mel,
100 | power=1,
101 | sr=self.sampling_rate,
102 | n_fft=self.n_fft,
103 | fmin=self.f_min,
104 | fmax=self.f_max)
105 | wav = librosa.core.griffinlim(
106 | S,
107 | n_iter=n_iter,
108 | hop_length=self.hop_length,
109 | win_length=self.win_length)
110 | return wav
111 |
112 | def display_mel(self, mel, is_normal=True):
113 | if is_normal:
114 | mel = self._denormalize(mel)
115 | f = plt.figure(figsize=(10, 4))
116 | s_db = librosa.power_to_db(mel, ref=np.max)
117 | ax = librosa.display.specshow(s_db,
118 | x_axis='time',
119 | y_axis='mel',
120 | sr=self.sampling_rate,
121 | fmin=self.f_min,
122 | fmax=self.f_max)
123 | f.add_subplot(ax)
124 | return f
125 |
126 | def load_wav(self, wav_path, preprocess=True):
127 | y, sr = librosa.load(wav_path, sr=self.sampling_rate)
128 | if preprocess:
129 | y = self.preprocess(y)
130 | return y, sr
131 |
132 | def preprocess(self, y):
133 | if self.norm_wav:
134 | y = self.normalize_volume(y, increase_only=True)
135 | if self.trim_long_silences:
136 | y = self.trim_audio_long_silences(y)
137 | if self.trim_silence:
138 | y = self.trim_audio_silence(y)
139 | if y.shape[0] % self.hop_length == 0:
140 | y = np.pad(y, (0, 1))
141 | return y
142 |
143 | def save_wav(self, y, wav_path):
144 | sf.write(wav_path, data=y, samplerate=self.sampling_rate)
145 |
146 | def extract_pitch(self, y):
147 | _f0, t = pw.dio(y.astype(np.float64), fs=self.sampling_rate,
148 | frame_period=self.hop_length / self.sampling_rate * 1000)
149 | f0 = pw.stonemask(y.astype(np.float64), _f0, t, fs=self.sampling_rate) # pitch refinement
150 |
151 | return f0
152 |
153 | # from https://github.com/resemble-ai/Resemblyzer/blob/master/resemblyzer/audio.py
154 | def normalize_volume(self, wav, increase_only=False, decrease_only=False):
155 | if increase_only and decrease_only:
156 | raise ValueError("Both increase only and decrease only are set")
157 | rms = np.sqrt(np.mean((wav * self.int16_max) ** 2))
158 | wave_dBFS = 20 * np.log10(rms / self.int16_max)
159 | dBFS_change = self.target_dBFS - wave_dBFS
160 | if dBFS_change < 0 and increase_only or dBFS_change > 0 and decrease_only:
161 | return wav
162 | return wav * (10 ** (dBFS_change / 20))
163 |
164 | def trim_audio_silence(self, wav):
165 | trimmed = librosa.effects.trim(wav,
166 | top_db=self.trim_silence_top_db,
167 | frame_length=256,
168 | hop_length=64)
169 | return trimmed[0]
170 |
171 | # from https://github.com/resemble-ai/Resemblyzer/blob/master/resemblyzer/audio.py
172 | def trim_audio_long_silences(self, wav):
173 | samples_per_window = (self.vad_window_length * self.vad_sample_rate) // 1000
174 | wav = wav[:len(wav) - (len(wav) % samples_per_window)]
175 | pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * self.int16_max)).astype(np.int16))
176 | voice_flags = []
177 | vad = webrtcvad.Vad(mode=3)
178 | for window_start in range(0, len(wav), samples_per_window):
179 | window_end = window_start + samples_per_window
180 | voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
181 | sample_rate=self.vad_sample_rate))
182 | voice_flags = np.array(voice_flags)
183 |
184 | def moving_average(array, width):
185 | array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
186 | ret = np.cumsum(array_padded, dtype=float)
187 | ret[width:] = ret[width:] - ret[:-width]
188 | return ret[width - 1:] / width
189 |
190 | audio_mask = moving_average(voice_flags, self.vad_moving_average_width)
191 | audio_mask = np.round(audio_mask).astype(bool)
192 | audio_mask[:] = binary_dilation(audio_mask[:], np.ones(self.vad_max_silence_length + 1))
193 | audio_mask = np.repeat(audio_mask, samples_per_window)
194 | return wav[audio_mask]
195 |
196 | @classmethod
197 | def from_config(cls, config: dict):
198 | return cls(**config)
199 |
200 |
201 | class Normalizer:
202 | def normalize(self, S):
203 | raise NotImplementedError
204 |
205 | def denormalize(self, S):
206 | raise NotImplementedError
207 |
208 |
209 | class MelGAN(Normalizer):
210 | def __init__(self):
211 | super().__init__()
212 | self.clip_min = 1.0e-5
213 |
214 | def normalize(self, S):
215 | S = np.clip(S, a_min=self.clip_min, a_max=None)
216 | return np.log(S)
217 |
218 | def denormalize(self, S):
219 | return np.exp(S)
220 |
221 |
222 | class WaveRNN(Normalizer):
223 | def __init__(self):
224 | super().__init__()
225 | self.min_level_db = - 100
226 | self.max_norm = 4
227 |
228 | def normalize(self, S):
229 | S = self.amp_to_db(S)
230 | S = np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
231 | return (S * 2 * self.max_norm) - self.max_norm
232 |
233 | def denormalize(self, S):
234 | S = (S + self.max_norm) / (2 * self.max_norm)
235 | S = (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
236 | return self.db_to_amp(S)
237 |
238 | def amp_to_db(self, x):
239 | return 20 * np.log10(np.maximum(1e-5, x))
240 |
241 | def db_to_amp(self, x):
242 | return np.power(10.0, x * 0.05)
243 |
--------------------------------------------------------------------------------
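
A usage sketch for the `Audio` class above (illustrative only: the wav paths are placeholders and `config/training_config.yaml` is assumed to be already edited):

```python
from data.audio import Audio
from utils.training_config_manager import TrainingConfigManager

# build an Audio object from the flattened training config, as create_training_data.py does
cm = TrainingConfigManager('config/training_config.yaml', aligner=True)
audio = Audio.from_config(cm.config)

y, sr = audio.load_wav('/path/to/some_file.wav')    # placeholder path
mel = audio.mel_spectrogram(y)                      # normalized mel, shape [frames, mel_channels]
pitch = audio.extract_pitch(y)                      # frame-wise f0 via pyworld
wav = audio.reconstruct_waveform(mel.T)             # Griffin-Lim expects [mel_channels, frames]
audio.save_wav(wav, '/path/to/reconstruction.wav')  # placeholder path
```
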
/data/datasets.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from random import Random
3 | from typing import List, Union
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | from utils.training_config_manager import TrainingConfigManager
9 | from data.text.tokenizer import Tokenizer
10 | from data.metadata_readers import get_preprocessor_by_name
11 |
12 |
13 | def get_files(path: Union[Path, str], extension='.wav') -> List[Path]:
14 | """ Get all files from all subdirs with given extension. """
15 | path = Path(path).expanduser().resolve()
16 | return list(path.rglob(f'*{extension}'))
17 |
18 |
19 | class DataReader:
20 | """
21 | Reads dataset folder and constructs three useful objects:
22 | text_dict: {filename: text}
23 | wav_paths: {filename: path/to/filename.wav}
24 | filenames: [filename1, filename2, ...]
25 |
26 | IMPORTANT: Use only for information available from source dataset, not for
27 | training data.
28 | """
29 |
30 | def __init__(self, wav_directory: str, metadata_path: str, metadata_reading_function=None, scan_wavs=False,
31 | training=False, is_processed=False):
32 | self.metadata_reading_function = metadata_reading_function
33 | self.wav_directory = Path(wav_directory)
34 | self.metadata_path = Path(metadata_path)
35 | if not is_processed:
36 | self.text_dict = self.metadata_reading_function(self.metadata_path)
37 | self.filenames = list(self.text_dict.keys())
38 | else:
39 | self.text_dict, self.upsample = self.metadata_reading_function(self.metadata_path)
40 | self.filenames = list(self.text_dict.keys())
41 | if training:
42 | self.filenames += self.upsample
43 | if scan_wavs:
44 | all_wavs = get_files(self.wav_directory, extension='.wav')
45 | self.wav_paths = {w.with_suffix('').name: w for w in all_wavs}
46 |
47 | @classmethod
48 | def from_config(cls, config_manager: TrainingConfigManager, kind: str, scan_wavs=False):
49 | kinds = ['original', 'phonemized', 'train', 'valid']
50 | if kind not in kinds:
51 | raise ValueError(f'Invalid kind type. Expected one of: {kinds}')
52 | reader = get_preprocessor_by_name('post_processed_reader')
53 | training = False
54 | is_processed = True
55 | if kind == 'train':
56 | metadata = config_manager.train_metadata_path
57 | training = True
58 | elif kind == 'original':
59 | metadata = config_manager.metadata_path
60 | reader = get_preprocessor_by_name(config_manager.config['data_name'])
61 | is_processed = False
62 | elif kind == 'valid':
63 | metadata = config_manager.valid_metadata_path
64 | elif kind == 'phonemized':
65 | metadata = config_manager.phonemized_metadata_path
66 |
67 | return cls(wav_directory=config_manager.wav_directory,
68 | metadata_reading_function=reader,
69 | metadata_path=metadata,
70 | scan_wavs=scan_wavs,
71 | training=training,
72 | is_processed=is_processed)
73 |
74 |
75 | class AlignerPreprocessor:
76 |
77 | def __init__(self,
78 | mel_channels: int,
79 | mel_start_value: float,
80 | mel_end_value: float,
81 | tokenizer: Tokenizer):
82 | self.output_types = (tf.float32, tf.int32, tf.int32, tf.string)
83 | self.padded_shapes = ([None, mel_channels], [None], [None], [])
84 | self.start_vec = np.ones((1, mel_channels)) * mel_start_value
85 | self.end_vec = np.ones((1, mel_channels)) * mel_end_value
86 | self.tokenizer = tokenizer
87 |
88 | def __call__(self, mel, text, sample_name):
89 | encoded_phonemes = self.tokenizer(text)
90 | norm_mel = np.concatenate([self.start_vec, mel, self.end_vec], axis=0)
91 | stop_probs = np.ones((norm_mel.shape[0]))
92 | stop_probs[-1] = 2
93 | return norm_mel, encoded_phonemes, stop_probs, sample_name
94 |
95 | def get_sample_length(self, norm_mel, encoded_phonemes, stop_probs, sample_name):
96 | return tf.shape(norm_mel)[0]
97 |
98 | @classmethod
99 | def from_config(cls, config: TrainingConfigManager, tokenizer: Tokenizer):
100 | return cls(mel_channels=config.config['mel_channels'],
101 | mel_start_value=config.config['mel_start_value'],
102 | mel_end_value=config.config['mel_end_value'],
103 | tokenizer=tokenizer)
104 |
105 |
106 | class AlignerDataset:
107 | def __init__(self,
108 | data_reader: DataReader,
109 | preprocessor,
110 | mel_directory: str):
111 | self.metadata_reader = data_reader
112 | self.preprocessor = preprocessor
113 | self.mel_directory = Path(mel_directory)
114 |
115 | def _read_sample(self, sample_name):
116 | text = self.metadata_reader.text_dict[sample_name]
117 | mel = np.load((self.mel_directory / sample_name).with_suffix('.npy').as_posix())
118 | return mel, text
119 |
120 | def _process_sample(self, sample_name):
121 | mel, text = self._read_sample(sample_name)
122 | return self.preprocessor(mel=mel, text=text, sample_name=sample_name)
123 |
124 | def get_dataset(self, bucket_batch_sizes, bucket_boundaries, shuffle=True, drop_remainder=False):
125 | return Dataset(
126 | samples=self.metadata_reader.filenames,
127 | preprocessor=self._process_sample,
128 | output_types=self.preprocessor.output_types,
129 | padded_shapes=self.preprocessor.padded_shapes,
130 | shuffle=shuffle,
131 | drop_remainder=drop_remainder,
132 | len_function=self.preprocessor.get_sample_length,
133 | bucket_batch_sizes=bucket_batch_sizes,
134 | bucket_boundaries=bucket_boundaries)
135 |
136 | @classmethod
137 | def from_config(cls,
138 | config: TrainingConfigManager,
139 | preprocessor,
140 | kind: str,
141 | mel_directory: str = None, ):
142 | kinds = ['original', 'phonemized', 'train', 'valid']
143 | if kind not in kinds:
144 | raise ValueError(f'Invalid kind type. Expected one of: {kinds}')
145 | if mel_directory is None:
146 | mel_directory = config.mel_dir
147 | metadata_reader = DataReader.from_config(config, kind=kind)
148 | return cls(preprocessor=preprocessor,
149 | data_reader=metadata_reader,
150 | mel_directory=mel_directory)
151 |
152 |
153 | class TTSPreprocessor:
154 | def __init__(self, mel_channels, tokenizer: Tokenizer):
155 | self.output_types = (tf.float32, tf.int32, tf.int32, tf.float32, tf.string)
156 | self.padded_shapes = ([None, mel_channels], [None], [None], [None], [])
157 | self.tokenizer = tokenizer
158 |
159 | def __call__(self, text, mel, durations, pitch, sample_name):
160 | encoded_phonemes = self.tokenizer(text)
161 | return mel, encoded_phonemes, durations, pitch, sample_name
162 |
163 | def get_sample_length(self, mel, encoded_phonemes, durations, pitch, sample_name):
164 | return tf.shape(mel)[0]
165 |
166 | @classmethod
167 | def from_config(cls, config: TrainingConfigManager, tokenizer: Tokenizer):
168 | return cls(mel_channels=config.config['mel_channels'],
169 | tokenizer=tokenizer)
170 |
171 |
172 | class TTSDataset:
173 | def __init__(self,
174 | data_reader: DataReader,
175 | preprocessor: TTSPreprocessor,
176 | mel_directory: str,
177 | pitch_directory: str,
178 | duration_directory: str,
179 | pitch_per_char_directory: str):
180 | self.metadata_reader = data_reader
181 | self.preprocessor = preprocessor
182 | self.mel_directory = Path(mel_directory)
183 | self.duration_directory = Path(duration_directory)
184 | self.pitch_directory = Path(pitch_directory)
185 | self.pitch_per_char_directory = Path(pitch_per_char_directory)
186 |
187 | def _read_sample(self, sample_name: str):
188 | text = self.metadata_reader.text_dict[sample_name]
189 | mel = np.load((self.mel_directory / sample_name).with_suffix('.npy').as_posix())
190 | durations = np.load(
191 | (self.duration_directory / sample_name).with_suffix('.npy').as_posix())
192 | char_wise_pitch = np.load((self.pitch_per_char_directory / sample_name).with_suffix('.npy').as_posix())
193 | return mel, text, durations, char_wise_pitch
194 |
195 | def _process_sample(self, sample_name: str):
196 | mel, text, durations, pitch = self._read_sample(sample_name)
197 | return self.preprocessor(mel=mel, text=text, durations=durations, pitch=pitch, sample_name=sample_name)
198 |
199 | def get_dataset(self, bucket_batch_sizes, bucket_boundaries, shuffle=True, drop_remainder=False):
200 | return Dataset(
201 | samples=self.metadata_reader.filenames,
202 | preprocessor=self._process_sample,
203 | output_types=self.preprocessor.output_types,
204 | padded_shapes=self.preprocessor.padded_shapes,
205 | len_function=self.preprocessor.get_sample_length,
206 | shuffle=shuffle,
207 | drop_remainder=drop_remainder,
208 | bucket_batch_sizes=bucket_batch_sizes,
209 | bucket_boundaries=bucket_boundaries)
210 |
211 | @classmethod
212 | def from_config(cls,
213 | config: TrainingConfigManager,
214 | preprocessor,
215 | kind: str,
216 | mel_directory: str = None,
217 | duration_directory: str = None,
218 | pitch_directory: str = None):
219 | kinds = ['phonemized', 'train', 'valid']
220 | if kind not in kinds:
221 | raise ValueError(f'Invalid kind type. Expected one of: {kinds}')
222 | if mel_directory is None:
223 | mel_directory = config.mel_dir
224 | if duration_directory is None:
225 | duration_directory = config.duration_dir
226 | if pitch_directory is None:
227 | pitch_directory = config.pitch_dir
228 | metadata_reader = DataReader.from_config(config,
229 | kind=kind)
230 | return cls(preprocessor=preprocessor,
231 | data_reader=metadata_reader,
232 | mel_directory=mel_directory,
233 | duration_directory=duration_directory,
234 | pitch_directory=pitch_directory,
235 | pitch_per_char_directory=config.pitch_per_char)
236 |
237 |
238 | class Dataset:
239 | """ Model digestible dataset. """
240 |
241 | def __init__(self,
242 | samples: list,
243 | preprocessor,
244 | len_function,
245 | padded_shapes: tuple,
246 | output_types: tuple,
247 | bucket_boundaries: list,
248 | bucket_batch_sizes: list,
249 | padding_values: tuple = None,
250 | shuffle=True,
251 | drop_remainder=True,
252 | seed=42):
253 | self._random = Random(seed)
254 | self._samples = samples[:]
255 | self.preprocessor = preprocessor
256 | dataset = tf.data.Dataset.from_generator(lambda: self._datagen(shuffle),
257 | output_types=output_types)
258 | # TODO: pass bin args
259 | binned_data = dataset.apply(
260 | tf.data.experimental.bucket_by_sequence_length(
261 | len_function,
262 | bucket_boundaries=bucket_boundaries,
263 | bucket_batch_sizes=bucket_batch_sizes,
264 | padded_shapes=padded_shapes,
265 | drop_remainder=drop_remainder,
266 | padding_values=padding_values
267 | ))
268 | self.dataset = binned_data
269 | self.data_iter = iter(binned_data.repeat(-1))
270 |
271 | def next_batch(self):
272 | return next(self.data_iter)
273 |
274 | def all_batches(self):
275 | return iter(self.dataset)
276 |
277 | def _datagen(self, shuffle):
278 | """
279 | Shuffle once before generating to avoid buffering
280 | """
281 | samples = self._samples[:]
282 | if shuffle:
283 | self._random.shuffle(samples)
284 | return (self.preprocessor(s) for s in samples)
285 |
--------------------------------------------------------------------------------
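
A construction sketch for the classes above, roughly as a training script would wire them together (hedged: the tokenizer flags and config keys mirror `config/training_config.yaml` and may differ from the actual `train_aligner.py`):

```python
from utils.training_config_manager import TrainingConfigManager
from data.datasets import AlignerDataset, AlignerPreprocessor
from data.text.tokenizer import Tokenizer

cm = TrainingConfigManager('config/training_config.yaml', aligner=True)
tokenizer = Tokenizer(add_start_end=True, model_breathing=cm.config['model_breathing'])
preprocessor = AlignerPreprocessor.from_config(cm, tokenizer=tokenizer)
data_handler = AlignerDataset.from_config(cm, preprocessor=preprocessor, kind='train')
train_dataset = data_handler.get_dataset(bucket_batch_sizes=cm.config['bucket_batch_sizes'],
                                         bucket_boundaries=cm.config['bucket_boundaries'])
mel, phonemes, stop_probs, names = train_dataset.next_batch()  # one padded, bucketed batch
```
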
/data/metadata_readers.py:
--------------------------------------------------------------------------------
1 | """
2 | Methods for reading a dataset and returning a dictionary of the form:
3 | {
4 | filename: text_line,
5 | ...
6 | }
7 | """
8 |
9 | import sys
10 | from typing import Dict, List, Tuple
11 |
12 |
13 | def get_preprocessor_by_name(name: str):
14 | """
15 | Returns the respective data function.
16 | Taken from https://github.com/mozilla/TTS/blob/master/TTS/tts/datasets/preprocess.py
17 | """
18 | thismodule = sys.modules[__name__]
19 | return getattr(thismodule, name.lower())
20 |
21 |
22 | def ljspeech(metadata_path: str, column_sep='|') -> dict:
23 | text_dict = {}
24 | with open(metadata_path, 'r', encoding='utf-8') as f:
25 | for l in f.readlines():
26 | l_split = l.split(column_sep)
27 | filename, text = l_split[0], l_split[-1]
28 | if filename.endswith('.wav'):
29 | filename = filename.split('.')[0]
30 | text = text.replace('\n', '')
31 | text_dict.update({filename: text})
32 | return text_dict
33 |
34 |
35 | def post_processed_reader(metadata_path: str, column_sep='|', upsample_indicators='?!', upsample_factor=10) -> Tuple[
36 | Dict, List]:
37 | """
38 | Used to read metadata files created within the repo.
39 | """
40 | text_dict = {}
41 | upsample = []
42 | with open(metadata_path, 'r', encoding='utf-8') as f:
43 | for l in f.readlines():
44 | l_split = l.split(column_sep)
45 | filename, text = l_split[0], l_split[1]
46 | text = text.replace('\n', '')
47 | if any(el in text for el in list(upsample_indicators)):
48 | upsample.extend([filename] * upsample_factor)
49 | text_dict.update({filename: text})
50 | return text_dict, upsample
51 |
52 |
53 | if __name__ == '__main__':
54 | metadata_path = '/Volumes/data/datasets/LJSpeech-1.1/metadata.csv'
55 | d = get_preprocessor_by_name('ljspeech')(metadata_path)
56 | key_list = list(d.keys())
57 | print('metadata head')
58 | for key in key_list[:5]:
59 | print(f'{key}: {d[key]}')
60 | print('metadata tail')
61 | for key in key_list[-5:]:
62 | print(f'{key}: {d[key]}')
63 |
--------------------------------------------------------------------------------
/data/text/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from data.text.symbols import all_phonemes
4 | from data.text.tokenizer import Phonemizer, Tokenizer
5 |
6 |
7 | class TextToTokens:
8 | def __init__(self, phonemizer: Phonemizer, tokenizer: Tokenizer):
9 | self.phonemizer = phonemizer
10 | self.tokenizer = tokenizer
11 |
12 | def __call__(self, input_text: Union[str, list]) -> list:
13 | phons = self.phonemizer(input_text)
14 | tokens = self.tokenizer(phons)
15 | return tokens
16 |
17 | @classmethod
18 | def default(cls, language: str, add_start_end: bool, with_stress: bool, model_breathing: bool, njobs=1):
19 | phonemizer = Phonemizer(language=language, njobs=njobs, with_stress=with_stress)
20 | tokenizer = Tokenizer(add_start_end=add_start_end, model_breathing=model_breathing)
21 | return cls(phonemizer=phonemizer, tokenizer=tokenizer)
22 |
--------------------------------------------------------------------------------
/data/text/symbols.py:
--------------------------------------------------------------------------------
1 | _vowels = 'iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ'
2 | _non_pulmonic_consonants = 'ʘɓǀɗǃʄǂɠǁʛ'
3 | _pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ'
4 | _suprasegmentals = 'ˈˌːˑ'
5 | _other_symbols = 'ʍwɥʜʢʡɕʑɺɧ'
6 | _diacrilics = 'ɚ˞ɫ'
7 | _phonemes = sorted(list(
8 | _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics))
9 | _punctuations = '!,-.:;? \'()'
10 | _alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzäüößÄÖÜ'
11 |
12 | all_phonemes = sorted(list(_phonemes) + list(_punctuations))
13 |
--------------------------------------------------------------------------------
/data/text/tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import re
3 |
4 | from phonemizer.phonemize import phonemize
5 |
6 | from data.text.symbols import all_phonemes, _punctuations
7 |
8 |
9 | class Tokenizer:
10 |
11 | def __init__(self, start_token='>', end_token='<', pad_token='/', add_start_end=True, alphabet=None,
12 | model_breathing=True):
13 | if not alphabet:
14 | self.alphabet = all_phonemes
15 | else:
16 | self.alphabet = sorted(list(set(alphabet))) # for testing
17 | self.idx_to_token = {i: s for i, s in enumerate(self.alphabet, start=1)}
18 | self.idx_to_token[0] = pad_token
19 | self.token_to_idx = {s: [i] for i, s in self.idx_to_token.items()}
20 | self.vocab_size = len(self.alphabet) + 1
21 | self.add_start_end = add_start_end
22 | if add_start_end:
23 | self.start_token_index = len(self.alphabet) + 1
24 | self.end_token_index = len(self.alphabet) + 2
25 | self.vocab_size += 2
26 | self.idx_to_token[self.start_token_index] = start_token
27 | self.idx_to_token[self.end_token_index] = end_token
28 | self.model_breathing = model_breathing
29 | if model_breathing:
30 | self.breathing_token_index = self.vocab_size
31 | self.token_to_idx[' '] = self.token_to_idx[' '] + [self.breathing_token_index]
32 | self.vocab_size += 1
33 | self.breathing_token = '@'
34 | self.idx_to_token[self.breathing_token_index] = self.breathing_token
35 | self.token_to_idx[self.breathing_token] = [self.breathing_token_index]
36 |
37 | def __call__(self, sentence: str) -> list:
38 | sequence = [self.token_to_idx[c] for c in sentence] # No filtering: text should only contain known chars.
39 | sequence = [item for items in sequence for item in items]
40 | if self.model_breathing:
41 | sequence = [self.breathing_token_index] + sequence
42 | if self.add_start_end:
43 | sequence = [self.start_token_index] + sequence + [self.end_token_index]
44 | return sequence
45 |
46 | def decode(self, sequence: list) -> str:
47 | return ''.join([self.idx_to_token[int(t)] for t in sequence])
48 |
49 |
50 | class Phonemizer:
51 | def __init__(self, language: str, with_stress: bool, njobs=4):
52 | self.language = language
53 | self.njobs = njobs
54 | self.with_stress = with_stress
55 | self.special_hyphen = '—'
56 | self.punctuation = ';:,.!?¡¿—…"«»“”'
57 | self._whitespace_re = re.compile(r'\s+')
58 | self._whitespace_punctuation_re = re.compile(rf'\s*([{_punctuations}])\s*')
59 |
60 | def __call__(self, text: Union[str, list], with_stress=None, njobs=None, language=None) -> Union[str, list]:
61 | language = language or self.language
62 | njobs = njobs or self.njobs
63 | with_stress = with_stress or self.with_stress
64 | # phonemizer does not like hyphens.
65 | text = self._preprocess(text)
66 | phonemes = phonemize(text,
67 | language=language,
68 | backend='espeak',
69 | strip=True,
70 | preserve_punctuation=True,
71 | with_stress=with_stress,
72 | punctuation_marks=self.punctuation,
73 | njobs=njobs,
74 | language_switch='remove-flags')
75 | return self._postprocess(phonemes)
76 |
77 | def _preprocess_string(self, text: str):
78 | text = text.replace('-', self.special_hyphen)
79 | return text
80 |
81 | def _preprocess(self, text: Union[str, list]) -> Union[str, list]:
82 | if isinstance(text, list):
83 | return [self._preprocess_string(t) for t in text]
84 | elif isinstance(text, str):
85 | return self._preprocess_string(text)
86 | else:
87 | raise TypeError(f'{self} input must be list or str, not {type(text)}')
88 |
89 | def _collapse_whitespace(self, text: str) -> str:
90 | text = re.sub(self._whitespace_re, ' ', text)
91 | return re.sub(self._whitespace_punctuation_re, r'\1', text)
92 |
93 | def _postprocess_string(self, text: str) -> str:
94 | text = text.replace(self.special_hyphen, '-')
95 | text = ''.join([c for c in text if c in all_phonemes])
96 | text = self._collapse_whitespace(text)
97 | text = text.strip()
98 | return text
99 |
100 | def _postprocess(self, text: Union[str, list]) -> Union[str, list]:
101 | if isinstance(text, list):
102 | return [self._postprocess_string(t) for t in text]
103 | elif isinstance(text, str):
104 | return self._postprocess_string(text)
105 | else:
106 | raise TypeError(f'{self} input must be list or str, not {type(text)}')
107 |
--------------------------------------------------------------------------------
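The two classes above are typically chained: text is first phonemized, then the phoneme string is mapped to integer indices. A minimal usage sketch, assuming the repository root is on `PYTHONPATH`, espeak is installed, and using an arbitrary sample sentence and the 'en-us' language code:

```python
# Sketch only: phonemize a sentence, then encode it with the default Tokenizer.
from data.text.tokenizer import Phonemizer, Tokenizer

phonemizer = Phonemizer(language='en-us', with_stress=False)
tokenizer = Tokenizer()  # defaults: pad, start/end and breathing tokens enabled

phonemes = phonemizer('Scientists say they have discovered a new particle.')
token_ids = tokenizer(phonemes)          # list of ints, framed by start/end indices
roundtrip = tokenizer.decode(token_ids)  # phoneme string plus the special tokens
print(len(token_ids), roundtrip)
```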
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | _site
2 | .sass-cache
3 | .jekyll-cache
4 | .jekyll-metadata
5 | vendor
6 | .bundle
--------------------------------------------------------------------------------
/docs/404.html:
--------------------------------------------------------------------------------
1 | ---
2 | permalink: /404.html
3 | layout: default
4 | ---
5 |
6 | <style type="text/css" media="screen">
7 |   .container {
8 |     margin: 10px auto;
9 |     max-width: 600px;
10 |     text-align: center;
11 |   }
12 |   h1 {
13 |     margin: 30px 0;
14 |     font-size: 4em;
15 |     line-height: 1;
16 |     letter-spacing: -1px;
17 |   }
18 | </style>
19 |
20 | <div class="container">
21 |   <h1>404</h1>
22 |
23 |   <p><strong>Page not found :(</strong></p>
24 |   <p>The requested page could not be found.</p>
25 | </div>
26 |
--------------------------------------------------------------------------------
/docs/Gemfile:
--------------------------------------------------------------------------------
1 | source "https://rubygems.org"
2 | # Hello! This is where you manage which Jekyll version is used to run.
3 | # When you want to use a different version, change it below, save the
4 | # file and run `bundle install`. Run Jekyll with `bundle exec`, like so:
5 | #
6 | # bundle exec jekyll serve
7 | #
8 | # This will help ensure the proper Jekyll version is running.
9 | # Happy Jekylling!
10 | # gem "jekyll", "~> 4.0.0"
11 | # This is the default theme for new Jekyll sites. You may change this to anything you like.
12 | gem "minima", "~> 2.5"
13 | # If you want to use GitHub Pages, remove the "gem "jekyll"" above and
14 | # uncomment the line below. To upgrade, run `bundle update github-pages`.
15 | gem "github-pages", group: :jekyll_plugins
16 | # If you have any plugins, put them here!
17 | group :jekyll_plugins do
18 | gem "jekyll-feed", "~> 0.12"
19 | end
20 |
21 | # Windows and JRuby do not include zoneinfo files, so bundle the tzinfo-data gem
22 | # and associated library.
23 | install_if -> { RUBY_PLATFORM =~ %r!mingw|mswin|java! } do
24 | gem "tzinfo", "~> 1.2"
25 | gem "tzinfo-data"
26 | end
27 |
28 | # Performance-booster for watching directories on Windows
29 | gem "wdm", "~> 0.1.1", :install_if => Gem.win_platform?
30 |
31 | gem "jekyll", "~> 3.8"
32 |
--------------------------------------------------------------------------------
/docs/Gemfile.lock:
--------------------------------------------------------------------------------
1 | GEM
2 | remote: https://rubygems.org/
3 | specs:
4 | activesupport (6.0.3.1)
5 | concurrent-ruby (~> 1.0, >= 1.0.2)
6 | i18n (>= 0.7, < 2)
7 | minitest (~> 5.1)
8 | tzinfo (~> 1.1)
9 | zeitwerk (~> 2.2, >= 2.2.2)
10 | addressable (2.7.0)
11 | public_suffix (>= 2.0.2, < 5.0)
12 | coffee-script (2.4.1)
13 | coffee-script-source
14 | execjs
15 | coffee-script-source (1.11.1)
16 | colorator (1.1.0)
17 | commonmarker (0.17.13)
18 | ruby-enum (~> 0.5)
19 | concurrent-ruby (1.1.6)
20 | dnsruby (1.61.3)
21 | addressable (~> 2.5)
22 | em-websocket (0.5.1)
23 | eventmachine (>= 0.12.9)
24 | http_parser.rb (~> 0.6.0)
25 | ethon (0.12.0)
26 | ffi (>= 1.3.0)
27 | eventmachine (1.2.7)
28 | execjs (2.7.0)
29 | faraday (1.0.1)
30 | multipart-post (>= 1.2, < 3)
31 | ffi (1.12.2)
32 | forwardable-extended (2.6.0)
33 | gemoji (3.0.1)
34 | github-pages (204)
35 | github-pages-health-check (= 1.16.1)
36 | jekyll (= 3.8.5)
37 | jekyll-avatar (= 0.7.0)
38 | jekyll-coffeescript (= 1.1.1)
39 | jekyll-commonmark-ghpages (= 0.1.6)
40 | jekyll-default-layout (= 0.1.4)
41 | jekyll-feed (= 0.13.0)
42 | jekyll-gist (= 1.5.0)
43 | jekyll-github-metadata (= 2.13.0)
44 | jekyll-mentions (= 1.5.1)
45 | jekyll-optional-front-matter (= 0.3.2)
46 | jekyll-paginate (= 1.1.0)
47 | jekyll-readme-index (= 0.3.0)
48 | jekyll-redirect-from (= 0.15.0)
49 | jekyll-relative-links (= 0.6.1)
50 | jekyll-remote-theme (= 0.4.1)
51 | jekyll-sass-converter (= 1.5.2)
52 | jekyll-seo-tag (= 2.6.1)
53 | jekyll-sitemap (= 1.4.0)
54 | jekyll-swiss (= 1.0.0)
55 | jekyll-theme-architect (= 0.1.1)
56 | jekyll-theme-cayman (= 0.1.1)
57 | jekyll-theme-dinky (= 0.1.1)
58 | jekyll-theme-hacker (= 0.1.1)
59 | jekyll-theme-leap-day (= 0.1.1)
60 | jekyll-theme-merlot (= 0.1.1)
61 | jekyll-theme-midnight (= 0.1.1)
62 | jekyll-theme-minimal (= 0.1.1)
63 | jekyll-theme-modernist (= 0.1.1)
64 | jekyll-theme-primer (= 0.5.4)
65 | jekyll-theme-slate (= 0.1.1)
66 | jekyll-theme-tactile (= 0.1.1)
67 | jekyll-theme-time-machine (= 0.1.1)
68 | jekyll-titles-from-headings (= 0.5.3)
69 | jemoji (= 0.11.1)
70 | kramdown (= 1.17.0)
71 | liquid (= 4.0.3)
72 | mercenary (~> 0.3)
73 | minima (= 2.5.1)
74 | nokogiri (>= 1.10.4, < 2.0)
75 | rouge (= 3.13.0)
76 | terminal-table (~> 1.4)
77 | github-pages-health-check (1.16.1)
78 | addressable (~> 2.3)
79 | dnsruby (~> 1.60)
80 | octokit (~> 4.0)
81 | public_suffix (~> 3.0)
82 | typhoeus (~> 1.3)
83 | html-pipeline (2.12.3)
84 | activesupport (>= 2)
85 | nokogiri (>= 1.4)
86 | http_parser.rb (0.6.0)
87 | i18n (0.9.5)
88 | concurrent-ruby (~> 1.0)
89 | jekyll (3.8.5)
90 | addressable (~> 2.4)
91 | colorator (~> 1.0)
92 | em-websocket (~> 0.5)
93 | i18n (~> 0.7)
94 | jekyll-sass-converter (~> 1.0)
95 | jekyll-watch (~> 2.0)
96 | kramdown (~> 1.14)
97 | liquid (~> 4.0)
98 | mercenary (~> 0.3.3)
99 | pathutil (~> 0.9)
100 | rouge (>= 1.7, < 4)
101 | safe_yaml (~> 1.0)
102 | jekyll-avatar (0.7.0)
103 | jekyll (>= 3.0, < 5.0)
104 | jekyll-coffeescript (1.1.1)
105 | coffee-script (~> 2.2)
106 | coffee-script-source (~> 1.11.1)
107 | jekyll-commonmark (1.3.1)
108 | commonmarker (~> 0.14)
109 | jekyll (>= 3.7, < 5.0)
110 | jekyll-commonmark-ghpages (0.1.6)
111 | commonmarker (~> 0.17.6)
112 | jekyll-commonmark (~> 1.2)
113 | rouge (>= 2.0, < 4.0)
114 | jekyll-default-layout (0.1.4)
115 | jekyll (~> 3.0)
116 | jekyll-feed (0.13.0)
117 | jekyll (>= 3.7, < 5.0)
118 | jekyll-gist (1.5.0)
119 | octokit (~> 4.2)
120 | jekyll-github-metadata (2.13.0)
121 | jekyll (>= 3.4, < 5.0)
122 | octokit (~> 4.0, != 4.4.0)
123 | jekyll-mentions (1.5.1)
124 | html-pipeline (~> 2.3)
125 | jekyll (>= 3.7, < 5.0)
126 | jekyll-optional-front-matter (0.3.2)
127 | jekyll (>= 3.0, < 5.0)
128 | jekyll-paginate (1.1.0)
129 | jekyll-readme-index (0.3.0)
130 | jekyll (>= 3.0, < 5.0)
131 | jekyll-redirect-from (0.15.0)
132 | jekyll (>= 3.3, < 5.0)
133 | jekyll-relative-links (0.6.1)
134 | jekyll (>= 3.3, < 5.0)
135 | jekyll-remote-theme (0.4.1)
136 | addressable (~> 2.0)
137 | jekyll (>= 3.5, < 5.0)
138 | rubyzip (>= 1.3.0)
139 | jekyll-sass-converter (1.5.2)
140 | sass (~> 3.4)
141 | jekyll-seo-tag (2.6.1)
142 | jekyll (>= 3.3, < 5.0)
143 | jekyll-sitemap (1.4.0)
144 | jekyll (>= 3.7, < 5.0)
145 | jekyll-swiss (1.0.0)
146 | jekyll-theme-architect (0.1.1)
147 | jekyll (~> 3.5)
148 | jekyll-seo-tag (~> 2.0)
149 | jekyll-theme-cayman (0.1.1)
150 | jekyll (~> 3.5)
151 | jekyll-seo-tag (~> 2.0)
152 | jekyll-theme-dinky (0.1.1)
153 | jekyll (~> 3.5)
154 | jekyll-seo-tag (~> 2.0)
155 | jekyll-theme-hacker (0.1.1)
156 | jekyll (~> 3.5)
157 | jekyll-seo-tag (~> 2.0)
158 | jekyll-theme-leap-day (0.1.1)
159 | jekyll (~> 3.5)
160 | jekyll-seo-tag (~> 2.0)
161 | jekyll-theme-merlot (0.1.1)
162 | jekyll (~> 3.5)
163 | jekyll-seo-tag (~> 2.0)
164 | jekyll-theme-midnight (0.1.1)
165 | jekyll (~> 3.5)
166 | jekyll-seo-tag (~> 2.0)
167 | jekyll-theme-minimal (0.1.1)
168 | jekyll (~> 3.5)
169 | jekyll-seo-tag (~> 2.0)
170 | jekyll-theme-modernist (0.1.1)
171 | jekyll (~> 3.5)
172 | jekyll-seo-tag (~> 2.0)
173 | jekyll-theme-primer (0.5.4)
174 | jekyll (> 3.5, < 5.0)
175 | jekyll-github-metadata (~> 2.9)
176 | jekyll-seo-tag (~> 2.0)
177 | jekyll-theme-slate (0.1.1)
178 | jekyll (~> 3.5)
179 | jekyll-seo-tag (~> 2.0)
180 | jekyll-theme-tactile (0.1.1)
181 | jekyll (~> 3.5)
182 | jekyll-seo-tag (~> 2.0)
183 | jekyll-theme-time-machine (0.1.1)
184 | jekyll (~> 3.5)
185 | jekyll-seo-tag (~> 2.0)
186 | jekyll-titles-from-headings (0.5.3)
187 | jekyll (>= 3.3, < 5.0)
188 | jekyll-watch (2.2.1)
189 | listen (~> 3.0)
190 | jemoji (0.11.1)
191 | gemoji (~> 3.0)
192 | html-pipeline (~> 2.2)
193 | jekyll (>= 3.0, < 5.0)
194 | kramdown (1.17.0)
195 | liquid (4.0.3)
196 | listen (3.2.1)
197 | rb-fsevent (~> 0.10, >= 0.10.3)
198 | rb-inotify (~> 0.9, >= 0.9.10)
199 | mercenary (0.3.6)
200 | mini_portile2 (2.5.1)
201 | minima (2.5.1)
202 | jekyll (>= 3.5, < 5.0)
203 | jekyll-feed (~> 0.9)
204 | jekyll-seo-tag (~> 2.1)
205 | minitest (5.14.1)
206 | multipart-post (2.1.1)
207 | nokogiri (1.11.5)
208 | mini_portile2 (~> 2.5.0)
209 | racc (~> 1.4)
210 | octokit (4.18.0)
211 | faraday (>= 0.9)
212 | sawyer (~> 0.8.0, >= 0.5.3)
213 | pathutil (0.16.2)
214 | forwardable-extended (~> 2.6)
215 | public_suffix (3.1.1)
216 | racc (1.5.2)
217 | rb-fsevent (0.10.3)
218 | rb-inotify (0.10.1)
219 | ffi (~> 1.0)
220 | rouge (3.13.0)
221 | ruby-enum (0.8.0)
222 | i18n
223 | rubyzip (2.3.0)
224 | safe_yaml (1.0.5)
225 | sass (3.7.4)
226 | sass-listen (~> 4.0.0)
227 | sass-listen (4.0.0)
228 | rb-fsevent (~> 0.9, >= 0.9.4)
229 | rb-inotify (~> 0.9, >= 0.9.7)
230 | sawyer (0.8.2)
231 | addressable (>= 2.3.5)
232 | faraday (> 0.8, < 2.0)
233 | terminal-table (1.8.0)
234 | unicode-display_width (~> 1.1, >= 1.1.1)
235 | thread_safe (0.3.6)
236 | typhoeus (1.3.1)
237 | ethon (>= 0.9.0)
238 | tzinfo (1.2.7)
239 | thread_safe (~> 0.1)
240 | tzinfo-data (1.2019.3)
241 | tzinfo (>= 1.0.0)
242 | unicode-display_width (1.7.0)
243 | wdm (0.1.1)
244 | zeitwerk (2.3.0)
245 |
246 | PLATFORMS
247 | ruby
248 |
249 | DEPENDENCIES
250 | github-pages
251 | jekyll (~> 3.8)
252 | jekyll-feed (~> 0.12)
253 | minima (~> 2.5)
254 | tzinfo (~> 1.2)
255 | tzinfo-data
256 | wdm (~> 0.1.1)
257 |
258 | BUNDLED WITH
259 | 2.1.4
260 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # TransformerTTS Project Page
2 |
3 | We use [Jekyll](https://jekyllrb.com/) to generate our project page.
4 |
5 | ### Usage
6 |
7 | To serve locally run:
8 | ```
9 | bundle exec jekyll serve
10 | ```
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | # Welcome to Jekyll!
2 | #
3 | # This config file is meant for settings that affect your whole blog, values
4 | # which you are expected to set up once and rarely edit after that. If you find
5 | # yourself editing this file very often, consider using Jekyll's data files
6 | # feature for the data you need to update frequently.
7 | #
8 | # For technical reasons, this file is *NOT* reloaded automatically when you use
9 | # 'bundle exec jekyll serve'. If you change this file, please restart the server process.
10 | #
11 | # If you need help with YAML syntax, here are some quick references for you:
12 | # https://learn-the-web.algonquindesign.ca/topics/markdown-yaml-cheat-sheet/#yaml
13 | # https://learnxinyminutes.com/docs/yaml/
14 | #
15 | # Site settings
16 | # These are used to personalize your new site. If you look in the HTML files,
17 | # you will see them accessed via {{ site.title }}, {{ site.email }}, and so on.
18 | # You can create any custom variable you would like, and they will be accessible
19 | # in the templates via {{ site.myvariable }}.
20 |
21 | title: TransformerTTS
22 | email: ai@axelspringer.com
23 | description: >- # this means to ignore newlines until "baseurl:"
24 | Implementation of a Transformer based neural network for text to speech.
25 | google_analytics: UA-137434942-7
26 | baseurl: "/TransformerTTS" # the subpath of your site, e.g. /blog
27 | url: "" # the base hostname & protocol for your site, e.g. http://example.com
28 | twitter_username: axelspringerai
29 | github_username: as-ideas
30 |
31 | # Build settings
32 | markdown: kramdown
33 | theme: jekyll-theme-primer
34 | plugins:
35 | - jekyll-feed
36 |
37 | # Exclude from processing.
38 | # The following items will not be processed, by default.
39 | # Any item listed under the `exclude:` key here will be automatically added to
40 | # the internal "default list".
41 | #
42 | # Excluded items can be processed by explicitly listing the directories or
43 | # their entries' file path in the `include:` list.
44 | #
45 | # exclude:
46 | # - .sass-cache/
47 | # - .jekyll-cache/
48 | # - gemfiles/
49 | # - Gemfile
50 | # - Gemfile.lock
51 | # - node_modules/
52 | # - vendor/bundle/
53 | # - vendor/cache/
54 | # - vendor/gems/
55 | # - vendor/ruby/
56 |
--------------------------------------------------------------------------------
/docs/_layouts/default.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | {% seo %}
11 |
12 |
13 |
14 | Fork me on GitHub
15 |
16 |
17 | {{ content }}
18 |
19 | {% if site.github.private != true and site.github.license %}
20 |
23 | {% endif %}
24 |
25 |
26 |
27 | {% if site.google_analytics %}
28 |
36 | {% endif %}
37 |
38 |
--------------------------------------------------------------------------------
/docs/assets/css/style.scss:
--------------------------------------------------------------------------------
1 | ---
2 | ---
3 |
4 | @import "{{ site.theme }}";
5 |
6 | .text {
7 | font-size: 20px;
8 | }
--------------------------------------------------------------------------------
/docs/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spring-media/TransformerTTS/363805548abdd93b33508da2c027ae514bfc1a07/docs/favicon.png
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | A Text-to-Speech Transformer in TensorFlow 2
9 |
10 |
11 | Samples are converted with the pre-trained HiFiGAN vocoder and, for comparison, with the standard Griffin-Lim algorithm.
12 |
13 |
14 | ## 🎧 Model samples
15 |
16 | Introductory speech ODSC Boston 2021
17 |
18 |
19 |
20 |
21 | Peter piper picked a peck of pickled peppers.
22 |
23 | | HiFiGAN | Griffin-Lim |
24 | |:---:|:---:|
25 | | | |
26 |
27 | President Trump met with other leaders at the Group of twenty conference.
28 |
29 | | HiFiGAN | Griffin-Lim |
30 | |:---:|:---:|
31 | | | |
32 |
33 | Scientists at the CERN laboratory say they have discovered a new particle.
34 |
35 | | HiFiGAN | Griffin-Lim |
36 | |:---:|:---:|
37 | | | |
38 |
39 | There’s a way to measure the acute emotional intelligence that has never gone out of style.
40 |
41 | | HiFiGAN | Griffin-Lim |
42 | |:---:|:---:|
43 | | | |
44 |
45 | The Senate's bill to repeal and replace the Affordable Care-Act is now imperiled.
46 |
47 | | HiFiGAN | Griffin-Lim |
48 | |:---:|:---:|
49 | | | |
50 |
51 |
52 | If I were to talk to a human, I would definitely try to sound normal. Wouldn't I?
53 |
54 | | HiFiGAN | Griffin-Lim |
55 | |:---:|:---:|
56 | | | |
57 |
58 |
59 | ### Robustness
60 |
61 | To deliver interfaces that are significantly better suited to create and process RFC eight twenty one , RFC eight twenty two , RFC nine seventy seven , and MIME content.
62 |
63 | | HiFiGAN | Griffin-Lim |
64 | |:---:|:---:|
65 | | | |
66 |
67 |
68 | ### Comparison with [ForwardTacotron](https://github.com/as-ideas/ForwardTacotron)
69 | In a statement announcing his resignation, Mr Ross, said: "While the intentions may have been well meaning, the reaction to this news shows that Mr Cummings interpretation of the government advice was not shared by the vast majority of people who have done as the government asked."
70 |
71 | | TransformerTTS | ForwardTacotron |
72 | |:---:|:---:|
73 | | | |
74 |
--------------------------------------------------------------------------------
/docs/tboard_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spring-media/TransformerTTS/363805548abdd93b33508da2c027ae514bfc1a07/docs/tboard_demo.gif
--------------------------------------------------------------------------------
/docs/transformer_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spring-media/TransformerTTS/363805548abdd93b33508da2c027ae514bfc1a07/docs/transformer_logo.png
--------------------------------------------------------------------------------
/extract_durations.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 |
4 | import tensorflow as tf
5 | import numpy as np
6 | from tqdm import tqdm
7 | from p_tqdm import p_umap
8 |
9 | from utils.training_config_manager import TrainingConfigManager
10 | from utils.logging_utils import SummaryManager
11 | from data.datasets import AlignerPreprocessor
12 | from utils.alignments import get_durations_from_alignment
13 | from utils.scripts_utils import dynamic_memory_allocation
14 | from data.datasets import AlignerDataset
15 | from data.datasets import DataReader
16 |
17 | np.random.seed(42)
18 | tf.random.set_seed(42)
19 | dynamic_memory_allocation()
20 |
21 | if __name__ == '__main__':
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--config', dest='config', type=str)
24 | parser.add_argument('--best', dest='best', action='store_true',
25 | help='Use best head instead of weighted average of heads.')
26 | parser.add_argument('--autoregressive_weights', type=str, default=None,
27 | help='Explicit path to autoregressive model weights.')
28 | parser.add_argument('--skip_char_pitch', dest='skip_char_pitch', action='store_true')
29 | parser.add_argument('--skip_durations', dest='skip_durations', action='store_true')
30 | args = parser.parse_args()
31 | weighted = not args.best
32 | tag_description = ''.join([
33 | f'{"_weighted" * weighted}{"_best" * (not weighted)}',
34 | ])
35 | writer_tag = f'DurationExtraction{tag_description}'
36 | print(writer_tag)
37 | config_manager = TrainingConfigManager(config_path=args.config, aligner=True)
38 | config = config_manager.config
39 | config_manager.print_config()
40 |
41 | if not args.skip_durations:
42 | model = config_manager.load_model(args.autoregressive_weights)
43 | if model.r != 1:
44 |             print(f"ERROR: model's reduction factor is greater than 1, check config. (r={model.r})")
45 |
46 | data_prep = AlignerPreprocessor.from_config(config=config_manager,
47 | tokenizer=model.text_pipeline.tokenizer)
48 | data_handler = AlignerDataset.from_config(config_manager,
49 | preprocessor=data_prep,
50 | kind='phonemized')
51 | target_dir = config_manager.duration_dir
52 | config_manager.dump_config()
53 | dataset = data_handler.get_dataset(bucket_batch_sizes=config['bucket_batch_sizes'],
54 | bucket_boundaries=config['bucket_boundaries'],
55 | shuffle=False,
56 | drop_remainder=False)
57 |
58 | last_layer_key = 'Decoder_LastBlock_CrossAttention'
59 | print(f'Extracting attention from layer {last_layer_key}')
60 |
61 | summary_manager = SummaryManager(model=model, log_dir=config_manager.log_dir / 'Duration Extraction',
62 | config=config,
63 | default_writer='Duration Extraction')
64 | all_durations = np.array([])
65 | new_alignments = []
66 | iterator = tqdm(enumerate(dataset.all_batches()))
67 | step = 0
68 | for c, (mel_batch, text_batch, stop_batch, file_name_batch) in iterator:
69 | iterator.set_description(f'Processing dataset')
70 | outputs = model.val_step(inp=text_batch,
71 | tar=mel_batch,
72 | stop_prob=stop_batch)
73 | attention_values = outputs['decoder_attention'][last_layer_key].numpy()
74 | text = text_batch.numpy()
75 |
76 | mel = mel_batch.numpy()
77 |
78 | durations, final_align, jumpiness, peakiness, diag_measure = get_durations_from_alignment(
79 | batch_alignments=attention_values,
80 | mels=mel,
81 | phonemes=text,
82 | weighted=weighted)
83 | batch_avg_jumpiness = tf.reduce_mean(jumpiness, axis=0)
84 | batch_avg_peakiness = tf.reduce_mean(peakiness, axis=0)
85 | batch_avg_diag_measure = tf.reduce_mean(diag_measure, axis=0)
86 | for i in range(tf.shape(jumpiness)[1]):
87 | summary_manager.display_scalar(tag=f'DurationAttentionJumpiness/head{i}',
88 | scalar_value=tf.reduce_mean(batch_avg_jumpiness[i]), step=c)
89 | summary_manager.display_scalar(tag=f'DurationAttentionPeakiness/head{i}',
90 | scalar_value=tf.reduce_mean(batch_avg_peakiness[i]), step=c)
91 | summary_manager.display_scalar(tag=f'DurationAttentionDiagonality/head{i}',
92 | scalar_value=tf.reduce_mean(batch_avg_diag_measure[i]), step=c)
93 |
94 | for i, name in enumerate(file_name_batch):
95 | all_durations = np.append(all_durations, durations[i]) # for plotting only
96 | summary_manager.add_image(tag='ExtractedAlignments',
97 | image=tf.expand_dims(tf.expand_dims(final_align[i], 0), -1),
98 | step=step)
99 |
100 | step += 1
101 | np.save(str(target_dir / f"{name.numpy().decode('utf-8')}.npy"), durations[i])
102 |
103 | all_durations[all_durations >= 20] = 20 # for plotting only
104 | buckets = len(set(all_durations)) # for plotting only
105 | summary_manager.add_histogram(values=all_durations, tag='ExtractedDurations', buckets=buckets)
106 |
107 | if not args.skip_char_pitch:
108 | def _pitch_per_char(pitch, durations, mel_len):
109 | durs_cum = np.cumsum(np.pad(durations, (1, 0)))
110 |             pitch_char = np.zeros((durations.shape[0],), dtype=float)
111 | for idx, a, b in zip(range(mel_len), durs_cum[:-1], durs_cum[1:]):
112 | values = pitch[a:b][np.where(pitch[a:b] != 0.0)[0]]
113 | values = values[np.where((values * pitch_stats['pitch_std'] + pitch_stats['pitch_mean']) < 400)[0]]
114 | pitch_char[idx] = np.mean(values) if len(values) > 0 else 0.0
115 | return pitch_char
116 |
117 |
118 | def process_per_char_pitch(sample_name: str):
119 | pitch = np.load((config_manager.pitch_dir / sample_name).with_suffix('.npy').as_posix())
120 | durations = np.load((config_manager.duration_dir / sample_name).with_suffix('.npy').as_posix())
121 | mel = np.load((config_manager.mel_dir / sample_name).with_suffix('.npy').as_posix())
122 | char_wise_pitch = _pitch_per_char(pitch, durations, mel.shape[0])
123 | np.save((config_manager.pitch_per_char / sample_name).with_suffix('.npy').as_posix(), char_wise_pitch)
124 |
125 |
126 | metadatareader = DataReader.from_config(config_manager, kind='phonemized', scan_wavs=False)
127 | pitch_stats = pickle.load(open(config_manager.data_dir / 'pitch_stats.pkl', 'rb'))
128 | print(f'\nComputing phoneme-wise pitch')
129 | print(f'{len(metadatareader.filenames)} items found in {metadatareader.metadata_path}.')
130 | wav_iter = p_umap(process_per_char_pitch, metadatareader.filenames)
131 | print('Done.')
132 |
--------------------------------------------------------------------------------
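The `_pitch_per_char` helper above averages frame-level pitch over each character's duration span, skipping unvoiced (zero) frames; the 400 Hz outlier filter needs the dataset's `pitch_stats`, so it is left out of this toy sketch with made-up values:

```python
import numpy as np

# Toy frame-level pitch track (8 frames) and per-character durations summing to 8.
pitch = np.array([0.0, 110.0, 115.0, 0.0, 220.0, 225.0, 230.0, 0.0])
durations = np.array([3, 4, 1])

durs_cum = np.cumsum(np.pad(durations, (1, 0)))        # [0, 3, 7, 8]
pitch_char = np.zeros(durations.shape[0], dtype=float)
for idx, (a, b) in enumerate(zip(durs_cum[:-1], durs_cum[1:])):
    voiced = pitch[a:b][pitch[a:b] != 0.0]             # drop unvoiced frames
    pitch_char[idx] = voiced.mean() if len(voiced) > 0 else 0.0

print(pitch_char)  # [112.5 225.   0. ]
```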
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spring-media/TransformerTTS/363805548abdd93b33508da2c027ae514bfc1a07/model/__init__.py
--------------------------------------------------------------------------------
/model/factory.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from pathlib import Path
3 |
4 | import tensorflow as tf
5 | import ruamel.yaml
6 |
7 | from model.models import ForwardTransformer, Aligner
8 |
9 |
10 | def tts_ljspeech(step='95000') -> Tuple[ForwardTransformer, dict]:
11 | model_name = f'bdf06b9_ljspeech_step_{step}.zip'
12 | remote_dir = 'https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/api_weights/bdf06b9_ljspeech/'
13 | model_path = tf.keras.utils.get_file(model_name,
14 | remote_dir + model_name,
15 | extract=True,
16 | archive_format='zip',
17 | cache_subdir='TransformerTTS_models')
18 | model_path = Path(model_path).with_suffix('') # remove extension
19 | return ForwardTransformer.load_model(model_path.as_posix())
20 |
21 |
22 | def tts_custom(config_path: str, weights_path: str) -> Tuple[ForwardTransformer, dict]:
23 | yaml = ruamel.yaml.YAML()
24 | with open(config_path, 'rb') as session_yaml:
25 | config = yaml.load(session_yaml)
26 | model = ForwardTransformer.from_config(config)
27 | model.build_model_weights()
28 | model.load_weights(weights_path)
29 | return model, config
30 |
31 |
32 | def aligner_custom(config_path: str, weights_path: str) -> Tuple[Aligner, dict]:
33 | yaml = ruamel.yaml.YAML()
34 | with open(config_path, 'rb') as session_yaml:
35 | config = yaml.load(session_yaml)
36 | model = Aligner.from_config(config)
37 | model.build_model_weights()
38 | model.load_weights(weights_path)
39 | return model, config
40 |
--------------------------------------------------------------------------------
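For reference, a minimal sketch of how the factory is used elsewhere in this repository (see the notebook below): `tts_ljspeech` downloads and caches the published LJSpeech checkpoint and returns the model together with its config. The sentence here is arbitrary.

```python
# Sketch only: load the published checkpoint and synthesize a mel spectrogram.
from model.factory import tts_ljspeech

model, config = tts_ljspeech()     # downloads the zipped weights on the first call
out = model.predict('A text to speech transformer in TensorFlow two.')
mel = out['mel'].numpy().T         # (mel_channels, frames), ready for a vocoder
```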
/model/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from model.transformer_utils import positional_encoding
4 |
5 |
6 | class CNNResNorm(tf.keras.layers.Layer):
7 | """
8 | Module used in attention blocks, after MHA
9 | """
10 |
11 | def __init__(self,
12 | filters: list,
13 | kernel_size: int,
14 | inner_activation: str,
15 | padding: str,
16 | dout_rate: float):
17 | super(CNNResNorm, self).__init__()
18 | self.n_layers = len(filters)
19 | self.convolutions = [tf.keras.layers.Conv1D(filters=f,
20 | kernel_size=kernel_size,
21 | padding=padding)
22 | for f in filters[:-1]]
23 | self.inner_activations = [tf.keras.layers.Activation(inner_activation) for _ in range(self.n_layers - 1)]
24 | self.last_conv = tf.keras.layers.Conv1D(filters=filters[-1],
25 | kernel_size=kernel_size,
26 | padding=padding)
27 | self.normalization = tf.keras.layers.LayerNormalization(epsilon=1e-6)
28 | self.dropout = tf.keras.layers.Dropout(rate=dout_rate)
29 |
30 | def call_convs(self, x):
31 | for i in range(0, self.n_layers - 1):
32 | x = self.convolutions[i](x)
33 | x = self.inner_activations[i](x)
34 | return x
35 |
36 | def call(self, inputs, training):
37 | x = self.call_convs(inputs)
38 | x = self.last_conv(x)
39 | x = self.dropout(x, training=training)
40 | return self.normalization(inputs + x)
41 |
42 |
43 | class TransposedCNNResNorm(tf.keras.layers.Layer):
44 | """
45 | Module used in attention blocks, after MHA
46 | """
47 |
48 | def __init__(self,
49 | filters: list,
50 | kernel_size: int,
51 | inner_activation: str,
52 | padding: str,
53 | dout_rate: float):
54 | super(TransposedCNNResNorm, self).__init__()
55 | self.n_layers = len(filters)
56 | self.convolutions = [tf.keras.layers.Conv1D(filters=f,
57 | kernel_size=kernel_size,
58 | padding=padding)
59 | for f in filters[:-1]]
60 | self.inner_activations = [tf.keras.layers.Activation(inner_activation) for _ in range(self.n_layers - 1)]
61 | self.last_conv = tf.keras.layers.Conv1D(filters=filters[-1],
62 | kernel_size=kernel_size,
63 | padding=padding)
64 | self.normalization = tf.keras.layers.LayerNormalization(epsilon=1e-6)
65 | self.dropout = tf.keras.layers.Dropout(rate=dout_rate)
66 |
67 | def call_convs(self, x):
68 | for i in range(0, self.n_layers - 1):
69 | x = self.convolutions[i](x)
70 | x = self.inner_activations[i](x)
71 | return x
72 |
73 | def call(self, inputs, training):
74 | x = tf.transpose(inputs, (0, 1, 2))
75 | x = self.call_convs(x)
76 | x = self.last_conv(x)
77 | x = tf.transpose(x, (0, 1, 2))
78 | x = self.dropout(x, training=training)
79 | return self.normalization(inputs + x)
80 |
81 |
82 | class FFNResNorm(tf.keras.layers.Layer):
83 | """
84 | Module used in attention blocks, after MHA
85 | """
86 |
87 | def __init__(self,
88 | model_dim: int,
89 | dense_hidden_units: int,
90 | dropout_rate: float,
91 | **kwargs):
92 | super(FFNResNorm, self).__init__(**kwargs)
93 | self.d1 = tf.keras.layers.Dense(dense_hidden_units, 'relu')
94 | self.d2 = tf.keras.layers.Dense(model_dim)
95 | self.dropout = tf.keras.layers.Dropout(dropout_rate)
96 | self.last_ln = tf.keras.layers.LayerNormalization(epsilon=1e-6)
97 |
98 | def call(self, x, training):
99 | ffn_out = self.d1(x)
100 | ffn_out = self.d2(ffn_out) # (batch_size, input_seq_len, model_dim)
101 | ffn_out = self.dropout(ffn_out, training=training)
102 | return self.last_ln(ffn_out + x)
103 |
104 |
105 | class MultiHeadAttention(tf.keras.layers.Layer):
106 |
107 | def __init__(self, model_dim: int, num_heads: int, dropout: float, **kwargs):
108 | super(MultiHeadAttention, self).__init__(**kwargs)
109 | self.num_heads = num_heads
110 | self.model_dim = model_dim
111 |
112 | assert model_dim % self.num_heads == 0
113 |
114 | self.depth = model_dim // self.num_heads
115 |
116 | self.wq = tf.keras.layers.Dense(model_dim)
117 | self.wk = tf.keras.layers.Dense(model_dim)
118 | self.wv = tf.keras.layers.Dense(model_dim)
119 | self.attention = ScaledDotProductAttention(dropout=dropout)
120 | self.dense = tf.keras.layers.Dense(model_dim)
121 | self.dropout = tf.keras.layers.Dropout(dropout)
122 |
123 | def split_heads(self, x, batch_size: int):
124 | """ Split the last dimension into (num_heads, depth).
125 | Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
126 | """
127 |
128 | x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
129 | return tf.transpose(x, perm=[0, 2, 1, 3])
130 |
131 | def call(self, v, k, q_in, mask, training):
132 | batch_size = tf.shape(q_in)[0]
133 |
134 | q = self.wq(q_in) # (batch_size, seq_len, model_dim)
135 | k = self.wk(k) # (batch_size, seq_len, model_dim)
136 | v = self.wv(v) # (batch_size, seq_len, model_dim)
137 |
138 | q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth)
139 | k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth)
140 | v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth)
141 |
142 | scaled_attention, attention_weights = self.attention([q, k, v, mask], training=training)
143 |
144 | scaled_attention = tf.transpose(scaled_attention,
145 | perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth)
146 | concat_attention = tf.reshape(scaled_attention,
147 | (batch_size, -1, self.model_dim)) # (batch_size, seq_len_q, model_dim)
148 | concat_query = tf.concat([q_in, concat_attention], axis=-1)
149 | output = self.dense(concat_query) # (batch_size, seq_len_q, model_dim)
150 | output = self.dropout(output, training=training)
151 | return output, attention_weights
152 |
153 |
154 | class ScaledDotProductAttention(tf.keras.layers.Layer):
155 | """ Calculate the attention weights.
156 | q, k, v must have matching leading dimensions.
157 | k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
158 | The mask has different shapes depending on its type(padding or look ahead)
159 | but it must be broadcastable for addition.
160 |
161 | Args:
162 | q: query shape == (..., seq_len_q, depth)
163 | k: key shape == (..., seq_len_k, depth)
164 | v: value shape == (..., seq_len_v, depth_v)
165 | mask: Float tensor with shape broadcastable
166 | to (..., seq_len_q, seq_len_k). Defaults to None.
167 |
168 | Returns:
169 | output, attention_weights
170 | """
171 |
172 | def __init__(self, dropout: float):
173 | super(ScaledDotProductAttention, self).__init__()
174 | self.dropout = tf.keras.layers.Dropout(rate=dropout)
175 |
176 | def call(self, inputs, training=False):
177 | q, k, v, mask = inputs
178 |
179 | matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k)
180 |
181 | # scale matmul_qk
182 | dk = tf.cast(tf.shape(k)[-1], tf.float32)
183 | scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
184 |
185 | # add the mask to the scaled tensor.
186 | if mask is not None:
187 | scaled_attention_logits += mask * -1e9 # TODO: add mask expansion here and remove from create padding mask
188 |
189 | # softmax is normalized on the last axis (seq_len_k) so that the scores
190 | # add up to 1.
191 | attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k)
192 | attention_weights = self.dropout(attention_weights, training=training)
193 | output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v)
194 |
195 | return output, attention_weights
196 |
197 |
198 | class SelfAttentionResNorm(tf.keras.layers.Layer):
199 |
200 | def __init__(self,
201 | model_dim: int,
202 | num_heads: int,
203 | dropout_rate: float,
204 | **kwargs):
205 | super(SelfAttentionResNorm, self).__init__(**kwargs)
206 | self.mha = MultiHeadAttention(model_dim, num_heads, dropout=dropout_rate)
207 | self.last_ln = tf.keras.layers.LayerNormalization(epsilon=1e-6)
208 |
209 | def call(self, x, training, mask):
210 | attn_out, attn_weights = self.mha(x, x, x, mask, training=training) # (batch_size, input_seq_len, model_dim)
211 | return self.last_ln(attn_out + x), attn_weights
212 |
213 |
214 | class SelfAttentionDenseBlock(tf.keras.layers.Layer):
215 |
216 | def __init__(self,
217 | model_dim: int,
218 | num_heads: int,
219 | dense_hidden_units: int,
220 | dropout_rate: float,
221 | **kwargs):
222 | super(SelfAttentionDenseBlock, self).__init__(**kwargs)
223 | self.sarn = SelfAttentionResNorm(model_dim, num_heads, dropout_rate=dropout_rate)
224 | self.ffn = FFNResNorm(model_dim, dense_hidden_units, dropout_rate=dropout_rate)
225 |
226 | def call(self, x, training, mask):
227 | attn_out, attn_weights = self.sarn(x, mask=mask, training=training)
228 | dense_mask = 1. - tf.squeeze(mask, axis=(1, 2))[:, :, None]
229 | attn_out = attn_out * dense_mask
230 | return self.ffn(attn_out, training=training) * dense_mask, attn_weights
231 |
232 |
233 | class SelfAttentionConvBlock(tf.keras.layers.Layer):
234 |
235 | def __init__(self,
236 | model_dim: int,
237 | num_heads: int,
238 | dropout_rate: float,
239 | conv_filters: list,
240 | kernel_size: int,
241 | conv_activation: str,
242 | transposed_convs: bool,
243 | **kwargs):
244 | super(SelfAttentionConvBlock, self).__init__(**kwargs)
245 | self.sarn = SelfAttentionResNorm(model_dim, num_heads, dropout_rate=dropout_rate)
246 | if transposed_convs:
247 | self.conv = TransposedCNNResNorm(filters=conv_filters,
248 | kernel_size=kernel_size,
249 | inner_activation=conv_activation,
250 | dout_rate=dropout_rate,
251 | padding='same')
252 | else:
253 | self.conv = CNNResNorm(filters=conv_filters,
254 | kernel_size=kernel_size,
255 | inner_activation=conv_activation,
256 | dout_rate=dropout_rate,
257 | padding='same')
258 |
259 | def call(self, x, training, mask):
260 | attn_out, attn_weights = self.sarn(x, mask=mask, training=training)
261 | conv_mask = 1. - tf.squeeze(mask, axis=(1, 2))[:, :, None]
262 | attn_out = attn_out * conv_mask
263 | conv = self.conv(attn_out, training=training)
264 | return conv * conv_mask, attn_weights
265 |
266 |
267 | class SelfAttentionBlocks(tf.keras.layers.Layer):
268 | def __init__(self,
269 | model_dim: int,
270 | feed_forward_dimension: int,
271 | num_heads: list,
272 | maximum_position_encoding: int,
273 | conv_filters: list,
274 | dropout_rate: float,
275 | dense_blocks: int,
276 | kernel_size: int,
277 | conv_activation: str,
278 | transposed_convs: bool = None,
279 | **kwargs):
280 | super(SelfAttentionBlocks, self).__init__(**kwargs)
281 | self.model_dim = model_dim
282 | self.pos_encoding_scalar = tf.Variable(1.)
283 | self.pos_encoding = positional_encoding(maximum_position_encoding, model_dim)
284 | self.dropout = tf.keras.layers.Dropout(dropout_rate)
285 | self.encoder_SADB = [
286 | SelfAttentionDenseBlock(model_dim=model_dim, dropout_rate=dropout_rate, num_heads=n_heads,
287 | dense_hidden_units=feed_forward_dimension, name=f'{self.name}_SADB_{i}')
288 | for i, n_heads in enumerate(num_heads[:dense_blocks])]
289 | self.encoder_SACB = [
290 | SelfAttentionConvBlock(model_dim=model_dim, dropout_rate=dropout_rate, num_heads=n_heads,
291 | name=f'{self.name}_SACB_{i}', kernel_size=kernel_size,
292 | conv_activation=conv_activation, conv_filters=conv_filters,
293 | transposed_convs=transposed_convs)
294 | for i, n_heads in enumerate(num_heads[dense_blocks:])]
295 | self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
296 |
297 | def call(self, inputs, training, padding_mask, reduction_factor=1):
298 | seq_len = tf.shape(inputs)[1]
299 | x = self.layernorm(inputs)
300 | x += self.pos_encoding_scalar * self.pos_encoding[:, :seq_len * reduction_factor:reduction_factor, :]
301 | x = self.dropout(x, training=training)
302 | attention_weights = {}
303 | for i, block in enumerate(self.encoder_SADB):
304 | x, attn_weights = block(x, training=training, mask=padding_mask)
305 | attention_weights[f'{self.name}_DenseBlock{i + 1}_SelfAttention'] = attn_weights
306 | for i, block in enumerate(self.encoder_SACB):
307 | x, attn_weights = block(x, training=training, mask=padding_mask)
308 | attention_weights[f'{self.name}_ConvBlock{i + 1}_SelfAttention'] = attn_weights
309 |
310 | return x, attention_weights
311 |
312 |
313 | class CrossAttentionResnorm(tf.keras.layers.Layer):
314 |
315 | def __init__(self,
316 | model_dim: int,
317 | num_heads: int,
318 | dropout_rate: float,
319 | **kwargs):
320 | super(CrossAttentionResnorm, self).__init__(**kwargs)
321 | self.mha = MultiHeadAttention(model_dim, num_heads, dropout=dropout_rate)
322 | self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
323 |
324 | def call(self, q, k, v, training, mask):
325 | attn_values, attn_weights = self.mha(v, k=k, q_in=q, mask=mask, training=training)
326 | out = self.layernorm(attn_values + q)
327 | return out, attn_weights
328 |
329 |
330 | class CrossAttentionDenseBlock(tf.keras.layers.Layer):
331 |
332 | def __init__(self,
333 | model_dim: int,
334 | num_heads: int,
335 | dense_hidden_units: int,
336 | dropout_rate: float,
337 | **kwargs):
338 | super(CrossAttentionDenseBlock, self).__init__(**kwargs)
339 | self.sarn = SelfAttentionResNorm(model_dim, num_heads, dropout_rate=dropout_rate)
340 | self.carn = CrossAttentionResnorm(model_dim, num_heads, dropout_rate=dropout_rate)
341 | self.ffn = FFNResNorm(model_dim, dense_hidden_units, dropout_rate=dropout_rate)
342 |
343 | def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
344 | attn1, attn_weights_block1 = self.sarn(x, mask=look_ahead_mask, training=training)
345 |
346 | attn2, attn_weights_block2 = self.carn(attn1, v=enc_output, k=enc_output,
347 | mask=padding_mask, training=training)
348 | ffn_out = self.ffn(attn2, training=training)
349 | return ffn_out, attn_weights_block1, attn_weights_block2
350 |
351 | # This is never used.
352 | # class CrossAttentionConvBlock(tf.keras.layers.Layer):
353 | #
354 | # def __init__(self,
355 | # model_dim: int,
356 | # num_heads: int,
357 | # conv_filters: list,
358 | # dropout_rate: float,
359 | # kernel_size: int,
360 | # conv_padding: str,
361 | # conv_activation: str,
362 | # **kwargs):
363 | # super(CrossAttentionConvBlock, self).__init__(**kwargs)
364 | # self.sarn = SelfAttentionResNorm(model_dim, num_heads, dropout_rate=dropout_rate)
365 | # self.carn = CrossAttentionResnorm(model_dim, num_heads, dropout_rate=dropout_rate)
366 | # self.conv = CNNResNorm(filters=conv_filters,
367 | # kernel_size=kernel_size,
368 | # inner_activation=conv_activation,
369 | # padding=conv_padding,
370 | # dout_rate=dropout_rate)
371 | #
372 | # def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
373 | # attn1, attn_weights_block1 = self.sarn(x, mask=look_ahead_mask, training=training)
374 | #
375 | # attn2, attn_weights_block2 = self.carn(attn1, v=enc_output, k=enc_output,
376 | # mask=padding_mask, training=training)
377 | # ffn_out = self.conv(attn2, training=training)
378 | # return ffn_out, attn_weights_block1, attn_weights_block2
379 |
380 |
381 | class CrossAttentionBlocks(tf.keras.layers.Layer):
382 |
383 | def __init__(self,
384 | model_dim: int,
385 | feed_forward_dimension: int,
386 | num_heads: list,
387 | maximum_position_encoding: int,
388 | dropout_rate: float,
389 | **kwargs):
390 | super(CrossAttentionBlocks, self).__init__(**kwargs)
391 | self.model_dim = model_dim
392 | self.pos_encoding_scalar = tf.Variable(1.)
393 | self.pos_encoding = positional_encoding(maximum_position_encoding, model_dim)
394 | self.dropout = tf.keras.layers.Dropout(dropout_rate)
395 | self.CADB = [
396 | CrossAttentionDenseBlock(model_dim=model_dim, dropout_rate=dropout_rate, num_heads=n_heads,
397 | dense_hidden_units=feed_forward_dimension, name=f'{self.name}_CADB_{i}')
398 | for i, n_heads in enumerate(num_heads[:-1])]
399 | self.last_CADB = CrossAttentionDenseBlock(model_dim=model_dim, dropout_rate=dropout_rate,
400 | num_heads=num_heads[-1],
401 | dense_hidden_units=feed_forward_dimension,
402 | name=f'{self.name}_CADB_last')
403 | self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
404 |
405 | def call(self, inputs, enc_output, training, decoder_padding_mask, encoder_padding_mask,
406 | reduction_factor=1):
407 | seq_len = tf.shape(inputs)[1]
408 | x = self.layernorm(inputs)
409 | x += self.pos_encoding_scalar * self.pos_encoding[:, :seq_len * reduction_factor:reduction_factor, :]
410 | x = self.dropout(x, training=training)
411 | attention_weights = {}
412 | for i, block in enumerate(self.CADB):
413 | x, _, attn_weights = block(x, enc_output, training, decoder_padding_mask, encoder_padding_mask)
414 | attention_weights[f'{self.name}_DenseBlock{i + 1}_CrossAttention'] = attn_weights
415 | x, _, attn_weights = self.last_CADB(x, enc_output, training, decoder_padding_mask, encoder_padding_mask)
416 | attention_weights[f'{self.name}_LastBlock_CrossAttention'] = attn_weights
417 | return x, attention_weights
418 |
419 |
420 | class DecoderPrenet(tf.keras.layers.Layer):
421 |
422 | def __init__(self,
423 | model_dim: int,
424 | dense_hidden_units: int,
425 | dropout_rate: float,
426 | **kwargs):
427 | super(DecoderPrenet, self).__init__(**kwargs)
428 | self.d1 = tf.keras.layers.Dense(dense_hidden_units,
429 | activation='relu') # (batch_size, seq_len, dense_hidden_units)
430 | self.d2 = tf.keras.layers.Dense(model_dim, activation='relu') # (batch_size, seq_len, model_dim)
431 | self.rate = tf.Variable(dropout_rate, trainable=False)
432 | self.dropout_1 = tf.keras.layers.Dropout(self.rate)
433 | self.dropout_2 = tf.keras.layers.Dropout(self.rate)
434 |
435 | def call(self, x, training):
436 | self.dropout_1.rate = self.rate
437 | self.dropout_2.rate = self.rate
438 | x = self.d1(x)
439 | # use dropout also in inference for positional encoding relevance
440 | x = self.dropout_1(x, training=training)
441 | x = self.d2(x)
442 | x = self.dropout_2(x, training=training)
443 | return x
444 |
445 |
446 | class Postnet(tf.keras.layers.Layer):
447 |
448 | def __init__(self, mel_channels: int, **kwargs):
449 | super(Postnet, self).__init__(**kwargs)
450 | self.mel_channels = mel_channels
451 | self.stop_linear = tf.keras.layers.Dense(3)
452 | self.mel_out = tf.keras.layers.Dense(mel_channels)
453 |
454 | def call(self, x):
455 | stop = self.stop_linear(x)
456 | mel = self.mel_out(x)
457 | return {
458 | 'mel': mel,
459 | 'stop_prob': stop,
460 | }
461 |
462 |
463 | class StatPredictor(tf.keras.layers.Layer):
464 | def __init__(self,
465 | conv_filters: list,
466 | kernel_size: int,
467 | conv_padding: str,
468 | conv_activation: str,
469 | dense_activation: str,
470 | dropout_rate: float,
471 | **kwargs):
472 | super(StatPredictor, self).__init__(**kwargs)
473 | self.conv_blocks = CNNDropout(filters=conv_filters,
474 | kernel_size=kernel_size,
475 | padding=conv_padding,
476 | inner_activation=conv_activation,
477 | last_activation=conv_activation,
478 | dout_rate=dropout_rate)
479 | self.linear = tf.keras.layers.Dense(1, activation=dense_activation)
480 |
481 | def call(self, x, training, mask):
482 | x = x * mask
483 | x = self.conv_blocks(x, training=training)
484 | x = self.linear(x)
485 | return x * mask
486 |
487 |
488 | class CNNDropout(tf.keras.layers.Layer):
489 | def __init__(self,
490 | filters: list,
491 | kernel_size: int,
492 | inner_activation: str,
493 | last_activation: str,
494 | padding: str,
495 | dout_rate: float):
496 | super(CNNDropout, self).__init__()
497 | self.n_layers = len(filters)
498 | self.convolutions = [tf.keras.layers.Conv1D(filters=f,
499 | kernel_size=kernel_size,
500 | padding=padding)
501 | for f in filters[:-1]]
502 | self.inner_activations = [tf.keras.layers.Activation(inner_activation) for _ in range(self.n_layers - 1)]
503 | self.last_conv = tf.keras.layers.Conv1D(filters=filters[-1],
504 | kernel_size=kernel_size,
505 | padding=padding)
506 | self.last_activation = tf.keras.layers.Activation(last_activation)
507 | self.dropouts = [tf.keras.layers.Dropout(rate=dout_rate) for _ in range(self.n_layers)]
508 | self.normalization = [tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(self.n_layers)]
509 |
510 | def call_convs(self, x, training):
511 | for i in range(0, self.n_layers - 1):
512 | x = self.convolutions[i](x)
513 | x = self.inner_activations[i](x)
514 | x = self.normalization[i](x)
515 | x = self.dropouts[i](x, training=training)
516 | return x
517 |
518 | def call(self, inputs, training):
519 | x = self.call_convs(inputs, training=training)
520 | x = self.last_conv(x)
521 | x = self.last_activation(x)
522 | x = self.normalization[-1](x)
523 | x = self.dropouts[-1](x, training=training)
524 | return x
525 |
526 |
527 | class Expand(tf.keras.layers.Layer):
528 | """ Expands a 3D tensor on its second axis given a list of dimensions.
529 | Tensor should be:
530 | batch_size, seq_len, dimension
531 |
532 | E.g:
533 | input = tf.Tensor([[[0.54710746 0.8943467 ]
534 | [0.7140938 0.97968304]
535 | [0.5347662 0.15213418]]], shape=(1, 3, 2), dtype=float32)
536 | dimensions = tf.Tensor([1 3 2], shape=(3,), dtype=int32)
537 | output = tf.Tensor([[[0.54710746 0.8943467 ]
538 | [0.7140938 0.97968304]
539 | [0.7140938 0.97968304]
540 | [0.7140938 0.97968304]
541 | [0.5347662 0.15213418]
542 | [0.5347662 0.15213418]]], shape=(1, 6, 2), dtype=float32)
543 | """
544 |
545 | def __init__(self, model_dim, **kwargs):
546 | super(Expand, self).__init__(**kwargs)
547 | self.model_dimension = model_dim
548 |
549 | def call(self, x, dimensions):
550 | dimensions = tf.squeeze(dimensions, axis=-1)
551 | dimensions = tf.cast(tf.math.round(dimensions), tf.int32)
552 | seq_len = tf.shape(x)[1]
553 | batch_size = tf.shape(x)[0]
554 | # build masks from dimensions
555 | max_dim = tf.math.reduce_max(dimensions)
556 | tot_dim = tf.math.reduce_sum(dimensions)
557 | index_masks = tf.RaggedTensor.from_row_lengths(tf.ones(tot_dim), tf.reshape(dimensions, [-1])).to_tensor()
558 | index_masks = tf.cast(tf.reshape(index_masks, (batch_size, seq_len * max_dim)), tf.float32)
559 | non_zeros = seq_len * max_dim - tf.reduce_sum(max_dim - dimensions, axis=1)
560 | # stack and mask
561 | tiled = tf.tile(x, [1, 1, max_dim])
562 | reshaped = tf.reshape(tiled, (batch_size, seq_len * max_dim, self.model_dimension))
563 | mask_reshape = tf.multiply(reshaped, index_masks[:, :, tf.newaxis])
564 | ragged = tf.RaggedTensor.from_row_lengths(mask_reshape[index_masks > 0], non_zeros)
565 | return ragged.to_tensor()
566 |
--------------------------------------------------------------------------------
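The `Expand` docstring above already contains a worked example; the short sketch below reproduces it, passing the durations with the trailing singleton axis that `call` squeezes away (repository root assumed on `PYTHONPATH`):

```python
import tensorflow as tf
from model.layers import Expand

x = tf.constant([[[0.54710746, 0.8943467],
                  [0.7140938,  0.97968304],
                  [0.5347662,  0.15213418]]], dtype=tf.float32)  # (1, 3, 2)
durations = tf.constant([[[1.], [3.], [2.]]], dtype=tf.float32)  # (1, 3, 1)

expand = Expand(model_dim=2)
out = expand(x, durations)
print(out.shape)  # (1, 6, 2): rows repeated 1, 3 and 2 times, as in the docstring
```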
/model/transformer_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 |
5 | def get_angles(pos, i, model_dim):
6 | angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(model_dim))
7 | return pos * angle_rates
8 |
9 |
10 | def positional_encoding(position, model_dim):
11 | angle_rads = get_angles(np.arange(position)[:, np.newaxis], np.arange(model_dim)[np.newaxis, :], model_dim)
12 |
13 | # apply sin to even indices in the array; 2i
14 | angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
15 |
16 | # apply cos to odd indices in the array; 2i+1
17 | angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
18 |
19 | pos_encoding = angle_rads[np.newaxis, ...]
20 |
21 | return tf.cast(pos_encoding, dtype=tf.float32)
22 |
23 |
24 | def create_encoder_padding_mask(seq):
25 | seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
26 | return seq[:, tf.newaxis, tf.newaxis, :] # (batch_size, 1, y, x)
27 |
28 |
29 | def create_mel_padding_mask(seq):
30 | seq = tf.reduce_sum(tf.math.abs(seq), axis=-1)
31 | seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
32 | return seq[:, tf.newaxis, tf.newaxis, :] # (batch_size, 1, y, x)
33 |
34 |
35 | def create_look_ahead_mask(size):
36 | mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
37 | return mask
38 |
--------------------------------------------------------------------------------
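A short sketch of the helpers above and the shapes they produce, using a single padded token sequence (pad index 0) and arbitrary sizes:

```python
import tensorflow as tf
from model.transformer_utils import (positional_encoding, create_encoder_padding_mask,
                                     create_look_ahead_mask)

pos = positional_encoding(position=100, model_dim=256)  # (1, 100, 256), added to embeddings
padded_batch = tf.constant([[7, 3, 5, 0, 0]])           # two trailing pad tokens
pad_mask = create_encoder_padding_mask(padded_batch)    # (1, 1, 1, 5), 1.0 where padded
la_mask = create_look_ahead_mask(4)                     # (4, 4), 1.0 above the diagonal
print(pos.shape, pad_mask.numpy().squeeze(), la_mask.numpy())
```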
/notebooks/synthesize_forward_melgan.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "zdMgfG7GMF_R"
8 | },
9 | "source": [
10 | " \n",
11 | " Transformer TTS: A Text-to-Speech Transformer in TensorFlow 2 \n",
12 | " Audio synthesis with Forward Transformer TTS and MelGAN Vocoder \n",
13 | " \n",
14 | "\n",
15 | "## Forward Model"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": null,
21 | "metadata": {
22 | "colab": {
23 | "base_uri": "https://localhost:8080/",
24 | "height": 225
25 | },
26 | "colab_type": "code",
27 | "id": "JQ5YuFPAxXUy",
28 | "outputId": "23b61caa-9bc7-4bb8-b31b-0da6a029d6e6"
29 | },
30 | "outputs": [],
31 | "source": [
32 | "# Clone the Transformer TTS and MelGAN repos\n",
33 | "!git clone https://github.com/as-ideas/TransformerTTS.git\n",
34 | "!git clone https://github.com/seungwonpark/melgan.git"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {
41 | "colab": {
42 | "base_uri": "https://localhost:8080/",
43 | "height": 1000
44 | },
45 | "colab_type": "code",
46 | "id": "9bIzkIGjMRwm",
47 | "outputId": "c078ee93-da4c-4c93-daf2-e092acd8b929"
48 | },
49 | "outputs": [],
50 | "source": [
51 | "# Install requirements\n",
52 | "!apt-get install -y espeak\n",
53 | "!pip install -r TransformerTTS/requirements.txt"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {
60 | "colab": {
61 | "base_uri": "https://localhost:8080/",
62 | "height": 225
63 | },
64 | "colab_type": "code",
65 | "id": "cOxdx6L5Hjcf",
66 | "outputId": "e919b138-a9d4-431e-dcf6-873925e8f0a9"
67 | },
68 | "outputs": [],
69 | "source": [
70 | "!cd TransformerTTS/; git checkout c3405c53e435a06c809533aa4453923469081147"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {
77 | "colab": {},
78 | "colab_type": "code",
79 | "id": "W3tlwOlRbABh"
80 | },
81 | "outputs": [],
82 | "source": [
83 | "# Set up the paths\n",
84 | "from pathlib import Path\n",
85 | "MelGAN_path = 'melgan/'\n",
86 | "TTS_path = 'TransformerTTS/'\n",
87 | "\n",
88 | "import sys\n",
89 | "sys.path.append(TTS_path)"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "metadata": {
96 | "colab": {
97 | "base_uri": "https://localhost:8080/",
98 | "height": 69
99 | },
100 | "colab_type": "code",
101 | "id": "LucwkAK1yEVq",
102 | "outputId": "c2000cef-7533-4e95-e1c0-93559a93aec4"
103 | },
104 | "outputs": [],
105 | "source": [
106 | "# Load pretrained model\n",
107 | "from model.factory import tts_ljspeech\n",
108 | "from data.audio import Audio\n",
109 | "\n",
110 | "model, config = tts_ljspeech()\n",
111 | "audio = Audio(config)"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {
118 | "colab": {},
119 | "colab_type": "code",
120 | "id": "_5RKHIDQyZvo"
121 | },
122 | "outputs": [],
123 | "source": [
124 | "# Synthesize text\n",
125 | "sentence = 'Scientists at the CERN laboratory, say they have discovered a new particle.'\n",
126 | "out_normal = model.predict(sentence)"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "metadata": {
133 | "colab": {
134 | "base_uri": "https://localhost:8080/",
135 | "height": 62
136 | },
137 | "colab_type": "code",
138 | "id": "GXxdDHOAyZ6f",
139 | "outputId": "a5d611c2-316a-4aec-834b-93562b0f487a"
140 | },
141 | "outputs": [],
142 | "source": [
143 | "# Convert spectrogram to wav (with griffin lim)\n",
144 | "wav = audio.reconstruct_waveform(out_normal['mel'].numpy().T)"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "metadata": {
151 | "colab": {
152 | "base_uri": "https://localhost:8080/",
153 | "height": 62
154 | },
155 | "colab_type": "code",
156 | "id": "GXxdDHOAyZ6f",
157 | "outputId": "a5d611c2-316a-4aec-834b-93562b0f487a"
158 | },
159 | "outputs": [],
160 | "source": [
161 | "import IPython.display as ipd\n",
162 | "\n",
163 | "ipd.display(ipd.Audio(wav, rate=config['sampling_rate']))"
164 | ]
165 | },
166 | {
167 | "cell_type": "markdown",
168 | "metadata": {
169 | "colab_type": "text",
170 | "id": "uyidCx84bAB1"
171 | },
172 | "source": [
173 | "You can also vary the speech speed"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": null,
179 | "metadata": {
180 | "colab": {
181 | "base_uri": "https://localhost:8080/",
182 | "height": 62
183 | },
184 | "colab_type": "code",
185 | "id": "_N3rM6qmbAB2",
186 | "outputId": "c68e3da8-e15d-431c-e377-c98457b83a52"
187 | },
188 | "outputs": [],
189 | "source": [
190 | "# 20% faster\n",
191 | "sentence = 'Scientists at the CERN laboratory, say they have discovered a new particle.'\n",
192 | "out = model.predict(sentence, speed_regulator=1.20)\n",
193 | "wav = audio.reconstruct_waveform(out['mel'].numpy().T)\n",
194 | "ipd.display(ipd.Audio(wav, rate=config['sampling_rate']))"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": null,
200 | "metadata": {
201 | "colab": {
202 | "base_uri": "https://localhost:8080/",
203 | "height": 62
204 | },
205 | "colab_type": "code",
206 | "id": "QxAIl9LkbAB6",
207 | "outputId": "a8e0531e-268f-4fd4-8a55-2c7914cc67d6"
208 | },
209 | "outputs": [],
210 | "source": [
211 | "# 10% slower\n",
212 | "sentence = 'Scientists at the CERN laboratory, say they have discovered a new particle.'\n",
213 | "out = model.predict(sentence, speed_regulator=.9)\n",
214 | "wav = audio.reconstruct_waveform(out['mel'].numpy().T)\n",
215 | "ipd.display(ipd.Audio(wav, rate=config['sampling_rate']))"
216 | ]
217 | },
218 | {
219 | "cell_type": "markdown",
220 | "metadata": {
221 | "colab_type": "text",
222 | "id": "eZJo81viVus-"
223 | },
224 | "source": [
225 | "### MelGAN"
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": null,
231 | "metadata": {
232 | "colab": {
233 | "base_uri": "https://localhost:8080/",
234 | "height": 34
235 | },
236 | "colab_type": "code",
237 | "id": "WjIuQALHTr-R",
238 | "outputId": "503b1a50-b658-4f38-acd3-ed8346571f24"
239 | },
240 | "outputs": [],
241 | "source": [
242 | "# Do some sys cleaning\n",
243 | "sys.path.remove(TTS_path)\n",
244 | "sys.modules.pop('model')"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": null,
250 | "metadata": {
251 | "colab": {
252 | "base_uri": "https://localhost:8080/",
253 | "height": 101,
254 | "referenced_widgets": [
255 | "d2d68febdf104697bbf2d4a7d132891a",
256 | "2820cb67ddb6435a9abc1b7252e7fe44",
257 | "a29f9e1975f248348c32254fe5b3d026",
258 | "b0b25d43cc6949d9b921d33984deea0f",
259 | "171439a16fba452a8fdab03a5711a6dc",
260 | "7b4b771dce304335933820d574f49e4f",
261 | "4c9a01915a2648f192071342e0cd3202",
262 | "90ad1e7188db4f98be3081554bb06cbc"
263 | ]
264 | },
265 | "colab_type": "code",
266 | "id": "L4gZZOgmbACF",
267 | "outputId": "d7d953a7-9fe5-44e3-c750-4327dcf3d541"
268 | },
269 | "outputs": [],
270 | "source": [
271 | "sys.path.append(MelGAN_path)\n",
272 | "import torch\n",
273 | "import numpy as np\n",
274 | "\n",
275 | "vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')\n",
276 | "vocoder.eval()\n",
277 | "\n",
278 | "mel = torch.tensor(out_normal['mel'].numpy().T[np.newaxis,:,:])"
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": null,
284 | "metadata": {
285 | "colab": {},
286 | "colab_type": "code",
287 | "id": "iqXjJY_2bACJ"
288 | },
289 | "outputs": [],
290 | "source": [
291 | "if torch.cuda.is_available():\n",
292 | " vocoder = vocoder.cuda()\n",
293 | " mel = mel.cuda()\n",
294 | "\n",
295 | "with torch.no_grad():\n",
296 | " audio = vocoder.inference(mel)"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": null,
302 | "metadata": {
303 | "colab": {
304 | "base_uri": "https://localhost:8080/",
305 | "height": 62
306 | },
307 | "colab_type": "code",
308 | "id": "vQYaZawLXTJI",
309 | "outputId": "0137a959-76c6-4c2b-cf0d-87d68c585651"
310 | },
311 | "outputs": [],
312 | "source": [
313 | "# Display audio\n",
314 | "ipd.display(ipd.Audio(audio.cpu().numpy(), rate=22050))"
315 | ]
316 | }
317 | ],
318 | "metadata": {
319 | "accelerator": "GPU",
320 | "colab": {
321 | "collapsed_sections": [],
322 | "name": "synthesize_forward_melgan",
323 | "provenance": []
324 | },
325 | "kernelspec": {
326 | "display_name": "ttsTF",
327 | "language": "python",
328 | "name": "ttstf"
329 | },
330 | "language_info": {
331 | "codemirror_mode": {
332 | "name": "ipython",
333 | "version": 3
334 | },
335 | "file_extension": ".py",
336 | "mimetype": "text/x-python",
337 | "name": "python",
338 | "nbconvert_exporter": "python",
339 | "pygments_lexer": "ipython3",
340 | "version": "3.6.9"
341 | },
342 | "widgets": {
343 | "application/vnd.jupyter.widget-state+json": {
344 | "171439a16fba452a8fdab03a5711a6dc": {
345 | "model_module": "@jupyter-widgets/controls",
346 | "model_name": "ProgressStyleModel",
347 | "state": {
348 | "_model_module": "@jupyter-widgets/controls",
349 | "_model_module_version": "1.5.0",
350 | "_model_name": "ProgressStyleModel",
351 | "_view_count": null,
352 | "_view_module": "@jupyter-widgets/base",
353 | "_view_module_version": "1.2.0",
354 | "_view_name": "StyleView",
355 | "bar_color": null,
356 | "description_width": "initial"
357 | }
358 | },
359 | "2820cb67ddb6435a9abc1b7252e7fe44": {
360 | "model_module": "@jupyter-widgets/base",
361 | "model_name": "LayoutModel",
362 | "state": {
363 | "_model_module": "@jupyter-widgets/base",
364 | "_model_module_version": "1.2.0",
365 | "_model_name": "LayoutModel",
366 | "_view_count": null,
367 | "_view_module": "@jupyter-widgets/base",
368 | "_view_module_version": "1.2.0",
369 | "_view_name": "LayoutView",
370 | "align_content": null,
371 | "align_items": null,
372 | "align_self": null,
373 | "border": null,
374 | "bottom": null,
375 | "display": null,
376 | "flex": null,
377 | "flex_flow": null,
378 | "grid_area": null,
379 | "grid_auto_columns": null,
380 | "grid_auto_flow": null,
381 | "grid_auto_rows": null,
382 | "grid_column": null,
383 | "grid_gap": null,
384 | "grid_row": null,
385 | "grid_template_areas": null,
386 | "grid_template_columns": null,
387 | "grid_template_rows": null,
388 | "height": null,
389 | "justify_content": null,
390 | "justify_items": null,
391 | "left": null,
392 | "margin": null,
393 | "max_height": null,
394 | "max_width": null,
395 | "min_height": null,
396 | "min_width": null,
397 | "object_fit": null,
398 | "object_position": null,
399 | "order": null,
400 | "overflow": null,
401 | "overflow_x": null,
402 | "overflow_y": null,
403 | "padding": null,
404 | "right": null,
405 | "top": null,
406 | "visibility": null,
407 | "width": null
408 | }
409 | },
410 | "4c9a01915a2648f192071342e0cd3202": {
411 | "model_module": "@jupyter-widgets/controls",
412 | "model_name": "DescriptionStyleModel",
413 | "state": {
414 | "_model_module": "@jupyter-widgets/controls",
415 | "_model_module_version": "1.5.0",
416 | "_model_name": "DescriptionStyleModel",
417 | "_view_count": null,
418 | "_view_module": "@jupyter-widgets/base",
419 | "_view_module_version": "1.2.0",
420 | "_view_name": "StyleView",
421 | "description_width": ""
422 | }
423 | },
424 | "7b4b771dce304335933820d574f49e4f": {
425 | "model_module": "@jupyter-widgets/base",
426 | "model_name": "LayoutModel",
427 | "state": {
428 | "_model_module": "@jupyter-widgets/base",
429 | "_model_module_version": "1.2.0",
430 | "_model_name": "LayoutModel",
431 | "_view_count": null,
432 | "_view_module": "@jupyter-widgets/base",
433 | "_view_module_version": "1.2.0",
434 | "_view_name": "LayoutView",
435 | "align_content": null,
436 | "align_items": null,
437 | "align_self": null,
438 | "border": null,
439 | "bottom": null,
440 | "display": null,
441 | "flex": null,
442 | "flex_flow": null,
443 | "grid_area": null,
444 | "grid_auto_columns": null,
445 | "grid_auto_flow": null,
446 | "grid_auto_rows": null,
447 | "grid_column": null,
448 | "grid_gap": null,
449 | "grid_row": null,
450 | "grid_template_areas": null,
451 | "grid_template_columns": null,
452 | "grid_template_rows": null,
453 | "height": null,
454 | "justify_content": null,
455 | "justify_items": null,
456 | "left": null,
457 | "margin": null,
458 | "max_height": null,
459 | "max_width": null,
460 | "min_height": null,
461 | "min_width": null,
462 | "object_fit": null,
463 | "object_position": null,
464 | "order": null,
465 | "overflow": null,
466 | "overflow_x": null,
467 | "overflow_y": null,
468 | "padding": null,
469 | "right": null,
470 | "top": null,
471 | "visibility": null,
472 | "width": null
473 | }
474 | },
475 | "90ad1e7188db4f98be3081554bb06cbc": {
476 | "model_module": "@jupyter-widgets/base",
477 | "model_name": "LayoutModel",
478 | "state": {
479 | "_model_module": "@jupyter-widgets/base",
480 | "_model_module_version": "1.2.0",
481 | "_model_name": "LayoutModel",
482 | "_view_count": null,
483 | "_view_module": "@jupyter-widgets/base",
484 | "_view_module_version": "1.2.0",
485 | "_view_name": "LayoutView",
486 | "align_content": null,
487 | "align_items": null,
488 | "align_self": null,
489 | "border": null,
490 | "bottom": null,
491 | "display": null,
492 | "flex": null,
493 | "flex_flow": null,
494 | "grid_area": null,
495 | "grid_auto_columns": null,
496 | "grid_auto_flow": null,
497 | "grid_auto_rows": null,
498 | "grid_column": null,
499 | "grid_gap": null,
500 | "grid_row": null,
501 | "grid_template_areas": null,
502 | "grid_template_columns": null,
503 | "grid_template_rows": null,
504 | "height": null,
505 | "justify_content": null,
506 | "justify_items": null,
507 | "left": null,
508 | "margin": null,
509 | "max_height": null,
510 | "max_width": null,
511 | "min_height": null,
512 | "min_width": null,
513 | "object_fit": null,
514 | "object_position": null,
515 | "order": null,
516 | "overflow": null,
517 | "overflow_x": null,
518 | "overflow_y": null,
519 | "padding": null,
520 | "right": null,
521 | "top": null,
522 | "visibility": null,
523 | "width": null
524 | }
525 | },
526 | "a29f9e1975f248348c32254fe5b3d026": {
527 | "model_module": "@jupyter-widgets/controls",
528 | "model_name": "FloatProgressModel",
529 | "state": {
530 | "_dom_classes": [],
531 | "_model_module": "@jupyter-widgets/controls",
532 | "_model_module_version": "1.5.0",
533 | "_model_name": "FloatProgressModel",
534 | "_view_count": null,
535 | "_view_module": "@jupyter-widgets/controls",
536 | "_view_module_version": "1.5.0",
537 | "_view_name": "ProgressView",
538 | "bar_style": "success",
539 | "description": "100%",
540 | "description_tooltip": null,
541 | "layout": "IPY_MODEL_7b4b771dce304335933820d574f49e4f",
542 | "max": 17090302,
543 | "min": 0,
544 | "orientation": "horizontal",
545 | "style": "IPY_MODEL_171439a16fba452a8fdab03a5711a6dc",
546 | "value": 17090302
547 | }
548 | },
549 | "b0b25d43cc6949d9b921d33984deea0f": {
550 | "model_module": "@jupyter-widgets/controls",
551 | "model_name": "HTMLModel",
552 | "state": {
553 | "_dom_classes": [],
554 | "_model_module": "@jupyter-widgets/controls",
555 | "_model_module_version": "1.5.0",
556 | "_model_name": "HTMLModel",
557 | "_view_count": null,
558 | "_view_module": "@jupyter-widgets/controls",
559 | "_view_module_version": "1.5.0",
560 | "_view_name": "HTMLView",
561 | "description": "",
562 | "description_tooltip": null,
563 | "layout": "IPY_MODEL_90ad1e7188db4f98be3081554bb06cbc",
564 | "placeholder": "",
565 | "style": "IPY_MODEL_4c9a01915a2648f192071342e0cd3202",
566 | "value": " 16.3M/16.3M [00:00<00:00, 91.6MB/s]"
567 | }
568 | },
569 | "d2d68febdf104697bbf2d4a7d132891a": {
570 | "model_module": "@jupyter-widgets/controls",
571 | "model_name": "HBoxModel",
572 | "state": {
573 | "_dom_classes": [],
574 | "_model_module": "@jupyter-widgets/controls",
575 | "_model_module_version": "1.5.0",
576 | "_model_name": "HBoxModel",
577 | "_view_count": null,
578 | "_view_module": "@jupyter-widgets/controls",
579 | "_view_module_version": "1.5.0",
580 | "_view_name": "HBoxView",
581 | "box_style": "",
582 | "children": [
583 | "IPY_MODEL_a29f9e1975f248348c32254fe5b3d026",
584 | "IPY_MODEL_b0b25d43cc6949d9b921d33984deea0f"
585 | ],
586 | "layout": "IPY_MODEL_2820cb67ddb6435a9abc1b7252e7fe44"
587 | }
588 | }
589 | }
590 | }
591 | },
592 | "nbformat": 4,
593 | "nbformat_minor": 1
594 | }
595 |
--------------------------------------------------------------------------------
/predict_tts.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from pathlib import Path
3 |
4 | import numpy as np
5 |
6 | from model.factory import tts_ljspeech
7 | from data.audio import Audio
8 | from model.models import ForwardTransformer
9 |
10 | if __name__ == '__main__':
11 | parser = ArgumentParser()
12 | parser.add_argument('--path', '-p', dest='path', default=None, type=str)
13 | parser.add_argument('--step', dest='step', default='90000', type=str)
14 | parser.add_argument('--text', '-t', dest='text', default=None, type=str)
15 | parser.add_argument('--file', '-f', dest='file', default=None, type=str)
16 | parser.add_argument('--outdir', '-o', dest='outdir', default=None, type=str)
17 | parser.add_argument('--store_mel', '-m', dest='store_mel', action='store_true')
18 | parser.add_argument('--verbose', '-v', dest='verbose', action='store_true')
19 | parser.add_argument('--single', '-s', dest='single', action='store_true')
20 | args = parser.parse_args()
21 |
22 | if args.file is not None:
23 | with open(args.file, 'r') as file:
24 | text = file.readlines()
25 | fname = Path(args.file).stem
26 | elif args.text is not None:
27 | text = [args.text]
28 | fname = 'custom_text'
29 | else:
30 | fname = None
31 | text = None
32 | print(f'Specify either an input text (-t "some text") or a text input file (-f /path/to/file.txt)')
33 | exit()
34 | # load the appropriate model
35 | outdir = Path(args.outdir) if args.outdir is not None else Path('.')
36 | if args.path is not None:
37 | print(f'Loading model from {args.path}')
38 | model = ForwardTransformer.load_model(args.path)
39 | else:
40 | model = tts_ljspeech(args.step)
41 | file_name = f"{fname}_{model.config['data_name']}_{model.config['git_hash']}_{model.config['step']}"
42 | outdir = outdir / 'outputs' / f'{fname}'
43 | outdir.mkdir(exist_ok=True, parents=True)
44 | output_path = (outdir / file_name).with_suffix('.wav')
45 | audio = Audio.from_config(model.config)
46 | print(f'Output wav under {output_path.parent}')
47 | wavs = []
48 | for i, text_line in enumerate(text):
49 | phons = model.text_pipeline.phonemizer(text_line)
50 | tokens = model.text_pipeline.tokenizer(phons)
51 | if args.verbose:
52 | print(f'Predicting {text_line}')
53 | print(f'Phonemes: "{phons}"')
54 | print(f'Tokens: "{tokens}"')
55 | out = model.predict(tokens, encode=False, phoneme_max_duration=None)
56 | mel = out['mel'].numpy().T
57 | wav = audio.reconstruct_waveform(mel)
58 | wavs.append(wav)
59 | if args.store_mel:
60 | np.save((outdir / (file_name + f'_{i}')).with_suffix('.mel'), out['mel'].numpy())
61 | if args.single:
62 | audio.save_wav(wav, (outdir / (file_name + f'_{i}')).with_suffix('.wav'))
63 | audio.save_wav(np.concatenate(wavs), output_path)
64 |
--------------------------------------------------------------------------------
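A minimal, hedged sketch (not part of the repository) of what predict_tts.py does for a single sentence, assuming the pretrained LJSpeech weights are reachable through tts_ljspeech and that model.predict accepts raw text with its defaults, as in the notebook above:

# Hypothetical usage sketch mirroring predict_tts.py and the notebook above.
from model.factory import tts_ljspeech
from data.audio import Audio

model = tts_ljspeech('90000')                        # same default step as the --step flag
audio = Audio.from_config(model.config)
out = model.predict('This is a nice test.')          # returns a dict containing a 'mel' tensor
wav = audio.reconstruct_waveform(out['mel'].numpy().T)
audio.save_wav(wav, 'sample.wav')                    # pass speed_regulator to model.predict to change the pace
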
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.2.2
2 | librosa==0.7.1
3 | numba==0.48
4 | numpy>=1.17.4
5 | phonemizer~=2.2.1
6 | ruamel.yaml>=0.16.6
7 | tensorflow>=2.2.0
8 | tqdm==4.40.1
9 | p_tqdm
10 | soundfile
11 | webrtcvad
12 | scipy
13 | pyworld
14 |
--------------------------------------------------------------------------------
/test_sentences.txt:
--------------------------------------------------------------------------------
1 | This is a nice test.
2 | Is this a nice test?
3 | President Trump met with other leaders at the Group of twenty conference.
4 | Scientists at the CERN laboratory say they have discovered a new particle.
5 | There’s a way to measure the acute emotional intelligence that has never gone out of style.
6 | The Senate's bill to repeal and replace the Affordable Care-Act is now imperiled.
7 | Peter piper picked a peck of pickled peppers.
8 | If I were to talk to a human, I would definitely try to sound normal. Wouldn't I?
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spring-media/TransformerTTS/363805548abdd93b33508da2c027ae514bfc1a07/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_char_tokenizer.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import tensorflow as tf
4 | import numpy as np
5 |
6 | from data.text.tokenizer import Tokenizer
7 |
8 |
9 | class TestCharTokenizer(unittest.TestCase):
10 |
11 | def test_tokenizer(self):
12 | tokenizer = Tokenizer(alphabet=list('ab c'))
13 | self.assertEqual(5, tokenizer.start_token_index)
14 | self.assertEqual(6, tokenizer.end_token_index)
15 | self.assertEqual(7, tokenizer.vocab_size)
16 |
17 | seq = tokenizer('a b d')
18 | self.assertEqual([5, 1, 3, 2, 3, 6], seq)
19 |
20 | seq = np.array([5, 1, 3, 2, 8, 6])
21 | seq = tf.convert_to_tensor(seq)
22 | text = tokenizer.decode(seq)
23 | self.assertEqual('>a b<', text)
24 |
--------------------------------------------------------------------------------
/tests/test_config.yaml:
--------------------------------------------------------------------------------
1 | # ARCHITECTURE
2 | decoder_model_dimension: 128
3 | encoder_model_dimension: 128
4 | decoder_num_heads: [1]
5 | encoder_num_heads: [1]
6 | encoder_feed_forward_dimension: 128
7 | decoder_feed_forward_dimension: 128
8 | decoder_prenet_dimension: 128
9 | max_position_encoding: 10000
10 | postnet_conv_filters: 64
11 | postnet_conv_layers: 1
12 | postnet_kernel_size: 5
13 | dropout_rate: 0.1
14 | # DATA
15 | n_samples: 600
16 | mel_channels: 80
17 | sr: 22050
18 | mel_start_value: -3
19 | mel_end_value: 1
20 | # TRAINING
21 | use_decoder_prenet_dropout_schedule: True
22 | decoder_prenet_dropout_schedule_max: 0.9
23 | decoder_prenet_dropout_schedule_min: 0.6
24 | decoder_prenet_dropout_schedule_max_steps: 30_000
25 | fixed_decoder_prenet_dropout: 0.6
26 | epochs: 10
27 | batch_size: 2
28 | learning_rate_schedule:
29 | - [0, 1.0e-3]
30 | reduction_factor_schedule:
31 | - [0, 10]
32 | mask_prob: 0.3
33 | use_block_attention: False
34 | debug: True
35 | # LOGGING
36 | text_freq: 5000
37 | image_freq: 1000
38 | weights_save_freq: 10
39 | plot_attention_freq: 500
40 | keep_n_weights: 5
41 | warmup_steps: 1_000
42 | warmup_lr: 1.0e-6
43 | #TOKENIZER
44 | use_phonemes: True
45 | phoneme_language: 'en'
46 | tokenizer_alphabet: "!,.:;?'- abcdefghijklmnopqrstuvwxyzäüöß"
--------------------------------------------------------------------------------
/tests/test_loss.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 |
5 | from utils.losses import new_scaled_crossentropy, masked_crossentropy
6 |
7 |
8 | class TestLosses(unittest.TestCase):
9 |
10 | def test_crossentropy(self):
11 | scaled_crossent = new_scaled_crossentropy(index=2, scaling=5)
12 |
13 | targets = np.array([[0, 1, 2]])
14 | logits = np.array([[[.3, .2, .1], [.3, .2, .1], [.3, .2, .1]]])
15 |
16 | loss = scaled_crossent(targets, logits)
17 | self.assertAlmostEqual(2.3705523014068604, float(loss))
18 |
19 | scaled_crossent = new_scaled_crossentropy(index=2, scaling=1)
20 | loss = scaled_crossent(targets, logits)
21 | self.assertAlmostEqual(0.7679619193077087, float(loss))
22 |
23 | loss = masked_crossentropy(targets, logits)
24 | self.assertAlmostEqual(0.7679619193077087, float(loss))
25 |
--------------------------------------------------------------------------------
/train_aligner.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from tqdm import trange
4 |
5 | from utils.training_config_manager import TrainingConfigManager
6 | from data.datasets import AlignerDataset, AlignerPreprocessor
7 | from utils.decorators import ignore_exception, time_it
8 | from utils.scheduling import piecewise_linear_schedule, reduction_schedule
9 | from utils.logging_utils import SummaryManager
10 | from utils.scripts_utils import dynamic_memory_allocation, basic_train_parser
11 | from utils.metrics import attention_score
12 | from utils.spectrogram_ops import mel_lengths, phoneme_lengths
13 | from utils.alignments import get_durations_from_alignment
14 |
15 | np.random.seed(42)
16 | tf.random.set_seed(42)
17 |
18 | dynamic_memory_allocation()
19 | parser = basic_train_parser()
20 | args = parser.parse_args()
21 |
22 |
23 | def cut_with_durations(durations, mel, phonemes, snippet_len=10):
24 | phon_dur = np.pad(durations, (1, 0))
25 | starts = np.cumsum(phon_dur)[:-1]
26 | ends = np.cumsum(phon_dur)[1:]
27 | cut_mels = []
28 | cut_texts = []
29 | for end_idx in range(snippet_len, len(phon_dur), snippet_len):
30 | start_idx = end_idx - snippet_len
31 | cut_mels.append(mel[starts[start_idx]: ends[end_idx - 1], :])
32 | cut_texts.append(phonemes[start_idx: end_idx])
33 | return cut_mels, cut_texts
34 |
35 |
36 | @ignore_exception
37 | @time_it
38 | def validate(model,
39 | val_dataset,
40 | summary_manager,
41 | weighted_durations):
42 | val_loss = {'loss': 0.}
43 | norm = 0.
44 | current_r = model.r
45 | model.set_constants(reduction_factor=1)
46 | for val_mel, val_text, val_stop, fname in val_dataset.all_batches():
47 | model_out = model.val_step(inp=val_text,
48 | tar=val_mel,
49 | stop_prob=val_stop)
50 | norm += 1
51 | val_loss['loss'] += model_out['loss']
52 | val_loss['loss'] /= norm
53 | summary_manager.display_loss(model_out, tag='Validation', plot_all=True)
54 | summary_manager.display_last_attention(model_out, tag='ValidationAttentionHeads', fname=fname)
55 | attention_values = model_out['decoder_attention']['Decoder_LastBlock_CrossAttention'].numpy()
56 | text = val_text.numpy()
57 | mel = val_mel.numpy()
58 | model.set_constants(reduction_factor=current_r)
59 | modes = list({False, weighted_durations})
60 | for mode in modes:
61 | durations, final_align, jumpiness, peakiness, diag_measure = get_durations_from_alignment(
62 | batch_alignments=attention_values,
63 | mels=mel,
64 | phonemes=text,
65 | weighted=mode)
66 | for k in range(len(durations)):
67 | phon_dur = durations[k]
68 | imel = mel[k][1:] # remove start token (is padded so end token can't be removed/not an issue)
69 | itext = text[k][1:] # remove start token (is padded so end token can't be removed/not an issue)
70 | iphon = model.text_pipeline.tokenizer.decode(itext).replace('/', '')
71 | cut_mels, cut_texts = cut_with_durations(durations=phon_dur, mel=imel, phonemes=iphon)
72 | for cut_idx, cut_text in enumerate(cut_texts):
73 | weighted_label = 'weighted_' * mode
74 | summary_manager.display_audio(
75 | tag=f'CutAudio {weighted_label}{fname[k].numpy().decode("utf-8")}/{cut_idx}/{cut_text}',
76 | mel=cut_mels[cut_idx], description=iphon)
77 | return val_loss['loss']
78 |
79 |
80 | config_manager = TrainingConfigManager(config_path=args.config, aligner=True)
81 | config = config_manager.config
82 | config_manager.create_remove_dirs(clear_dir=args.clear_dir,
83 | clear_logs=args.clear_logs,
84 | clear_weights=args.clear_weights)
85 | config_manager.dump_config()
86 | config_manager.print_config()
87 |
88 | # get model, prepare data for model, create datasets
89 | model = config_manager.get_model()
90 | config_manager.compile_model(model)
91 | data_prep = AlignerPreprocessor.from_config(config_manager,
92 | tokenizer=model.text_pipeline.tokenizer) # TODO: tokenizer is now static
93 | train_data_handler = AlignerDataset.from_config(config_manager,
94 | preprocessor=data_prep,
95 | kind='train')
96 | valid_data_handler = AlignerDataset.from_config(config_manager,
97 | preprocessor=data_prep,
98 | kind='valid')
99 |
100 | train_dataset = train_data_handler.get_dataset(bucket_batch_sizes=config['bucket_batch_sizes'],
101 | bucket_boundaries=config['bucket_boundaries'],
102 | shuffle=True)
103 | valid_dataset = valid_data_handler.get_dataset(bucket_batch_sizes=config['val_bucket_batch_size'],
104 | bucket_boundaries=config['bucket_boundaries'],
105 | shuffle=False, drop_remainder=True)
106 |
107 | # create logger and checkpointer and restore latest model
108 |
109 | summary_manager = SummaryManager(model=model, log_dir=config_manager.log_dir, config=config)
110 | checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
111 | optimizer=model.optimizer,
112 | net=model)
113 | manager = tf.train.CheckpointManager(checkpoint, str(config_manager.weights_dir),
114 | max_to_keep=config['keep_n_weights'],
115 | keep_checkpoint_every_n_hours=config['keep_checkpoint_every_n_hours'])
116 | manager_training = tf.train.CheckpointManager(checkpoint, str(config_manager.weights_dir / 'latest'),
117 | max_to_keep=1, checkpoint_name='latest')
118 |
119 | checkpoint.restore(manager_training.latest_checkpoint)
120 | if manager_training.latest_checkpoint:
121 | print(f'\nresuming training from step {model.step} ({manager_training.latest_checkpoint})')
122 | else:
123 | print(f'\nstarting training from scratch')
124 |
125 | if config['debug'] is True:
126 | print('\nWARNING: DEBUG is set to True. Training in eager mode.')
127 | # main event
128 | print('\nTRAINING')
129 |
130 | texts = []
131 | for text_file in config['test_stencences']:
132 | with open(text_file, 'r') as file:
133 | text = file.readlines()
134 | texts.append(text)
135 |
136 | losses = []
137 | test_mel, test_phonemes, _, test_fname = valid_dataset.next_batch()
138 | val_test_sample, val_test_fname, val_test_mel = test_phonemes[0], test_fname[0], test_mel[0]
139 | val_test_sample = tf.boolean_mask(val_test_sample, val_test_sample!=0)
140 |
141 | _ = train_dataset.next_batch()
142 | t = trange(model.step, config['max_steps'], leave=True)
143 | for _ in t:
144 | t.set_description(f'step {model.step}')
145 | mel, phonemes, stop, sample_name = train_dataset.next_batch()
146 | learning_rate = piecewise_linear_schedule(model.step, config['learning_rate_schedule'])
147 | reduction_factor = reduction_schedule(model.step, config['reduction_factor_schedule'])
148 | t.display(f'reduction factor {reduction_factor}', pos=10)
149 | force_encoder_diagonal = model.step < config['force_encoder_diagonal_steps']
150 | force_decoder_diagonal = model.step < config['force_decoder_diagonal_steps']
151 | model.set_constants(learning_rate=learning_rate,
152 | reduction_factor=reduction_factor,
153 | force_encoder_diagonal=force_encoder_diagonal,
154 | force_decoder_diagonal=force_decoder_diagonal)
155 |
156 | output = model.train_step(inp=phonemes,
157 | tar=mel,
158 | stop_prob=stop)
159 | losses.append(float(output['loss']))
160 |
161 | t.display(f'step loss: {losses[-1]}', pos=1)
162 | for pos, n_steps in enumerate(config['n_steps_avg_losses']):
163 | if len(losses) > n_steps:
164 | t.display(f'{n_steps}-steps average loss: {sum(losses[-n_steps:]) / n_steps}', pos=pos + 2)
165 |
166 | summary_manager.display_loss(output, tag='Train')
167 | summary_manager.display_scalar(tag='Meta/learning_rate', scalar_value=model.optimizer.lr)
168 | summary_manager.display_scalar(tag='Meta/reduction_factor', scalar_value=model.r)
169 | summary_manager.display_scalar(scalar_value=t.avg_time, tag='Meta/iter_time')
170 | summary_manager.display_scalar(scalar_value=tf.shape(sample_name)[0], tag='Meta/batch_size')
171 | if model.step % config['train_images_plotting_frequency'] == 0:
172 | summary_manager.display_attention_heads(output, tag='TrainAttentionHeads')
173 | summary_manager.display_mel(mel=output['mel'][0], tag=f'Train/predicted_mel')
174 | for layer, k in enumerate(output['decoder_attention'].keys()):
175 | mel_lens = mel_lengths(mel_batch=mel, padding_value=0) // model.r # [N]
176 | phon_len = phoneme_lengths(phonemes)
177 | loc_score, peak_score, diag_measure = attention_score(att=output['decoder_attention'][k],
178 | mel_len=mel_lens,
179 | phon_len=phon_len,
180 | r=model.r)
181 | loc_score = tf.reduce_mean(loc_score, axis=0)
182 | peak_score = tf.reduce_mean(peak_score, axis=0)
183 | diag_measure = tf.reduce_mean(diag_measure, axis=0)
184 | for i in range(tf.shape(loc_score)[0]):
185 | summary_manager.display_scalar(tag=f'TrainDecoderAttentionJumpiness/layer{layer}_head{i}',
186 | scalar_value=tf.reduce_mean(loc_score[i]))
187 | summary_manager.display_scalar(tag=f'TrainDecoderAttentionPeakiness/layer{layer}_head{i}',
188 | scalar_value=tf.reduce_mean(peak_score[i]))
189 | summary_manager.display_scalar(tag=f'TrainDecoderAttentionDiagonality/layer{layer}_head{i}',
190 | scalar_value=tf.reduce_mean(diag_measure[i]))
191 |
192 | if model.step % 1000 == 0:
193 | save_path = manager_training.save()
194 | if model.step % config['weights_save_frequency'] == 0:
195 | save_path = manager.save()
196 | t.display(f'checkpoint at step {model.step}: {save_path}', pos=len(config['n_steps_avg_losses']) + 2)
197 |
198 | if model.step % config['validation_frequency'] == 0 and (model.step >= config['prediction_start_step']):
199 | val_loss, time_taken = validate(model=model,
200 | val_dataset=valid_dataset,
201 | summary_manager=summary_manager,
202 | weighted_durations=config['extract_attention_weighted'])
203 | t.display(f'validation loss at step {model.step}: {val_loss} (took {time_taken}s)',
204 | pos=len(config['n_steps_avg_losses']) + 3)
205 |
206 | if model.step % config['prediction_frequency'] == 0 and (model.step >= config['prediction_start_step']):
207 | for j, text in enumerate(texts):
208 | for i, text_line in enumerate(text):
209 | out = model.predict(text_line, encode=True)
210 | wav = summary_manager.audio.reconstruct_waveform(out['mel'].numpy().T)
211 | wav = tf.expand_dims(wav, 0)
212 | wav = tf.expand_dims(wav, -1)
213 | summary_manager.add_audio(f'Predictions/{text_line}', wav.numpy(), sr=summary_manager.config['sampling_rate'],
214 | step=summary_manager.global_step)
215 |
216 | out = model.predict(val_test_sample, encode=False)#, max_length=tf.shape(val_test_mel)[-2])
217 | wav = summary_manager.audio.reconstruct_waveform(out['mel'].numpy().T)
218 | wav = tf.expand_dims(wav, 0)
219 | wav = tf.expand_dims(wav, -1)
220 | summary_manager.add_audio(f'Predictions/val_sample {val_test_fname.numpy().decode("utf-8")}', wav.numpy(), sr=summary_manager.config['sampling_rate'],
221 | step=summary_manager.global_step)
222 | print('Done.')
223 |
--------------------------------------------------------------------------------
/train_tts.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from tqdm import trange
4 |
5 | from utils.training_config_manager import TrainingConfigManager
6 | from data.datasets import TTSDataset, TTSPreprocessor
7 | from utils.decorators import ignore_exception, time_it
8 | from utils.scheduling import piecewise_linear_schedule
9 | from utils.logging_utils import SummaryManager
10 | from model.transformer_utils import create_mel_padding_mask
11 | from utils.scripts_utils import dynamic_memory_allocation, basic_train_parser
12 | from data.metadata_readers import post_processed_reader
13 |
14 | np.random.seed(42)
15 | tf.random.set_seed(42)
16 | dynamic_memory_allocation()
17 |
18 |
19 | def display_target_symbol_duration_distributions():
20 | phon_data, ups = post_processed_reader(config.phonemized_metadata_path)
21 | dur_dict = {}
22 | for key in phon_data.keys():
23 | dur_dict[key] = np.load((config.duration_dir / key).with_suffix('.npy'))
24 | symbol_durs = {}
25 | for key in dur_dict:
26 | for i, phoneme in enumerate(phon_data[key]):
27 | symbol_durs.setdefault(phoneme, []).append(dur_dict[key][i])
28 | for symbol in symbol_durs.keys():
29 | summary_manager.add_histogram(tag=f'"{symbol}"/Target durations', values=symbol_durs[symbol],
30 | buckets=len(set(symbol_durs[symbol])) + 1, step=0)
31 |
32 |
33 | def display_predicted_symbol_duration_distributions(all_durations):
34 | phon_data, ups = post_processed_reader(config.phonemized_metadata_path)
35 | symbol_durs = {}
36 | for key in all_durations.keys():
37 | clean_key = key.decode('utf-8')
38 | for i, phoneme in enumerate(phon_data[clean_key]):
39 | symbol_durs.setdefault(phoneme, []).append(all_durations[key][i])
40 | for symbol in symbol_durs.keys():
41 | summary_manager.add_histogram(tag=f'"{symbol}"/Predicted durations', values=symbol_durs[symbol])
42 |
43 |
44 | @ignore_exception
45 | @time_it
46 | def validate(model,
47 | val_dataset,
48 | summary_manager):
49 | val_loss = {'loss': 0.}
50 | norm = 0.
51 | for mel, phonemes, durations, pitch, fname in val_dataset.all_batches():
52 | model_out = model.val_step(input_sequence=phonemes,
53 | target_sequence=mel,
54 | target_durations=durations,
55 | target_pitch=pitch)
56 | norm += 1
57 | val_loss['loss'] += model_out['loss']
58 | val_loss['loss'] /= norm
59 | summary_manager.display_loss(model_out, tag='Validation', plot_all=True)
60 | summary_manager.display_attention_heads(model_out, tag='ValidationAttentionHeads')
61 | summary_manager.add_histogram(tag=f'Validation/Predicted durations', values=model_out['duration'])
62 | summary_manager.add_histogram(tag=f'Validation/Target durations', values=durations)
63 | summary_manager.display_plot1D(tag=f'Validation/{fname[0].numpy().decode("utf-8")} predicted pitch',
64 | y=model_out['pitch'][0])
65 | summary_manager.display_plot1D(tag=f'Validation/{fname[0].numpy().decode("utf-8")} target pitch', y=pitch[0])
66 | summary_manager.display_mel(mel=model_out['mel'][0],
67 | tag=f'Validation/{fname[0].numpy().decode("utf-8")} predicted_mel')
68 | summary_manager.display_mel(mel=mel[0], tag=f'Validation/{fname[0].numpy().decode("utf-8")} target_mel')
69 | summary_manager.display_audio(tag=f'Validation {fname[0].numpy().decode("utf-8")}/prediction',
70 | mel=model_out['mel'][0])
71 | summary_manager.display_audio(tag=f'Validation {fname[0].numpy().decode("utf-8")}/target', mel=mel[0])
72 |     # predict without enforcing durations and pitch
73 | model_out = model.predict(phonemes, encode=False)
74 | pred_lengths = tf.cast(tf.reduce_sum(1 - model_out['expanded_mask'], axis=-1), tf.int32)
75 | pred_lengths = tf.squeeze(pred_lengths)
76 | tar_lengths = tf.cast(tf.reduce_sum(1 - create_mel_padding_mask(mel), axis=-1), tf.int32)
77 | tar_lengths = tf.squeeze(tar_lengths)
78 | for j, pred_mel in enumerate(model_out['mel']):
79 | predval = pred_mel[:pred_lengths[j], :]
80 | tar_value = mel[j, :tar_lengths[j], :]
81 | summary_manager.display_mel(mel=predval, tag=f'Test/{fname[j].numpy().decode("utf-8")}/predicted')
82 | summary_manager.display_mel(mel=tar_value, tag=f'Test/{fname[j].numpy().decode("utf-8")}/target')
83 | summary_manager.display_audio(tag=f'Prediction {fname[j].numpy().decode("utf-8")}/target', mel=tar_value)
84 | summary_manager.display_audio(tag=f'Prediction {fname[j].numpy().decode("utf-8")}/prediction',
85 | mel=predval)
86 | return val_loss['loss']
87 |
88 |
89 | parser = basic_train_parser()
90 | args = parser.parse_args()
91 |
92 | config = TrainingConfigManager(config_path=args.config)
93 | config_dict = config.config
94 | config.create_remove_dirs(clear_dir=args.clear_dir,
95 | clear_logs=args.clear_logs,
96 | clear_weights=args.clear_weights)
97 | config.dump_config()
98 | config.print_config()
99 |
100 | model = config.get_model()
101 | config.compile_model(model)
102 |
103 | data_prep = TTSPreprocessor.from_config(config=config,
104 | tokenizer=model.text_pipeline.tokenizer)
105 | train_data_handler = TTSDataset.from_config(config,
106 | preprocessor=data_prep,
107 | kind='train')
108 | valid_data_handler = TTSDataset.from_config(config,
109 | preprocessor=data_prep,
110 | kind='valid')
111 | train_dataset = train_data_handler.get_dataset(bucket_batch_sizes=config_dict['bucket_batch_sizes'],
112 | bucket_boundaries=config_dict['bucket_boundaries'],
113 | shuffle=True)
114 | valid_dataset = valid_data_handler.get_dataset(bucket_batch_sizes=config_dict['val_bucket_batch_size'],
115 | bucket_boundaries=config_dict['bucket_boundaries'],
116 | shuffle=False,
117 | drop_remainder=True)
118 |
119 | # create logger and checkpointer and restore latest model
120 | summary_manager = SummaryManager(model=model, log_dir=config.log_dir, config=config_dict)
121 | checkpoint = tf.train.Checkpoint(step=tf.Variable(1),
122 | optimizer=model.optimizer,
123 | net=model)
124 | manager_training = tf.train.CheckpointManager(checkpoint, str(config.weights_dir / 'latest'),
125 | max_to_keep=1, checkpoint_name='latest')
126 |
127 | checkpoint.restore(manager_training.latest_checkpoint)
128 | if manager_training.latest_checkpoint:
129 | print(f'\nresuming training from step {model.step} ({manager_training.latest_checkpoint})')
130 | else:
131 | print(f'\nstarting training from scratch')
132 |
133 | if config_dict['debug'] is True:
134 | print('\nWARNING: DEBUG is set to True. Training in eager mode.')
135 |
136 | display_target_symbol_duration_distributions()
137 | # main event
138 | print('\nTRAINING')
139 | losses = []
140 | texts = []
141 | for text_file in config_dict['text_prediction']:
142 | with open(text_file, 'r') as file:
143 | text = file.readlines()
144 | texts.append(text)
145 |
146 | all_files = len(set(train_data_handler.metadata_reader.filenames)) # without duplicates
147 | all_durations = {}
148 | t = trange(model.step, config_dict['max_steps'], leave=True)
149 | for _ in t:
150 | t.set_description(f'step {model.step}')
151 | mel, phonemes, durations, pitch, fname = train_dataset.next_batch()
152 | learning_rate = piecewise_linear_schedule(model.step, config_dict['learning_rate_schedule'])
153 | model.set_constants(learning_rate=learning_rate)
154 | output = model.train_step(input_sequence=phonemes,
155 | target_sequence=mel,
156 | target_durations=durations,
157 | target_pitch=pitch)
158 | losses.append(float(output['loss']))
159 |
160 | predicted_durations = dict(zip(fname.numpy(), output['duration'].numpy()))
161 | all_durations.update(predicted_durations)
162 | if len(all_durations) >= all_files: # all the dataset has been processed
163 | display_predicted_symbol_duration_distributions(all_durations)
164 | all_durations = {}
165 |
166 | t.display(f'step loss: {losses[-1]}', pos=1)
167 | for pos, n_steps in enumerate(config_dict['n_steps_avg_losses']):
168 | if len(losses) > n_steps:
169 | t.display(f'{n_steps}-steps average loss: {sum(losses[-n_steps:]) / n_steps}', pos=pos + 2)
170 |
171 | summary_manager.display_loss(output, tag='Train')
172 | summary_manager.display_scalar(scalar_value=t.avg_time, tag='Meta/iter_time')
173 | summary_manager.display_scalar(scalar_value=tf.shape(fname)[0], tag='Meta/batch_size')
174 | summary_manager.display_scalar(tag='Meta/learning_rate', scalar_value=model.optimizer.lr)
175 | if model.step % config_dict['train_images_plotting_frequency'] == 0:
176 | summary_manager.display_attention_heads(output, tag='TrainAttentionHeads')
177 | summary_manager.display_mel(mel=output['mel'][0], tag=f'Train/predicted_mel')
178 | summary_manager.display_mel(mel=mel[0], tag=f'Train/target_mel')
179 | summary_manager.display_plot1D(tag=f'Train/Predicted pitch', y=output['pitch'][0])
180 | summary_manager.display_plot1D(tag=f'Train/Target pitch', y=pitch[0])
181 |
182 | if model.step % 1000 == 0:
183 | save_path = manager_training.save()
184 | if (model.step % config_dict['weights_save_frequency'] == 0) & (
185 | model.step >= config_dict['weights_save_starting_step']):
186 | model.save_model(config.weights_dir / f'step_{model.step}')
187 | t.display(f'checkpoint at step {model.step}: {config.weights_dir / f"step_{model.step}"}',
188 | pos=len(config_dict['n_steps_avg_losses']) + 2)
189 |
190 | if model.step % config_dict['validation_frequency'] == 0:
191 | t.display(f'Validating', pos=len(config_dict['n_steps_avg_losses']) + 3)
192 | val_loss, time_taken = validate(model=model,
193 | val_dataset=valid_dataset,
194 | summary_manager=summary_manager)
195 | t.display(f'validation loss at step {model.step}: {val_loss} (took {time_taken}s)',
196 | pos=len(config_dict['n_steps_avg_losses']) + 3)
197 |
198 | if model.step % config_dict['prediction_frequency'] == 0 and (model.step >= config_dict['prediction_start_step']):
199 |         for j, text in enumerate(texts):
200 | wavs = []
201 | for i, text_line in enumerate(text):
202 | out = model.predict(text_line, encode=True)
203 | wav = summary_manager.audio.reconstruct_waveform(out['mel'].numpy().T)
204 | wavs.append(wav)
205 | wavs = np.concatenate(wavs)
206 | wavs = tf.expand_dims(wavs, 0)
207 | wavs = tf.expand_dims(wavs, -1)
208 | summary_manager.add_audio(f'Text file input', wavs.numpy(), sr=summary_manager.config['sampling_rate'],
209 | step=summary_manager.global_step)
210 |
211 | print('Done.')
212 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spring-media/TransformerTTS/363805548abdd93b33508da2c027ae514bfc1a07/utils/__init__.py
--------------------------------------------------------------------------------
/utils/alignments.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from utils.metrics import attention_score
4 | from utils.spectrogram_ops import mel_lengths, phoneme_lengths
5 |
6 | logger = tf.get_logger()
7 | logger.setLevel('ERROR')
8 | import numpy as np
9 | from scipy.sparse import coo_matrix
10 | from scipy.sparse.csgraph import dijkstra
11 |
12 |
13 | def to_node_index(i, j, cols):
14 | return cols * i + j
15 |
16 |
17 | def from_node_index(node_index, cols):
18 | return node_index // cols, node_index % cols
19 |
20 |
21 | def to_adj_matrix(mat):
22 | rows = mat.shape[0]
23 | cols = mat.shape[1]
24 |
25 | row_ind = []
26 | col_ind = []
27 | data = []
28 |
29 | for i in range(rows):
30 | for j in range(cols):
31 |
32 | node = to_node_index(i, j, cols)
33 |
34 | if j < cols - 1:
35 | right_node = to_node_index(i, j + 1, cols)
36 | weight_right = mat[i, j + 1]
37 | row_ind.append(node)
38 | col_ind.append(right_node)
39 | data.append(weight_right)
40 |
41 | if i < rows - 1 and j < cols:
42 | bottom_node = to_node_index(i + 1, j, cols)
43 | weight_bottom = mat[i + 1, j]
44 | row_ind.append(node)
45 | col_ind.append(bottom_node)
46 | data.append(weight_bottom)
47 |
48 | if i < rows - 1 and j < cols - 1:
49 | bottom_right_node = to_node_index(i + 1, j + 1, cols)
50 | weight_bottom_right = mat[i + 1, j + 1]
51 | row_ind.append(node)
52 | col_ind.append(bottom_right_node)
53 | data.append(weight_bottom_right)
54 |
55 | adj_mat = coo_matrix((data, (row_ind, col_ind)), shape=(rows * cols, rows * cols))
56 | return adj_mat.tocsr()
57 |
58 |
59 | def extract_durations_with_dijkstra(attention_map: np.array) -> np.array:
60 | """
61 | Extracts durations from the attention matrix by finding the shortest monotonic path from
62 | top left to bottom right.
63 | """
64 | attn_max = np.max(attention_map)
65 | path_probs = attn_max - attention_map
66 | adj_matrix = to_adj_matrix(path_probs)
67 | dist_matrix, predecessors = dijkstra(csgraph=adj_matrix, directed=True,
68 | indices=0, return_predecessors=True)
69 | path = []
70 | pr_index = predecessors[-1]
71 | while pr_index != 0:
72 | path.append(pr_index)
73 | pr_index = predecessors[pr_index]
74 | path.reverse()
75 |
76 | # append first and last node
77 | path = [0] + path + [dist_matrix.size - 1]
78 | cols = path_probs.shape[1]
79 | mel_text = {}
80 | durations = np.zeros(attention_map.shape[1], dtype=np.int32)
81 |
82 | # collect indices (mel, text) along the path
83 | for node_index in path:
84 | i, j = from_node_index(node_index, cols)
85 | mel_text[i] = j
86 |
87 | for j in mel_text.values():
88 | durations[j] += 1
89 |
90 | return durations
91 |
92 |
93 | def duration_to_alignment_matrix(durations):
94 | starts = np.cumsum(np.append([0], durations[:-1]))
95 | tot_duration = np.sum(durations)
96 | pads = tot_duration - starts - durations
97 | alignments = [np.concatenate([np.zeros(starts[i]), np.ones(durations[i]), np.zeros(pads[i])]) for i in
98 | range(len(durations))]
99 | return np.array(alignments)
100 |
101 |
102 | def get_durations_from_alignment(batch_alignments, mels, phonemes, weighted=False):
103 | """
104 |     Extracts per-phoneme durations from a batch of autoregressive attention alignments.
105 | 
106 |     :param batch_alignments: attention weights from the autoregressive model, [N, n_heads, mel_len, phon_len].
107 |     :param mels: mel spectrograms.
108 |     :param phonemes: phoneme sequences (including start and end tokens).
109 |     :param weighted: if True use the attention-score weighted sum of all heads,
110 |         if False use only the best-scoring head.
111 |     :return: (durations, final_alignment, jumpiness, peakiness, diag_measure), where durations
112 |         is a list of integer duration arrays (one per batch element) and final_alignment
113 |         overlays the best attention head with the binarized duration alignment.
114 | 
115 | """
116 | # mel_len - 1 because we remove last timestep, which is end_vector. start vector is not predicted (or removed from GTA)
117 | mel_len = mel_lengths(mels, padding_value=0.) - 1 # [N]
118 | # phonemes contain start and end tokens (start will be removed later)
119 | phon_len = phoneme_lengths(phonemes) - 1
120 | jumpiness, peakiness, diag_measure = attention_score(att=batch_alignments, mel_len=mel_len, phon_len=phon_len, r=1)
121 | attn_scores = diag_measure + jumpiness + peakiness
122 | durations = []
123 | final_alignment = []
124 | for batch_num, al in enumerate(batch_alignments):
125 | unpad_mel_len = mel_len[batch_num]
126 | unpad_phon_len = phon_len[batch_num]
127 | unpad_alignments = al[:, 1:unpad_mel_len, 1:unpad_phon_len] # first dim is heads
128 | scored_attention = unpad_alignments * attn_scores[batch_num][:, None, None]
129 |
130 | if weighted:
131 | ref_attention_weights = np.sum(scored_attention, axis=0)
132 | else:
133 | best_head = np.argmax(attn_scores[batch_num])
134 | ref_attention_weights = unpad_alignments[best_head]
135 | integer_durations = extract_durations_with_dijkstra(ref_attention_weights)
136 |
137 | assert np.sum(integer_durations) == mel_len[batch_num]-1, f'{np.sum(integer_durations)} vs {mel_len[batch_num]-1}'
138 | new_alignment = duration_to_alignment_matrix(integer_durations.astype(int))
139 | best_head = np.argmax(attn_scores[batch_num])
140 | best_attention = unpad_alignments[best_head]
141 | final_alignment.append(best_attention.T + new_alignment)
142 | durations.append(integer_durations)
143 | return durations, final_alignment, jumpiness, peakiness, diag_measure
144 |
--------------------------------------------------------------------------------
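A small, hedged sketch (not part of the repository) of the two duration utilities above: duration_to_alignment_matrix expands integer durations into a binary phoneme-by-frame alignment, and extract_durations_with_dijkstra assigns every mel frame to exactly one phoneme along the cheapest monotonic path.

import numpy as np
from utils.alignments import duration_to_alignment_matrix, extract_durations_with_dijkstra

# Three phonemes covering 2 + 1 + 3 = 6 mel frames.
print(duration_to_alignment_matrix(np.array([2, 1, 3])))
# [[1. 1. 0. 0. 0. 0.]
#  [0. 0. 1. 0. 0. 0.]
#  [0. 0. 0. 1. 1. 1.]]

# Toy attention map (mel frames x phonemes), roughly diagonal.
attn = np.array([[.9, .1, .0],
                 [.8, .2, .0],
                 [.1, .8, .1],
                 [.0, .2, .8],
                 [.0, .1, .9]])
durations = extract_durations_with_dijkstra(attn)
# Every one of the 5 mel frames is assigned to exactly one of the 3 phonemes.
assert durations.sum() == attn.shape[0]
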
/utils/decorators.py:
--------------------------------------------------------------------------------
1 | import traceback
2 | from time import time
3 |
4 |
5 | def ignore_exception(f):
6 | def apply_func(*args, **kwargs):
7 | try:
8 | result = f(*args, **kwargs)
9 | return result
10 | except Exception:
11 |             print(f'Caught exception in {f}:')
12 | traceback.print_exc()
13 | return None
14 |
15 | return apply_func
16 |
17 |
18 | def time_it(f):
19 | def apply_func(*args, **kwargs):
20 | t_start = time()
21 | result = f(*args, **kwargs)
22 | t_end = time()
23 | dur = round(t_end - t_start, ndigits=2)
24 | return result, dur
25 |
26 | return apply_func
27 |
--------------------------------------------------------------------------------
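A hedged usage sketch (not in the repository) showing how the training scripts stack these two decorators, for example on validate(): time_it turns the return value into a (result, duration) tuple, and ignore_exception prints the traceback and returns None if anything raises.

from utils.decorators import ignore_exception, time_it

@ignore_exception   # outermost: on any exception, print the traceback and return None
@time_it            # innermost: return (result, duration_in_seconds)
def fake_validation_step():
    return 0.123    # stands in for a validation loss

out = fake_validation_step()
# on success: out == (0.123, <elapsed seconds, rounded to 2 decimals>)
# on failure: out is None
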
/utils/display.py:
--------------------------------------------------------------------------------
1 | import io
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 |
6 |
7 | def buffer_image(figure):
8 | buf = io.BytesIO()
9 | figure.savefig(buf, format='png')
10 | buf.seek(0)
11 | plt.close('all')
12 | return buf
13 |
14 | def plot1D(y, figsize=None, title='', x=None):
15 | f = plt.figure(figsize=figsize)
16 | if x is None:
17 | x = np.arange(len(y))
18 | plt.plot(x,y)
19 | plt.title(title)
20 | buf = buffer_image(f)
21 | return buf
22 |
23 |
24 | def plot_image(image, with_bar, figsize=None, title=''):
25 | """Create a pyplot plot and save to buffer."""
26 | f = plt.figure(figsize=figsize)
27 | plt.imshow(image)
28 | plt.title(title)
29 | if with_bar:
30 | plt.colorbar()
31 | buf = buffer_image(f)
32 | return buf
33 |
34 |
35 | def tight_grid(images):
36 | images = np.array(images)
37 | images = np.pad(images, [[0, 0], [1, 1], [1, 1]], 'constant', constant_values=1) # add borders
38 | if len(images.shape) != 3:
39 |         raise ValueError('tight_grid expects a 3D array of images [n, y, x]')
40 | else:
41 | n, y, x = images.shape
42 | ratio = y / x
43 | if ratio > 1:
44 | ny = max(int(np.sqrt(n / ratio)), 1)
45 | nx = int(n / ny)
46 | nx += n - (nx * ny)
47 | extra = nx * ny - n
48 | else:
49 | nx = max(int(np.sqrt(n * ratio)), 1)
50 | ny = int(n / nx)
51 | ny += n - (nx * ny)
52 | extra = nx * ny - n
53 | tot = np.append(images, np.zeros((extra, y, x)), axis=0)
54 | img = np.block([[*tot[i * nx:(i + 1) * nx]] for i in range(ny)])
55 | return img
56 |
--------------------------------------------------------------------------------
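A brief, hedged example (not in the repository) of how these plotting helpers are consumed: each returns an in-memory PNG buffer, which logging_utils later decodes with tf.image.decode_png.

import numpy as np
from utils.display import plot1D, plot_image

pitch_buf = plot1D(np.sin(np.linspace(0, 6.3, 100)), title='pitch contour')
mel_buf = plot_image(np.random.rand(80, 120), with_bar=True, title='mel')
# both are io.BytesIO objects holding PNG bytes, e.g. tf.image.decode_png(pitch_buf.getvalue(), channels=4)
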
/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import tensorflow as tf
4 |
5 | from data.audio import Audio
6 | from utils.display import tight_grid, buffer_image, plot_image, plot1D
7 | from utils.vec_ops import norm_tensor
8 | from utils.decorators import ignore_exception
9 |
10 |
11 | def control_frequency(f):
12 | def apply_func(*args, **kwargs):
13 | # args[0] is self
14 | plot_all = ('plot_all' in kwargs) and kwargs['plot_all']
15 | if (args[0].global_step % args[0].plot_frequency == 0) or plot_all:
16 | result = f(*args, **kwargs)
17 | return result
18 | else:
19 | return None
20 |
21 | return apply_func
22 |
23 |
24 | class SummaryManager:
25 | """ Writes tensorboard logs during training.
26 |
27 | :arg model: model object that is trained
28 | :arg log_dir: base directory where logs of a config are created
29 | :arg config: configuration dictionary
30 | :arg max_plot_frequency: every how many steps to plot
31 | """
32 |
33 | def __init__(self,
34 | model: tf.keras.models.Model,
35 | log_dir: str,
36 | config: dict,
37 | max_plot_frequency=10,
38 | default_writer='log_dir'):
39 | self.model = model
40 | self.log_dir = Path(log_dir)
41 | self.config = config
42 | self.audio = Audio.from_config(config)
43 | self.plot_frequency = max_plot_frequency
44 | self.default_writer = default_writer
45 | self.writers = {}
46 | self.add_writer(tag=default_writer, path=self.log_dir, default=True)
47 |
48 | def add_writer(self, path, tag=None, default=False):
49 | """ Adds a writer to self.writers if the writer does not exist already.
50 | To avoid spamming writers on disk.
51 |
52 |             :returns: the writer for path, stored under tag (or under path if no tag is given)
53 | """
54 | if not tag:
55 | tag = path
56 | if tag not in self.writers.keys():
57 | self.writers[tag] = tf.summary.create_file_writer(str(path))
58 | if default:
59 | self.default_writer = tag
60 | return self.writers[tag]
61 |
62 | @property
63 | def global_step(self):
64 | if self.model is not None:
65 | return self.model.step
66 | else:
67 | return 0
68 |
69 | def add_scalars(self, tag, dictionary, step=None):
70 | if step is None:
71 | step = self.global_step
72 | for k in dictionary.keys():
73 | with self.add_writer(str(self.log_dir / k)).as_default():
74 | tf.summary.scalar(name=tag, data=dictionary[k], step=step)
75 |
76 | def add_scalar(self, tag, scalar_value, step=None):
77 | if step is None:
78 | step = self.global_step
79 | with self.writers[self.default_writer].as_default():
80 | tf.summary.scalar(name=tag, data=scalar_value, step=step)
81 |
82 | def add_image(self, tag, image, step=None):
83 | if step is None:
84 | step = self.global_step
85 | with self.writers[self.default_writer].as_default():
86 | tf.summary.image(name=tag, data=image, step=step, max_outputs=4)
87 |
88 | def add_histogram(self, tag, values, buckets=None, step=None):
89 | if step is None:
90 | step = self.global_step
91 | with self.writers[self.default_writer].as_default():
92 | tf.summary.histogram(name=tag, data=values, step=step, buckets=buckets)
93 |
94 | def add_audio(self, tag, wav, sr, step=None, description=None):
95 | if step is None:
96 | step = self.global_step
97 | with self.writers[self.default_writer].as_default():
98 | tf.summary.audio(name=tag,
99 | data=wav,
100 | sample_rate=sr,
101 | step=step,
102 | description=description)
103 |
104 | def add_text(self, tag, text, step=None):
105 | if step is None:
106 | step = self.global_step
107 | with self.writers[self.default_writer].as_default():
108 | tf.summary.text(name=tag,
109 | data=text,
110 | step=step)
111 |
112 | @ignore_exception
113 | def display_attention_heads(self, outputs: dict, tag='', step: int=None, fname: list=None):
114 | if step is None:
115 | step = self.global_step
116 | for layer in ['encoder_attention', 'decoder_attention']:
117 | for k in outputs[layer].keys():
118 | if fname is None:
119 | image = tight_grid(norm_tensor(outputs[layer][k][0])) # dim 0 of image_batch is now number of heads
120 | if k == 'Decoder_LastBlock_CrossAttention':
121 | batch_plot_path = f'{tag}_Decoder_Final_Attention'
122 | else:
123 | batch_plot_path = f'{tag}_{layer}/{k}'
124 | self.add_image(str(batch_plot_path), tf.expand_dims(tf.expand_dims(image, 0), -1), step=step)
125 | else:
126 | for j, file in enumerate(fname):
127 | image = tight_grid(
128 | norm_tensor(outputs[layer][k][j])) # dim 0 of image_batch is now number of heads
129 | if k == 'Decoder_LastBlock_CrossAttention':
130 | batch_plot_path = f'{tag}_Decoder_Final_Attention/{file.numpy().decode("utf-8")}'
131 | else:
132 | batch_plot_path = f'{tag}_{layer}/{k}/{file.numpy().decode("utf-8")}'
133 | self.add_image(str(batch_plot_path), tf.expand_dims(tf.expand_dims(image, 0), -1), step=step)
134 |
135 | @ignore_exception
136 | def display_last_attention(self, outputs, tag='', step=None, fname=None):
137 | if step is None:
138 | step = self.global_step
139 |
140 | if fname is None:
141 | image = tight_grid(norm_tensor(outputs['decoder_attention']['Decoder_LastBlock_CrossAttention'][0])) # dim 0 of image_batch is now number of heads
142 | batch_plot_path = f'{tag}_Decoder_Final_Attention'
143 | self.add_image(str(batch_plot_path), tf.expand_dims(tf.expand_dims(image, 0), -1), step=step)
144 | else:
145 | for j, file in enumerate(fname):
146 | image = tight_grid(
147 | norm_tensor(outputs['decoder_attention']['Decoder_LastBlock_CrossAttention'][j])) # dim 0 of image_batch is now number of heads
148 | batch_plot_path = f'{tag}_Decoder_Final_Attention/{file.numpy().decode("utf-8")}'
149 | self.add_image(str(batch_plot_path), tf.expand_dims(tf.expand_dims(image, 0), -1), step=step)
150 |
151 | @ignore_exception
152 | def display_mel(self, mel, tag='', step=None):
153 | if step is None:
154 | step = self.global_step
155 | img = tf.transpose(mel)
156 | figure = self.audio.display_mel(img, is_normal=True)
157 | buf = buffer_image(figure)
158 | img_tf = tf.image.decode_png(buf.getvalue(), channels=3)
159 | self.add_image(tag, tf.expand_dims(img_tf, 0), step=step)
160 |
161 | @ignore_exception
162 | def display_image(self, image, with_bar=False, figsize=None, tag='', step=None):
163 | if step is None:
164 | step = self.global_step
165 | buf = plot_image(image, with_bar=with_bar, figsize=figsize)
166 | image = tf.image.decode_png(buf.getvalue(), channels=4)
167 | image = tf.expand_dims(image, 0)
168 | self.add_image(tag=tag, image=image, step=step)
169 |
170 | @ignore_exception
171 | def display_plot1D(self, y, x=None, figsize=None, tag='', step=None):
172 | if step is None:
173 | step = self.global_step
174 | buf = plot1D(y, x=x, figsize=figsize)
175 | image = tf.image.decode_png(buf.getvalue(), channels=4)
176 | image = tf.expand_dims(image, 0)
177 | self.add_image(tag=tag, image=image, step=step)
178 |
179 | @control_frequency
180 | @ignore_exception
181 | def display_loss(self, output, tag='', plot_all=False, step=None):
182 | if step is None:
183 | step = self.global_step
184 | self.add_scalars(tag=f'{tag}/losses', dictionary=output['losses'], step=step)
185 | self.add_scalar(tag=f'{tag}/loss', scalar_value=output['loss'], step=step)
186 |
187 | @control_frequency
188 | @ignore_exception
189 | def display_scalar(self, tag, scalar_value, plot_all=False, step=None):
190 | if step is None:
191 | step = self.global_step
192 | self.add_scalar(tag=tag, scalar_value=scalar_value, step=step)
193 |
194 | @ignore_exception
195 | def display_audio(self, tag, mel, step=None, description=None):
196 | wav = tf.transpose(mel)
197 | wav = self.audio.reconstruct_waveform(wav)
198 | wav = tf.expand_dims(wav, 0)
199 | wav = tf.expand_dims(wav, -1)
200 | self.add_audio(tag, wav.numpy(), sr=self.config['sampling_rate'], step=step, description=description)
201 |
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def new_scaled_crossentropy(index=2, scaling=1.0):
5 | """
6 | Returns masked crossentropy with extra scaling:
7 |     Scales the loss at the stop token positions (targets == index) by the scaling factor
8 | """
9 |
10 | def masked_crossentropy(targets: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
11 | crossentropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
12 | padding_mask = tf.math.equal(targets, 0)
13 | padding_mask = tf.math.logical_not(padding_mask)
14 | padding_mask = tf.cast(padding_mask, dtype=tf.float32)
15 | stop_mask = tf.math.equal(targets, index)
16 | stop_mask = tf.cast(stop_mask, dtype=tf.float32) * (scaling - 1.)
17 | combined_mask = padding_mask + stop_mask
18 | loss = crossentropy(targets, logits, sample_weight=combined_mask)
19 | return loss
20 |
21 | return masked_crossentropy
22 |
23 |
24 | def masked_crossentropy(targets: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
25 | crossentropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
26 | mask = tf.math.logical_not(tf.math.equal(targets, 0))
27 | mask = tf.cast(mask, dtype=tf.int32)
28 | loss = crossentropy(targets, logits, sample_weight=mask)
29 | return loss
30 |
31 |
32 | def masked_mean_squared_error(targets: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
33 | mse = tf.keras.losses.MeanSquaredError()
34 | mask = tf.math.logical_not(tf.math.equal(targets, 0))
35 | mask = tf.cast(mask, dtype=tf.int32)
36 | mask = tf.reduce_max(mask, axis=-1)
37 | loss = mse(targets, logits, sample_weight=mask)
38 | return loss
39 |
40 |
41 | def masked_mean_absolute_error(targets: tf.Tensor, logits: tf.Tensor, mask_value=0,
42 | mask: tf.Tensor = None) -> tf.Tensor:
43 | mae = tf.keras.losses.MeanAbsoluteError()
44 | if mask is not None:
45 | mask = tf.math.logical_not(tf.math.equal(targets, mask_value))
46 | mask = tf.cast(mask, dtype=tf.int32)
47 | mask = tf.reduce_max(mask, axis=-1)
48 | loss = mae(targets, logits, sample_weight=mask)
49 | return loss
50 |
51 |
52 | def masked_binary_crossentropy(targets: tf.Tensor, logits: tf.Tensor, mask_value=-1) -> tf.Tensor:
53 | bc = tf.keras.losses.BinaryCrossentropy(reduction='none')
54 | mask = tf.math.logical_not(tf.math.equal(logits,
55 | mask_value)) # TODO: masking based on the logits requires a masking layer. But masking layer produces 0. as outputs.
56 | # Need explicit masking
57 | mask = tf.cast(mask, dtype=tf.int32)
58 | loss_ = bc(targets, logits)
59 | loss_ *= mask
60 | return tf.reduce_mean(loss_)
61 |
62 |
63 | def weighted_sum_losses(targets, pred, loss_functions, coeffs):
64 | total_loss = 0
65 | loss_vals = []
66 | for i in range(len(loss_functions)):
67 | loss = loss_functions[i](targets[i], pred[i])
68 | loss_vals.append(loss)
69 | total_loss += coeffs[i] * loss
70 | return total_loss, loss_vals
71 |
--------------------------------------------------------------------------------
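A hedged sketch mirroring tests/test_loss.py: inside new_scaled_crossentropy the sample weights become 0 for padding, 1 for regular tokens and scaling for the stop token, so with scaling=1 it reduces to masked_crossentropy.

import numpy as np
from utils.losses import new_scaled_crossentropy, masked_crossentropy

targets = np.array([[0, 1, 2]])   # 0 = padding, 2 = stop token
logits = np.array([[[.3, .2, .1], [.3, .2, .1], [.3, .2, .1]]])

print(float(new_scaled_crossentropy(index=2, scaling=1)(targets, logits)))  # ~0.768
print(float(masked_crossentropy(targets, logits)))                          # ~0.768 (same weights [0, 1, 1])
print(float(new_scaled_crossentropy(index=2, scaling=5)(targets, logits)))  # ~2.371 (stop token weighted 5x)
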
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def attention_score(att, mel_len, phon_len, r):
5 | """
6 |     returns a tuple of scores (loc_score, peak_score, diag_score): loc_score measures monotonicity,
7 |     peak_score measures the sharpness of the attention peaks and diag_score rewards diagonal attention
8 |     att : [N, n_heads, mel_dim, phoneme_dim]
9 | """
10 | assert len(tf.shape(att)) == 4
11 |
12 | mask = tf.range(tf.shape(att)[2])[None, :] < mel_len[:, None]
13 | mask = tf.cast(mask, tf.int32)[:, None, :] # [N, 1, mel_dim]
14 |
15 | # distance between max (jumpiness)
16 | loc_score = attention_jumps_score(att=att, mel_mask=mask, mel_len=mel_len, r=r)
17 |
18 | # variance
19 | peak_score = attention_peak_score(att, mask)
20 |
21 | # diagonality
22 | diag_score = diagonality_score(att, mel_len, phon_len)
23 |
24 | return loc_score, peak_score, 3. / diag_score
25 |
26 |
27 | def attention_jumps_score(att, mel_mask, mel_len, r):
28 | max_loc = tf.argmax(att, axis=3) # [N, n_heads, mel_max]
29 |     max_loc_diff = tf.abs(max_loc[:, :, 1:] - max_loc[:, :, :-1])  # [N, n_heads, mel_max - 1]
30 |     loc_score = tf.cast(max_loc_diff >= 0, tf.int32) * tf.cast(max_loc_diff <= r, tf.int32)  # [N, n_heads, mel_max - 1]
31 | loc_score = tf.reduce_sum(loc_score * mel_mask[:, :, 1:], axis=-1)
32 | loc_score = loc_score / (mel_len - 1)[:, None]
33 | return tf.cast(loc_score, tf.float32)
34 |
35 |
36 | def attention_peak_score(att, mel_mask):
37 | max_loc = tf.reduce_max(att, axis=3) # [N, n_heads, mel_dim]
38 | peak_score = tf.reduce_mean(max_loc * tf.cast(mel_mask, tf.float32), axis=-1)
39 | return tf.cast(peak_score, tf.float32)
40 |
41 | def diagonality_score(att, mel_len, phon_len, diag_mask=None):
42 | if diag_mask is None:
43 | diag_mask = batch_diagonal_mask(att, mel_len, phon_len)
44 | diag_score = tf.reduce_sum(att * diag_mask, axis=(-2, -1))
45 | return diag_score
46 |
47 | def batch_diagonal_mask(att, mel_len, phon_len):
48 | batch_size = tf.shape(att)[0]
49 | mel_size = tf.shape(att)[2]
50 | phon_size = tf.shape(att)[3]
51 | diag_mask = tf.TensorArray(tf.float32, size=batch_size)
52 | for i in range(batch_size):
53 | d_mask = diagonal_mask(mel_len[i], phon_len[i], padded_shape=(mel_size, phon_size))
54 | diag_mask = diag_mask.write(i, d_mask)
55 | diag_mask = tf.cast(diag_mask.stack(), tf.float32)
56 | diag_mask = tf.expand_dims(diag_mask, 1)
57 | return diag_mask
58 |
59 |
60 | def diagonal_mask(mel_len, phon_len, padded_shape):
61 |     """ loss mask that grows linearly with the distance from the (normalised) diagonal """
62 | max_m = tf.cast(mel_len, tf.int32)
63 | if max_m > padded_shape[0]: # this can happen due to rounding errors when calculating mel lengths with r>1
64 | max_m = padded_shape[0]
65 | max_n = tf.cast(phon_len, tf.int32)
66 | i = tf.tile(tf.range(max_n)[None, :], [max_m, 1]) / max_n
67 | j = tf.tile(tf.range(max_m)[:, None], [1, max_n]) / max_m
68 | diag_mask = tf.math.sqrt(tf.square(i - j))
69 | expanded_mask = tf.pad(diag_mask, [[0, padded_shape[0] - max_m], [0, padded_shape[1] - max_n]])
70 | return tf.cast(expanded_mask, tf.float32)
71 |
--------------------------------------------------------------------------------
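A sketch of how the attention metrics above might be called on a batch of (softmaxed) attention maps; the shapes and lengths are illustrative and `r` is the mel reduction factor:

    import tensorflow as tf
    from utils.metrics import attention_score

    batch, heads, mel_max, phon_max = 2, 4, 50, 12
    att = tf.nn.softmax(tf.random.normal((batch, heads, mel_max, phon_max)), axis=-1)
    mel_len = tf.constant([50, 38])
    phon_len = tf.constant([12, 9])

    # each score has shape [batch, n_heads]; higher means more monotonic / sharper / more diagonal
    loc_score, peak_score, diag_score = attention_score(att, mel_len, phon_len, r=1)
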
/utils/scheduling.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 | def linear_function(x, x0, x1, y0, y1):
6 | m = (y1 - y0) / (x1 - x0)
7 | b = y0 - m * x0
8 | return m * x + b
9 |
10 |
11 | def piecewise_linear(step, X, Y):
12 | """
13 | Piecewise linear function.
14 |
15 | :param step: current step.
16 | :param X: list of breakpoints
17 | :param Y: list of values at breakpoints
18 | :return: value of piecewise linear function with values Y_i at step X_i
19 | """
20 | assert len(X) == len(Y)
21 | X = np.array(X)
22 | if step < X[0]:
23 | return Y[0]
24 | idx = np.where(step >= X)[0][-1]
25 | if idx == (len(Y) - 1):
26 | return Y[-1]
27 | else:
28 | return linear_function(step, X[idx], X[idx + 1], Y[idx], Y[idx + 1])
29 |
30 |
31 | def piecewise_linear_schedule(step, schedule):
32 | schedule = np.array(schedule)
33 | x_schedule = schedule[:, 0]
34 | y_schedule = schedule[:, 1]
35 | value = piecewise_linear(step, x_schedule, y_schedule)
36 | return tf.cast(value, tf.float32)
37 |
38 |
39 | def reduction_schedule(step, schedule):
40 | schedule = np.array(schedule)
41 | r = schedule[0, 0]
42 | for i in range(schedule.shape[0]):
43 | if schedule[i, 0] <= step:
44 | r = schedule[i, 1]
45 | else:
46 | break
47 | return int(r)
48 |
--------------------------------------------------------------------------------
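For illustration, schedules of the [[step, value], ...] form used in the training configs can be evaluated like this (the breakpoints below are made up):

    from utils.scheduling import piecewise_linear_schedule, reduction_schedule

    # learning rate: linear interpolation between breakpoints, constant outside them
    lr_schedule = [[0, 1e-4], [10_000, 1e-4], [50_000, 1e-5]]
    lr = piecewise_linear_schedule(step=30_000, schedule=lr_schedule)   # ~5.5e-05

    # reduction factor: piecewise constant, takes the value of the last breakpoint reached
    r_schedule = [[0, 10], [20_000, 5], [50_000, 2]]
    r = reduction_schedule(step=30_000, schedule=r_schedule)            # 5
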
/utils/scripts_utils.py:
--------------------------------------------------------------------------------
1 | import traceback
2 | import argparse
3 |
4 | import tensorflow as tf
5 |
6 |
7 | def dynamic_memory_allocation():
8 | gpus = tf.config.experimental.list_physical_devices('GPU')
9 | if gpus:
10 | try:
11 | # Currently, memory growth needs to be the same across GPUs
12 | for gpu in gpus:
13 | tf.config.experimental.set_memory_growth(gpu, True)
14 | logical_gpus = tf.config.experimental.list_logical_devices('GPU')
15 | print(len(gpus), 'Physical GPUs,', len(logical_gpus), 'Logical GPUs')
16 | except Exception:
17 | traceback.print_exc()
18 |
19 |
20 | def basic_train_parser():
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--config', dest='config', type=str)
23 | parser.add_argument('--reset_dir', dest='clear_dir', action='store_true',
24 | help="deletes everything under this config's folder.")
25 | parser.add_argument('--reset_logs', dest='clear_logs', action='store_true',
26 | help="deletes logs under this config's folder.")
27 | parser.add_argument('--reset_weights', dest='clear_weights', action='store_true',
28 | help="deletes weights under this config's folder.")
29 | return parser
30 |
--------------------------------------------------------------------------------
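A rough sketch of how a training entry point would use these helpers (the actual scripts in this repo may parse further arguments):

    from utils.scripts_utils import dynamic_memory_allocation, basic_train_parser

    dynamic_memory_allocation()   # enable GPU memory growth before any tensors are allocated
    parser = basic_train_parser()
    args = parser.parse_args()    # e.g. python train_tts.py --config config/training_config.yaml
    print(args.config, args.clear_dir, args.clear_logs, args.clear_weights)
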
/utils/spectrogram_ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def mel_padding_mask(mel_batch, padding_value=0):
5 | return 1.0 - tf.cast(mel_batch == padding_value, tf.float32)
6 |
7 |
8 | def mel_lengths(mel_batch, padding_value=0):
9 | mask = mel_padding_mask(mel_batch, padding_value=padding_value)
10 |     # a frame counts as valid if at least one channel differs from the padding value
11 |     frame_has_content = tf.reduce_sum(mask, axis=-1) > 0
12 |     idxs = tf.cast(frame_has_content, tf.int32)
13 | return tf.reduce_sum(idxs, axis=-1)
14 |
15 |
16 | def phoneme_lengths(phonemes, phoneme_padding=0):
17 | return tf.reduce_sum(tf.cast(phonemes != phoneme_padding, tf.int32), axis=-1)
18 |
--------------------------------------------------------------------------------
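A small sketch of the padding helpers above on a zero-padded toy batch:

    import tensorflow as tf
    from utils.spectrogram_ops import mel_lengths, phoneme_lengths

    # 2 mels with 5 frames and 3 channels; the second sample has its last 2 frames zero-padded
    mel = tf.random.normal((2, 5, 3))
    mel = mel * tf.constant([[1., 1., 1., 1., 1.],
                             [1., 1., 1., 0., 0.]])[:, :, None]
    print(mel_lengths(mel).numpy())           # [5 3]

    phonemes = tf.constant([[4, 7, 9, 0], [5, 6, 0, 0]])
    print(phoneme_lengths(phonemes).numpy())  # [3 2]
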
/utils/training_config_manager.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import shutil
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | import ruamel.yaml
8 |
9 | from model.models import Aligner, ForwardTransformer
10 | from utils.scheduling import reduction_schedule
11 |
12 |
13 | class TrainingConfigManager:
14 | def __init__(self, config_path: str, aligner=False):
15 | if aligner:
16 | model_kind = 'aligner'
17 | else:
18 | model_kind = 'tts'
19 | self.config_path = Path(config_path)
20 | self.model_kind = model_kind
21 | self.yaml = ruamel.yaml.YAML()
22 | self.config = self._load_config()
23 | self.git_hash = self._get_git_hash()
24 | self.data_name = self.config['data_name'] # raw data
25 | # make session names
26 | self.session_names = {'data': f"{self.config['text_settings_name']}.{self.config['audio_settings_name']}"}
27 | self.session_names['aligner'] = f"{self.config['aligner_settings_name']}.{self.session_names['data']}"
28 | self.session_names['tts'] = f"{self.config['tts_settings_name']}.{self.config['aligner_settings_name']}"
29 | # create paths
30 | self.wav_directory = Path(self.config['wav_directory'])
31 | self.data_dir = Path(f"{self.config['train_data_directory']}.{self.data_name}")
32 | self.metadata_path = Path(self.config['metadata_path'])
33 | self.base_dir = Path(self.config['log_directory']) / self.data_name / self.session_names[model_kind]
34 | self.log_dir = self.base_dir / 'logs'
35 | self.weights_dir = self.base_dir / 'weights'
36 | self.train_metadata_path = self.data_dir / f"train_metadata.{self.config['text_settings_name']}.txt"
37 | self.valid_metadata_path = self.data_dir / f"valid_metadata.{self.config['text_settings_name']}.txt"
38 | self.phonemized_metadata_path = self.data_dir / f"phonemized_metadata.{self.config['text_settings_name']}.txt"
39 | self.mel_dir = self.data_dir / f"mels.{self.config['audio_settings_name']}"
40 | self.pitch_dir = self.data_dir / f"pitch.{self.config['audio_settings_name']}"
41 | self.duration_dir = self.data_dir / f"durations.{self.session_names['aligner']}"
42 | self.pitch_per_char = self.data_dir / f"char_pitch.{self.session_names['aligner']}"
43 | # training parameters
44 | self.learning_rate = np.array(self.config['learning_rate_schedule'])[0, 1].astype(np.float32)
45 | if model_kind == 'aligner':
46 | self.max_r = np.array(self.config['reduction_factor_schedule'])[0, 1].astype(np.int32)
47 | self.stop_scaling = self.config.get('stop_loss_scaling', 1.)
48 |
49 | def _load_config(self):
50 | all_config = {}
51 | with open(str(self.config_path), 'rb') as session_yaml:
52 | session_config = self.yaml.load(session_yaml)
53 |             for key in ['paths', 'naming', 'training_data_settings', 'audio_settings',
54 |                         'text_settings', f'{self.model_kind}_settings']:
55 | all_config.update(session_config[key])
56 | return all_config
57 |
58 | @staticmethod
59 | def _get_git_hash():
60 | try:
61 | return subprocess.check_output(['git', 'describe', '--always']).strip().decode()
62 | except Exception as e:
63 | print(f'WARNING: could not retrieve git hash. {e}')
64 |
65 | def _check_hash(self):
66 | try:
67 | git_hash = subprocess.check_output(['git', 'describe', '--always']).strip().decode()
68 | if self.config['git_hash'] != git_hash:
69 | print(f"WARNING: git hash mismatch. Current: {git_hash}. Training config hash: {self.config['git_hash']}")
70 | except Exception as e:
71 | print(f'WARNING: could not check git hash. {e}')
72 |
73 | @staticmethod
74 | def _print_dict_values(values, key_name, level=0, tab_size=2):
75 | tab = level * tab_size * ' '
76 | print(tab + '-', key_name, ':', values)
77 |
78 | def _print_dictionary(self, dictionary, recursion_level=0):
79 | for key in dictionary.keys():
80 | if isinstance(key, dict):
81 | recursion_level += 1
82 | self._print_dictionary(dictionary[key], recursion_level)
83 | else:
84 | self._print_dict_values(dictionary[key], key_name=key, level=recursion_level)
85 |
86 | def print_config(self):
87 | print('\nCONFIGURATION', self.session_names[self.model_kind])
88 | self._print_dictionary(self.config)
89 |
90 | def update_config(self):
91 | self.config['git_hash'] = self.git_hash
92 | self.config['automatic'] = True
93 |
94 | def get_model(self, ignore_hash=False):
95 | if not ignore_hash:
96 | self._check_hash()
97 | if self.model_kind == 'aligner':
98 | return Aligner.from_config(self.config, max_r=self.max_r)
99 | else:
100 | return ForwardTransformer.from_config(self.config)
101 |
102 | def compile_model(self, model, beta_1=0.9, beta_2=0.98):
103 | optimizer = tf.keras.optimizers.Adam(self.learning_rate,
104 | beta_1=beta_1,
105 | beta_2=beta_2,
106 | epsilon=1e-9)
107 | if self.model_kind == 'aligner':
108 | model._compile(stop_scaling=self.stop_scaling, optimizer=optimizer)
109 | else:
110 | model._compile(optimizer=optimizer)
111 |
112 | def dump_config(self):
113 | self.update_config()
114 |         with open(self.base_dir / 'config.yaml', 'w') as model_yaml:
115 | self.yaml.dump(self.config, model_yaml)
116 |
117 | def create_remove_dirs(self, clear_dir=False, clear_logs=False, clear_weights=False):
118 | self.base_dir.mkdir(exist_ok=True, parents=True)
119 | self.data_dir.mkdir(exist_ok=True)
120 | self.pitch_dir.mkdir(exist_ok=True)
121 | self.pitch_per_char.mkdir(exist_ok=True)
122 | self.mel_dir.mkdir(exist_ok=True)
123 | self.duration_dir.mkdir(exist_ok=True)
124 | if clear_dir:
125 | delete = input(f'Delete {self.log_dir} AND {self.weights_dir}? (y/[n])')
126 | if delete == 'y':
127 | shutil.rmtree(self.log_dir, ignore_errors=True)
128 | shutil.rmtree(self.weights_dir, ignore_errors=True)
129 | if clear_logs:
130 | delete = input(f'Delete {self.log_dir}? (y/[n])')
131 | if delete == 'y':
132 | shutil.rmtree(self.log_dir, ignore_errors=True)
133 | if clear_weights:
134 | delete = input(f'Delete {self.weights_dir}? (y/[n])')
135 | if delete == 'y':
136 | shutil.rmtree(self.weights_dir, ignore_errors=True)
137 | self.log_dir.mkdir(exist_ok=True)
138 | self.weights_dir.mkdir(exist_ok=True)
139 |
140 | def load_model(self, checkpoint_path: str = None, verbose=True):
141 | model = self.get_model()
142 | self.compile_model(model)
143 | ckpt = tf.train.Checkpoint(net=model)
144 | manager = tf.train.CheckpointManager(ckpt, self.weights_dir,
145 | max_to_keep=None)
146 | if checkpoint_path:
147 | ckpt.restore(checkpoint_path)
148 | if verbose:
149 | print(f'restored weights from {checkpoint_path} at step {model.step}')
150 | else:
151 | if manager.latest_checkpoint is None:
152 |                 print(f'WARNING: could not find a checkpoint to restore under \n {self.weights_dir}.')
153 |                 print('Edit the config to point at the correct log directory.')
154 | ckpt.restore(manager.latest_checkpoint)
155 | if verbose:
156 | print(f'restored weights from {manager.latest_checkpoint} at step {model.step}')
157 | if self.model_kind == 'aligner':
158 | reduction_factor = reduction_schedule(model.step, self.config['reduction_factor_schedule'])
159 | model.set_constants(reduction_factor=reduction_factor)
160 | return model
161 |
--------------------------------------------------------------------------------
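A hedged sketch of how the manager above might be driven from a training script (the config path points at the file shipped in this repo; the call order is illustrative):

    from utils.training_config_manager import TrainingConfigManager

    config = TrainingConfigManager('config/training_config.yaml', aligner=False)
    config.create_remove_dirs()    # create data/log/weight folders, optionally clearing old runs
    config.dump_config()           # snapshot the resolved config (plus git hash) into the session folder
    config.print_config()

    model = config.load_model()    # build, compile and restore the latest checkpoint, if any
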
/utils/vec_ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def norm_tensor(tensor):
5 | return tf.math.divide(
6 | tf.math.subtract(
7 | tensor,
8 | tf.math.reduce_min(tensor)
9 | ),
10 | tf.math.subtract(
11 | tf.math.reduce_max(tensor),
12 | tf.math.reduce_min(tensor)
13 | )
14 | )
15 |
--------------------------------------------------------------------------------
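norm_tensor performs a plain min-max normalisation to [0, 1]; note that a constant tensor would lead to a division by zero. A tiny example:

    import tensorflow as tf
    from utils.vec_ops import norm_tensor

    x = tf.constant([2.0, 4.0, 6.0])
    print(norm_tensor(x).numpy())   # [0.  0.5 1. ]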