├── .gitignore ├── README.md ├── app.py ├── constants └── hparams.py ├── model ├── __init__.py ├── feeder.py ├── helpers.py ├── modules.py ├── networks.py └── tacotron.py ├── preprocess.py ├── requirements.txt ├── signal_proc ├── __init__.py ├── audio.py └── synthesizer.py ├── test.py ├── text ├── __init__.py ├── character_set.py ├── mm_num2word.py └── tokenizer.py ├── train.py └── utils ├── __init__.py ├── logger.py └── plotter.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | 126 | data/* 127 | .DS_Store 128 | .vscode 129 | trained_* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Myanmar End-to-End Text-to-Speech 2 | 3 | This is the development of a Myanmar Text-to-Speech system with the famous End-to-End Speech Synthesis Model, Tacotron. It is a part of a thesis for B.E. Degree that I've been assigned at Yangon Technological University. My supervisor was **Dr. Yuzana Win** and she guided me throughout this development. 
4 | 5 | ## License 6 | 7 | This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0) License. See the [license](http://creativecommons.org/licenses/by-nc-sa/4.0/) for details. 8 | 9 | This myanmar-tts may be used freely for educational purposes, but it is not licensed for commercial use cases. 10 | 11 | ## Corpus 12 | 13 | **[Base Technology, Expa.Ai (Myanmar)](https://expa.ai)** kindly provided a Myanmar text corpus and their amazing tool for creating a speech corpus. 14 | 15 | The speech corpus (mmSpeech, as I call it) was created entirely on my own with the recorder tool mentioned above, and it currently contains over 5,000 recorded text-audio pairs. I intend to publish the corpus through some channel in the future. 16 | 17 | ## Instructions 18 | 19 | ### Installing dependencies 20 | 21 | 1. Install Python 3 22 | 2. Install [TensorFlow](https://www.tensorflow.org/install/) 23 | 3. Install the remaining modules 24 | ``` 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | 29 | ### Preparing Text and Audio Dataset 30 | 31 | 1. First of all, the corpus should reside in `~/mm-tts`; this is not a **must** and can easily be changed with a command-line argument. 32 | ``` 33 | mm-tts 34 | | mmSpeech 35 | | metadata.csv 36 | | wavs 37 | ``` 38 | Each line of `metadata.csv` is expected to contain a transcript and the corresponding wav filename separated by a comma, since that is how `preprocess.py` parses it. 39 | 2. **Preprocess the data** 40 | ``` 41 | python3 preprocess.py 42 | ``` 43 | After it is done, you should see the outputs in `~/mm-tts/training/` 44 | 45 | 46 | ### Training 47 | 48 | ``` 49 | python3 train.py 50 | ``` 51 | 52 | If you want to resume training from a checkpoint 53 | ``` 54 | python3 train.py --restore_step Number 55 | ``` 56 | 57 | 58 | ### Evaluation 59 | 60 | There are some sentences defined in test.py; you can synthesize them with a trained model to see how good the current model is. 61 | ``` 62 | python3 test.py --checkpoint /path/to/checkpoint 63 | ``` 64 | 65 | 66 | ### Testing with Custom Inputs 67 | 68 | There is a simple app implemented to try out the trained models and judge their performance. 69 | ``` 70 | python3 app.py --checkpoint /path/to/checkpoint 71 | ``` 72 | This will start a simple web app listening on port 4000 unless you specify another port with `--port`. 73 | Open your browser and go to `http://localhost:4000`; you should see a simple interface with a text input for entering the text to synthesize. 74 | 75 | 76 | ### Pretrained Model 77 | 78 | * [mmSpeech 150K Steps](https://drive.google.com/open?id=1P3JQYjGNoPbNykOg4-45LPUsZGuWkAd7) 79 | 80 | 81 | ### Generated Audio Samples 82 | 83 | Generated samples are available on SoundCloud: 84 | 85 | * [mmSpeech 150K Steps](https://soundcloud.com/htoo-pyae-466960846/sets/mmspeech-outputs) 86 | 87 | 88 | ### Notes 89 | 90 | * Google Colab, which gives excellent GPU access, was used for training this model. 91 | * On average, each step took about 1.6 seconds; at peak, each step took about 1.2 and sometimes 1.1 seconds. 92 | * For my thesis, I trained this model for 150,000 steps (it took me about a week). 93 | 94 | 95 | ### Loss Curve 96 | 97 | Below are the loss curves produced by training mmSpeech for 150,000 steps.
98 | 99 | ![Loss](https://user-images.githubusercontent.com/34838719/65132116-7353e080-da26-11e9-9299-a08883811f47.png) 100 | 101 | 102 | ### Alignment Plot 103 | 104 | ![Alignment Plot](https://user-images.githubusercontent.com/34838719/65257737-afbb3580-db27-11e9-9046-e9a9931c55d3.gif) 105 | 106 | 107 | ### References 108 | 109 | * [Tacotron: Towards End-to-End Speech Synthesis](https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=2&cad=rja&uact=8&ved=2ahUKEwiBjuL828vkAhWh6nMBHYccCdYQFjABegQIABAB&url=https%3A%2F%2Farxiv.org%2Fabs%2F1703.10135&usg=AOvVaw0_KT-Hbe9h_egPMynMsJOM) 110 | * [keithito/tacotron](https://github.com/keithito/tacotron) 111 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import falcon 5 | 6 | from signal_proc.synthesizer import Synthesizer 7 | from constants.hparams import Hyperparams as hparams 8 | 9 | 10 | html_body = ''' 11 | 12 | 13 | 14 | 15 | 16 | MM-TTS 17 | 22 | 23 | 24 | 25 | 26 |
27 |
28 |

Myanmar TTS

29 |
30 |
31 |
32 |
33 |
34 |
35 |

End-to-End Speech Synthesis

36 |
37 |

You can type any Burmese sentences in Unicode and the model will try to 38 | synthesize the speech based on your inputs.
39 | Synthesizing process may take a few seconds. 40 |

41 |
42 |
43 | 44 | 45 |
46 |
47 |
48 |
49 |
50 |

51 |
52 |
53 |
54 | 55 |
56 |
57 |
58 |
59 | 60 |
61 |
62 |
63 |
64 |
65 |
66 | 98 | 99 | 100 | 101 | 102 | ''' 103 | 104 | 105 | class UIResource: 106 | def on_get(self, req, res): 107 | res.content_type = 'text/html' 108 | res.body = html_body 109 | 110 | 111 | class SynthesisResource: 112 | def on_get(self, req, res): 113 | inputTxt = req.params.get('text') 114 | if not inputTxt: 115 | raise falcon.HTTPBadRequest() 116 | res.data = synthesizer.synthesize(inputTxt) 117 | res.content_type = 'audio/wav' 118 | 119 | 120 | synthesizer = Synthesizer() 121 | api = falcon.API() 122 | api.add_route('/synthesize', SynthesisResource()) 123 | api.add_route('/', UIResource()) 124 | 125 | 126 | if __name__ == '__main__': 127 | from wsgiref import simple_server 128 | 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument('--checkpoint', required=True, help='Full path to model checkpoint') 131 | parser.add_argument('--port', type=int, default=4000) 132 | args = parser.parse_args() 133 | 134 | synthesizer.init(args.checkpoint) 135 | 136 | print('Serving on port %d' % args.port) 137 | simple_server.make_server('0.0.0.0', args.port, api).serve_forever() 138 | -------------------------------------------------------------------------------- /constants/hparams.py: -------------------------------------------------------------------------------- 1 | class Hyperparams: 2 | """ Hyper parameters """ 3 | # Signal 4 | num_mels = 80 5 | num_freq = 1025 6 | sample_rate = 20000 7 | frame_length = 0.05 8 | frame_shift = 0.0125 9 | preemphasis = 0.97 10 | min_db = -100 11 | ref_db = 20 12 | 13 | # parameters 14 | n_fft = (num_freq - 1) * 2 15 | hop_length = int(frame_shift * sample_rate) 16 | win_length = int(frame_length * sample_rate) 17 | 18 | max_iters = 200 19 | griffin_lim_iters = 60 20 | power = 1.5 21 | 22 | # for training 23 | batch_size = 32 24 | learning_rate_decay = True 25 | initial_lr = 0.002 26 | adam_beta_1 = 0.9 27 | adam_beta_2 = 0.999 28 | 29 | # Model 30 | outputs_per_step = 5 31 | embed_depth = 256 32 | prenet_depths = [256, 128] 33 | encoder_depth = 256 34 | postnet_depth = 256 35 | attention_depth = 256 36 | decoder_depth = 256 37 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpbyte/myanmar-tts/fa3adaa7291b459f8cae67b96098dc05d1a7fdd2/model/__init__.py -------------------------------------------------------------------------------- /model/feeder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | import traceback 5 | import threading 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | from constants.hparams import Hyperparams as hparams 11 | from text.tokenizer import text_to_sequence 12 | from utils.logger import log 13 | 14 | batches_per_group = 32 15 | pad_val = 0 16 | 17 | class DataFeeder(threading.Thread): 18 | """ Feeds batches of data on a queue in a background thread """ 19 | 20 | def __init__(self, coordinator, metadata_filename): 21 | super(DataFeeder, self).__init__() 22 | self._coordi = coordinator 23 | self._hparams = hparams 24 | self._offset = 0 25 | 26 | # load metadata 27 | self._data_dir = os.path.dirname(metadata_filename) 28 | with open(metadata_filename, encoding='utf-8') as f: 29 | self._metadata = [line.strip().split('|') for line in f] 30 | hours = sum((int(x[2]) for x in self._metadata)) * hparams.frame_shift / 3600 31 | log('Loaded metadata for %d examples (%.2f 
hours)' % (len(self._metadata), hours)) 32 | 33 | # create placeholders for inputs and targets 34 | # didn't specify batch size bcuz of the need to feed different sized batches at eval time 35 | self._placeholders = [ 36 | tf.compat.v1.placeholder(tf.int32, [None, None], 'inputs'), 37 | tf.compat.v1.placeholder(tf.int32, [None], 'input_lengths'), 38 | tf.compat.v1.placeholder(tf.float32, [None, None, hparams.num_mels], 'mel_targets'), 39 | tf.compat.v1.placeholder(tf.float32, [None, None, hparams.num_freq], 'linear_targets') 40 | ] 41 | 42 | # create a queue for buffering data 43 | queue = tf.FIFOQueue(8, [tf.int32, tf.int32, tf.float32, tf.float32], name='input_queue') 44 | self._enqueue_op = queue.enqueue(self._placeholders) 45 | self.inputs, self.input_lengths, self.mel_targets, self.linear_targets = queue.dequeue() 46 | self.inputs.set_shape(self._placeholders[0].shape) 47 | self.input_lengths.set_shape(self._placeholders[1].shape) 48 | self.mel_targets.set_shape(self._placeholders[2].shape) 49 | self.linear_targets.set_shape(self._placeholders[3].shape) 50 | 51 | 52 | def start_in_session(self, session): 53 | self._session = session 54 | # starting the thread which in turn, invokes the run() method 55 | self.start() 56 | 57 | 58 | def run(self): 59 | # perform queueing operations until it should stop 60 | try: 61 | while not self._coordi.should_stop(): 62 | # if it shoudn't stop, enqueue the next batches 63 | self.enqueue_next_group() 64 | except Exception as e: 65 | # print the exception occurred 66 | traceback.print_exc() 67 | # tell the coordinator to stop 68 | self._coordi.request_stop(e) 69 | 70 | 71 | def enqueue_next_group(self): 72 | """ Enqueue next group of batches into the queue """ 73 | 74 | start = time.time() 75 | 76 | # read a group of examples 77 | nb_batches = self._hparams.batch_size 78 | r = self._hparams.outputs_per_step 79 | examples = [self.get_next_example() for i in range(nb_batches * batches_per_group)] 80 | 81 | # sort examples based on their length for efficiency 82 | examples.sort(key=lambda x: x[-1]) 83 | batches = [examples[i:i+nb_batches] for i in range(0, len(examples), nb_batches)] 84 | random.shuffle(batches) 85 | 86 | log('Generated %d batches of size %d in %0.3f sec' % (len(batches), nb_batches, time.time() - start)) 87 | 88 | for b in batches: 89 | # make a feeding dictionary of iterables with the placeholders mapping to the input data 90 | # { 91 | # (inputs => input_data_text), (input_lengths => input_data_lengths), 92 | # (mel_targets => input_mel_targets), (linear_targets => input_linear_targets) 93 | # } 94 | feed_dict = dict(zip(self._placeholders, _prepare_batch(b, r))) 95 | # run the session with the fed placeholders 96 | self._session.run(self._enqueue_op, feed_dict=feed_dict) 97 | 98 | 99 | def get_next_example(self): 100 | """ 101 | Get a single example (input, mel_target, linear_target, cost) from metadata. 102 | This read the metadata file by offsetting the position. 
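Each line of train.txt, as written by preprocess.py, has the form: spectrogram_filename|mel_filename|n_frames|text.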
103 | """ 104 | 105 | if self._offset >= len(self._metadata): 106 | # if somehow the offset gets larger than metadata size, 107 | # set the offset back to 0 and shuffle the metadata 108 | self._offset = 0 109 | random.shuffle(self._metadata) 110 | 111 | meta = self._metadata[self._offset] 112 | self._offset += 1 113 | 114 | text = meta[3] 115 | # get the normalized sequence of text 116 | input_data = np.asarray(text_to_sequence(text), dtype=np.int32) 117 | # load the linear spectrogram.npy 118 | linear_target = np.load(os.path.join(self._data_dir, meta[0])) 119 | # laod the mel-spectrogram.npy 120 | mel_target = np.load(os.path.join(self._data_dir, meta[1])) 121 | 122 | return (input_data, mel_target, linear_target, len(linear_target)) 123 | 124 | 125 | # helper functions 126 | def _prepare_batch(batch, outputs_per_step): 127 | """ 128 | Having constant input length is essential for training, 129 | so we need to pad each of them if needed 130 | """ 131 | random.shuffle(batch) 132 | # since a single example looks like this (input, mel_target, linear_target, cost), 133 | # x[0] => inputs, x[1] => mel_target, x[2] => linear_target 134 | inputs = _get_padded_inputs([x[0] for x in batch]) 135 | input_lengths = np.asarray([len(x[0]) for x in batch], dtype=np.int32) 136 | 137 | mel_targets = _get_padded_targets([x[1] for x in batch], outputs_per_step) 138 | linear_targets = _get_padded_targets([x[2] for x in batch], outputs_per_step) 139 | 140 | return (inputs, input_lengths, mel_targets, linear_targets) 141 | 142 | 143 | def _get_padded_inputs(inputs): 144 | """ join a sequence of arrays of padded inputs """ 145 | max_len = max((len(x) for x in inputs)) 146 | return np.stack([_pad_input(x, max_len) for x in inputs]) 147 | 148 | 149 | def _get_padded_targets(targets, alignment): 150 | """ join a sequence of arrays of padded targets """ 151 | max_len = max((len(t) for t in targets)) + 1 152 | # make the max_len to be a multiple of outputs_per_step 153 | return np.stack([_pad_target(t, _round_up(max_len, alignment)) for t in targets]) 154 | 155 | 156 | def _pad_input(x, max_len): 157 | """ pad the input (whose length is lower than the max_len) with zeros """ 158 | return np.pad(x, (0, max_len - x.shape[0]), mode='constant', constant_values=pad_val) 159 | 160 | 161 | def _pad_target(t, max_len): 162 | """ pad the target (whose length is lower than the max_len) with zeros """ 163 | return np.pad(t, [(0, max_len - t.shape[0]), (0, 0)], mode='constant', constant_values=pad_val) 164 | 165 | 166 | def _round_up(x, multiple): 167 | """ get the rounded max_len for a target """ 168 | remainder = x % multiple 169 | return x if remainder == 0 else x + multiple - remainder 170 | -------------------------------------------------------------------------------- /model/helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | # Adapted from tf.contrib.seq2seq.GreedyEmbeddingHelper 6 | class TacoTestHelper(tf.contrib.seq2seq.Helper): 7 | def __init__(self, batch_size, output_dim, r): 8 | with tf.name_scope('TacoTestHelper'): 9 | self._batch_size = batch_size 10 | self._output_dim = output_dim 11 | self._end_token = tf.tile([0.0], [output_dim * r]) 12 | 13 | @property 14 | def batch_size(self): 15 | return self._batch_size 16 | 17 | @property 18 | def sample_ids_shape(self): 19 | return tf.TensorShape([]) 20 | 21 | @property 22 | def sample_ids_dtype(self): 23 | return np.int32 24 | 25 | def initialize(self, name=None): 
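# Nothing is marked finished at the start; decoding begins from the all-zero <GO> frames below.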
26 | return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim)) 27 | 28 | def sample(self, time, outputs, state, name=None): 29 | return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them 30 | 31 | def next_inputs(self, time, outputs, state, sample_ids, name=None): 32 | '''Stop on EOS. Otherwise, pass the last output as the next input and pass through state.''' 33 | with tf.name_scope('TacoTestHelper'): 34 | finished = tf.reduce_all(tf.equal(outputs, self._end_token), axis=1) 35 | # Feed last output frame as next input. outputs is [N, output_dim * r] 36 | next_inputs = outputs[:, -self._output_dim:] 37 | return (finished, next_inputs, state) 38 | 39 | 40 | class TacoTrainingHelper(tf.contrib.seq2seq.Helper): 41 | def __init__(self, inputs, targets, output_dim, r): 42 | # inputs is [N, T_in], targets is [N, T_out, D] 43 | with tf.name_scope('TacoTrainingHelper'): 44 | self._batch_size = tf.shape(inputs)[0] 45 | self._output_dim = output_dim 46 | 47 | # Feed every r-th target frame as input 48 | self._targets = targets[:, r-1::r, :] 49 | 50 | # Use full length for every target because we don't want to mask the padding frames 51 | num_steps = tf.shape(self._targets)[1] 52 | self._lengths = tf.tile([num_steps], [self._batch_size]) 53 | 54 | @property 55 | def batch_size(self): 56 | return self._batch_size 57 | 58 | @property 59 | def sample_ids_shape(self): 60 | return tf.TensorShape([]) 61 | 62 | @property 63 | def sample_ids_dtype(self): 64 | return np.int32 65 | 66 | def initialize(self, name=None): 67 | return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim)) 68 | 69 | def sample(self, time, outputs, state, name=None): 70 | return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them 71 | 72 | def next_inputs(self, time, outputs, state, sample_ids, name=None): 73 | with tf.name_scope(name or 'TacoTrainingHelper'): 74 | finished = (time + 1 >= self._lengths) 75 | next_inputs = self._targets[:, time, :] 76 | return (finished, next_inputs, state) 77 | 78 | 79 | def _go_frames(batch_size, output_dim): 80 | '''Returns all-zero frames for a given batch size and output dimension''' 81 | return tf.tile([[0.0]], [batch_size, output_dim]) 82 | 83 | -------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.rnn import GRUCell, RNNCell 3 | 4 | def prenet(inputs, is_training, layer_sizes, scope=None): 5 | """ Pre-Net 6 | 7 | FC-256-ReLU -> Dropout(0.5) -> FC-128-ReLU -> Dropout(0.5) 8 | """ 9 | x = inputs 10 | drop_rate = 0.5 if is_training else 0.0 11 | with tf.variable_scope(scope or 'prenet'): 12 | for i, size in enumerate(layer_sizes): 13 | dense = tf.layers.dense(x, units=size, activation=tf.nn.relu, name='dense_%d' % (i+1)) 14 | x = tf.layers.dropout(dense, rate=drop_rate, training=is_training, name='dropout_%d' % (i+1)) 15 | return x 16 | 17 | 18 | def encoder_cbhg(inputs, input_lengths, is_training, depth): 19 | input_channels = inputs.get_shape()[2] 20 | return cbhg( 21 | inputs, 22 | input_lengths, 23 | is_training, 24 | scope='encoder_cbhg', 25 | K=16, 26 | projections=[128, input_channels], 27 | depth=depth) 28 | 29 | 30 | def post_cbhg(inputs, input_dim, is_training, depth): 31 | return cbhg( 32 | inputs, 33 | None, 34 | is_training, 35 | scope='post_cbhg', 36 | K=8, 37 | projections=[256, input_dim], 38 | depth=depth) 39 
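# CBHG module used by both encoder_cbhg (K=16) and post_cbhg (K=8) above:
# a bank of K 1-D convolutions with kernel sizes 1..K (outputs concatenated) -> max pooling (stride 1) ->
# two 1-D convolution projections -> residual connection with the input ->
# four highway layers -> a bidirectional GRU whose forward and backward outputs are concatenated.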
| 40 | 41 | def cbhg(inputs, input_lengths, is_training, scope, K, projections, depth): 42 | with tf.variable_scope(scope): 43 | with tf.variable_scope('conv_bank'): 44 | # Convolution bank: concatenate on the last axis to stack channels from all convolutions 45 | conv_outputs = tf.concat( 46 | [conv1d(inputs, k, 128, tf.nn.relu, is_training, 'conv1d_%d' % k) for k in range(1, K+1)], 47 | axis=-1 48 | ) 49 | 50 | # Maxpooling: 51 | maxpool_output = tf.layers.max_pooling1d( 52 | conv_outputs, 53 | pool_size=2, 54 | strides=1, 55 | padding='same') 56 | 57 | # Two projection layers: 58 | proj1_output = conv1d(maxpool_output, 3, projections[0], tf.nn.relu, is_training, 'proj_1') 59 | proj2_output = conv1d(proj1_output, 3, projections[1], None, is_training, 'proj_2') 60 | 61 | # Residual connection: 62 | highway_input = proj2_output + inputs 63 | 64 | half_depth = depth // 2 65 | assert half_depth*2 == depth, 'encoder and postnet depths must be even.' 66 | 67 | # Handle dimensionality mismatch: 68 | if highway_input.shape[2] != half_depth: 69 | highway_input = tf.layers.dense(highway_input, half_depth) 70 | 71 | # 4-layer HighwayNet: 72 | for i in range(4): 73 | highway_input = highwaynet(highway_input, 'highway_%d' % (i+1), half_depth) 74 | rnn_input = highway_input 75 | 76 | # Bidirectional RNN 77 | outputs, states = tf.nn.bidirectional_dynamic_rnn( 78 | GRUCell(half_depth), 79 | GRUCell(half_depth), 80 | rnn_input, 81 | sequence_length=input_lengths, 82 | dtype=tf.float32) 83 | return tf.concat(outputs, axis=2) # Concat forward and backward 84 | 85 | 86 | def highwaynet(inputs, scope, depth): 87 | with tf.variable_scope(scope): 88 | H = tf.layers.dense( 89 | inputs, 90 | units=depth, 91 | activation=tf.nn.relu, 92 | name='H') 93 | T = tf.layers.dense( 94 | inputs, 95 | units=depth, 96 | activation=tf.nn.sigmoid, 97 | name='T', 98 | bias_initializer=tf.constant_initializer(-1.0)) 99 | return H * T + inputs * (1.0 - T) 100 | 101 | 102 | def conv1d(inputs, kernel_size, channels, activation, is_training, scope): 103 | with tf.variable_scope(scope): 104 | conv1d_output = tf.layers.conv1d( 105 | inputs, 106 | filters=channels, 107 | kernel_size=kernel_size, 108 | activation=activation, 109 | padding='same') 110 | return tf.layers.batch_normalization(conv1d_output, training=is_training) 111 | 112 | # Other Modules 113 | 114 | class DecoderPrenetWrapper(RNNCell): 115 | '''Runs RNN inputs through a prenet before sending them to the cell.''' 116 | def __init__(self, cell, is_training, layer_sizes): 117 | super(DecoderPrenetWrapper, self).__init__() 118 | self._cell = cell 119 | self._is_training = is_training 120 | self._layer_sizes = layer_sizes 121 | 122 | @property 123 | def state_size(self): 124 | return self._cell.state_size 125 | 126 | @property 127 | def output_size(self): 128 | return self._cell.output_size 129 | 130 | def call(self, inputs, state): 131 | prenet_out = prenet(inputs, self._is_training, self._layer_sizes, scope='decoder_prenet') 132 | return self._cell(prenet_out, state) 133 | 134 | def zero_state(self, batch_size, dtype): 135 | return self._cell.zero_state(batch_size, dtype) 136 | 137 | 138 | class ConcatOutputAndAttentionWrapper(RNNCell): 139 | """Concatenates RNN cell output with the attention context vector. 140 | 141 | This is expected to wrap a cell wrapped with an AttentionWrapper constructed with 142 | attention_layer_size=None and output_attention=False. Such a cell's state will include an 143 | "attention" field that is the context vector. 
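In this model it wraps the prenet-wrapped attention cell built in model/networks.py, so each decoder step receives the RNN output concatenated with the attention context (a 2 * attention_depth = 512 dimensional vector).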
144 | """ 145 | def __init__(self, cell): 146 | super(ConcatOutputAndAttentionWrapper, self).__init__() 147 | self._cell = cell 148 | 149 | @property 150 | def state_size(self): 151 | return self._cell.state_size 152 | 153 | @property 154 | def output_size(self): 155 | return self._cell.output_size + self._cell.state_size.attention 156 | 157 | def call(self, inputs, state): 158 | output, res_state = self._cell(inputs, state) 159 | return tf.concat([output, res_state.attention], axis=-1), res_state 160 | 161 | def zero_state(self, batch_size, dtype): 162 | return self._cell.zero_state(batch_size, dtype) 163 | 164 | -------------------------------------------------------------------------------- /model/networks.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, OutputProjectionWrapper, ResidualWrapper 3 | from tensorflow.contrib.seq2seq import BasicDecoder, BahdanauAttention, AttentionWrapper 4 | 5 | from model.modules import (prenet, encoder_cbhg, post_cbhg, 6 | DecoderPrenetWrapper, ConcatOutputAndAttentionWrapper) 7 | from model.helpers import (TacoTrainingHelper, TacoTestHelper) 8 | from constants.hparams import Hyperparams as hparams 9 | from text.character_set import characters 10 | from utils.logger import log 11 | 12 | 13 | def encoder(inputs, input_lengths, is_training): 14 | """ Encoder 15 | 16 | Embeddings -> Prenet -> Encoder CBHG 17 | 18 | @param inputs int32 Tensor with shape [N, T_in] where N is batch size, T_in is number of 19 | steps in the input time series, and values are character IDs 20 | @param input_lengths lengths of the inputs 21 | @param is_training flag for training or eval 22 | 23 | @returns outputs from the encoder 24 | """ 25 | 26 | # Character Embeddings 27 | embedding_table = tf.get_variable( 28 | 'embedding', 29 | [len(characters), hparams.embed_depth], 30 | dtype=tf.float32, 31 | initializer=tf.truncated_normal_initializer(stddev=0.5) 32 | ) 33 | 34 | embedded_inputs = tf.nn.embedding_lookup(embedding_table, inputs) # [N, T_in, embed_depth=256] 35 | 36 | # Encoder 37 | prenet_outputs = prenet(embedded_inputs, is_training, hparams.prenet_depths) # [N, T_in, prenet_depths[-1]=128] 38 | 39 | encoder_outputs = encoder_cbhg(prenet_outputs, input_lengths, is_training, hparams.encoder_depth) 40 | # [N, T_in, encoder_depth=256] 41 | 42 | log('Encoder Network ...') 43 | log(' embedding: %d' % embedded_inputs.shape[-1]) 44 | log(' prenet out: %d' % prenet_outputs.shape[-1]) 45 | log(' encoder out: %d' % encoder_outputs.shape[-1]) 46 | 47 | return encoder_outputs 48 | 49 | 50 | def decoder(inputs, encoder_outputs, is_training, batch_size, mel_targets): 51 | """ Decoder 52 | 53 | Prenet -> Attention RNN 54 | Postprocessing CBHG 55 | 56 | @param encoder_outputs outputs from the encoder wtih shape [N, T_in, prenet_depth=256] 57 | @param inputs int32 Tensor with shape [N, T_in] where N is batch size, T_in is number of 58 | steps in the input time series, and values are character IDs 59 | @param is_training flag for training or eval 60 | @param batch_size number of samples per batch 61 | @param mel_targets float32 Tensor with shape [N, T_out, M] where N is batch size, T_out is number 62 | of steps in the output time series, M is num_mels, and values are entries in the mel 63 | @param output_cell attention cell 64 | @param decoder_init_state initial state of the decoder 65 | 66 | @return linear_outputs, mel_outputs and alignments 67 | """ 68 | 69 | if 
(is_training): 70 | helper = TacoTrainingHelper(inputs, mel_targets, hparams.num_mels, hparams.outputs_per_step) 71 | else: 72 | helper = TacoTestHelper(batch_size, hparams.num_mels, hparams.outputs_per_step) 73 | 74 | # Attention 75 | attention_cell = AttentionWrapper( 76 | GRUCell(hparams.attention_depth), 77 | BahdanauAttention(hparams.attention_depth, encoder_outputs), 78 | alignment_history=True, 79 | output_attention=False 80 | ) # [N, T_in, attention_depth=256] 81 | 82 | # Apply prenet before concatenation in AttentionWrapper. 83 | attention_cell = DecoderPrenetWrapper(attention_cell, is_training, hparams.prenet_depths) 84 | 85 | # Concatenate attention context vector and RNN cell output into a 2*attention_depth=512D vector. 86 | concat_cell = ConcatOutputAndAttentionWrapper(attention_cell) # [N, T_in, 2*attention_depth=512] 87 | 88 | # Decoder (layers specified bottom to top): 89 | decoder_cell = MultiRNNCell([ 90 | OutputProjectionWrapper(concat_cell, hparams.decoder_depth), 91 | ResidualWrapper(GRUCell(hparams.decoder_depth)), 92 | ResidualWrapper(GRUCell(hparams.decoder_depth)) 93 | ], state_is_tuple=True) # [N, T_in, decoder_depth=256] 94 | 95 | # Project onto r mel spectrograms (predict r outputs at each RNN step): 96 | output_cell = OutputProjectionWrapper(decoder_cell, hparams.num_mels * hparams.outputs_per_step) 97 | 98 | decoder_init_state = output_cell.zero_state(batch_size=batch_size, dtype=tf.float32) 99 | 100 | (decoder_outputs, _), final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode( 101 | BasicDecoder(output_cell, helper, decoder_init_state), 102 | maximum_iterations=hparams.max_iters 103 | ) # [N, T_out/r, M*r] 104 | 105 | # Reshape outputs to be one output per entry 106 | mel_outputs = tf.reshape(decoder_outputs, [batch_size, -1, hparams.num_mels]) # [N, T_out, M] 107 | 108 | # Add post-processing CBHG: 109 | post_outputs = post_cbhg(mel_outputs, hparams.num_mels, is_training, # [N, T_out, postnet_depth=256] 110 | hparams.postnet_depth) 111 | linear_outputs = tf.layers.dense(post_outputs, hparams.num_freq) # [N, T_out, F] 112 | 113 | # Grab alignments from the final decoder state: 114 | alignments = tf.transpose(final_decoder_state[0].alignment_history.stack(), [1, 2, 0]) 115 | 116 | log('Decoder Network ...') 117 | log(' attention out: %d' % attention_cell.output_size) 118 | log(' concat attn & out: %d' % concat_cell.output_size) 119 | log(' decoder cell out: %d' % decoder_cell.output_size) 120 | log(' decoder out (%d frames): %d' % (hparams.outputs_per_step, decoder_outputs.shape[-1])) 121 | log(' decoder out (1 frame): %d' % mel_outputs.shape[-1]) 122 | log(' postnet out: %d' % post_outputs.shape[-1]) 123 | log(' linear out: %d' % linear_outputs.shape[-1]) 124 | 125 | return linear_outputs, mel_outputs, alignments 126 | -------------------------------------------------------------------------------- /model/tacotron.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from model.networks import encoder, decoder 4 | from constants.hparams import Hyperparams as hparams 5 | from utils.logger import log 6 | 7 | 8 | class Tacotron(): 9 | """ A Complete Tacotron Model """ 10 | def __init__(self): 11 | pass 12 | 13 | 14 | def init(self, inputs, input_lengths, mel_targets=None, linear_targets=None): 15 | """ Initialize the model for inference 16 | Sets "mel_outputs", "linear_outputs", and "alignments" fields. 
17 | 18 | @param inputs int32 Tensor with shape [N, T_in] where N is batch size, T_in is number of 19 | steps in the input time series, and values are character IDs 20 | @param input_lengths: int32 Tensor with shape [N] where N is batch size and values are the lengths 21 | of each sequence in inputs. 22 | @param mel_targets float32 Tensor with shape [N, T_out, M] where N is batch size, T_out is number 23 | of steps in the output time series, M is num_mels, and values are entries in the mel 24 | spectrogram. Only needed for training. 25 | @param linear_targets float32 Tensor with shape [N, T_out, F] where N is batch_size, T_out is number 26 | of steps in the output time series, F is num_freq, and values are entries in the linear 27 | spectrogram. Only needed for training. 28 | """ 29 | 30 | with tf.variable_scope('inference') as scope: 31 | is_training = linear_targets is not None 32 | batch_size = tf.shape(inputs)[0] 33 | 34 | log('----------------------------------------------------------------') 35 | log('Initialized Tacotron model with dimensions: ') 36 | 37 | # encoder 38 | encoder_outputs = encoder(inputs, input_lengths, is_training) 39 | 40 | # decoder 41 | linear_outputs, mel_outputs, alignments = decoder(inputs, encoder_outputs, is_training, batch_size, mel_targets) 42 | 43 | self.inputs = inputs 44 | self.input_lengths = input_lengths 45 | self.mel_outputs = mel_outputs 46 | self.linear_outputs = linear_outputs 47 | self.alignments = alignments 48 | self.mel_targets = mel_targets 49 | self.linear_targets = linear_targets 50 | 51 | log('----------------------------------------------------------------') 52 | 53 | 54 | def add_loss(self): 55 | """ Adding Loss to the model """ 56 | with tf.variable_scope('loss'): 57 | self.mel_loss = tf.reduce_mean(tf.abs(self.mel_targets - self.mel_outputs)) 58 | l1_loss = tf.abs(self.linear_targets - self.linear_outputs) 59 | 60 | # prioritize loss for freqeuncies under 3000 Hz 61 | n_priority_freq = int(3000 / (hparams.sample_rate * 0.5) * hparams.num_freq) 62 | 63 | self.linear_loss = 0.5 * tf.reduce_mean(l1_loss) + 0.5 * tf.reduce_mean(l1_loss[:,:,0:n_priority_freq]) 64 | self.loss = self.mel_loss + self.linear_loss 65 | 66 | 67 | def add_optimizer(self, global_step): 68 | """ Adding optimizer to the model """ 69 | with tf.variable_scope('optimizer'): 70 | if (hparams.learning_rate_decay): 71 | self.learning_rate = _learning_rate_decay(hparams.initial_lr, global_step) 72 | else: 73 | self.learning_rate = tf.convert_to_tensor(hparams.initial_lr) 74 | 75 | optimizer = tf.train.AdamOptimizer(self.learning_rate, hparams.adam_beta_1, hparams.adam_beta_2) 76 | gradients, variables = zip(*optimizer.compute_gradients(self.loss)) 77 | 78 | self.gradients = gradients 79 | clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0) 80 | 81 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 82 | self.optimize = optimizer.apply_gradients(zip(clipped_gradients, variables), 83 | global_step=global_step) 84 | 85 | 86 | def _learning_rate_decay(initial_lr, global_step): 87 | """ Learning rate decay 88 | 89 | @param initial_lr initial learning rate 90 | @param global_step global step number 91 | """ 92 | 93 | warmup_step = 4000.0 94 | step = tf.cast(global_step + 1, dtype=tf.float32) 95 | 96 | return initial_lr * warmup_step**0.5 * tf.minimum(step * warmup_step**-1.5, step**-0.5) 97 | -------------------------------------------------------------------------------- /preprocess.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from multiprocessing import cpu_count 4 | from concurrent.futures import ProcessPoolExecutor 5 | from functools import partial 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | from constants.hparams import Hyperparams as hparams 11 | from signal_proc import audio 12 | 13 | 14 | def prepare_audio_dataset(in_dir, out_dir, nb_workers=1, tqdm=lambda x: x): 15 | """ 16 | Preprocess the dataset of the audio files from a given input path into a given output path 17 | 18 | @type in_dir str 19 | @type out_dir str 20 | @type nb_workers int 21 | @type tqdm lambda 22 | 23 | @param in_dir directory which contains speech corpus 24 | @param out_dir directory in which the training data will be created 25 | @param nb_workers number of parallel processes 26 | @param tqdm for progress bar 27 | 28 | @rtype list 29 | @return a list of tuples describing the training examples 30 | """ 31 | 32 | executor = ProcessPoolExecutor(max_workers=nb_workers) 33 | futures = [] 34 | indx = 1 35 | with open(os.path.join(in_dir, 'metadata.csv'), encoding='utf-8') as f: 36 | for line in f: 37 | parts = line.strip().split(',') 38 | txt = parts[0] 39 | wav_path = os.path.join(in_dir, 'wavs', '%s' % parts[1]) 40 | futures.append(executor.submit(partial(process_utterance, out_dir, indx, wav_path, txt))) 41 | indx += 1 42 | 43 | return [future.result() for future in tqdm(futures)] 44 | 45 | 46 | def prepare_text_dataset(metadata, out_dir): 47 | """ 48 | Preprocess the dataset of the texts from a given input path into a given output path. 49 | This writes a file called train.txt as the input dataset 50 | 51 | @type metadata list 52 | @type out_dir str 53 | 54 | @param metadata text data 55 | @param out_dir output directory for the preprocessed texts 56 | """ 57 | with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f: 58 | for m in metadata: 59 | f.write('|'.join([str(x) for x in m]) + '\n') 60 | 61 | 62 | def process_utterance(out_dir, index, wav_path, text): 63 | """ 64 | Preprocess a single pair and outputs both linear and mel spectrograms and a tuple about them. 
65 | 66 | @type out_dir str 67 | @type index int 68 | @type wav_path str 69 | @type text str 70 | 71 | @param out_dir output directory for spectrograms 72 | @param index index for spectrogram filenames 73 | @param wav_path path to the audio file 74 | @param text the text spoken in the input audio 75 | 76 | @rtype tuple 77 | @return a (spectrogram_filename, mel_filename, n_frames, text) tuple for train.txt 78 | """ 79 | 80 | # load the audio to a numpy array 81 | wav = audio.load_audio(wav_path) 82 | 83 | # compute linear-scale spectrogram 84 | spectrogram = audio.wav_to_spectrogram(wav).astype(np.float32) 85 | n_frames = spectrogram.shape[1] 86 | 87 | # compute mel-scale spectrogram 88 | mel_spectrogram = audio.wav_to_melspectrogram(wav).astype(np.float32) 89 | 90 | # outputs the spectrograms 91 | spectrogram_filename = 'mmspeech-spec-%05d.npy' % index 92 | mel_spectrogram_filename = 'mmspeech-mel-%05d.npy' % index 93 | np.save(os.path.join(out_dir, spectrogram_filename), spectrogram.T, allow_pickle=False) 94 | np.save(os.path.join(out_dir, mel_spectrogram_filename), mel_spectrogram.T, allow_pickle=False) 95 | 96 | # return the tuple 97 | return (spectrogram_filename, mel_spectrogram_filename, n_frames, text) 98 | 99 | 100 | def preprocess(args): 101 | in_dir = os.path.join(args.base_dir, args.input) 102 | out_dir = os.path.join(args.base_dir, args.output) 103 | os.makedirs(out_dir, exist_ok=True) 104 | metadata = prepare_audio_dataset(in_dir, out_dir, args.nb_workers, tqdm=tqdm) 105 | prepare_text_dataset(metadata, out_dir) 106 | # give feedback 107 | frames = sum([m[2] for m in metadata]) 108 | hours = frames * hparams.frame_shift / 3600 109 | print('Wrote %d utterances, %d frames (%.2f hours)' % (len(metadata), frames, hours)) 110 | print('Max input length: %d' % max(len(m[3]) for m in metadata)) 111 | print('Max output length %d' % max(m[2] for m in metadata)) 112 | 113 | 114 | def main(): 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument('--base_dir', default=os.path.expanduser('~/mm-tts')) 117 | parser.add_argument('--input', default='mmSpeech') 118 | parser.add_argument('--output', default='training') 119 | parser.add_argument('--nb_workers', type=int, default=cpu_count()) 120 | args = parser.parse_args() 121 | 122 | preprocess(args) 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # TensorFlow should be installed beforehand 2 | numpy 3 | scikit-learn 4 | librosa 5 | falcon 6 | tqdm 7 | matplotlib 8 | -------------------------------------------------------------------------------- /signal_proc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpbyte/myanmar-tts/fa3adaa7291b459f8cae67b96098dc05d1a7fdd2/signal_proc/__init__.py -------------------------------------------------------------------------------- /signal_proc/audio.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import tensorflow as tf 4 | import scipy 5 | import numpy as np 6 | import librosa 7 | 8 | from constants.hparams import Hyperparams as hparams 9 | 10 | 11 | # Utils 12 | def load_audio(path): 13 | return librosa.core.load(path, sr=hparams.sample_rate)[0] 14 | 15 | 16 | def save_audio(wav, path): 17 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 18 | 
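# Rescale the waveform into the int16 range before writing; the max(0.01, ...) guard avoids division by zero on near-silent audio.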
scipy.io.wavfile.write(path, hparams.sample_rate, wav.astype(np.int16)) 19 | 20 | 21 | # Unit Conversions 22 | def amp_to_db(x): 23 | return 20 * np.log10(np.maximum(1e-5, x)) 24 | 25 | 26 | def db_to_amp(x): 27 | return np.power(10.0, x * 0.05) 28 | 29 | 30 | def db_to_amp_tf(x): 31 | return tf.pow(tf.ones(tf.shape(x)) * 10.0, x * 0.05) 32 | 33 | 34 | def normalize(S): 35 | return np.clip((S - hparams.min_db) / -hparams.min_db, 0, 1) 36 | 37 | 38 | def denormalize(S): 39 | return (np.clip(S, 0, 1) * -hparams.min_db) + hparams.min_db 40 | 41 | 42 | def denormalize_tf(S): 43 | return (tf.clip_by_value(S, 0, 1) * -hparams.min_db) + hparams.min_db 44 | 45 | 46 | # Signal Processing formulas 47 | def stft(y): 48 | """ Short-Time-Fourier-Transform """ 49 | return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=hparams.hop_length, win_length=hparams.win_length) 50 | 51 | 52 | def stft_tf(signals): 53 | """ Short-Time-Fourier-Transform in TensorFlow """ 54 | return tf.contrib.signal.stft(signals, hparams.win_length, hparams.hop_length, hparams.n_fft, pad_end=False) 55 | 56 | 57 | def inv_stft(y): 58 | """ Inverse-Short-Time-Fourier-Transform """ 59 | return librosa.istft(y, hop_length=hparams.hop_length, win_length=hparams.win_length) 60 | 61 | 62 | def inv_stft_tf(stfts): 63 | """ Inverse-Short-Time-Fourier-Transform in TensorFlow """ 64 | return tf.contrib.signal.inverse_stft(stfts, hparams.win_length, hparams.hop_length, hparams.n_fft) 65 | 66 | 67 | def preemphasis(x): 68 | return scipy.signal.lfilter([1, -hparams.preemphasis], [1], x) 69 | 70 | 71 | def inv_preemphasis(x): 72 | return scipy.signal.lfilter([1], [1, -hparams.preemphasis], x) 73 | 74 | 75 | # Linear-scale and Mel-scale Spectrograms 76 | def wav_to_spectrogram(y): 77 | """ waveform to spectrogram conversion """ 78 | spectro = np.abs(stft(preemphasis(y))) 79 | S = amp_to_db(spectro) - hparams.ref_db 80 | return normalize(S) 81 | 82 | 83 | def spectrogram_to_wav(S): 84 | """ spectrogram to waveform conversion """ 85 | spectro = db_to_amp(denormalize(S) + hparams.ref_db) 86 | return inv_preemphasis(Griffin_Lim(spectro ** hparams.power)) 87 | 88 | 89 | def spectrogram_to_wav_tf(S): 90 | """ spectrogram to waveform conversion in TensorFlow (without inv_preemphasis) """ 91 | spectro = db_to_amp_tf(denormalize_tf(S) + hparams.ref_db) 92 | return Griffin_Lim_tf(tf.pow(spectro, hparams.power)) 93 | 94 | 95 | def wav_to_melspectrogram(y): 96 | """ waveform to mel-scale spectrogram conversion """ 97 | mel_transform_matrix = librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels) 98 | spectro = np.abs(stft(preemphasis(y))) 99 | mel_spectro = np.dot(mel_transform_matrix, spectro) 100 | S = amp_to_db(mel_spectro) - hparams.ref_db 101 | return normalize(S) 102 | 103 | 104 | # Griffin-Lim Reconstruction Algorithm 105 | def Griffin_Lim(S): 106 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) 107 | S_complex = np.abs(S).astype(np.complex) 108 | y = inv_stft(S_complex * angles) 109 | 110 | for i in range(hparams.griffin_lim_iters): 111 | angles = np.exp(1j * np.angle(stft(y))) 112 | y = inv_stft(S_complex * angles) 113 | 114 | return y 115 | 116 | 117 | def Griffin_Lim_tf(S): 118 | with tf.compat.v1.variable_scope('griffinlim'): 119 | S = tf.expand_dims(S, 0) 120 | S_complex = tf.identity(tf.cast(S, dtype=tf.complex64)) 121 | y = inv_stft_tf(S_complex) 122 | 123 | for i in range(hparams.griffin_lim_iters): 124 | est = stft_tf(y) 125 | angles = est / tf.cast(tf.maximum(1e-8, tf.abs(est)), tf.complex64) 126 | y = 
inv_stft_tf(S_complex * angles) 127 | 128 | return tf.squeeze(y, 0) 129 | 130 | 131 | def find_endpoint(wav, threshold_db = -40, min_silence_sec = 0.8): 132 | window_length = int(hparams.sample_rate * min_silence_sec) 133 | hop_length = int(window_length / 4) 134 | threshold = db_to_amp(threshold_db) 135 | 136 | for x in range(hop_length, len(wav) - window_length, hop_length): 137 | if np.max(wav[x:(x + window_length)]) < threshold: 138 | return x + hop_length 139 | 140 | return len(wav) 141 | -------------------------------------------------------------------------------- /signal_proc/synthesizer.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from librosa import effects 6 | 7 | from model.tacotron import Tacotron 8 | from signal_proc import audio 9 | from text.tokenizer import text_to_sequence 10 | from constants.hparams import Hyperparams as hparams 11 | 12 | 13 | class Synthesizer(): 14 | """ Synthesizer """ 15 | 16 | def init(self, checkpoint_path): 17 | """ Initialize Synthesizer 18 | 19 | @type checkpoint_path str 20 | @param checkpoint_path path to checkpoint to be restored 21 | """ 22 | print('Constructing Tacotron Model ...') 23 | 24 | inputs = tf.compat.v1.placeholder(tf.int32, [1, None], 'inputs') 25 | input_lengths = tf.compat.v1.placeholder(tf.int32, [1], 'input_lengths') 26 | 27 | with tf.compat.v1.variable_scope('model'): 28 | self.model = Tacotron() 29 | self.model.init(inputs, input_lengths) 30 | self.wav_output = audio.spectrogram_to_wav_tf(self.model.linear_outputs[0]) 31 | 32 | print('Loading checkpoint: %s' % checkpoint_path) 33 | self.session = tf.compat.v1.Session() 34 | self.session.run(tf.compat.v1.global_variables_initializer()) 35 | saver = tf.compat.v1.train.Saver() 36 | saver.restore(self.session, checkpoint_path) 37 | 38 | 39 | def synthesize(self, text): 40 | """ Convert the text into synthesized speech 41 | 42 | @type text str 43 | @param text text to be synthesized 44 | 45 | @rtype object 46 | @return synthesized speech 47 | """ 48 | 49 | seq = text_to_sequence(text) 50 | 51 | feed_dict = { 52 | self.model.inputs: [np.asarray(seq, dtype=np.int32)], 53 | self.model.input_lengths: np.asarray([len(seq)], dtype=np.int32) 54 | } 55 | 56 | wav = self.session.run(self.wav_output, feed_dict=feed_dict) 57 | wav = audio.inv_preemphasis(wav) 58 | wav = wav[:audio.find_endpoint(wav)] 59 | out = io.BytesIO() 60 | audio.save_audio(wav, out) 61 | 62 | return out.getvalue() 63 | 64 | 65 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import argparse 4 | 5 | from signal_proc.synthesizer import Synthesizer 6 | from constants.hparams import Hyperparams as hparams 7 | 8 | 9 | sentences = [ 10 | 'အခုပဲ ဝယ်လာလို့ စမ်းဖွင့်ကြည့်တယ်။ နာရီဝက်ပဲ ဖွင့်ရသေးတယ်၊', 11 | 'သူစစ်ပေးထားတဲ့ ရေက မနည်းဘူး ရနေပြီ။ လေထဲမှာ ရေငွေ့ပါဝင်မှု တော်တော် များနေတာပဲ။', 12 | 'ကိုယ်တွေလည်း အသက်ရှုရင်း ရေနစ်နေကြတာ', 13 | 'မန်ယူ - ချယ်ဆီး ပွဲ ကြည့်ဖြစ်တယ်။', 14 | '"ဘောလုံးဆိုတာအလုံးကြီး" ဆိုတဲ့ စကားလိုပဲ၊ ဘာမဆို ဖြစ်သွားနိုင်တဲ့ ပွဲတွေဆိုပေမယ့်', 15 | 'ဒိုမိန်းတွေကို မြန်မာလိုပေးပြီး စမ်းထားကြတာတစ်ချို့ တွေ့ဖူးတယ်။', 16 | 'ဒါပေမယ့် သိပ် စိတ်မဝင်စားမိဘူး။' 17 | ] 18 | 19 | 20 | def get_output_base_path(checkpoint_path, out_dir): 21 | base_dir = os.path.abspath(out_dir) 22 | m = re.compile(r'.*?\.ckpt\-([0-9]+)').match(checkpoint_path) 23 | name 
= 'eval-%d' % int(m.group(1)) if m else 'eval' 24 | return os.path.join(base_dir, name) 25 | 26 | 27 | def test(args): 28 | synthesizer = Synthesizer() 29 | synthesizer.init(args.checkpoint) 30 | base_path = get_output_base_path(args.checkpoint, args.out_dir) 31 | 32 | for i, text in enumerate(sentences): 33 | path = '%s-%d.wav' % (base_path, i) 34 | print('Synthesizing: %s' % path) 35 | with open(path, 'wb') as f: 36 | f.write(synthesizer.synthesize(text)) 37 | 38 | 39 | def main(): 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('--checkpoint', required=True, help='Path to model checkpoint') 42 | parser.add_argument('--out_dir', default=os.path.expanduser('~/mm-tts')) 43 | args = parser.parse_args() 44 | 45 | test(args) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpbyte/myanmar-tts/fa3adaa7291b459f8cae67b96098dc05d1a7fdd2/text/__init__.py -------------------------------------------------------------------------------- /text/character_set.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the definition of the whole characters used in the input of the model. 3 | I excluded all the digits since they will be converted into spoken words accordingly. 4 | """ 5 | 6 | _pad = '_' 7 | _eos = '~' 8 | # from \u1000 to \u1021 9 | _consonants = 'ကခဂဃငစဆဇဈဉညဋဌဍဎဏတထဒဓနပဖဗဘမယရလဝသဿဟဠအ' 10 | # from \u1023 to \u102A (except \u1028) 11 | _vowels = 'ဣဤဥဦဧဩဪ' 12 | # from \u102B to \u103E (except \u1033,34,35,39) 13 | _signs = '\u102B\u102C\u102D\u102E\u102F\u1030\u1031\u1032\u1036\u1037\u1038\u103A\u103B\u103C\u103D\u103E' 14 | # from \u104C to \u104F 15 | _other_signs = '၌၍၎၏' 16 | # from \u104A to \u104B 17 | _punctuation = '၊။' 18 | # special characters 19 | _other_chars = '!\'(),-.:;? 
' 20 | 21 | # export all of them 22 | characters = [_pad, _eos] + list(_consonants + _vowels + _signs + _other_signs + _punctuation + _other_chars) 23 | -------------------------------------------------------------------------------- /text/mm_num2word.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module has a separate git repository https://github.com/hpbyte/Myanmar_Number_to_Words.git 3 | """ 4 | 5 | import re 6 | 7 | mm_digit = { 8 | '၀': 'သုည', 9 | '၁': 'တစ်', 10 | '၂': 'နှစ်', 11 | '၃': 'သုံ:', 12 | '၄': 'လေ:', 13 | '၅': 'ငါ:', 14 | '၆': 'ခြောက်', 15 | '၇': 'ခုနှစ်', 16 | '၈': 'ရှစ်', 17 | '၉': 'ကို:' 18 | } 19 | 20 | # regular expressions 21 | rgxPh = '^(၀၁|၀၉)' 22 | rgxDate = '[၀-၉]{1,2}-[၀-၉]{1,2}-[၀-၉]{4}|[၀-၉]{1,2}\/[၀-၉]{1,2}\/[၀-၉]{4}' 23 | rgxTime = '[၀-၉]{1,2}:[၀-၉]{1,2}' 24 | rgxDec = '[၀-၉]*\.[၀-၉]*' 25 | rgxAmt = '[,၀-၉]+' 26 | 27 | 28 | def convert_digit(num): 29 | """ 30 | @type num str 31 | @param num Myanmar number 32 | @rtype str 33 | @return converted Myanmar spoken words 34 | """ 35 | 36 | converted = '' 37 | nb_digits = len(num) 38 | 39 | def check_if_zero(pos): 40 | return not num[-pos] == '၀' 41 | 42 | def hundred_thousandth_val(): 43 | n = num[:-5] 44 | return ('သိန်: ' + mm_num2word(n)) if (n[-2:] == '၀၀') else (mm_num2word(n) + 'သိန်: ') 45 | 46 | def thousandth_val(): 47 | return mm_digit[num[-4]] + ('ထောင် ' if (num[-3:] == '၀၀၀') else 'ထောင့် ') 48 | 49 | def hundredth_val(): 50 | return mm_digit[num[-3]] + ('ရာ့ ' if ( 51 | (num[-2] == '၀' and re.match(r'[၁-၉]', num[-1])) or (re.match(r'[၁-၉]', num[-2]) and num[-1] == '၀') 52 | ) else 'ရာ ') 53 | 54 | def tenth_val(): 55 | return ('' if (num[-2] == '၁') else mm_digit[num[-2]]) + ('ဆယ် ' if (num[-1] == '၀') else 'ဆယ့် ') 56 | 57 | if ((nb_digits > 5)): 58 | converted += hundred_thousandth_val() 59 | if ((nb_digits > 4) and check_if_zero(5)): 60 | converted += mm_digit[num[-5]] + 'သောင်: ' 61 | if ((nb_digits > 3) and check_if_zero(4)): 62 | converted += thousandth_val() 63 | if ((nb_digits > 2) and check_if_zero(3)): 64 | converted += hundredth_val() 65 | if ((nb_digits > 1) and check_if_zero(2)): 66 | converted += tenth_val() 67 | if ((nb_digits > 0) and check_if_zero(1)): 68 | converted += mm_digit[num[-1]] 69 | 70 | return converted 71 | 72 | 73 | def mm_num2word(num): 74 | """ 75 | Detect type of number and convert accordingly 76 | 77 | @type num str 78 | @param num Myanmar number 79 | @rtype str 80 | @return converted Myanmar spoken words 81 | """ 82 | 83 | word = '' 84 | 85 | # phone number 86 | if (re.match(r'' + rgxPh, num[:2])): 87 | word = ' '.join([(mm_digit[d] if not d == '၇' else 'ခွန်') for d in num]) 88 | # date 89 | elif (re.match(r'' + rgxDate, num)): 90 | n = re.split(r'-|/', num) 91 | word = convert_digit(n[-1]) + ' ခုနှစ် ' + convert_digit(n[1]) + ' လပိုင်: ' + convert_digit(n[0]) + ' ရက်' 92 | # time 93 | elif (re.match(r'' + rgxTime, num)): 94 | n = re.split(r':', num) 95 | word = (convert_digit(n[0]) + ' နာရီ ') + ('ခွဲ' if (n[1] == '၃၀') else (convert_digit(n[1]) + ' မိနစ်')) 96 | # decimal 97 | elif (re.match(r'' + rgxDec, num)): 98 | n = re.split(r'\.', num) 99 | word = convert_digit(n[0]) + ' ဒဿမ ' + ' '.join([mm_digit[d] for d in n[1]]) 100 | # amount 101 | elif (re.match(r'' + rgxAmt, num)): 102 | word = convert_digit(num.replace(',', '')) 103 | # default 104 | else: 105 | raise Exception('Cannot convert the provided number format!') 106 | 107 | return word 108 | 109 | 110 | def extract_num(S): 111 | """ 112 | Extract 
numbers from the input string 113 | 114 | @type S str 115 | @param S Myanmar sentence 116 | @rtype list 117 | @return a list of Myanmar numbers 118 | """ 119 | matchedNums = re.compile('%s|%s|%s|%s' % (rgxDate, rgxTime, rgxDec, rgxAmt)).findall(S) 120 | 121 | return matchedNums 122 | -------------------------------------------------------------------------------- /text/tokenizer.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from text.character_set import characters 4 | from text.mm_num2word import extract_num, mm_num2word 5 | 6 | # mappings of each character and its id 7 | char_to_id = {s : i for i, s in enumerate(characters)} 8 | id_to_char = {i : s for i, s in enumerate(characters)} 9 | 10 | 11 | def _should_keep_char(c): 12 | """ 13 | Determines whether the input character is defined in the character set 14 | 15 | @type c str 16 | @param c char 17 | 18 | @rtype bool 19 | @return result of the check 20 | """ 21 | return c in char_to_id and c is not '_' and c is not '~' 22 | 23 | 24 | def collapse_whitespace(text): 25 | """ 26 | Combine a series of whitespaces into a single whitespace 27 | 28 | @type text str 29 | @param text input string of text 30 | 31 | @rtype str 32 | @return collapsed text string 33 | """ 34 | rgx_whitespace = re.compile(r'\s+') 35 | return re.sub(rgx_whitespace, ' ', text) 36 | 37 | 38 | def numbers_to_words(text): 39 | """ 40 | Convert numbers into corresponding spoken words 41 | 42 | @type text str 43 | @param text input string of text 44 | 45 | @rtype str 46 | @return converted spoken words 47 | """ 48 | nums = extract_num(text) 49 | for n in nums: 50 | text = text.replace(n, mm_num2word(n)) 51 | 52 | return text 53 | 54 | 55 | def normalize(text): 56 | """ 57 | Normalize text string for numbers and whitespaces 58 | 59 | @type text str 60 | @param text input string of text 61 | 62 | @rtype str 63 | @return normalized string 64 | """ 65 | text = collapse_whitespace(text) 66 | text = numbers_to_words(text) 67 | 68 | return text 69 | 70 | 71 | def text_to_sequence(text): 72 | """ 73 | Convert an input text into a sequence of ids 74 | 75 | @type text str 76 | @param text input string of text 77 | 78 | @rtype list 79 | @return list of IDs corresponding to the characters 80 | """ 81 | text = normalize(text) 82 | seq = [char_to_id[c] for c in text if _should_keep_char(c)] 83 | 84 | seq.append(char_to_id['~']) 85 | return seq 86 | 87 | 88 | def sequence_to_text(seq): 89 | """ 90 | Convert a sequence of ids into the corresponding characters 91 | 92 | @type seq list 93 | @param seq list of ids 94 | 95 | @rtype str 96 | @return a string of text 97 | """ 98 | text = '' 99 | for char_id in seq: 100 | if char_id in id_to_char: 101 | text += id_to_char[char_id] 102 | 103 | return text 104 | 105 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import subprocess 4 | import time 5 | from datetime import datetime 6 | import traceback 7 | import argparse 8 | 9 | import tensorflow as tf 10 | 11 | from model.feeder import DataFeeder 12 | from model.tacotron import Tacotron 13 | from signal_proc import audio 14 | from text.tokenizer import sequence_to_text 15 | from constants.hparams import Hyperparams as hparams 16 | from utils import logger, plotter, ValueWindow 17 | 18 | 19 | def add_stats(model): 20 | with tf.compat.v1.variable_scope('stats'): 21 | 
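# TensorBoard summaries: histograms of predicted vs. target spectrograms, the individual mel/linear losses, the learning rate, the total loss, and gradient norms.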
tf.compat.v1.summary.histogram('linear_outputs', model.linear_outputs) 22 | tf.compat.v1.summary.histogram('linear_targets', model.linear_targets) 23 | tf.compat.v1.summary.histogram('mel_outputs', model.mel_outputs) 24 | tf.compat.v1.summary.histogram('mel_targets', model.mel_targets) 25 | tf.compat.v1.summary.scalar('loss_mel', model.mel_loss) 26 | tf.compat.v1.summary.scalar('loss_linear', model.linear_loss) 27 | tf.compat.v1.summary.scalar('learning_rate', model.learning_rate) 28 | tf.compat.v1.summary.scalar('loss', model.loss) 29 | gradient_norms = [tf.norm(grad) for grad in model.gradients] 30 | tf.compat.v1.summary.histogram('gradient_norm', gradient_norms) 31 | tf.compat.v1.summary.scalar('max_gradient_norm', tf.reduce_max(gradient_norms)) 32 | return tf.compat.v1.summary.merge_all() 33 | 34 | 35 | def time_string(): 36 | return datetime.now().strftime('%Y-%m-%d %H:%M') 37 | 38 | 39 | def train(log_dir, args): 40 | checkpoint_path = os.path.join(log_dir, 'model.ckpt') 41 | input_path = os.path.join(args.base_dir, args.input) 42 | 43 | logger.log('Checkpoint path: %s' % checkpoint_path) 44 | logger.log('Loading training data from: %s' % input_path) 45 | 46 | # set up DataFeeder 47 | coordi = tf.train.Coordinator() 48 | with tf.compat.v1.variable_scope('data_feeder'): 49 | feeder = DataFeeder(coordi, input_path) 50 | 51 | # set up Model 52 | global_step = tf.Variable(0, name='global_step', trainable=False) 53 | with tf.compat.v1.variable_scope('model'): 54 | model = Tacotron() 55 | model.init(feeder.inputs, feeder.input_lengths, mel_targets=feeder.mel_targets, linear_targets=feeder.linear_targets) 56 | model.add_loss() 57 | model.add_optimizer(global_step) 58 | stats = add_stats(model) 59 | 60 | # bookkeeping 61 | step = 0 62 | loss_window = ValueWindow(100) 63 | time_window = ValueWindow(100) 64 | saver = tf.compat.v1.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2) 65 | 66 | # start training already! 67 | with tf.compat.v1.Session() as sess: 68 | try: 69 | summary_writer = tf.compat.v1.summary.FileWriter(log_dir, sess.graph) 70 | 71 | # initialize parameters 72 | sess.run(tf.compat.v1.global_variables_initializer()) 73 | 74 | # if requested, restore from step 75 | if (args.restore_step): 76 | restore_path = '%s-%d' % (checkpoint_path, args.restore_step) 77 | saver.restore(sess, restore_path) 78 | logger.log('Resuming from checkpoint: %s' % restore_path) 79 | else: 80 | logger.log('Starting a new training!') 81 | 82 | feeder.start_in_session(sess) 83 | 84 | while not coordi.should_stop(): 85 | start_time = time.time() 86 | 87 | step, loss, opt = sess.run([global_step, model.loss, model.optimize]) 88 | 89 | time_window.append(time.time() - start_time) 90 | loss_window.append(loss) 91 | 92 | msg = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (step, time_window.average, loss, loss_window.average) 93 | 94 | logger.log(msg) 95 | 96 | if loss > 100 or math.isnan(loss): 97 | # loss has exploded or gone NaN, so stop training 98 | logger.log('Loss exploded to %.05f at step %d!'
% (loss, step)) 99 | raise Exception('Loss Exploded') 100 | 101 | if step % args.summary_interval == 0: 102 | # it's time to write summary 103 | logger.log('Writing summary at step: %d' % step) 104 | summary_writer.add_summary(sess.run(stats), step) 105 | 106 | if step % args.checkpoint_interval == 0: 107 | # it's time to save a checkpoint 108 | logger.log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) 109 | saver.save(sess, checkpoint_path, global_step=step) 110 | logger.log('Saving audio and alignment...') 111 | 112 | input_seq, spectrogram, alignment = sess.run([ 113 | model.inputs[0], model.linear_outputs[0], model.alignments[0] 114 | ]) 115 | 116 | # convert spectrogram to waveform 117 | waveform = audio.spectrogram_to_wav(spectrogram.T) 118 | # save it 119 | audio.save_audio(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step)) 120 | 121 | plotter.plot_alignment( 122 | alignment, 123 | os.path.join(log_dir, 'step-%d-align.png' % step), 124 | info='%s, %s, step=%d, loss=%.5f' % ('tacotron', time_string(), step, loss) 125 | ) 126 | 127 | logger.log('Input: %s' % sequence_to_text(input_seq)) 128 | 129 | except Exception as e: 130 | logger.log('Exiting due to exception %s' % e) 131 | traceback.print_exc() 132 | coordi.request_stop(e) 133 | 134 | 135 | def main(): 136 | parser = argparse.ArgumentParser() 137 | parser.add_argument('--base_dir', default=os.path.expanduser('~/mm-tts')) 138 | parser.add_argument('--input', default='training/train.txt') 139 | parser.add_argument('--log_dir', default=os.path.expanduser('~/mm-tts')) 140 | parser.add_argument('--restore_step', type=int, help='Global step to restore from checkpoint.') 141 | parser.add_argument('--summary_interval', type=int, default=100, help='Steps between running summary ops.') 142 | parser.add_argument('--checkpoint_interval', type=int, default=1000, help='Steps between writing checkpoints.') 143 | args = parser.parse_args() 144 | 145 | log_dir = os.path.join(args.log_dir, 'logs-mmspeech') 146 | os.makedirs(log_dir, exist_ok=True) 147 | logger.init(os.path.join(log_dir, 'train.log')) 148 | 149 | train(log_dir, args) 150 | 151 | 152 | if __name__ == "__main__": 153 | main() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | class ValueWindow(): 2 | def __init__(self, window_size=100): 3 | self._window_size = window_size 4 | self._values = [] 5 | 6 | def append(self, x): 7 | self._values = self._values[-(self._window_size - 1):] + [x] 8 | 9 | @property 10 | def sum(self): 11 | return sum(self._values) 12 | 13 | @property 14 | def count(self): 15 | return len(self._values) 16 | 17 | @property 18 | def average(self): 19 | return self.sum / max(1, self.count) 20 | 21 | def reset(self): 22 | self._values = [] 23 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | _log_file = None 4 | _date_time_format = '%Y-%m-%d %H:%M:%S.%f' 5 | 6 | 7 | def init(filename): 8 | global _log_file 9 | close_logging() 10 | _log_file = open(filename, 'a', encoding='utf-8') 11 | _log_file.write('\n------------------------------------------') 12 | _log_file.write('\nStarting new training---------------------') 13 | _log_file.write('\n------------------------------------------') 14 | 15 | 16 | def log(msg): 17 | print(msg) 18 | if _log_file 
is not None: 19 | _log_file.write('[%s] %s\n' % (datetime.now().strftime(_date_time_format)[:-3], msg)) 20 | 21 | 22 | def close_logging(): 23 | global _log_file 24 | if _log_file is not None: 25 | _log_file.close() 26 | _log_file = None 27 | -------------------------------------------------------------------------------- /utils/plotter.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def plot_alignment(alignment, path, info=None): 7 | fig, ax = plt.subplots() 8 | im = ax.imshow( 9 | alignment, 10 | aspect='auto', 11 | origin='lower', 12 | interpolation='none') 13 | fig.colorbar(im, ax=ax) 14 | xlabel = 'Decoder timestep' 15 | 16 | if info is not None: 17 | xlabel += '\n\n' + info 18 | 19 | plt.xlabel(xlabel) 20 | plt.ylabel('Encoder timestep') 21 | plt.tight_layout() 22 | plt.savefig(path, format='png') 23 | plt.close(fig) 24 | --------------------------------------------------------------------------------
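The short sketch below is a quick sanity check for the text front end shown above. It only calls functions that already exist in `text/mm_num2word.py` and `text/tokenizer.py`; the sample sentence and the assumption that it is run from the repository root (so the `text` package is importable) are illustrative additions, not part of the project.

```
# Quick sanity check for the Myanmar text front end.
# Assumption: run from the repository root so the `text` package is importable;
# the sample sentence is only an illustration.
from text.mm_num2word import extract_num, mm_num2word
from text.tokenizer import text_to_sequence, sequence_to_text

sample = 'စာအုပ် ၂၅ အုပ် ဝယ်ခဲ့သည်'

# numbers are detected first, then expanded into spoken words
print(extract_num(sample))      # Myanmar numerals found in the sentence
print(mm_num2word('၂၅'))        # spoken form of 25

# full pipeline: normalize, map characters to ids, append the EOS symbol '~'
seq = text_to_sequence(sample)
print(seq)
print(sequence_to_text(seq))    # characters recovered from the id sequence
```

If the printed id sequence is empty or characters are unexpectedly dropped, the character set defined in `text/character_set.py` is the first place to check.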