├── .gitignore
├── LICENSE
├── README.md
├── data_load.py
├── eval.py
├── fig
│   ├── attention.gif
│   ├── mean_loss.png
│   └── training_curve.png
├── harvard_sentences.txt
├── history.md
├── hyperparams.py
├── modules.py
├── networks.py
├── prepro.py
├── synthesize.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | etc/
2 | asset/
3 | data/
4 | corpora/
5 | logdir/
6 | samples/
7 | *.pyc
8 | _*
9 | *~
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright {yyyy} {name of copyright owner}
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A (Heavily Documented) TensorFlow Implementation of Tacotron: A Fully End-to-End Text-To-Speech Synthesis Model
2 |
3 | ## Requirements
4 |
5 | * NumPy >= 1.11.1
6 | * TensorFlow >= 1.3
7 | * librosa
8 | * tqdm
9 | * matplotlib
10 | * scipy
11 |
12 | ## Data
13 |
14 |
15 |
16 |
17 |
18 | We train the model on three different speech datasets.
19 | 1. [LJ Speech Dataset](https://keithito.com/LJ-Speech-Dataset/)
20 | 2. [Nick Offerman's Audiobooks](https://www.audible.com.au/search?searchNarrator=Nick+Offerman)
21 | 3. [The World English Bible](https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset)
22 |
23 | The LJ Speech Dataset has recently become widely used as a benchmark dataset for TTS because it is publicly available. It contains 24 hours of reasonable-quality samples.
24 | Nick's audiobooks are used in addition to see whether the model can learn even from less data and more variable speech samples. They total 18 hours.
25 | [The World English Bible](https://en.wikipedia.org/wiki/World_English_Bible) is a public domain update of the American Standard Version of 1901 into modern English. Its original audios are freely available [here](http://www.audiotreasure.com/webindex.htm). Kyubyong split each chapter by verse manually and aligned the segmented audio clips to the text. They are 72 hours in total. You can download them at [Kaggle Datasets](https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset).
26 |
27 | ## Training
28 | * STEP 0. Download [LJ Speech Dataset](https://keithito.com/LJ-Speech-Dataset/) or prepare your own data.
29 | * STEP 1. Adjust hyper parameters in `hyperparams.py`. (If you want to do preprocessing, set `prepro` to True.)
30 | * STEP 2. Run `python train.py`. (If you set `prepro` to True, run `python prepro.py` first.)
31 | * STEP 3. Run `python eval.py` regularly during training.
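
For STEP 0, if you prepare your own corpus, note that `data_load.py` expects a pipe-delimited `transcript.csv` at the root of `hp.data` and the audio clips under a `wavs/` subfolder. The sketch below only illustrates how a transcript line is parsed; the file name and text are made up.

    import os

    data_dir = "/data/private/voice/LJSpeech-1.0"   # hp.data in hyperparams.py
    line = "LJ001-0001|raw text|normalized text"    # one made-up line of <hp.data>/transcript.csv

    # data_load.load_data() keeps the first and third fields only.
    fname, _, text = line.strip().split("|")
    fpath = os.path.join(data_dir, "wavs", fname + ".wav")
    # `text` is then lower-cased, stripped of accents, and restricted to hp.vocab
    # before being mapped to character ids (see text_normalize in data_load.py).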
32 |
33 | ## Sample Synthesis
34 |
35 | We generate speech samples based on [Harvard Sentences](http://www.cs.columbia.edu/~hgs/audio/harvard.html) as the original paper does. It is already included in the repo.
36 |
37 | * Run `python synthesize.py` and check the files in `samples`.
38 |
39 | ## Training Curve
40 |
41 |
42 | ![training curve](fig/training_curve.png)
43 |
44 | ## Attention Plot
45 |
46 |
47 | ![attention plot](fig/attention.gif)
48 | ## Generated Samples
49 |
50 | * [LJ at 200k steps](https://soundcloud.com/kyubyong-park/sets/tacotron_lj_200k)
51 | * [Nick at 215k steps](https://soundcloud.com/kyubyong-park/sets/tacotron_nick_215k)
52 | * [WEB at 183k steps](https://soundcloud.com/kyubyong-park/sets/tacotron_web_183k)
53 |
54 | ## Pretrained Files
55 | * Keep in mind 200k steps may not be enough for the best performance.
56 | * [LJ 200k](https://www.dropbox.com/s/8kxa3xh2vfna3s9/LJ_logdir.zip?dl=0)
57 | * [WEB 200k](https://www.dropbox.com/s/g7m6xhd350ozkz7/WEB_logdir.zip?dl=0)
58 |
59 | ## Notes
60 |
61 | * It's important to monitor the attention plots during training. If the attention plots look good (the alignment looks linear) and then degrade (the plots start to resemble what they looked like at the beginning of training), training has gone awry and will most likely need to be restarted from a checkpoint where the attention still looked good; in our experience the loss is unlikely to recover on its own. This deterioration of attention corresponds with a spike in the loss.
62 |
63 | * In the original paper, the authors said, "An important trick we discovered was predicting multiple, non-overlapping output frames at each decoder step", where the number of frames predicted per step is the reduction factor, `r`. We originally interpreted this as predicting non-sequential (interleaved) frames at each decoding step `t`, and were therefore using the following scheme (with `r=5`) during decoding.
64 |
65 |
66 |     t    frame numbers
67 |     -----------------------
68 |     0    [ 0  5 10 15 20]
69 |     1    [ 1  6 11 16 21]
70 |     2    [ 2  7 12 17 22]
71 |     ...
72 |
73 | After much experimentation, we were unable to get our model to learn anything useful with this scheme. We then switched to predicting `r` sequential frames during each decoding step.
74 |
75 |
76 |     t    frame numbers
77 |     -----------------------
78 |     0    [ 0  1  2  3  4]
79 |     1    [ 5  6  7  8  9]
80 |     2    [10 11 12 13 14]
81 |     ...
82 |
83 | With this setup we noticed improvements in the attention and have since kept it. (A short sketch of how the `r` sequential frames get grouped per decoder step follows this list.)
84 |
85 | * **Perhaps the most important hyperparameter is the learning rate.** With an initial learning rate of 0.002 we were never able to learn a clean attention, and the loss frequently exploded. With an initial learning rate of 0.001 we were able to learn a clean attention, train for much longer, and get discernible words during synthesis.
86 | * Check other TTS models such as [DCTTS](https://github.com/kyubyong/dc_tts) or [deep voice 3](https://github.com/kyubyong/deepvoice3).
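
The following is a minimal sketch of the frame grouping described above: it mirrors the reshape done in `load_spectrograms` in `utils.py`, so that each decoder step `t` holds `r` consecutive mel frames. The frame count `T` is made up for illustration.

    import numpy as np

    r, n_mels = 5, 80                    # reduction factor and mel bins, as in hyperparams.py
    T = 23                               # made-up number of mel frames, for illustration only
    mel = np.zeros((T, n_mels), np.float32)

    # Pad so the number of frames is a multiple of r (as utils.load_spectrograms does),
    # then group r consecutive frames into one decoder step.
    num_paddings = r - (T % r) if T % r != 0 else 0
    mel = np.pad(mel, [[0, num_paddings], [0, 0]], mode="constant")
    grouped = mel.reshape((-1, n_mels * r))  # step t holds frames [t*r, ..., t*r + r - 1]
    print(grouped.shape)                     # (5, 400), i.e. (T_y/r, n_mels*r)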
87 |
88 | ### Differences from the original paper
89 |
90 | * We use Noam style warmup and decay.
91 | * We implement gradient clipping.
92 | * Our training batches are bucketed.
93 | * After the last convolutional layer of the post-processing net, we apply an affine transformation to bring the dimensionality up to 128 from 80, because the required dimensionality of highway net is 128. In the original highway networks paper, the authors mention that the dimensionality of the input can also be increased with zero-padding, but they used the affine transformation in all their experiments. We do not know what the Tacotron authors chose.
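
The Noam-style schedule mentioned above is the one implemented in `learning_rate_decay` in `utils.py`; a NumPy sketch of the same formula follows (the 4000 warmup steps match that function's default). Gradient clipping in `train.py` simply applies `tf.clip_by_norm(grad, 5.)` to each gradient before `apply_gradients`.

    import numpy as np

    def noam_lr(init_lr, global_step, warmup_steps=4000.0):
        # Linear warmup for `warmup_steps`, then ~1/sqrt(step) decay;
        # same formula as learning_rate_decay in utils.py.
        step = global_step + 1.0
        return init_lr * warmup_steps ** 0.5 * np.minimum(step * warmup_steps ** -1.5, step ** -0.5)

    print(noam_lr(0.001, 0))       # ~2.5e-07 at the first step
    print(noam_lr(0.001, 3999))    # ~0.001 at the end of warmup
    print(noam_lr(0.001, 100000))  # ~0.0002 after 100k steps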
94 |
95 |
96 | ## Papers that referenced this repo
97 |
98 | * [Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention](https://arxiv.org/abs/1710.08969)
99 | * [Storytime - End to end neural networks for audiobooks](http://web.stanford.edu/class/cs224s/reports/Pierce_Freeman.pdf)
100 |
101 | Jan. 2018,
102 | Kyubyong Park & [Tommy Mulc](mailto:tmulc18@gmail.com)
103 |
--------------------------------------------------------------------------------
/data_load.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python2
3 | '''
4 | By kyubyong park. kbpark.linguist@gmail.com.
5 | https://www.github.com/kyubyong/tacotron
6 | '''
7 |
8 | from __future__ import print_function
9 |
10 | from hyperparams import Hyperparams as hp
11 | import numpy as np
12 | import tensorflow as tf
13 | from utils import *
14 | import codecs
15 | import re
16 | import os
17 | import unicodedata
18 |
19 | def load_vocab():
20 | char2idx = {char: idx for idx, char in enumerate(hp.vocab)}
21 | idx2char = {idx: char for idx, char in enumerate(hp.vocab)}
22 | return char2idx, idx2char
23 |
24 | def text_normalize(text):
25 | text = ''.join(char for char in unicodedata.normalize('NFD', text)
26 | if unicodedata.category(char) != 'Mn') # Strip accents
27 |
28 | text = text.lower()
29 | text = re.sub("[^{}]".format(hp.vocab), " ", text)
30 | text = re.sub("[ ]+", " ", text)
31 | return text
32 |
33 | def load_data(mode="train"):
34 | # Load vocabulary
35 | char2idx, idx2char = load_vocab()
36 |
37 | if mode in ("train", "eval"):
38 | # Parse
39 | fpaths, text_lengths, texts = [], [], []
40 | transcript = os.path.join(hp.data, 'transcript.csv')
41 | lines = codecs.open(transcript, 'r', 'utf-8').readlines()
42 | total_hours = 0
43 | if mode=="train":
44 | lines = lines[1:]
45 | else: # For evaluation, we take only one sample.
46 | lines = lines[:1]
47 |
48 | for line in lines:
49 | fname, _, text = line.strip().split("|")
50 |
51 | fpath = os.path.join(hp.data, "wavs", fname + ".wav")
52 | fpaths.append(fpath)
53 |
54 | text = text_normalize(text) + "E" # E: EOS
55 | text = [char2idx[char] for char in text]
56 | text_lengths.append(len(text))
57 | texts.append(np.array(text, np.int32).tostring())
58 |
59 | return fpaths, text_lengths, texts
60 | else:
61 | # Parse
62 | lines = codecs.open(hp.test_data, 'r', 'utf-8').readlines()[1:]
63 | sents = [text_normalize(line.split(" ", 1)[-1]).strip() + "E" for line in lines] # text normalization, E: EOS
64 | lengths = [len(sent) for sent in sents]
65 | maxlen = sorted(lengths, reverse=True)[0]
66 | texts = np.zeros((len(sents), maxlen), np.int32)
67 | for i, sent in enumerate(sents):
68 | texts[i, :len(sent)] = [char2idx[char] for char in sent]
69 | return texts
70 |
71 | def get_batch():
72 | """Loads training data and put them in queues"""
73 | with tf.device('/cpu:0'):
74 | # Load data
75 | fpaths, text_lengths, texts = load_data() # list
76 | maxlen, minlen = max(text_lengths), min(text_lengths)
77 |
78 | # Calc total batch count
79 | num_batch = len(fpaths) // hp.batch_size
80 |
81 | fpaths = tf.convert_to_tensor(fpaths)
82 | text_lengths = tf.convert_to_tensor(text_lengths)
83 | texts = tf.convert_to_tensor(texts)
84 |
85 | # Create Queues
86 | fpath, text_length, text = tf.train.slice_input_producer([fpaths, text_lengths, texts], shuffle=True)
87 |
88 | # Parse
89 | text = tf.decode_raw(text, tf.int32) # (None,)
90 |
91 | if hp.prepro:
92 | def _load_spectrograms(fpath):
93 | fname = os.path.basename(fpath)
94 | mel = "mels/{}".format(fname.replace("wav", "npy"))
95 | mag = "mags/{}".format(fname.replace("wav", "npy"))
96 | return fname, np.load(mel), np.load(mag)
97 |
98 | fname, mel, mag = tf.py_func(_load_spectrograms, [fpath], [tf.string, tf.float32, tf.float32])
99 | else:
100 | fname, mel, mag = tf.py_func(load_spectrograms, [fpath], [tf.string, tf.float32, tf.float32]) # (None, n_mels)
101 |
102 | # Add shape information
103 | fname.set_shape(())
104 | text.set_shape((None,))
105 | mel.set_shape((None, hp.n_mels*hp.r))
106 | mag.set_shape((None, hp.n_fft//2+1))
107 |
108 | # Batching
109 | _, (texts, mels, mags, fnames) = tf.contrib.training.bucket_by_sequence_length(
110 | input_length=text_length,
111 | tensors=[text, mel, mag, fname],
112 | batch_size=hp.batch_size,
113 | bucket_boundaries=[i for i in range(minlen + 1, maxlen - 1, 20)],
114 | num_threads=16,
115 | capacity=hp.batch_size * 4,
116 | dynamic_pad=True)
117 |
118 | return texts, mels, mags, fnames, num_batch
119 |
120 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python2
3 | '''
4 | By kyubyong park. kbpark.linguist@gmail.com.
5 | https://www.github.com/kyubyong/tacotron
6 | '''
7 |
8 | from __future__ import print_function
9 |
10 | from hyperparams import Hyperparams as hp
11 | import numpy as np
12 | from data_load import load_data
13 | import tensorflow as tf
14 | from train import Graph
15 | from utils import load_spectrograms
16 |
17 |
18 | def eval():
19 | # Load graph
20 | g = Graph(mode="eval"); print("Evaluation Graph loaded")
21 |
22 | # Load data
23 | fpaths, text_lengths, texts = load_data(mode="eval")
24 |
25 | # Parse
26 | text = np.fromstring(texts[0], np.int32) # (None,)
27 | fname, mel, mag = load_spectrograms(fpaths[0])
28 |
29 | x = np.expand_dims(text, 0) # (1, None)
30 | y = np.expand_dims(mel, 0) # (1, None, n_mels*r)
31 | z = np.expand_dims(mag, 0) # (1, None, 1+n_fft//2)
32 |
33 | saver = tf.train.Saver()
34 | with tf.Session() as sess:
35 | saver.restore(sess, tf.train.latest_checkpoint(hp.logdir)); print("Restored!")
36 |
37 | writer = tf.summary.FileWriter(hp.logdir, sess.graph)
38 |
39 | # Feed Forward
40 | ## mel
41 | y_hat = np.zeros((1, y.shape[1], y.shape[2]), np.float32) # hp.n_mels*hp.r
42 | for j in range(y.shape[1]):
43 | _y_hat = sess.run(g.y_hat, {g.x: x, g.y: y_hat})
44 | y_hat[:, j, :] = _y_hat[:, j, :]
45 |
46 | ## mag
47 | merged, gs = sess.run([g.merged, g.global_step], {g.x:x, g.y:y, g.y_hat: y_hat, g.z: z})
48 | writer.add_summary(merged, global_step=gs)
49 | writer.close()
50 |
51 | if __name__ == '__main__':
52 | eval()
53 | print("Done")
54 |
55 |
56 |
--------------------------------------------------------------------------------
/fig/attention.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kyubyong/tacotron/379bb7f54c3359ffe97d1f09a773bc6da49eba6f/fig/attention.gif
--------------------------------------------------------------------------------
/fig/mean_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kyubyong/tacotron/379bb7f54c3359ffe97d1f09a773bc6da49eba6f/fig/mean_loss.png
--------------------------------------------------------------------------------
/fig/training_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kyubyong/tacotron/379bb7f54c3359ffe97d1f09a773bc6da49eba6f/fig/training_curve.png
--------------------------------------------------------------------------------
/harvard_sentences.txt:
--------------------------------------------------------------------------------
1 | http://www.cs.columbia.edu/~hgs/audio/harvard.html
2 | 1. The birch canoe slid on the smooth planks.
3 | 2. Glue the sheet to the dark blue background.
4 | 3. It's easy to tell the depth of a well.
5 | 4. These days a chicken leg is a rare dish.
6 | 5. Rice is often served in round bowls.
7 | 6. The juice of lemons makes fine punch.
8 | 7. The box was thrown beside the parked truck.
9 | 8. The hogs were fed chopped corn and garbage.
10 | 9. Four hours of steady work faced us.
11 | 10. Large size in stockings is hard to sell.
12 | 11. The boy was there when the sun rose.
13 | 12. A rod is used to catch pink salmon.
14 | 13. The source of the huge river is the clear spring.
15 | 14. Kick the ball straight and follow through.
16 | 15. Help the woman get back to her feet.
17 | 16. A pot of tea helps to pass the evening.
18 | 17. Smoky fires lack flame and heat.
19 | 18. The soft cushion broke the man's fall.
20 | 19. The salt breeze came across from the sea.
21 | 20. The girl at the booth sold fifty bonds.
22 |
--------------------------------------------------------------------------------
/history.md:
--------------------------------------------------------------------------------
1 | # A (Heavily Documented) TensorFlow Implementation of Tacotron: A Fully End-to-End Text-To-Speech Synthesis Model
2 |
3 | ## **Major History**
4 | * June 21, 2017. Fourth draft.
5 | * **I've updated the code for TF 1.1 to TF 1.2.** Turns out that TF 1.2 has a new api for attention wrapper and more detailed options.
6 | * I've added a sanity-check option to the `hyperparams.py` file. Basically, its purpose is to find out whether our model is able to learn a handful of training samples without caring about generalization. For that, training was done on a single mini-batch (32 samples) over and over again, and sample generation was based on the same text. I observed a quite smooth training curve, as shown below, and after around 18K global steps the model started to generate recognizable sounds. The sample results after 36K steps are available in the `logdir_s` folder. Training took around seven hours on a single GTX 1080. The pretrained files can be downloaded from [here](https://u42868014.dl.dropboxusercontent.com/u/42868014/tacotron/logdir_s.zip). The training curve looks like this.
7 |
8 |
9 |
10 | * June 4, 2017. Third draft.
11 | * Some people have reported promising results based on my code. Among them are [@ggsonic](https://www.github.com/ggsonic) and [@chief7](https://www.github.com/chief7). For relevant discussions, see this [issue](https://www.github.com/Kyubyong/tacotron/issues/30) or their repos.
12 | * According to @ggsonic, instance normalization worked better than batch normalization.
13 | * @chief7 trained on the PAVOQUE data, a German corpus spoken by a single male actor. He said that instance normalization and zero-masking are good choices.
14 | * Yuxuan, the first author of the paper, advised me to do a sanity check first with small data, and to adjust hyperparameters since our dataset is different from his. I really appreciate his tips, and hope they help you.
15 | * [Alex's repo](https://github.com/barronalex/Tacotron), which is another implementation of Tacotron, seems to be successful in getting promising results with some small dataset. He's working on a big one.
16 | * June 2, 2017.
17 | * Added `train_multiple_gpus.py` for multiple GPUs.
18 | * June 1, 2017. Second draft.
19 | * I corrected some mistakes with the help of several contributors (THANKS!) and refactored the source code so that it is more readable and modular. So far, I haven't gotten any promising results.
20 | * May 17, 2017. First draft.
21 | * You can run it following the steps below, but good results are not guaranteed. I'll be working on debugging this weekend. (**Code reviews and/or contributions are more than welcome!**)
22 |
23 | ## Requirements
24 | * NumPy >= 1.11.1
25 | * TensorFlow == 1.2
26 | * librosa
27 | * tqdm
28 |
29 | ## Data
30 | Since the [original paper](https://arxiv.org/abs/1703.10135) was based on their internal data, I use a freely available one, instead.
31 |
32 | [The World English Bible](https://en.wikipedia.org/wiki/World_English_Bible) is a public domain update of the American Standard Version of 1901 into modern English. Its text and audio recordings are freely available [here](http://www.audiotreasure.com/webindex.htm). Unfortunately, however, each of the audio files matches a chapter, not a verse, so is too long for many machine learning tasks. I had someone slice them by verse manually. You can download [the audio data](https://dl.dropboxusercontent.com/u/42868014/WEB.zip) and its [text](https://dl.dropboxusercontent.com/u/42868014/text.csv) from my dropbox.
33 |
34 |
35 |
36 | ## File description
37 | * `hyperparams.py` includes all hyper parameters that are needed.
38 | * `prepare_pavoque.py` creates sliced sound files from raw sound data, and constructs necessary information.
39 | * `prepro.py` loads the vocabulary and training/evaluation data.
40 | * `data_load.py` loads data and puts it in queues so that multiple mini-batches are generated in parallel.
41 | * `utils.py` has several custom operational functions.
42 | * `modules.py` contains building blocks for encoding/decoding networks.
43 | * `networks.py` has three core networks, that is, encoding, decoding, and postprocessing network.
44 | * `train.py` is for training.
45 | * `eval.py` is for sample synthesis.
46 |
47 |
48 | ## Training
49 | * STEP 1. Adjust hyper parameters in `hyperparams.py` if necessary.
50 | * STEP 2. Download and extract [the audio data](https://dl.dropboxusercontent.com/u/42868014/WEB.zip) and its [text](https://dl.dropboxusercontent.com/u/42868014/text.csv).
51 | * STEP 3. Run `train.py`. or `train_multi_gpus.py` if you have more than one gpu.
52 |
53 | ## Sample Synthesis
54 | * Run `eval.py` to get samples.
55 |
56 | ### Acknowledgements
57 | I would like to show my respect to Dave, the host of www.audiotreasure.com and the reader of the audio files.
58 |
--------------------------------------------------------------------------------
/hyperparams.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python2
3 | '''
4 | By kyubyong park. kbpark.linguist@gmail.com.
5 | https://www.github.com/kyubyong/tacotron
6 | '''
7 | class Hyperparams:
8 | '''Hyper parameters'''
9 |
10 | # pipeline
11 | prepro = False # if True, run `python prepro.py` first before running `python train.py`.
12 |
13 | vocab = "PE abcdefghijklmnopqrstuvwxyz'.?" # P: Padding E: End of Sentence
14 |
15 | # data
16 | data = "/data/private/voice/LJSpeech-1.0"
17 | # data = "/data/private/voice/nick"
18 | test_data = 'harvard_sentences.txt'
19 | max_duration = 10.0
20 |
21 | # signal processing
22 | sr = 22050 # Sample rate.
23 | n_fft = 2048 # fft points (samples)
24 | frame_shift = 0.0125 # seconds
25 | frame_length = 0.05 # seconds
26 | hop_length = int(sr*frame_shift) # samples.
27 | win_length = int(sr*frame_length) # samples.
28 | n_mels = 80 # Number of Mel banks to generate
29 | power = 1.2 # Exponent for amplifying the predicted magnitude
30 | n_iter = 50 # Number of inversion iterations
31 | preemphasis = .97 # or None
32 | max_db = 100
33 | ref_db = 20
34 |
35 | # model
36 | embed_size = 256 # alias = E
37 | encoder_num_banks = 16
38 | decoder_num_banks = 8
39 | num_highwaynet_blocks = 4
40 | r = 5 # Reduction factor. Paper => 2, 3, 5
41 | dropout_rate = .5
42 |
43 | # training scheme
44 | lr = 0.001 # Initial learning rate.
45 | logdir = "logdir/01"
46 | sampledir = 'samples'
47 | batch_size = 32
48 |
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python2
3 | '''
4 | By kyubyong park. kbpark.linguist@gmail.com.
5 | https://www.github.com/kyubyong/tacotron
6 | '''
7 |
8 | from __future__ import print_function
9 |
10 | from hyperparams import Hyperparams as hp
11 | import tensorflow as tf
12 |
13 |
14 | def embed(inputs, vocab_size, num_units, zero_pad=True, scope="embedding", reuse=None):
15 | '''Embeds a given tensor.
16 |
17 | Args:
18 | inputs: A `Tensor` with type `int32` or `int64` containing the ids
19 | to be looked up in `lookup table`.
20 | vocab_size: An int. Vocabulary size.
21 | num_units: An int. Number of embedding hidden units.
22 | zero_pad: A boolean. If True, all the values of the first row (id 0)
23 | should be constant zeros.
24 | scope: Optional scope for `variable_scope`.
25 | reuse: Boolean, whether to reuse the weights of a previous layer
26 | by the same name.
27 |
28 | Returns:
29 | A `Tensor` with one more rank than the inputs'. The last dimension
30 | should be `num_units`.
31 | '''
32 | with tf.variable_scope(scope, reuse=reuse):
33 | lookup_table = tf.get_variable('lookup_table',
34 | dtype=tf.float32,
35 | shape=[vocab_size, num_units],
36 | initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.01))
37 | if zero_pad:
38 | lookup_table = tf.concat((tf.zeros(shape=[1, num_units]),
39 | lookup_table[1:, :]), 0)
40 | return tf.nn.embedding_lookup(lookup_table, inputs)
41 |
42 |
43 | def bn(inputs,
44 | is_training=True,
45 | activation_fn=None,
46 | scope="bn",
47 | reuse=None):
48 | '''Applies batch normalization.
49 |
50 | Args:
51 | inputs: A tensor with 2 or more dimensions, where the first dimension is
52 | `batch_size`. This function does batch normalization, so the
53 | normalization is over all but the last dimension. If you instead want
54 | layer-norm-style normalization over the last dimension, note that the
55 | native `tf.contrib.layers.batch_norm` does not support it; you would have
56 | to change a line in `tensorflow/contrib/layers/python/layers/layers.py`
57 | as follows.
58 | Before: mean, variance = nn.moments(inputs, axis, keep_dims=True)
59 | After: mean, variance = nn.moments(inputs, [-1], keep_dims=True)
60 | is_training: Whether or not the layer is in training mode.
61 | activation_fn: Activation function.
62 | scope: Optional scope for `variable_scope`.
63 | reuse: Boolean, whether to reuse the weights of a previous layer
64 | by the same name.
65 |
66 | Returns:
67 | A tensor with the same shape and data dtype as `inputs`.
68 | '''
69 | inputs_shape = inputs.get_shape()
70 | inputs_rank = inputs_shape.ndims
71 |
72 | # use fused batch norm if inputs_rank in [2, 3, 4] as it is much faster.
73 | # pay attention to the fact that fused_batch_norm requires shape to be rank 4 of NHWC.
74 | if inputs_rank in [2, 3, 4]:
75 | if inputs_rank == 2:
76 | inputs = tf.expand_dims(inputs, axis=1)
77 | inputs = tf.expand_dims(inputs, axis=2)
78 | elif inputs_rank == 3:
79 | inputs = tf.expand_dims(inputs, axis=1)
80 |
81 | outputs = tf.contrib.layers.batch_norm(inputs=inputs,
82 | center=True,
83 | scale=True,
84 | updates_collections=None,
85 | is_training=is_training,
86 | scope=scope,
87 | fused=True,
88 | reuse=reuse)
89 | # restore original shape
90 | if inputs_rank == 2:
91 | outputs = tf.squeeze(outputs, axis=[1, 2])
92 | elif inputs_rank == 3:
93 | outputs = tf.squeeze(outputs, axis=1)
94 | else: # fallback to naive batch norm
95 | outputs = tf.contrib.layers.batch_norm(inputs=inputs,
96 | center=True,
97 | scale=True,
98 | updates_collections=None,
99 | is_training=is_training,
100 | scope=scope,
101 | reuse=reuse,
102 | fused=False)
103 | if activation_fn is not None:
104 | outputs = activation_fn(outputs)
105 |
106 | return outputs
107 |
108 | def conv1d(inputs,
109 | filters=None,
110 | size=1,
111 | rate=1,
112 | padding="SAME",
113 | use_bias=False,
114 | activation_fn=None,
115 | scope="conv1d",
116 | reuse=None):
117 | '''
118 | Args:
119 | inputs: A 3-D tensor with shape of [batch, time, depth].
120 | filters: An int. Number of outputs (=activation maps)
121 | size: An int. Filter size.
122 | rate: An int. Dilation rate.
123 | padding: Either `same` or `valid` or `causal` (case-insensitive).
124 | use_bias: A boolean.
125 | scope: Optional scope for `variable_scope`.
126 | reuse: Boolean, whether to reuse the weights of a previous layer
127 | by the same name.
128 | '''
129 | with tf.variable_scope(scope):
130 | if padding.lower()=="causal":
131 | # pre-padding for causality
132 | pad_len = (size - 1) * rate # padding size
133 | inputs = tf.pad(inputs, [[0, 0], [pad_len, 0], [0, 0]])
134 | padding = "valid"
135 |
136 | if filters is None:
137 | filters = inputs.get_shape().as_list()[-1]
138 |
139 | params = {"inputs":inputs, "filters":filters, "kernel_size":size,
140 | "dilation_rate":rate, "padding":padding, "activation":activation_fn,
141 | "use_bias":use_bias, "reuse":reuse}
142 |
143 | outputs = tf.layers.conv1d(**params)
144 | return outputs
145 |
146 | def conv1d_banks(inputs, K=16, is_training=True, scope="conv1d_banks", reuse=None):
147 | '''Applies a series of conv1d separately.
148 |
149 | Args:
150 | inputs: A 3d tensor with shape of [N, T, C]
151 | K: An int. The size of conv1d banks. That is,
152 | The `inputs` are convolved with K filters: 1, 2, ..., K.
153 | is_training: A boolean. This is passed to an argument of `bn`.
154 | scope: Optional scope for `variable_scope`.
155 | reuse: Boolean, whether to reuse the weights of a previous layer
156 | by the same name.
157 |
158 | Returns:
159 | A 3d tensor with shape of [N, T, K*Hp.embed_size//2].
160 | '''
161 | with tf.variable_scope(scope, reuse=reuse):
162 | outputs = conv1d(inputs, hp.embed_size//2, 1) # k=1
163 | for k in range(2, K+1): # k = 2...K
164 | with tf.variable_scope("num_{}".format(k)):
165 | output = conv1d(inputs, hp.embed_size // 2, k)
166 | outputs = tf.concat((outputs, output), -1)
167 | outputs = bn(outputs, is_training=is_training, activation_fn=tf.nn.relu)
168 | return outputs # (N, T, Hp.embed_size//2*K)
169 |
170 | def gru(inputs, num_units=None, bidirection=False, scope="gru", reuse=None):
171 | '''Applies a GRU.
172 |
173 | Args:
174 | inputs: A 3d tensor with shape of [N, T, C].
175 | num_units: An int. The number of hidden units.
176 | bidirection: A boolean. If True, bidirectional results
177 | are concatenated.
178 | scope: Optional scope for `variable_scope`.
179 | reuse: Boolean, whether to reuse the weights of a previous layer
180 | by the same name.
181 |
182 | Returns:
183 | If bidirection is True, a 3d tensor with shape of [N, T, 2*num_units],
184 | otherwise [N, T, num_units].
185 | '''
186 | with tf.variable_scope(scope, reuse=reuse):
187 | if num_units is None:
188 | num_units = inputs.get_shape().as_list()[-1]
189 |
190 | cell = tf.contrib.rnn.GRUCell(num_units)
191 | if bidirection:
192 | cell_bw = tf.contrib.rnn.GRUCell(num_units)
193 | outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell, cell_bw, inputs, dtype=tf.float32)
194 | return tf.concat(outputs, 2)
195 | else:
196 | outputs, _ = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
197 | return outputs
198 |
199 | def attention_decoder(inputs, memory, num_units=None, scope="attention_decoder", reuse=None):
200 | '''Applies a GRU to `inputs`, while attending `memory`.
201 | Args:
202 | inputs: A 3d tensor with shape of [N, T', C']. Decoder inputs.
203 | memory: A 3d tensor with shape of [N, T, C]. Outputs of encoder network.
204 | num_units: An int. Attention size.
205 | scope: Optional scope for `variable_scope`.
206 | reuse: Boolean, whether to reuse the weights of a previous layer
207 | by the same name.
208 |
209 | Returns:
210 | A 3d tensor with shape of [N, T, num_units].
211 | '''
212 | with tf.variable_scope(scope, reuse=reuse):
213 | if num_units is None:
214 | num_units = inputs.get_shape().as_list()[-1]
215 |
216 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units,
217 | memory)
218 | decoder_cell = tf.contrib.rnn.GRUCell(num_units)
219 | cell_with_attention = tf.contrib.seq2seq.AttentionWrapper(decoder_cell,
220 | attention_mechanism,
221 | num_units,
222 | alignment_history=True)
223 | outputs, state = tf.nn.dynamic_rnn(cell_with_attention, inputs, dtype=tf.float32) # (N, T', num_units)
224 |
225 | return outputs, state
226 |
227 | def prenet(inputs, num_units=None, is_training=True, scope="prenet", reuse=None):
228 | '''Prenet for Encoder and Decoder1.
229 | Args:
230 | inputs: A 2D or 3D tensor.
231 | num_units: A list of two integers. or None.
232 | is_training: A python boolean.
233 | scope: Optional scope for `variable_scope`.
234 | reuse: Boolean, whether to reuse the weights of a previous layer
235 | by the same name.
236 |
237 | Returns:
238 | A 3D tensor of shape [N, T, num_units/2].
239 | '''
240 | if num_units is None:
241 | num_units = [hp.embed_size, hp.embed_size//2]
242 |
243 | with tf.variable_scope(scope, reuse=reuse):
244 | outputs = tf.layers.dense(inputs, units=num_units[0], activation=tf.nn.relu, name="dense1")
245 | outputs = tf.layers.dropout(outputs, rate=hp.dropout_rate, training=is_training, name="dropout1")
246 | outputs = tf.layers.dense(outputs, units=num_units[1], activation=tf.nn.relu, name="dense2")
247 | outputs = tf.layers.dropout(outputs, rate=hp.dropout_rate, training=is_training, name="dropout2")
248 | return outputs # (N, ..., num_units[1])
249 |
250 | def highwaynet(inputs, num_units=None, scope="highwaynet", reuse=None):
251 | '''Highway networks, see https://arxiv.org/abs/1505.00387
252 |
253 | Args:
254 | inputs: A 3D tensor of shape [N, T, W].
255 | num_units: An int or `None`. Specifies the number of units in the highway layer
256 | or uses the input size if `None`.
257 | scope: Optional scope for `variable_scope`.
258 | reuse: Boolean, whether to reuse the weights of a previous layer
259 | by the same name.
260 |
261 | Returns:
262 | A 3D tensor of shape [N, T, W].
263 | '''
264 | if not num_units:
265 | num_units = inputs.get_shape()[-1]
266 |
267 | with tf.variable_scope(scope, reuse=reuse):
268 | H = tf.layers.dense(inputs, units=num_units, activation=tf.nn.relu, name="dense1")
269 | T = tf.layers.dense(inputs, units=num_units, activation=tf.nn.sigmoid,
270 | bias_initializer=tf.constant_initializer(-1.0), name="dense2")
271 | outputs = H*T + inputs*(1.-T)
272 | return outputs
273 |
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python2
3 | '''
4 | By kyubyong park. kbpark.linguist@gmail.com.
5 | https://www.github.com/kyubyong/tacotron
6 | '''
7 |
8 | from __future__ import print_function
9 |
10 | from hyperparams import Hyperparams as hp
11 | from modules import *
12 | import tensorflow as tf
13 |
14 |
15 | def encoder(inputs, is_training=True, scope="encoder", reuse=None):
16 | '''
17 | Args:
18 | inputs: A 3d tensor with shape of [N, T_x, E], with dtype of float32. Embedded encoder inputs.
19 | is_training: Whether or not the layer is in training mode.
20 | scope: Optional scope for `variable_scope`
21 | reuse: Boolean, whether to reuse the weights of a previous layer
22 | by the same name.
23 |
24 | Returns:
25 | A collection of Hidden vectors. So-called memory. Has the shape of (N, T_x, E).
26 | '''
27 | with tf.variable_scope(scope, reuse=reuse):
28 | # Encoder pre-net
29 | prenet_out = prenet(inputs, is_training=is_training) # (N, T_x, E/2)
30 |
31 | # Encoder CBHG
32 | ## Conv1D banks
33 | enc = conv1d_banks(prenet_out, K=hp.encoder_num_banks, is_training=is_training) # (N, T_x, K*E/2)
34 |
35 | ## Max pooling
36 | enc = tf.layers.max_pooling1d(enc, pool_size=2, strides=1, padding="same") # (N, T_x, K*E/2)
37 |
38 | ## Conv1D projections
39 | enc = conv1d(enc, filters=hp.embed_size//2, size=3, scope="conv1d_1") # (N, T_x, E/2)
40 | enc = bn(enc, is_training=is_training, activation_fn=tf.nn.relu, scope="conv1d_1")
41 |
42 | enc = conv1d(enc, filters=hp.embed_size // 2, size=3, scope="conv1d_2") # (N, T_x, E/2)
43 | enc = bn(enc, is_training=is_training, scope="conv1d_2")
44 |
45 | enc += prenet_out # (N, T_x, E/2) # residual connections
46 |
47 | ## Highway Nets
48 | for i in range(hp.num_highwaynet_blocks):
49 | enc = highwaynet(enc, num_units=hp.embed_size//2,
50 | scope='highwaynet_{}'.format(i)) # (N, T_x, E/2)
51 |
52 | ## Bidirectional GRU
53 | memory = gru(enc, num_units=hp.embed_size//2, bidirection=True) # (N, T_x, E)
54 |
55 | return memory
56 |
57 | def decoder1(inputs, memory, is_training=True, scope="decoder1", reuse=None):
58 | '''
59 | Args:
60 | inputs: A 3d tensor with shape of [N, T_y/r, n_mels(*r)]. Shifted log melspectrogram of sound files.
61 | memory: A 3d tensor with shape of [N, T_x, E].
62 | is_training: Whether or not the layer is in training mode.
63 | scope: Optional scope for `variable_scope`
64 | reuse: Boolean, whether to reuse the weights of a previous layer
65 | by the same name.
66 |
67 | Returns
68 | Predicted log melspectrogram tensor with shape of [N, T_y/r, n_mels*r].
69 | '''
70 | with tf.variable_scope(scope, reuse=reuse):
71 | # Decoder pre-net
72 | inputs = prenet(inputs, is_training=is_training) # (N, T_y/r, E/2)
73 |
74 | # Attention RNN
75 | dec, state = attention_decoder(inputs, memory, num_units=hp.embed_size) # (N, T_y/r, E)
76 |
77 | ## for attention monitoring
78 | alignments = tf.transpose(state.alignment_history.stack(),[1,2,0])
79 |
80 | # Decoder RNNs
81 | dec += gru(dec, hp.embed_size, bidirection=False, scope="decoder_gru1") # (N, T_y/r, E)
82 | dec += gru(dec, hp.embed_size, bidirection=False, scope="decoder_gru2") # (N, T_y/r, E)
83 |
84 | # Outputs => (N, T_y/r, n_mels*r)
85 | mel_hats = tf.layers.dense(dec, hp.n_mels*hp.r)
86 |
87 | return mel_hats, alignments
88 |
89 | def decoder2(inputs, is_training=True, scope="decoder2", reuse=None):
90 | '''Decoder Post-processing net = CBHG
91 | Args:
92 | inputs: A 3d tensor with shape of [N, T_y/r, n_mels*r]. The predicted log melspectrogram from decoder1.
93 | It is reshaped back to [N, T_y, n_mels] inside.
94 | is_training: Whether or not the layer is in training mode.
95 | scope: Optional scope for `variable_scope`
96 | reuse: Boolean, whether to reuse the weights of a previous layer
97 | by the same name.
98 |
99 | Returns
100 | Predicted linear spectrogram tensor with shape of [N, T_y, 1+n_fft//2].
101 | '''
102 | with tf.variable_scope(scope, reuse=reuse):
103 | # Restore shape -> (N, Ty, n_mels)
104 | inputs = tf.reshape(inputs, [tf.shape(inputs)[0], -1, hp.n_mels])
105 |
106 | # Conv1D bank
107 | dec = conv1d_banks(inputs, K=hp.decoder_num_banks, is_training=is_training) # (N, T_y, E*K/2)
108 |
109 | # Max pooling
110 | dec = tf.layers.max_pooling1d(dec, pool_size=2, strides=1, padding="same") # (N, T_y, E*K/2)
111 |
112 | ## Conv1D projections
113 | dec = conv1d(dec, filters=hp.embed_size // 2, size=3, scope="conv1d_1") # (N, T_y, E/2)
114 | dec = bn(dec, is_training=is_training, activation_fn=tf.nn.relu, scope="conv1d_1")
115 |
116 | dec = conv1d(dec, filters=hp.n_mels, size=3, scope="conv1d_2") # (N, T_y, n_mels)
117 | dec = bn(dec, is_training=is_training, scope="conv1d_2")
118 |
119 | # Extra affine transformation for dimensionality sync
120 | dec = tf.layers.dense(dec, hp.embed_size//2) # (N, T_y, E/2)
121 |
122 | # Highway Nets
123 | for i in range(4):
124 | dec = highwaynet(dec, num_units=hp.embed_size//2,
125 | scope='highwaynet_{}'.format(i)) # (N, T_y, E/2)
126 |
127 | # Bidirectional GRU
128 | dec = gru(dec, hp.embed_size//2, bidirection=True) # (N, T_y, E)
129 |
130 | # Outputs => (N, T_y, 1+n_fft//2)
131 | outputs = tf.layers.dense(dec, 1+hp.n_fft//2)
132 |
133 | return outputs
134 |
--------------------------------------------------------------------------------
/prepro.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python2
3 | '''
4 | By kyubyong park. kbpark.linguist@gmail.com.
5 | https://www.github.com/kyubyong/tacotron
6 | '''
7 |
8 | from __future__ import print_function
9 |
10 | from utils import load_spectrograms
11 | import os
12 | from data_load import load_data
13 | import numpy as np
14 | import tqdm
15 |
16 | # Load data
17 | fpaths, _, _ = load_data() # list
18 |
19 | for fpath in tqdm.tqdm(fpaths):
20 | fname, mel, mag = load_spectrograms(fpath)
21 | if not os.path.exists("mels"): os.mkdir("mels")
22 | if not os.path.exists("mags"): os.mkdir("mags")
23 |
24 | np.save("mels/{}".format(fname.replace("wav", "npy")), mel)
25 | np.save("mags/{}".format(fname.replace("wav", "npy")), mag)
--------------------------------------------------------------------------------
/synthesize.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # /usr/bin/python2
3 | '''
4 | By kyubyong park. kbpark.linguist@gmail.com.
5 | https://www.github.com/kyubyong/tacotron
6 | '''
7 |
8 | from __future__ import print_function
9 |
10 | from hyperparams import Hyperparams as hp
11 | import tqdm
12 | from data_load import load_data
13 | import tensorflow as tf
14 | from train import Graph
15 | from utils import spectrogram2wav
16 | from scipy.io.wavfile import write
17 | import os
18 | import numpy as np
19 |
20 |
21 | def synthesize():
22 | if not os.path.exists(hp.sampledir): os.mkdir(hp.sampledir)
23 |
24 | # Load graph
25 | g = Graph(mode="synthesize"); print("Graph loaded")
26 |
27 | # Load data
28 | texts = load_data(mode="synthesize")
29 |
30 | saver = tf.train.Saver()
31 | with tf.Session() as sess:
32 | saver.restore(sess, tf.train.latest_checkpoint(hp.logdir)); print("Restored!")
33 |
34 | # Feed Forward
35 | ## mel
36 | y_hat = np.zeros((texts.shape[0], 200, hp.n_mels*hp.r), np.float32) # hp.n_mels*hp.r
37 | for j in tqdm.tqdm(range(200)):
38 | _y_hat = sess.run(g.y_hat, {g.x: texts, g.y: y_hat})
39 | y_hat[:, j, :] = _y_hat[:, j, :]
40 | ## mag
41 | mags = sess.run(g.z_hat, {g.y_hat: y_hat})
42 | for i, mag in enumerate(mags):
43 | print("File {}.wav is being generated ...".format(i+1))
44 | audio = spectrogram2wav(mag)
45 | write(os.path.join(hp.sampledir, '{}.wav'.format(i+1)), hp.sr, audio)
46 |
47 | if __name__ == '__main__':
48 | synthesize()
49 | print("Done")
50 |
51 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python2
3 | '''
4 | By kyubyong park. kbpark.linguist@gmail.com.
5 | https://www.github.com/kyubyong/tacotron
6 | '''
7 |
8 | from __future__ import print_function
9 |
10 | import os
11 | from hyperparams import Hyperparams as hp
12 | import tensorflow as tf
13 | from tqdm import tqdm
14 | from data_load import get_batch, load_vocab
15 | from modules import *
16 | from networks import encoder, decoder1, decoder2
17 | from utils import *
18 |
19 | class Graph:
20 | def __init__(self, mode="train"):
21 | # Load vocabulary
22 | self.char2idx, self.idx2char = load_vocab()
23 |
24 | # Set phase
25 | is_training=True if mode=="train" else False
26 |
27 | # Graph
28 | # Data Feeding
29 | # x: Text. (N, Tx)
30 | # y: Reduced melspectrogram. (N, Ty//r, n_mels*r)
31 | # z: Magnitude. (N, Ty, n_fft//2+1)
32 | if mode=="train":
33 | self.x, self.y, self.z, self.fnames, self.num_batch = get_batch()
34 | elif mode=="eval":
35 | self.x = tf.placeholder(tf.int32, shape=(None, None))
36 | self.y = tf.placeholder(tf.float32, shape=(None, None, hp.n_mels*hp.r))
37 | self.z = tf.placeholder(tf.float32, shape=(None, None, 1+hp.n_fft//2))
38 | self.fnames = tf.placeholder(tf.string, shape=(None,))
39 | else: # Synthesize
40 | self.x = tf.placeholder(tf.int32, shape=(None, None))
41 | self.y = tf.placeholder(tf.float32, shape=(None, None, hp.n_mels * hp.r))
42 |
43 | # Get encoder/decoder inputs
44 | self.encoder_inputs = embed(self.x, len(hp.vocab), hp.embed_size) # (N, T_x, E)
45 | self.decoder_inputs = tf.concat((tf.zeros_like(self.y[:, :1, :]), self.y[:, :-1, :]), 1) # (N, Ty/r, n_mels*r)
46 | self.decoder_inputs = self.decoder_inputs[:, :, -hp.n_mels:] # feed last frames only (N, Ty/r, n_mels)
47 |
48 | # Networks
49 | with tf.variable_scope("net"):
50 | # Encoder
51 | self.memory = encoder(self.encoder_inputs, is_training=is_training) # (N, T_x, E)
52 |
53 | # Decoder1
54 | self.y_hat, self.alignments = decoder1(self.decoder_inputs,
55 | self.memory,
56 | is_training=is_training) # (N, T_y//r, n_mels*r)
57 | # Decoder2 or postprocessing
58 | self.z_hat = decoder2(self.y_hat, is_training=is_training) # (N, T_y, 1+n_fft//2)
59 |
60 | # monitor
61 | self.audio = tf.py_func(spectrogram2wav, [self.z_hat[0]], tf.float32)
62 |
63 | if mode in ("train", "eval"):
64 | # Loss
65 | self.loss1 = tf.reduce_mean(tf.abs(self.y_hat - self.y))
66 | self.loss2 = tf.reduce_mean(tf.abs(self.z_hat - self.z))
67 | self.loss = self.loss1 + self.loss2
68 |
69 | # Training Scheme
70 | self.global_step = tf.Variable(0, name='global_step', trainable=False)
71 | self.lr = learning_rate_decay(hp.lr, global_step=self.global_step)
72 | self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr)
73 |
74 | ## gradient clipping
75 | self.gvs = self.optimizer.compute_gradients(self.loss)
76 | self.clipped = []
77 | for grad, var in self.gvs:
78 | grad = tf.clip_by_norm(grad, 5.)
79 | self.clipped.append((grad, var))
80 | self.train_op = self.optimizer.apply_gradients(self.clipped, global_step=self.global_step)
81 |
82 | # Summary
83 | tf.summary.scalar('{}/loss1'.format(mode), self.loss1)
84 | tf.summary.scalar('{}/loss'.format(mode), self.loss)
85 | tf.summary.scalar('{}/lr'.format(mode), self.lr)
86 |
87 | tf.summary.image("{}/mel_gt".format(mode), tf.expand_dims(self.y, -1), max_outputs=1)
88 | tf.summary.image("{}/mel_hat".format(mode), tf.expand_dims(self.y_hat, -1), max_outputs=1)
89 | tf.summary.image("{}/mag_gt".format(mode), tf.expand_dims(self.z, -1), max_outputs=1)
90 | tf.summary.image("{}/mag_hat".format(mode), tf.expand_dims(self.z_hat, -1), max_outputs=1)
91 |
92 | tf.summary.audio("{}/sample".format(mode), tf.expand_dims(self.audio, 0), hp.sr)
93 | self.merged = tf.summary.merge_all()
94 |
95 | if __name__ == '__main__':
96 | g = Graph(); print("Training Graph loaded")
97 |
98 | # with g.graph.as_default():
99 | sv = tf.train.Supervisor(logdir=hp.logdir, save_summaries_secs=60, save_model_secs=0)
100 | with sv.managed_session() as sess:
101 | while 1:
102 | for _ in tqdm(range(g.num_batch), total=g.num_batch, ncols=70, leave=False, unit='b'):
103 | _, gs = sess.run([g.train_op, g.global_step])
104 |
105 | # Write checkpoint files
106 | if gs % 1000 == 0:
107 | sv.saver.save(sess, hp.logdir + '/model_gs_{}k'.format(gs//1000))
108 |
109 | # plot the first alignment for logging
110 | al = sess.run(g.alignments)
111 | plot_alignment(al[0], gs)
112 |
113 | print("Done")
114 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # /usr/bin/python2
3 | '''
4 | By kyubyong park. kbpark.linguist@gmail.com.
5 | https://www.github.com/kyubyong/dc_tts
6 | '''
7 | from __future__ import print_function, division
8 |
9 | from hyperparams import Hyperparams as hp
10 | import numpy as np
11 | import tensorflow as tf
12 | import librosa
13 | import copy
14 | import matplotlib
15 | matplotlib.use('pdf')
16 | import matplotlib.pyplot as plt
17 | from scipy import signal
18 | import os
19 |
20 |
21 | def get_spectrograms(fpath):
22 | '''Returns normalized log(melspectrogram) and log(magnitude) from `fpath`.
23 | Args:
24 | fpath: A string. The full path of a sound file.
25 |
26 | Returns:
27 | mel: A 2d array of shape (T, n_mels) <- Transposed
28 | mag: A 2d array of shape (T, 1+n_fft/2) <- Transposed
29 | '''
30 | # num = np.random.randn()
31 | # if num < .2:
32 | # y, sr = librosa.load(fpath, sr=hp.sr)
33 | # else:
34 | # if num < .4:
35 | # tempo = 1.1
36 | # elif num < .6:
37 | # tempo = 1.2
38 | # elif num < .8:
39 | # tempo = 0.9
40 | # else:
41 | # tempo = 0.8
42 | # cmd = "ffmpeg -i {} -y ar {} -hide_banner -loglevel panic -ac 1 -filter:a atempo={} -vn temp.wav".format(fpath, hp.sr, tempo)
43 | # os.system(cmd)
44 | # y, sr = librosa.load('temp.wav', sr=hp.sr)
45 |
46 | # Loading sound file
47 | y, sr = librosa.load(fpath, sr=hp.sr)
48 |
49 |
50 | # Trimming
51 | y, _ = librosa.effects.trim(y)
52 |
53 | # Preemphasis
54 | y = np.append(y[0], y[1:] - hp.preemphasis * y[:-1])
55 |
56 | # stft
57 | linear = librosa.stft(y=y,
58 | n_fft=hp.n_fft,
59 | hop_length=hp.hop_length,
60 | win_length=hp.win_length)
61 |
62 | # magnitude spectrogram
63 | mag = np.abs(linear) # (1+n_fft//2, T)
64 |
65 | # mel spectrogram
66 | mel_basis = librosa.filters.mel(hp.sr, hp.n_fft, hp.n_mels) # (n_mels, 1+n_fft//2)
67 | mel = np.dot(mel_basis, mag) # (n_mels, t)
68 |
69 | # to decibel
70 | mel = 20 * np.log10(np.maximum(1e-5, mel))
71 | mag = 20 * np.log10(np.maximum(1e-5, mag))
72 |
73 | # normalize
74 | mel = np.clip((mel - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1)
75 | mag = np.clip((mag - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1)
76 |
77 | # Transpose
78 | mel = mel.T.astype(np.float32) # (T, n_mels)
79 | mag = mag.T.astype(np.float32) # (T, 1+n_fft//2)
80 |
81 | return mel, mag
82 |
83 |
84 | def spectrogram2wav(mag):
85 | '''# Generate wave file from spectrogram'''
86 | # transpose
87 | mag = mag.T
88 |
89 | # de-normalize
90 | mag = (np.clip(mag, 0, 1) * hp.max_db) - hp.max_db + hp.ref_db
91 |
92 | # to amplitude
93 | mag = np.power(10.0, mag * 0.05)
94 |
95 | # wav reconstruction
96 | wav = griffin_lim(mag)
97 |
98 | # de-preemphasis
99 | wav = signal.lfilter([1], [1, -hp.preemphasis], wav)
100 |
101 | # trim
102 | wav, _ = librosa.effects.trim(wav)
103 |
104 | return wav.astype(np.float32)
105 |
106 |
107 | def griffin_lim(spectrogram):
108 | '''Applies the Griffin-Lim algorithm to reconstruct a waveform from a magnitude spectrogram.
109 | '''
110 | X_best = copy.deepcopy(spectrogram)
111 | for i in range(hp.n_iter):
112 | X_t = invert_spectrogram(X_best)
113 | est = librosa.stft(X_t, hp.n_fft, hp.hop_length, win_length=hp.win_length)
114 | phase = est / np.maximum(1e-8, np.abs(est))
115 | X_best = spectrogram * phase
116 | X_t = invert_spectrogram(X_best)
117 | y = np.real(X_t)
118 |
119 | return y
120 |
121 |
122 | def invert_spectrogram(spectrogram):
123 | '''
124 | spectrogram: [f, t]
125 | '''
126 | return librosa.istft(spectrogram, hp.hop_length, win_length=hp.win_length, window="hann")
127 |
128 |
129 | def plot_alignment(alignment, gs):
130 | """Plots the alignment
131 | alignment: A numpy array of shape (encoder_steps, decoder_steps)
132 | gs : (int) global step
133 | """
134 | fig, ax = plt.subplots()
135 | im = ax.imshow(alignment)
136 |
137 | # cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
138 | fig.colorbar(im)
139 | plt.title('{} Steps'.format(gs))
140 | plt.savefig('{}/alignment_{}k.png'.format(hp.logdir, gs//1000), format='png')
141 |
142 | def learning_rate_decay(init_lr, global_step, warmup_steps=4000.):
143 | '''Noam scheme from tensor2tensor'''
144 | step = tf.cast(global_step + 1, dtype=tf.float32)
145 | return init_lr * warmup_steps ** 0.5 * tf.minimum(step * warmup_steps ** -1.5, step ** -0.5)
146 |
147 | def load_spectrograms(fpath):
148 | fname = os.path.basename(fpath)
149 | mel, mag = get_spectrograms(fpath)
150 | t = mel.shape[0]
151 | num_paddings = hp.r - (t % hp.r) if t % hp.r != 0 else 0 # for reduction
152 | mel = np.pad(mel, [[0, num_paddings], [0, 0]], mode="constant")
153 | mag = np.pad(mag, [[0, num_paddings], [0, 0]], mode="constant")
154 | return fname, mel.reshape((-1, hp.n_mels*hp.r)), mag
155 |
--------------------------------------------------------------------------------