├── .gitignore
├── LICENSE
├── README.md
├── TRAINING_DATA.md
├── datasets
│   ├── __init__.py
│   ├── blizzard.py
│   ├── datafeeder.py
│   └── ljspeech.py
├── demo_server.py
├── eval.py
├── hparams.py
├── models
│   ├── __init__.py
│   ├── helpers.py
│   ├── modules.py
│   ├── rnn_wrappers.py
│   └── tacotron.py
├── preprocess.py
├── requirements.txt
├── synthesizer.py
├── tests
│   ├── __init__.py
│   ├── cmudict_test.py
│   ├── numbers_test.py
│   └── text_test.py
├── text
│   ├── __init__.py
│   ├── cleaners.py
│   ├── cmudict.py
│   ├── numbers.py
│   └── symbols.py
├── train.py
└── util
    ├── __init__.py
    ├── audio.py
    ├── infolog.py
    └── plot.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .cache/
3 | *.pyc
4 | .DS_Store
5 | run*.sh
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 Keith Ito
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tacotron
2 |
3 | An implementation of Tacotron speech synthesis in TensorFlow.
4 |
5 |
6 | ### Audio Samples
7 |
8 | * **[Audio Samples](https://keithito.github.io/audio-samples/)** from models trained using this repo.
9 | * The first set was trained for 441K steps on the [LJ Speech Dataset](https://keithito.com/LJ-Speech-Dataset/)
10 | * Speech started to become intelligible around 20K steps.
11 | * The second set was trained by [@MXGray](https://github.com/MXGray) for 140K steps on the [Nancy Corpus](http://www.cstr.ed.ac.uk/projects/blizzard/2011/lessac_blizzard2011/).
12 |
13 |
14 | ### Recent Updates
15 |
16 | 1. @npuichigo [fixed](https://github.com/keithito/tacotron/pull/205) a bug where dropout was not being applied in the prenet.
17 |
18 | 2. @begeekmyfriend created a [fork](https://github.com/begeekmyfriend/tacotron) that adds location-sensitive attention and the stop token from the [Tacotron 2](https://arxiv.org/abs/1712.05884) paper. This can greatly reduce the amount of data required to train a model.
19 |
20 |
21 | ## Background
22 |
23 | In April 2017, Google published a paper, [Tacotron: Towards End-to-End Speech Synthesis](https://arxiv.org/pdf/1703.10135.pdf),
24 | where they present a neural text-to-speech model that learns to synthesize speech directly from
25 | (text, audio) pairs. However, they didn't release their source code or training data. This is an
26 | independent attempt to provide an open-source implementation of the model described in their paper.
27 |
28 | The quality isn't as good as Google's demo yet, but hopefully it will get there someday :-).
29 | Pull requests are welcome!
30 |
31 |
32 |
33 | ## Quick Start
34 |
35 | ### Installing dependencies
36 |
37 | 1. Install Python 3.
38 |
39 | 2. Install the latest version of [TensorFlow](https://www.tensorflow.org/install/) for your platform. For better
40 | performance, install with GPU support if it's available. This code works with TensorFlow 1.3 and later.
41 |
42 | 3. Install requirements:
43 | ```
44 | pip install -r requirements.txt
45 | ```
46 |
47 |
48 | ### Using a pre-trained model
49 |
50 | 1. **Download and unpack a model**:
51 | ```
52 | curl https://data.keithito.com/data/speech/tacotron-20180906.tar.gz | tar xzC /tmp
53 | ```
54 |
55 | 2. **Run the demo server**:
56 | ```
57 | python3 demo_server.py --checkpoint /tmp/tacotron-20180906/model.ckpt
58 | ```
59 |
60 | 3. **Point your browser at localhost:9000**
61 | * Type what you want to synthesize
62 |
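The demo page calls the server's `/synthesize` endpoint, which you can also fetch
directly (the output filename here is just an example):
```
curl 'http://localhost:9000/synthesize?text=Hello%20world' > hello.wav
```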
63 |
64 |
65 | ### Training
66 |
67 | *Note: you need at least 40GB of free disk space to train a model.*
68 |
69 | 1. **Download a speech dataset.**
70 |
71 | The following are supported out of the box:
72 | * [LJ Speech](https://keithito.com/LJ-Speech-Dataset/) (Public Domain)
73 | * [Blizzard 2012](http://www.cstr.ed.ac.uk/projects/blizzard/2012/phase_one) (Creative Commons Attribution Share-Alike)
74 |
75 | You can use other datasets if you convert them to the right format. See [TRAINING_DATA.md](TRAINING_DATA.md) for more info.
76 |
77 |
78 | 2. **Unpack the dataset into `~/tacotron`**
79 |
80 | After unpacking, your tree should look like this for LJ Speech:
81 | ```
82 | tacotron
83 | |- LJSpeech-1.1
84 | |- metadata.csv
85 | |- wavs
86 | ```
87 |
88 | or like this for Blizzard 2012:
89 | ```
90 | tacotron
91 | |- Blizzard2012
92 | |- ATrampAbroad
93 | | |- sentence_index.txt
94 | | |- lab
95 | | |- wav
96 | |- TheManThatCorruptedHadleyburg
97 | |- sentence_index.txt
98 | |- lab
99 | |- wav
100 | ```
101 |
102 | 3. **Preprocess the data**
103 | ```
104 | python3 preprocess.py --dataset ljspeech
105 | ```
106 | * Use `--dataset blizzard` for Blizzard data
107 |
108 | 4. **Train a model**
109 | ```
110 | python3 train.py
111 | ```
112 |
113 | Tunable hyperparameters are found in [hparams.py](hparams.py). You can adjust these at the command
114 | line using the `--hparams` flag, for example `--hparams="batch_size=16,outputs_per_step=2"`.
115 | Hyperparameters should generally be set to the same values at both training and eval time.
116 | The default hyperparameters are recommended for LJ Speech and other English-language data.
117 | See [TRAINING_DATA.md](TRAINING_DATA.md) for other languages.
118 |
119 |
120 | 5. **Monitor with Tensorboard** (optional)
121 | ```
122 | tensorboard --logdir ~/tacotron/logs-tacotron
123 | ```
124 |
125 | The trainer dumps audio and alignments every 1000 steps. You can find these in
126 | `~/tacotron/logs-tacotron`.
127 |
128 | 6. **Synthesize from a checkpoint**
129 | ```
130 | python3 demo_server.py --checkpoint ~/tacotron/logs-tacotron/model.ckpt-185000
131 | ```
132 | Replace "185000" with the checkpoint number that you want to use, then open a browser
133 | to `localhost:9000` and type what you want to speak. Alternately, you can
134 | run [eval.py](eval.py) at the command line:
135 | ```
136 | python3 eval.py --checkpoint ~/tacotron/logs-tacotron/model.ckpt-185000
137 | ```
138 | If you set the `--hparams` flag when training, set the same value here.
139 |
140 |
141 | ## Notes and Common Issues
142 |
143 | * [TCMalloc](http://goog-perftools.sourceforge.net/doc/tcmalloc.html) seems to improve
144 | training speed and avoids occasional slowdowns seen with the default allocator. You
145 | can enable it by installing it and setting `LD_PRELOAD=/usr/lib/libtcmalloc.so`. With TCMalloc,
146 | you can get around 1.1 sec/step on a GTX 1080Ti.
147 |
148 | * You can train with [CMUDict](http://www.speech.cs.cmu.edu/cgi-bin/cmudict) by downloading the
149 | dictionary to ~/tacotron/training and then passing the flag `--hparams="use_cmudict=True"` to
150 | train.py. This will allow you to pass ARPAbet phonemes enclosed in curly braces at eval
151 | time to force a particular pronunciation, e.g. `Turn left on {HH AW1 S S T AH0 N} Street.`
152 |
153 | * If you pass a Slack incoming webhook URL as the `--slack_url` flag to train.py, it will send
154 | you progress updates every 1000 steps.
155 |
156 | * Occasionally, you may see a spike in loss and the model will forget how to attend (the
157 | alignments will no longer make sense). Although it will recover eventually, it may
158 | save time to restart at a checkpoint prior to the spike by passing the
159 | `--restore_step=150000` flag to train.py (replacing 150000 with a step number prior to the
160 | spike). **Update**: a recent [fix](https://github.com/keithito/tacotron/pull/7) to gradient
161 | clipping by @candlewill may have fixed this.
162 |
163 | * During eval and training, audio length is limited to `max_iters * outputs_per_step * frame_shift_ms`
164 | milliseconds. With the defaults (max_iters=200, outputs_per_step=5, frame_shift_ms=12.5), this is
165 | 12.5 seconds.
166 |
167 | If your training examples are longer, you will see an error like this:
168 | `Incompatible shapes: [32,1340,80] vs. [32,1000,80]`
169 |
170 | To fix this, you can set a larger value of `max_iters` by passing `--hparams="max_iters=300"` to
171 | train.py (replace "300" with a value based on your longest clip and the formula above; see the sketch after this list).
172 |
173 | * Here is the expected loss curve when training on LJ Speech with the default hyperparameters:
174 | 
175 |
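As a quick way to size `max_iters` for the note above, here is a minimal sketch
(the helper name and the example durations are ours, not part of the repo):

```python
# How large must max_iters be to cover a clip of a given length?
def min_max_iters(clip_seconds, outputs_per_step=5, frame_shift_ms=12.5):
  frames = clip_seconds * 1000 / frame_shift_ms   # spectrogram frames in the clip
  return int(-(-frames // outputs_per_step))      # ceil(frames / outputs_per_step)

print(min_max_iters(12.5))  # 200 -> the default max_iters covers 12.5 seconds
print(min_max_iters(18.0))  # 288 -> pass --hparams="max_iters=300"
```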
176 |
177 | ## Other Implementations
178 | * By Alex Barron: https://github.com/barronalex/Tacotron
179 | * By Kyubyong Park: https://github.com/Kyubyong/tacotron
180 |
--------------------------------------------------------------------------------
/TRAINING_DATA.md:
--------------------------------------------------------------------------------
1 | # Training Data
2 |
3 |
4 | This repo supports the following speech datasets:
5 | * [LJ Speech](https://keithito.com/LJ-Speech-Dataset/) (Public Domain)
6 | * [Blizzard 2012](http://www.cstr.ed.ac.uk/projects/blizzard/2012/phase_one) (Creative Commons Attribution Share-Alike)
7 |
8 | You can use any other dataset if you write a preprocessor for it.
9 |
10 |
11 | ### Writing a Preprocessor
12 |
13 | Each training example consists of:
14 | 1. The text that was spoken
15 | 2. A mel-scale spectrogram of the audio
16 | 3. A linear-scale spectrogram of the audio
17 |
18 | The preprocessor is responsible for generating these. See [ljspeech.py](datasets/ljspeech.py) for a
19 | commented example.
20 |
21 | For each training example, a preprocessor should:
22 |
23 | 1. Load the audio file:
24 | ```python
25 | wav = audio.load_wav(wav_path)
26 | ```
27 |
28 | 2. Compute linear-scale and mel-scale spectrograms (float32 numpy arrays):
29 | ```python
30 | spectrogram = audio.spectrogram(wav).astype(np.float32)
31 | mel_spectrogram = audio.melspectrogram(wav).astype(np.float32)
32 | ```
33 |
34 | 3. Save the spectrograms to disk:
35 | ```python
36 | np.save(os.path.join(out_dir, spectrogram_filename), spectrogram.T, allow_pickle=False)
37 | np.save(os.path.join(out_dir, mel_spectrogram_filename), mel_spectrogram.T, allow_pickle=False)
38 | ```
39 | Note that the transpose of the matrix returned by `audio.spectrogram` is saved so that it's
40 | in time-major format.
41 |
42 | 4. Generate a tuple `(spectrogram_filename, mel_spectrogram_filename, n_frames, text)` to
43 | write to train.txt. n_frames is just the length of the time axis of the spectrogram.
44 |
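Putting the four steps together, a complete preprocessor can be quite short. Below
is a minimal single-process sketch for a hypothetical dataset laid out like LJ
Speech (a `metadata.csv` of `id|text` lines plus a `wavs/` folder); the "mydata"
names are placeholders:

```python
import os
import numpy as np
from util import audio


def build_from_path(in_dir, out_dir):
  metadata = []
  index = 1
  with open(os.path.join(in_dir, 'metadata.csv'), encoding='utf-8') as f:
    for line in f:
      wav_id, text = line.strip().split('|')[:2]
      # 1. Load the audio:
      wav = audio.load_wav(os.path.join(in_dir, 'wavs', '%s.wav' % wav_id))
      # 2. Compute linear-scale and mel-scale spectrograms:
      spectrogram = audio.spectrogram(wav).astype(np.float32)
      mel_spectrogram = audio.melspectrogram(wav).astype(np.float32)
      # 3. Save them to disk, transposed into time-major format:
      spec_filename = 'mydata-spec-%05d.npy' % index
      mel_filename = 'mydata-mel-%05d.npy' % index
      np.save(os.path.join(out_dir, spec_filename), spectrogram.T, allow_pickle=False)
      np.save(os.path.join(out_dir, mel_filename), mel_spectrogram.T, allow_pickle=False)
      # 4. Collect the tuple that becomes one line of train.txt:
      metadata.append((spec_filename, mel_filename, spectrogram.shape[1], text))
      index += 1
  return metadata
```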
45 |
46 | After you've written your preprocessor, you can add it to [preprocess.py](preprocess.py) by
47 | following the example of the other preprocessors in that file.
48 |
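For example, registering the hypothetical "mydata" preprocessor above would look
roughly like this (a sketch, following the pattern of `preprocess_ljspeech`):

```python
# In preprocess.py:
from datasets import mydata

def preprocess_mydata(args):
  in_dir = os.path.join(args.base_dir, 'MyData')
  out_dir = os.path.join(args.base_dir, args.output)
  os.makedirs(out_dir, exist_ok=True)
  metadata = mydata.build_from_path(in_dir, out_dir)
  write_metadata(metadata, out_dir)

# ...and add 'mydata' to the --dataset choices in main().
```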
49 |
50 | ### Non-English Data
51 |
52 | If your training data is in a language other than English, you will probably want to change the
53 | text cleaners by setting the `cleaners` hyperparameter.
54 |
55 | * If your text is in a Latin script or can be transliterated to ASCII using the
56 | [Unidecode](https://pypi.python.org/pypi/Unidecode) library, you can use the transliteration
57 | cleaners by setting the hyperparameter `cleaners=transliteration_cleaners`.
58 |
59 | * If you don't want to transliterate, you can define a custom character set.
60 | This allows you to train directly on the character set used in your data.
61 |
62 | To do so, edit [symbols.py](text/symbols.py) and change the `_characters` variable to be a
63 | string containing the UTF-8 characters in your data. Then set the hyperparameter `cleaners=basic_cleaners`.
64 |
65 | * If you're not sure which option to use, you can evaluate the transliteration cleaners like this:
66 |
67 | ```python
68 | from text import cleaners
69 | cleaners.transliteration_cleaners('Здравствуйте') # Replace with the text you want to try
70 | ```
71 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keithito/tacotron/d26c763342518d4e432e9c4036a1aff3b4fdaa1e/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/blizzard.py:
--------------------------------------------------------------------------------
1 | from concurrent.futures import ProcessPoolExecutor
2 | from functools import partial
3 | import numpy as np
4 | import os
5 | from hparams import hparams
6 | from util import audio
7 |
8 |
9 | _max_out_length = 700
10 | _end_buffer = 0.05
11 | _min_confidence = 90
12 |
13 | # Note: "A Tramp Abroad" & "The Man That Corrupted Hadleyburg" are higher quality than the others.
14 | books = [
15 | 'ATrampAbroad',
16 | 'TheManThatCorruptedHadleyburg',
17 | # 'LifeOnTheMississippi',
18 | # 'TheAdventuresOfTomSawyer',
19 | ]
20 |
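# Per the parsing below, each line of sentence_index.txt carries 8 tab-separated
# fields: field 0 is the utterance id, field 3 a confidence score (kept only if
# above _min_confidence), and field 5 the transcript text.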
21 | def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x):
22 | executor = ProcessPoolExecutor(max_workers=num_workers)
23 | futures = []
24 | index = 1
25 | for book in books:
26 | with open(os.path.join(in_dir, book, 'sentence_index.txt')) as f:
27 | for line in f:
28 | parts = line.strip().split('\t')
29 |         if line[0] != '#' and len(parts) == 8 and float(parts[3]) > _min_confidence:
30 | wav_path = os.path.join(in_dir, book, 'wav', '%s.wav' % parts[0])
31 | labels_path = os.path.join(in_dir, book, 'lab', '%s.lab' % parts[0])
32 | text = parts[5]
33 | task = partial(_process_utterance, out_dir, index, wav_path, labels_path, text)
34 | futures.append(executor.submit(task))
35 | index += 1
36 | results = [future.result() for future in tqdm(futures)]
37 | return [r for r in results if r is not None]
38 |
39 |
40 | def _process_utterance(out_dir, index, wav_path, labels_path, text):
41 | # Load the wav file and trim silence from the ends:
42 | wav = audio.load_wav(wav_path)
43 | start_offset, end_offset = _parse_labels(labels_path)
44 | start = int(start_offset * hparams.sample_rate)
45 | end = int(end_offset * hparams.sample_rate) if end_offset is not None else -1
46 | wav = wav[start:end]
47 | max_samples = _max_out_length * hparams.frame_shift_ms / 1000 * hparams.sample_rate
48 | if len(wav) > max_samples:
49 | return None
50 | spectrogram = audio.spectrogram(wav).astype(np.float32)
51 | n_frames = spectrogram.shape[1]
52 | mel_spectrogram = audio.melspectrogram(wav).astype(np.float32)
53 | spectrogram_filename = 'blizzard-spec-%05d.npy' % index
54 | mel_filename = 'blizzard-mel-%05d.npy' % index
55 | np.save(os.path.join(out_dir, spectrogram_filename), spectrogram.T, allow_pickle=False)
56 | np.save(os.path.join(out_dir, mel_filename), mel_spectrogram.T, allow_pickle=False)
57 | return (spectrogram_filename, mel_filename, n_frames, text)
58 |
59 |
60 | def _parse_labels(path):
61 | labels = []
62 |   with open(path) as f:
63 | for line in f:
64 | parts = line.strip().split(' ')
65 | if len(parts) >= 3:
66 | labels.append((float(parts[0]), ' '.join(parts[2:])))
67 | start = 0
68 | end = None
69 | if labels[0][1] == 'sil':
70 | start = labels[0][0]
71 | if labels[-1][1] == 'sil':
72 | end = labels[-2][0] + _end_buffer
73 | return (start, end)
74 |
--------------------------------------------------------------------------------
/datasets/datafeeder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import random
4 | import tensorflow as tf
5 | import threading
6 | import time
7 | import traceback
8 | from text import cmudict, text_to_sequence
9 | from util.infolog import log
10 |
11 |
12 | _batches_per_group = 32
13 | _p_cmudict = 0.5
14 | _pad = 0
15 |
16 |
17 | class DataFeeder(threading.Thread):
18 | '''Feeds batches of data into a queue on a background thread.'''
19 |
20 | def __init__(self, coordinator, metadata_filename, hparams):
21 | super(DataFeeder, self).__init__()
22 | self._coord = coordinator
23 | self._hparams = hparams
24 | self._cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]
25 | self._offset = 0
26 |
27 | # Load metadata:
28 | self._datadir = os.path.dirname(metadata_filename)
29 | with open(metadata_filename, encoding='utf-8') as f:
30 | self._metadata = [line.strip().split('|') for line in f]
31 | hours = sum((int(x[2]) for x in self._metadata)) * hparams.frame_shift_ms / (3600 * 1000)
32 | log('Loaded metadata for %d examples (%.2f hours)' % (len(self._metadata), hours))
33 |
34 | # Create placeholders for inputs and targets. Don't specify batch size because we want to
35 | # be able to feed different sized batches at eval time.
36 | self._placeholders = [
37 | tf.placeholder(tf.int32, [None, None], 'inputs'),
38 | tf.placeholder(tf.int32, [None], 'input_lengths'),
39 | tf.placeholder(tf.float32, [None, None, hparams.num_mels], 'mel_targets'),
40 | tf.placeholder(tf.float32, [None, None, hparams.num_freq], 'linear_targets')
41 | ]
42 |
43 | # Create queue for buffering data:
44 | queue = tf.FIFOQueue(8, [tf.int32, tf.int32, tf.float32, tf.float32], name='input_queue')
45 | self._enqueue_op = queue.enqueue(self._placeholders)
46 | self.inputs, self.input_lengths, self.mel_targets, self.linear_targets = queue.dequeue()
47 | self.inputs.set_shape(self._placeholders[0].shape)
48 | self.input_lengths.set_shape(self._placeholders[1].shape)
49 | self.mel_targets.set_shape(self._placeholders[2].shape)
50 | self.linear_targets.set_shape(self._placeholders[3].shape)
51 |
52 | # Load CMUDict: If enabled, this will randomly substitute some words in the training data with
53 | # their ARPABet equivalents, which will allow you to also pass ARPABet to the model for
54 | # synthesis (useful for proper nouns, etc.)
55 | if hparams.use_cmudict:
56 | cmudict_path = os.path.join(self._datadir, 'cmudict-0.7b')
57 | if not os.path.isfile(cmudict_path):
58 | raise Exception('If use_cmudict=True, you must download ' +
59 | 'http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b to %s' % cmudict_path)
60 | self._cmudict = cmudict.CMUDict(cmudict_path, keep_ambiguous=False)
61 | log('Loaded CMUDict with %d unambiguous entries' % len(self._cmudict))
62 | else:
63 | self._cmudict = None
64 |
65 |
66 | def start_in_session(self, session):
67 | self._session = session
68 | self.start()
69 |
70 |
71 | def run(self):
72 | try:
73 | while not self._coord.should_stop():
74 | self._enqueue_next_group()
75 | except Exception as e:
76 | traceback.print_exc()
77 | self._coord.request_stop(e)
78 |
79 |
80 | def _enqueue_next_group(self):
81 | start = time.time()
82 |
83 | # Read a group of examples:
84 | n = self._hparams.batch_size
85 | r = self._hparams.outputs_per_step
86 | examples = [self._get_next_example() for i in range(n * _batches_per_group)]
87 |
88 | # Bucket examples based on similar output sequence length for efficiency:
89 | examples.sort(key=lambda x: x[-1])
90 | batches = [examples[i:i+n] for i in range(0, len(examples), n)]
91 | random.shuffle(batches)
92 |
93 | log('Generated %d batches of size %d in %.03f sec' % (len(batches), n, time.time() - start))
94 | for batch in batches:
95 | feed_dict = dict(zip(self._placeholders, _prepare_batch(batch, r)))
96 | self._session.run(self._enqueue_op, feed_dict=feed_dict)
97 |
98 |
99 | def _get_next_example(self):
100 |     '''Loads a single example (input, mel_target, linear_target, cost) from disk; cost is the output length, used to bucket similar-length examples.'''
101 | if self._offset >= len(self._metadata):
102 | self._offset = 0
103 | random.shuffle(self._metadata)
104 | meta = self._metadata[self._offset]
105 | self._offset += 1
106 |
107 | text = meta[3]
108 | if self._cmudict and random.random() < _p_cmudict:
109 | text = ' '.join([self._maybe_get_arpabet(word) for word in text.split(' ')])
110 |
111 | input_data = np.asarray(text_to_sequence(text, self._cleaner_names), dtype=np.int32)
112 | linear_target = np.load(os.path.join(self._datadir, meta[0]))
113 | mel_target = np.load(os.path.join(self._datadir, meta[1]))
114 | return (input_data, mel_target, linear_target, len(linear_target))
115 |
116 |
117 | def _maybe_get_arpabet(self, word):
118 | arpabet = self._cmudict.lookup(word)
119 | return '{%s}' % arpabet[0] if arpabet is not None and random.random() < 0.5 else word
120 |
121 |
122 | def _prepare_batch(batch, outputs_per_step):
123 | random.shuffle(batch)
124 | inputs = _prepare_inputs([x[0] for x in batch])
125 | input_lengths = np.asarray([len(x[0]) for x in batch], dtype=np.int32)
126 | mel_targets = _prepare_targets([x[1] for x in batch], outputs_per_step)
127 | linear_targets = _prepare_targets([x[2] for x in batch], outputs_per_step)
128 | return (inputs, input_lengths, mel_targets, linear_targets)
129 |
130 |
131 | def _prepare_inputs(inputs):
132 | max_len = max((len(x) for x in inputs))
133 | return np.stack([_pad_input(x, max_len) for x in inputs])
134 |
135 |
136 | def _prepare_targets(targets, alignment):
137 | max_len = max((len(t) for t in targets)) + 1
138 | return np.stack([_pad_target(t, _round_up(max_len, alignment)) for t in targets])
139 |
140 |
141 | def _pad_input(x, length):
142 | return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
143 |
144 |
145 | def _pad_target(t, length):
146 | return np.pad(t, [(0, length - t.shape[0]), (0,0)], mode='constant', constant_values=_pad)
147 |
148 |
149 | def _round_up(x, multiple):
150 | remainder = x % multiple
151 | return x if remainder == 0 else x + multiple - remainder
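
# Example: with outputs_per_step=5, a batch whose longest target has 103 frames
# is padded by _prepare_targets to _round_up(103 + 1, 5) = 105 frames, so the
# targets divide evenly into decoder steps that emit 5 frames each.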
152 |
--------------------------------------------------------------------------------
/datasets/ljspeech.py:
--------------------------------------------------------------------------------
1 | from concurrent.futures import ProcessPoolExecutor
2 | from functools import partial
3 | import numpy as np
4 | import os
5 | from util import audio
6 |
7 |
8 | def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x):
9 | '''Preprocesses the LJ Speech dataset from a given input path into a given output directory.
10 |
11 | Args:
12 | in_dir: The directory where you have downloaded the LJ Speech dataset
13 | out_dir: The directory to write the output into
14 | num_workers: Optional number of worker processes to parallelize across
15 | tqdm: You can optionally pass tqdm to get a nice progress bar
16 |
17 | Returns:
18 | A list of tuples describing the training examples. This should be written to train.txt
19 | '''
20 |
21 |   # We use ProcessPoolExecutor to parallelize across processes. This is just an optimization;
22 |   # you can omit it and call _process_utterance on each input directly if you prefer.
23 | executor = ProcessPoolExecutor(max_workers=num_workers)
24 | futures = []
25 | index = 1
26 | with open(os.path.join(in_dir, 'metadata.csv'), encoding='utf-8') as f:
27 | for line in f:
28 | parts = line.strip().split('|')
29 | wav_path = os.path.join(in_dir, 'wavs', '%s.wav' % parts[0])
30 | text = parts[2]
31 | futures.append(executor.submit(partial(_process_utterance, out_dir, index, wav_path, text)))
32 | index += 1
33 | return [future.result() for future in tqdm(futures)]
34 |
35 |
36 | def _process_utterance(out_dir, index, wav_path, text):
37 | '''Preprocesses a single utterance audio/text pair.
38 |
39 | This writes the mel and linear scale spectrograms to disk and returns a tuple to write
40 | to the train.txt file.
41 |
42 | Args:
43 | out_dir: The directory to write the spectrograms into
44 | index: The numeric index to use in the spectrogram filenames.
45 | wav_path: Path to the audio file containing the speech input
46 | text: The text spoken in the input audio file
47 |
48 | Returns:
49 | A (spectrogram_filename, mel_filename, n_frames, text) tuple to write to train.txt
50 | '''
51 |
52 | # Load the audio to a numpy array:
53 | wav = audio.load_wav(wav_path)
54 |
55 | # Compute the linear-scale spectrogram from the wav:
56 | spectrogram = audio.spectrogram(wav).astype(np.float32)
57 | n_frames = spectrogram.shape[1]
58 |
59 | # Compute a mel-scale spectrogram from the wav:
60 | mel_spectrogram = audio.melspectrogram(wav).astype(np.float32)
61 |
62 | # Write the spectrograms to disk:
63 | spectrogram_filename = 'ljspeech-spec-%05d.npy' % index
64 | mel_filename = 'ljspeech-mel-%05d.npy' % index
65 | np.save(os.path.join(out_dir, spectrogram_filename), spectrogram.T, allow_pickle=False)
66 | np.save(os.path.join(out_dir, mel_filename), mel_spectrogram.T, allow_pickle=False)
67 |
68 | # Return a tuple describing this training example:
69 | return (spectrogram_filename, mel_filename, n_frames, text)
70 |
--------------------------------------------------------------------------------
/demo_server.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import falcon
3 | from hparams import hparams, hparams_debug_string
4 | import os
5 | from synthesizer import Synthesizer
6 |
7 |
8 | html_body = '''<html><title>Demo</title>
<!-- The rest of the page (lines 9-56: inline CSS, a text-input form, and the
JavaScript that requests /synthesize?text=... and plays the returned WAV in an
<audio> element) was lost to HTML extraction; only the title survived. -->
57 | '''
58 |
59 |
60 | class UIResource:
61 | def on_get(self, req, res):
62 | res.content_type = 'text/html'
63 | res.body = html_body
64 |
65 |
66 | class SynthesisResource:
67 | def on_get(self, req, res):
68 | if not req.params.get('text'):
69 | raise falcon.HTTPBadRequest()
70 | res.data = synthesizer.synthesize(req.params.get('text'))
71 | res.content_type = 'audio/wav'
72 |
73 |
74 | synthesizer = Synthesizer()
75 | api = falcon.API()
76 | api.add_route('/synthesize', SynthesisResource())
77 | api.add_route('/', UIResource())
78 |
79 |
80 | if __name__ == '__main__':
81 | from wsgiref import simple_server
82 | parser = argparse.ArgumentParser()
83 | parser.add_argument('--checkpoint', required=True, help='Full path to model checkpoint')
84 | parser.add_argument('--port', type=int, default=9000)
85 | parser.add_argument('--hparams', default='',
86 | help='Hyperparameter overrides as a comma-separated list of name=value pairs')
87 | args = parser.parse_args()
88 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
89 | hparams.parse(args.hparams)
90 | print(hparams_debug_string())
91 | synthesizer.load(args.checkpoint)
92 | print('Serving on port %d' % args.port)
93 | simple_server.make_server('0.0.0.0', args.port, api).serve_forever()
94 | else:
95 | synthesizer.load(os.environ['CHECKPOINT'])
96 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import re
4 | from hparams import hparams, hparams_debug_string
5 | from synthesizer import Synthesizer
6 |
7 |
8 | sentences = [
9 | # From July 8, 2017 New York Times:
10 | 'Scientists at the CERN laboratory say they have discovered a new particle.',
11 | 'There’s a way to measure the acute emotional intelligence that has never gone out of style.',
12 | 'President Trump met with other leaders at the Group of 20 conference.',
13 | 'The Senate\'s bill to repeal and replace the Affordable Care Act is now imperiled.',
14 | # From Google's Tacotron example page:
15 | 'Generative adversarial network or variational auto-encoder.',
16 | 'The buses aren\'t the problem, they actually provide a solution.',
17 | 'Does the quick brown fox jump over the lazy dog?',
18 | 'Talib Kweli confirmed to AllHipHop that he will be releasing an album in the next year.',
19 | ]
20 |
21 |
22 | def get_output_base_path(checkpoint_path):
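  # e.g. '~/tacotron/logs-tacotron/model.ckpt-185000' -> '~/tacotron/logs-tacotron/eval-185000';
  # falls back to plain 'eval' if the checkpoint name carries no step number.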
23 | base_dir = os.path.dirname(checkpoint_path)
24 | m = re.compile(r'.*?\.ckpt\-([0-9]+)').match(checkpoint_path)
25 | name = 'eval-%d' % int(m.group(1)) if m else 'eval'
26 | return os.path.join(base_dir, name)
27 |
28 |
29 | def run_eval(args):
30 | print(hparams_debug_string())
31 | synth = Synthesizer()
32 | synth.load(args.checkpoint)
33 | base_path = get_output_base_path(args.checkpoint)
34 | for i, text in enumerate(sentences):
35 | path = '%s-%d.wav' % (base_path, i)
36 | print('Synthesizing: %s' % path)
37 | with open(path, 'wb') as f:
38 | f.write(synth.synthesize(text))
39 |
40 |
41 | def main():
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument('--checkpoint', required=True, help='Path to model checkpoint')
44 | parser.add_argument('--hparams', default='',
45 | help='Hyperparameter overrides as a comma-separated list of name=value pairs')
46 | args = parser.parse_args()
47 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
48 | hparams.parse(args.hparams)
49 | run_eval(args)
50 |
51 |
52 | if __name__ == '__main__':
53 | main()
54 |
--------------------------------------------------------------------------------
/hparams.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | # Default hyperparameters:
5 | hparams = tf.contrib.training.HParams(
6 | # Comma-separated list of cleaners to run on text prior to training and eval. For non-English
7 | # text, you may want to use "basic_cleaners" or "transliteration_cleaners". See TRAINING_DATA.md.
8 | cleaners='english_cleaners',
9 |
10 | # Audio:
11 | num_mels=80,
12 | num_freq=1025,
13 | sample_rate=20000,
14 | frame_length_ms=50,
15 | frame_shift_ms=12.5,
16 | preemphasis=0.97,
17 | min_level_db=-100,
18 | ref_level_db=20,
19 |
20 | # Model:
21 | outputs_per_step=5,
22 | embed_depth=256,
23 | prenet_depths=[256, 128],
24 | encoder_depth=256,
25 | postnet_depth=256,
26 | attention_depth=256,
27 | decoder_depth=256,
28 |
29 | # Training:
30 | batch_size=32,
31 | adam_beta1=0.9,
32 | adam_beta2=0.999,
33 | initial_learning_rate=0.002,
34 | decay_learning_rate=True,
35 | use_cmudict=False, # Use CMUDict during training to learn pronunciation of ARPAbet phonemes
36 |
37 | # Eval:
38 | max_iters=200,
39 | griffin_lim_iters=60,
40 | power=1.5, # Power to raise magnitudes to prior to Griffin-Lim
41 | )
42 |
43 |
44 | def hparams_debug_string():
45 | values = hparams.values()
46 | hp = [' %s: %s' % (name, values[name]) for name in sorted(values)]
47 | return 'Hyperparameters:\n' + '\n'.join(hp)
48 |
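
# Usage sketch (this mirrors how train.py, eval.py, and demo_server.py apply the
# --hparams flag; the override string is just an example):
#
#   hparams.parse('batch_size=16,outputs_per_step=2')
#   print(hparams_debug_string())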
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .tacotron import Tacotron
2 |
3 |
4 | def create_model(name, hparams):
5 | if name == 'tacotron':
6 | return Tacotron(hparams)
7 | else:
8 | raise Exception('Unknown model: ' + name)
9 |
--------------------------------------------------------------------------------
/models/helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from tensorflow.contrib.seq2seq import Helper
4 |
5 |
6 | # Adapted from tf.contrib.seq2seq.GreedyEmbeddingHelper
7 | class TacoTestHelper(Helper):
8 | def __init__(self, batch_size, output_dim, r):
9 | with tf.name_scope('TacoTestHelper'):
10 | self._batch_size = batch_size
11 | self._output_dim = output_dim
12 | self._end_token = tf.tile([0.0], [output_dim * r])
13 |
14 | @property
15 | def batch_size(self):
16 | return self._batch_size
17 |
18 | @property
19 | def sample_ids_shape(self):
20 | return tf.TensorShape([])
21 |
22 | @property
23 | def sample_ids_dtype(self):
24 | return np.int32
25 |
26 | def initialize(self, name=None):
27 | return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim))
28 |
29 | def sample(self, time, outputs, state, name=None):
30 | return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them
31 |
32 | def next_inputs(self, time, outputs, state, sample_ids, name=None):
33 | '''Stop on EOS. Otherwise, pass the last output as the next input and pass through state.'''
34 | with tf.name_scope('TacoTestHelper'):
35 | finished = tf.reduce_all(tf.equal(outputs, self._end_token), axis=1)
36 | # Feed last output frame as next input. outputs is [N, output_dim * r]
37 | next_inputs = outputs[:, -self._output_dim:]
38 | return (finished, next_inputs, state)
39 |
40 |
41 | class TacoTrainingHelper(Helper):
42 | def __init__(self, inputs, targets, output_dim, r):
43 | # inputs is [N, T_in], targets is [N, T_out, D]
44 | with tf.name_scope('TacoTrainingHelper'):
45 | self._batch_size = tf.shape(inputs)[0]
46 | self._output_dim = output_dim
47 |
48 | # Feed every r-th target frame as input
49 | self._targets = targets[:, r-1::r, :]
50 |
51 | # Use full length for every target because we don't want to mask the padding frames
52 | num_steps = tf.shape(self._targets)[1]
53 | self._lengths = tf.tile([num_steps], [self._batch_size])
54 |
55 | @property
56 | def batch_size(self):
57 | return self._batch_size
58 |
59 | @property
60 | def sample_ids_shape(self):
61 | return tf.TensorShape([])
62 |
63 | @property
64 | def sample_ids_dtype(self):
65 | return np.int32
66 |
67 | def initialize(self, name=None):
68 | return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim))
69 |
70 | def sample(self, time, outputs, state, name=None):
71 | return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them
72 |
73 | def next_inputs(self, time, outputs, state, sample_ids, name=None):
74 | with tf.name_scope(name or 'TacoTrainingHelper'):
75 | finished = (time + 1 >= self._lengths)
76 | next_inputs = self._targets[:, time, :]
77 | return (finished, next_inputs, state)
78 |
79 |
80 | def _go_frames(batch_size, output_dim):
81 | '''Returns all-zero frames for a given batch size and output dimension'''
82 | return tf.tile([[0.0]], [batch_size, output_dim])
83 |
84 |
--------------------------------------------------------------------------------
/models/modules.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.rnn import GRUCell
3 |
4 |
5 | def prenet(inputs, is_training, layer_sizes, scope=None):
6 | x = inputs
7 | drop_rate = 0.5 if is_training else 0.0
8 | with tf.variable_scope(scope or 'prenet'):
9 | for i, size in enumerate(layer_sizes):
10 | dense = tf.layers.dense(x, units=size, activation=tf.nn.relu, name='dense_%d' % (i+1))
11 | x = tf.layers.dropout(dense, rate=drop_rate, training=is_training, name='dropout_%d' % (i+1))
12 | return x
13 |
14 |
15 | def encoder_cbhg(inputs, input_lengths, is_training, depth):
16 | input_channels = inputs.get_shape()[2]
17 | return cbhg(
18 | inputs,
19 | input_lengths,
20 | is_training,
21 | scope='encoder_cbhg',
22 | K=16,
23 | projections=[128, input_channels],
24 | depth=depth)
25 |
26 |
27 | def post_cbhg(inputs, input_dim, is_training, depth):
28 | return cbhg(
29 | inputs,
30 | None,
31 | is_training,
32 | scope='post_cbhg',
33 | K=8,
34 | projections=[256, input_dim],
35 | depth=depth)
36 |
37 |
38 | def cbhg(inputs, input_lengths, is_training, scope, K, projections, depth):
39 | with tf.variable_scope(scope):
40 | with tf.variable_scope('conv_bank'):
41 | # Convolution bank: concatenate on the last axis to stack channels from all convolutions
42 | conv_outputs = tf.concat(
43 | [conv1d(inputs, k, 128, tf.nn.relu, is_training, 'conv1d_%d' % k) for k in range(1, K+1)],
44 | axis=-1
45 | )
46 |
47 | # Maxpooling:
48 | maxpool_output = tf.layers.max_pooling1d(
49 | conv_outputs,
50 | pool_size=2,
51 | strides=1,
52 | padding='same')
53 |
54 | # Two projection layers:
55 | proj1_output = conv1d(maxpool_output, 3, projections[0], tf.nn.relu, is_training, 'proj_1')
56 | proj2_output = conv1d(proj1_output, 3, projections[1], None, is_training, 'proj_2')
57 |
58 | # Residual connection:
59 | highway_input = proj2_output + inputs
60 |
61 | half_depth = depth // 2
62 | assert half_depth*2 == depth, 'encoder and postnet depths must be even.'
63 |
64 | # Handle dimensionality mismatch:
65 | if highway_input.shape[2] != half_depth:
66 | highway_input = tf.layers.dense(highway_input, half_depth)
67 |
68 | # 4-layer HighwayNet:
69 | for i in range(4):
70 | highway_input = highwaynet(highway_input, 'highway_%d' % (i+1), half_depth)
71 | rnn_input = highway_input
72 |
73 | # Bidirectional RNN
74 | outputs, states = tf.nn.bidirectional_dynamic_rnn(
75 | GRUCell(half_depth),
76 | GRUCell(half_depth),
77 | rnn_input,
78 | sequence_length=input_lengths,
79 | dtype=tf.float32)
80 | return tf.concat(outputs, axis=2) # Concat forward and backward
81 |
82 |
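# HighwayNet layer (Srivastava et al., 2015): H is the candidate transform and T
# the transform gate. T's bias is initialized to -1.0 so the gate starts mostly
# closed and the layer initially passes its input through largely unchanged.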
83 | def highwaynet(inputs, scope, depth):
84 | with tf.variable_scope(scope):
85 | H = tf.layers.dense(
86 | inputs,
87 | units=depth,
88 | activation=tf.nn.relu,
89 | name='H')
90 | T = tf.layers.dense(
91 | inputs,
92 | units=depth,
93 | activation=tf.nn.sigmoid,
94 | name='T',
95 | bias_initializer=tf.constant_initializer(-1.0))
96 | return H * T + inputs * (1.0 - T)
97 |
98 |
99 | def conv1d(inputs, kernel_size, channels, activation, is_training, scope):
100 | with tf.variable_scope(scope):
101 | conv1d_output = tf.layers.conv1d(
102 | inputs,
103 | filters=channels,
104 | kernel_size=kernel_size,
105 | activation=activation,
106 | padding='same')
107 | return tf.layers.batch_normalization(conv1d_output, training=is_training)
108 |
--------------------------------------------------------------------------------
/models/rnn_wrappers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from tensorflow.contrib.rnn import RNNCell
4 | from .modules import prenet
5 |
6 |
7 | class DecoderPrenetWrapper(RNNCell):
8 | '''Runs RNN inputs through a prenet before sending them to the cell.'''
9 | def __init__(self, cell, is_training, layer_sizes):
10 | super(DecoderPrenetWrapper, self).__init__()
11 | self._cell = cell
12 | self._is_training = is_training
13 | self._layer_sizes = layer_sizes
14 |
15 | @property
16 | def state_size(self):
17 | return self._cell.state_size
18 |
19 | @property
20 | def output_size(self):
21 | return self._cell.output_size
22 |
23 | def call(self, inputs, state):
24 | prenet_out = prenet(inputs, self._is_training, self._layer_sizes, scope='decoder_prenet')
25 | return self._cell(prenet_out, state)
26 |
27 | def zero_state(self, batch_size, dtype):
28 | return self._cell.zero_state(batch_size, dtype)
29 |
30 |
31 |
32 | class ConcatOutputAndAttentionWrapper(RNNCell):
33 | '''Concatenates RNN cell output with the attention context vector.
34 |
35 | This is expected to wrap a cell wrapped with an AttentionWrapper constructed with
36 | attention_layer_size=None and output_attention=False. Such a cell's state will include an
37 | "attention" field that is the context vector.
38 | '''
39 | def __init__(self, cell):
40 | super(ConcatOutputAndAttentionWrapper, self).__init__()
41 | self._cell = cell
42 |
43 | @property
44 | def state_size(self):
45 | return self._cell.state_size
46 |
47 | @property
48 | def output_size(self):
49 | return self._cell.output_size + self._cell.state_size.attention
50 |
51 | def call(self, inputs, state):
52 | output, res_state = self._cell(inputs, state)
53 | return tf.concat([output, res_state.attention], axis=-1), res_state
54 |
55 | def zero_state(self, batch_size, dtype):
56 | return self._cell.zero_state(batch_size, dtype)
57 |
--------------------------------------------------------------------------------
/models/tacotron.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, OutputProjectionWrapper, ResidualWrapper
3 | from tensorflow.contrib.seq2seq import BasicDecoder, BahdanauAttention, AttentionWrapper
4 | from text.symbols import symbols
5 | from util.infolog import log
6 | from .helpers import TacoTestHelper, TacoTrainingHelper
7 | from .modules import encoder_cbhg, post_cbhg, prenet
8 | from .rnn_wrappers import DecoderPrenetWrapper, ConcatOutputAndAttentionWrapper
9 |
10 |
11 |
12 | class Tacotron():
13 | def __init__(self, hparams):
14 | self._hparams = hparams
15 |
16 |
17 | def initialize(self, inputs, input_lengths, mel_targets=None, linear_targets=None):
18 | '''Initializes the model for inference.
19 |
20 | Sets "mel_outputs", "linear_outputs", and "alignments" fields.
21 |
22 | Args:
23 | inputs: int32 Tensor with shape [N, T_in] where N is batch size, T_in is number of
24 | steps in the input time series, and values are character IDs
25 | input_lengths: int32 Tensor with shape [N] where N is batch size and values are the lengths
26 | of each sequence in inputs.
27 | mel_targets: float32 Tensor with shape [N, T_out, M] where N is batch size, T_out is number
28 | of steps in the output time series, M is num_mels, and values are entries in the mel
29 | spectrogram. Only needed for training.
30 | linear_targets: float32 Tensor with shape [N, T_out, F] where N is batch_size, T_out is number
31 | of steps in the output time series, F is num_freq, and values are entries in the linear
32 | spectrogram. Only needed for training.
33 | '''
34 | with tf.variable_scope('inference') as scope:
35 | is_training = linear_targets is not None
36 | batch_size = tf.shape(inputs)[0]
37 | hp = self._hparams
38 |
39 | # Embeddings
40 | embedding_table = tf.get_variable(
41 | 'embedding', [len(symbols), hp.embed_depth], dtype=tf.float32,
42 | initializer=tf.truncated_normal_initializer(stddev=0.5))
43 | embedded_inputs = tf.nn.embedding_lookup(embedding_table, inputs) # [N, T_in, embed_depth=256]
44 |
45 | # Encoder
46 | prenet_outputs = prenet(embedded_inputs, is_training, hp.prenet_depths) # [N, T_in, prenet_depths[-1]=128]
47 | encoder_outputs = encoder_cbhg(prenet_outputs, input_lengths, is_training, # [N, T_in, encoder_depth=256]
48 | hp.encoder_depth)
49 |
50 | # Attention
51 | attention_cell = AttentionWrapper(
52 | GRUCell(hp.attention_depth),
53 | BahdanauAttention(hp.attention_depth, encoder_outputs),
54 | alignment_history=True,
55 | output_attention=False) # [N, T_in, attention_depth=256]
56 |
57 | # Apply prenet before concatenation in AttentionWrapper.
58 | attention_cell = DecoderPrenetWrapper(attention_cell, is_training, hp.prenet_depths)
59 |
60 | # Concatenate attention context vector and RNN cell output into a 2*attention_depth=512D vector.
61 | concat_cell = ConcatOutputAndAttentionWrapper(attention_cell) # [N, T_in, 2*attention_depth=512]
62 |
63 | # Decoder (layers specified bottom to top):
64 | decoder_cell = MultiRNNCell([
65 | OutputProjectionWrapper(concat_cell, hp.decoder_depth),
66 | ResidualWrapper(GRUCell(hp.decoder_depth)),
67 | ResidualWrapper(GRUCell(hp.decoder_depth))
68 | ], state_is_tuple=True) # [N, T_in, decoder_depth=256]
69 |
70 | # Project onto r mel spectrograms (predict r outputs at each RNN step):
71 | output_cell = OutputProjectionWrapper(decoder_cell, hp.num_mels * hp.outputs_per_step)
72 | decoder_init_state = output_cell.zero_state(batch_size=batch_size, dtype=tf.float32)
73 |
74 | if is_training:
75 | helper = TacoTrainingHelper(inputs, mel_targets, hp.num_mels, hp.outputs_per_step)
76 | else:
77 | helper = TacoTestHelper(batch_size, hp.num_mels, hp.outputs_per_step)
78 |
79 | (decoder_outputs, _), final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(
80 | BasicDecoder(output_cell, helper, decoder_init_state),
81 | maximum_iterations=hp.max_iters) # [N, T_out/r, M*r]
82 |
83 | # Reshape outputs to be one output per entry
84 | mel_outputs = tf.reshape(decoder_outputs, [batch_size, -1, hp.num_mels]) # [N, T_out, M]
85 |
86 | # Add post-processing CBHG:
87 | post_outputs = post_cbhg(mel_outputs, hp.num_mels, is_training, # [N, T_out, postnet_depth=256]
88 | hp.postnet_depth)
89 | linear_outputs = tf.layers.dense(post_outputs, hp.num_freq) # [N, T_out, F]
90 |
91 | # Grab alignments from the final decoder state:
92 | alignments = tf.transpose(final_decoder_state[0].alignment_history.stack(), [1, 2, 0])
93 |
94 | self.inputs = inputs
95 | self.input_lengths = input_lengths
96 | self.mel_outputs = mel_outputs
97 | self.linear_outputs = linear_outputs
98 | self.alignments = alignments
99 | self.mel_targets = mel_targets
100 | self.linear_targets = linear_targets
101 | log('Initialized Tacotron model. Dimensions: ')
102 | log(' embedding: %d' % embedded_inputs.shape[-1])
103 | log(' prenet out: %d' % prenet_outputs.shape[-1])
104 | log(' encoder out: %d' % encoder_outputs.shape[-1])
105 | log(' attention out: %d' % attention_cell.output_size)
106 | log(' concat attn & out: %d' % concat_cell.output_size)
107 | log(' decoder cell out: %d' % decoder_cell.output_size)
108 | log(' decoder out (%d frames): %d' % (hp.outputs_per_step, decoder_outputs.shape[-1]))
109 | log(' decoder out (1 frame): %d' % mel_outputs.shape[-1])
110 | log(' postnet out: %d' % post_outputs.shape[-1])
111 | log(' linear out: %d' % linear_outputs.shape[-1])
112 |
113 |
114 | def add_loss(self):
115 | '''Adds loss to the model. Sets "loss" field. initialize must have been called.'''
116 | with tf.variable_scope('loss') as scope:
117 | hp = self._hparams
118 | self.mel_loss = tf.reduce_mean(tf.abs(self.mel_targets - self.mel_outputs))
119 | l1 = tf.abs(self.linear_targets - self.linear_outputs)
120 | # Prioritize loss for frequencies under 3000 Hz.
121 | n_priority_freq = int(3000 / (hp.sample_rate * 0.5) * hp.num_freq)
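# With the defaults (sample_rate=20000, num_freq=1025), this is int(0.3 * 1025) = 307 bins.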
122 | self.linear_loss = 0.5 * tf.reduce_mean(l1) + 0.5 * tf.reduce_mean(l1[:,:,0:n_priority_freq])
123 | self.loss = self.mel_loss + self.linear_loss
124 |
125 |
126 | def add_optimizer(self, global_step):
127 | '''Adds optimizer. Sets "gradients" and "optimize" fields. add_loss must have been called.
128 |
129 | Args:
130 | global_step: int32 scalar Tensor representing current global step in training
131 | '''
132 | with tf.variable_scope('optimizer') as scope:
133 | hp = self._hparams
134 | if hp.decay_learning_rate:
135 | self.learning_rate = _learning_rate_decay(hp.initial_learning_rate, global_step)
136 | else:
137 | self.learning_rate = tf.convert_to_tensor(hp.initial_learning_rate)
138 | optimizer = tf.train.AdamOptimizer(self.learning_rate, hp.adam_beta1, hp.adam_beta2)
139 | gradients, variables = zip(*optimizer.compute_gradients(self.loss))
140 | self.gradients = gradients
141 | clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
142 |
143 | # Add dependency on UPDATE_OPS; otherwise batchnorm won't work correctly. See:
144 | # https://github.com/tensorflow/tensorflow/issues/1122
145 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
146 | self.optimize = optimizer.apply_gradients(zip(clipped_gradients, variables),
147 | global_step=global_step)
148 |
149 |
150 | def _learning_rate_decay(init_lr, global_step):
151 | # Noam scheme from tensor2tensor:
152 | warmup_steps = 4000.0
153 | step = tf.cast(global_step + 1, dtype=tf.float32)
154 | return init_lr * warmup_steps**0.5 * tf.minimum(step * warmup_steps**-1.5, step**-0.5)
155 |
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from multiprocessing import cpu_count
4 | from tqdm import tqdm
5 | from datasets import blizzard, ljspeech
6 | from hparams import hparams
7 |
8 |
9 | def preprocess_blizzard(args):
10 | in_dir = os.path.join(args.base_dir, 'Blizzard2012')
11 | out_dir = os.path.join(args.base_dir, args.output)
12 | os.makedirs(out_dir, exist_ok=True)
13 | metadata = blizzard.build_from_path(in_dir, out_dir, args.num_workers, tqdm=tqdm)
14 | write_metadata(metadata, out_dir)
15 |
16 |
17 | def preprocess_ljspeech(args):
18 | in_dir = os.path.join(args.base_dir, 'LJSpeech-1.1')
19 | out_dir = os.path.join(args.base_dir, args.output)
20 | os.makedirs(out_dir, exist_ok=True)
21 | metadata = ljspeech.build_from_path(in_dir, out_dir, args.num_workers, tqdm=tqdm)
22 | write_metadata(metadata, out_dir)
23 |
24 |
25 | def write_metadata(metadata, out_dir):
26 | with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f:
27 | for m in metadata:
28 | f.write('|'.join([str(x) for x in m]) + '\n')
29 | frames = sum([m[2] for m in metadata])
30 | hours = frames * hparams.frame_shift_ms / (3600 * 1000)
31 | print('Wrote %d utterances, %d frames (%.2f hours)' % (len(metadata), frames, hours))
32 | print('Max input length: %d' % max(len(m[3]) for m in metadata))
33 | print('Max output length: %d' % max(m[2] for m in metadata))
34 |
35 |
36 | def main():
37 | parser = argparse.ArgumentParser()
38 | parser.add_argument('--base_dir', default=os.path.expanduser('~/tacotron'))
39 | parser.add_argument('--output', default='training')
40 | parser.add_argument('--dataset', required=True, choices=['blizzard', 'ljspeech'])
41 | parser.add_argument('--num_workers', type=int, default=cpu_count())
42 | args = parser.parse_args()
43 | if args.dataset == 'blizzard':
44 | preprocess_blizzard(args)
45 | elif args.dataset == 'ljspeech':
46 | preprocess_ljspeech(args)
47 |
48 |
49 | if __name__ == "__main__":
50 | main()
51 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Note: this doesn't include tensorflow or tensorflow-gpu because the package you need to install
2 | # depends on your platform. It is assumed you have already installed tensorflow.
3 | falcon==1.2.0
4 | inflect==0.2.5
5 | librosa==0.5.1
6 | matplotlib==2.0.2
7 | numpy==1.14.3
8 | scipy==0.19.0
9 | tqdm==4.11.2
10 | Unidecode==0.4.20
11 |
--------------------------------------------------------------------------------
/synthesizer.py:
--------------------------------------------------------------------------------
1 | import io
2 | import numpy as np
3 | import tensorflow as tf
4 | from hparams import hparams
5 | from librosa import effects
6 | from models import create_model
7 | from text import text_to_sequence
8 | from util import audio
9 |
10 |
11 | class Synthesizer:
12 | def load(self, checkpoint_path, model_name='tacotron'):
13 | print('Constructing model: %s' % model_name)
14 | inputs = tf.placeholder(tf.int32, [1, None], 'inputs')
15 | input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths')
16 | with tf.variable_scope('model') as scope:
17 | self.model = create_model(model_name, hparams)
18 | self.model.initialize(inputs, input_lengths)
19 | self.wav_output = audio.inv_spectrogram_tensorflow(self.model.linear_outputs[0])
20 |
21 | print('Loading checkpoint: %s' % checkpoint_path)
22 | self.session = tf.Session()
23 | self.session.run(tf.global_variables_initializer())
24 | saver = tf.train.Saver()
25 | saver.restore(self.session, checkpoint_path)
26 |
27 |
28 | def synthesize(self, text):
29 | cleaner_names = [x.strip() for x in hparams.cleaners.split(',')]
30 | seq = text_to_sequence(text, cleaner_names)
31 | feed_dict = {
32 | self.model.inputs: [np.asarray(seq, dtype=np.int32)],
33 | self.model.input_lengths: np.asarray([len(seq)], dtype=np.int32)
34 | }
35 | wav = self.session.run(self.wav_output, feed_dict=feed_dict)
36 | wav = audio.inv_preemphasis(wav)
37 | wav = wav[:audio.find_endpoint(wav)]
38 | out = io.BytesIO()
39 | audio.save_wav(wav, out)
40 | return out.getvalue()
41 |
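# Usage sketch (using the pre-trained checkpoint path from the README):
#
#   synth = Synthesizer()
#   synth.load('/tmp/tacotron-20180906/model.ckpt')
#   with open('out.wav', 'wb') as f:
#     f.write(synth.synthesize('Hello, world.'))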
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keithito/tacotron/d26c763342518d4e432e9c4036a1aff3b4fdaa1e/tests/__init__.py
--------------------------------------------------------------------------------
/tests/cmudict_test.py:
--------------------------------------------------------------------------------
1 | import io
2 | from text import cmudict
3 |
4 |
5 | test_data = '''
6 | ;;; # CMUdict -- Major Version: 0.07
7 | )PAREN P ER EH N
8 | 'TIS T IH Z
9 | ADVERSE AE0 D V ER1 S
10 | ADVERSE(1) AE1 D V ER2 S
11 | ADVERSE(2) AE2 D V ER1 S
12 | ADVERSELY AE0 D V ER1 S L IY0
13 | ADVERSITY AE0 D V ER1 S IH0 T IY2
14 | BARBERSHOP B AA1 R B ER0 SH AA2 P
15 | YOU'LL Y UW1 L
16 | '''
17 |
18 |
19 | def test_cmudict():
20 | c = cmudict.CMUDict(io.StringIO(test_data))
21 | assert len(c) == 6
22 | assert len(cmudict.valid_symbols) == 84
23 | assert c.lookup('ADVERSITY') == ['AE0 D V ER1 S IH0 T IY2']
24 | assert c.lookup('BarberShop') == ['B AA1 R B ER0 SH AA2 P']
25 | assert c.lookup("You'll") == ['Y UW1 L']
26 | assert c.lookup("'tis") == ['T IH Z']
27 | assert c.lookup('adverse') == [
28 | 'AE0 D V ER1 S',
29 | 'AE1 D V ER2 S',
30 | 'AE2 D V ER1 S',
31 | ]
32 | assert c.lookup('') == None
33 | assert c.lookup('foo') == None
34 | assert c.lookup(')paren') == None
35 |
36 |
37 | def test_cmudict_no_keep_ambiguous():
38 | c = cmudict.CMUDict(io.StringIO(test_data), keep_ambiguous=False)
39 | assert len(c) == 5
40 | assert c.lookup('adversity') == ['AE0 D V ER1 S IH0 T IY2']
41 | assert c.lookup('adverse') == None
42 |
--------------------------------------------------------------------------------
/tests/numbers_test.py:
--------------------------------------------------------------------------------
1 | from text.numbers import normalize_numbers
2 |
3 |
4 | def test_normalize_numbers():
5 | assert normalize_numbers('1') == 'one'
6 | assert normalize_numbers('15') == 'fifteen'
7 | assert normalize_numbers('24') == 'twenty-four'
8 | assert normalize_numbers('100') == 'one hundred'
9 | assert normalize_numbers('101') == 'one hundred one'
10 | assert normalize_numbers('456') == 'four hundred fifty-six'
11 | assert normalize_numbers('1000') == 'one thousand'
12 | assert normalize_numbers('1800') == 'eighteen hundred'
13 | assert normalize_numbers('2,000') == 'two thousand'
14 | assert normalize_numbers('3000') == 'three thousand'
15 | assert normalize_numbers('18000') == 'eighteen thousand'
16 | assert normalize_numbers('24,000') == 'twenty-four thousand'
17 | assert normalize_numbers('124,001') == 'one hundred twenty-four thousand one'
18 | assert normalize_numbers('6.4 sec') == 'six point four sec'
19 |
20 |
21 | def test_normalize_ordinals():
22 | assert normalize_numbers('1st') == 'first'
23 | assert normalize_numbers('2nd') == 'second'
24 | assert normalize_numbers('9th') == 'ninth'
25 | assert normalize_numbers('243rd place') == 'two hundred and forty-third place'
26 |
27 |
28 | def test_normalize_dates():
29 | assert normalize_numbers('1400') == 'fourteen hundred'
30 | assert normalize_numbers('1901') == 'nineteen oh one'
31 | assert normalize_numbers('1999') == 'nineteen ninety-nine'
32 | assert normalize_numbers('2000') == 'two thousand'
33 | assert normalize_numbers('2004') == 'two thousand four'
34 | assert normalize_numbers('2010') == 'twenty ten'
35 | assert normalize_numbers('2012') == 'twenty twelve'
36 | assert normalize_numbers('2025') == 'twenty twenty-five'
37 | assert normalize_numbers('September 11, 2001') == 'September eleven, two thousand one'
38 | assert normalize_numbers('July 26, 1984.') == 'July twenty-six, nineteen eighty-four.'
39 |
40 |
41 | def test_normalize_money():
42 | assert normalize_numbers('$0.00') == 'zero dollars'
43 | assert normalize_numbers('$1') == 'one dollar'
44 | assert normalize_numbers('$10') == 'ten dollars'
45 | assert normalize_numbers('$.01') == 'one cent'
46 | assert normalize_numbers('$0.25') == 'twenty-five cents'
47 | assert normalize_numbers('$5.00') == 'five dollars'
48 | assert normalize_numbers('$5.01') == 'five dollars, one cent'
49 | assert normalize_numbers('$135.99.') == 'one hundred thirty-five dollars, ninety-nine cents.'
50 | assert normalize_numbers('$40,000') == 'forty thousand dollars'
51 | assert normalize_numbers('for £2500!') == 'for twenty-five hundred pounds!'
52 |
--------------------------------------------------------------------------------
/tests/text_test.py:
--------------------------------------------------------------------------------
1 | from text import cleaners, symbols, text_to_sequence, sequence_to_text
2 | from unidecode import unidecode
3 |
4 |
5 | def test_symbols():
6 | assert len(symbols) >= 3
7 | assert symbols[0] == '_'
8 | assert symbols[1] == '~'
9 |
10 |
11 | def test_text_to_sequence():
12 |   assert text_to_sequence('', []) == [1]
13 |   assert text_to_sequence('Hi!', []) == [9, 36, 54, 1]
14 |   assert text_to_sequence('"A"_B', []) == [2, 3, 1]
15 |   assert text_to_sequence('A {AW1 S} B', []) == [2, 64, 83, 132, 64, 3, 1]
16 |   assert text_to_sequence('Hi', ['lowercase']) == [35, 36, 1]
17 |   assert text_to_sequence('A {AW1 S} B', ['english_cleaners']) == [28, 64, 83, 132, 64, 29, 1]
18 |
19 |
20 | def test_sequence_to_text():
21 |   assert sequence_to_text([]) == ''
22 |   assert sequence_to_text([1]) == '~'
23 |   assert sequence_to_text([9, 36, 54, 1]) == 'Hi!~'
24 |   assert sequence_to_text([2, 64, 83, 132, 64, 3]) == 'A {AW1 S} B'
25 |
26 |
27 | def test_collapse_whitespace():
28 |   assert cleaners.collapse_whitespace('') == ''
29 |   assert cleaners.collapse_whitespace(' ') == ' '
30 |   assert cleaners.collapse_whitespace('x') == 'x'
31 |   assert cleaners.collapse_whitespace(' x. y, \tz') == ' x. y, z'
32 |
33 |
34 | def test_convert_to_ascii():
35 |   assert cleaners.convert_to_ascii("raison d'être") == "raison d'etre"
36 |   assert cleaners.convert_to_ascii('grüß gott') == 'gruss gott'
37 |   assert cleaners.convert_to_ascii('안녕') == 'annyeong'
38 |   assert cleaners.convert_to_ascii('Здравствуйте') == 'Zdravstvuite'
39 |
40 |
41 | def test_lowercase():
42 |   assert cleaners.lowercase('Happy Birthday!') == 'happy birthday!'
43 |   assert cleaners.lowercase('CAFÉ') == 'café'
44 |
45 |
46 | def test_expand_abbreviations():
47 |   assert cleaners.expand_abbreviations('mr. and mrs. smith') == 'mister and misess smith'
48 |
49 |
50 | def test_expand_numbers():
51 |   assert cleaners.expand_numbers('3 apples and 44 pears') == 'three apples and forty-four pears'
52 |   assert cleaners.expand_numbers('$3.50 for gas.') == 'three dollars, fifty cents for gas.'
53 |
54 |
55 | def test_cleaner_pipelines():
56 |   text = 'Mr. Müller ate 2 Apples'
57 |   assert cleaners.english_cleaners(text) == 'mister muller ate two apples'
58 |   assert cleaners.transliteration_cleaners(text) == 'mr. muller ate 2 apples'
59 |   assert cleaners.basic_cleaners(text) == 'mr. müller ate 2 apples'
60 |
61 |
--------------------------------------------------------------------------------
/text/__init__.py:
--------------------------------------------------------------------------------
1 | import re
2 | from text import cleaners
3 | from text.symbols import symbols
4 |
5 |
6 | # Mappings from symbol to numeric ID and vice versa:
7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9 |
10 | # Regular expression matching text enclosed in curly braces:
11 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
12 |
13 |
14 | def text_to_sequence(text, cleaner_names):
15 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
16 |
17 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded
18 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
19 |
20 | Args:
21 | text: string to convert to a sequence
22 | cleaner_names: names of the cleaner functions to run the text through
23 |
24 | Returns:
25 | List of integers corresponding to the symbols in the text
26 | '''
27 | sequence = []
28 |
29 | # Check for curly braces and treat their contents as ARPAbet:
30 | while len(text):
31 | m = _curly_re.match(text)
32 | if not m:
33 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
34 | break
35 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
36 | sequence += _arpabet_to_sequence(m.group(2))
37 | text = m.group(3)
38 |
39 | # Append EOS token
40 | sequence.append(_symbol_to_id['~'])
41 | return sequence
42 |
43 |
44 | def sequence_to_text(sequence):
45 | '''Converts a sequence of IDs back to a string'''
46 | result = ''
47 | for symbol_id in sequence:
48 | if symbol_id in _id_to_symbol:
49 | s = _id_to_symbol[symbol_id]
50 | # Enclose ARPAbet back in curly braces:
51 | if len(s) > 1 and s[0] == '@':
52 | s = '{%s}' % s[1:]
53 | result += s
54 | return result.replace('}{', ' ')
55 |
56 |
57 | def _clean_text(text, cleaner_names):
58 | for name in cleaner_names:
59 |     cleaner = getattr(cleaners, name, None)  # default None so the unknown-cleaner check below can fire
60 | if not cleaner:
61 | raise Exception('Unknown cleaner: %s' % name)
62 | text = cleaner(text)
63 | return text
64 |
65 |
66 | def _symbols_to_sequence(symbols):
67 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
68 |
69 |
70 | def _arpabet_to_sequence(text):
71 | return _symbols_to_sequence(['@' + s for s in text.split()])
72 |
73 |
74 | def _should_keep_symbol(s):
75 |   return s in _symbol_to_id and s != '_' and s != '~'
76 |
--------------------------------------------------------------------------------
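A quick sketch of how the encoding above behaves end to end (illustrative only, assuming the repository root is on PYTHONPATH):

    from text import text_to_sequence, sequence_to_text

    # Plain text is cleaned and mapped symbol-by-symbol to IDs; ARPAbet inside
    # curly braces is encoded via the '@'-prefixed phoneme symbols, and the EOS
    # marker '~' (ID 1) is always appended.
    ids = text_to_sequence('Turn left on {HH AW1 S S T AH0 N} Street.', ['english_cleaners'])

    # Decoding re-wraps ARPAbet in curly braces and keeps the trailing EOS:
    print(sequence_to_text(ids))  # turn left on {HH AW1 S S T AH0 N} street.~
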
/text/cleaners.py:
--------------------------------------------------------------------------------
1 | '''
2 | Cleaners are transformations that run over the input text at both training and eval time.
3 |
4 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
5 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
6 | 1. "english_cleaners" for English text
7 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
8 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
9 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
10 | the symbols in symbols.py to match your data).
11 | '''
12 |
13 | import re
14 | from unidecode import unidecode
15 | from .numbers import normalize_numbers
16 |
17 |
18 | # Regular expression matching whitespace:
19 | _whitespace_re = re.compile(r'\s+')
20 |
21 | # List of (regular expression, replacement) pairs for abbreviations:
22 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
23 | ('mrs', 'misess'),
24 | ('mr', 'mister'),
25 | ('dr', 'doctor'),
26 | ('st', 'saint'),
27 | ('co', 'company'),
28 | ('jr', 'junior'),
29 | ('maj', 'major'),
30 | ('gen', 'general'),
31 | ('drs', 'doctors'),
32 | ('rev', 'reverend'),
33 | ('lt', 'lieutenant'),
34 | ('hon', 'honorable'),
35 | ('sgt', 'sergeant'),
36 | ('capt', 'captain'),
37 | ('esq', 'esquire'),
38 | ('ltd', 'limited'),
39 | ('col', 'colonel'),
40 | ('ft', 'fort'),
41 | ]]
42 |
43 |
44 | def expand_abbreviations(text):
45 | for regex, replacement in _abbreviations:
46 | text = re.sub(regex, replacement, text)
47 | return text
48 |
49 |
50 | def expand_numbers(text):
51 | return normalize_numbers(text)
52 |
53 |
54 | def lowercase(text):
55 | return text.lower()
56 |
57 |
58 | def collapse_whitespace(text):
59 | return re.sub(_whitespace_re, ' ', text)
60 |
61 |
62 | def convert_to_ascii(text):
63 | return unidecode(text)
64 |
65 |
66 | def basic_cleaners(text):
67 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
68 | text = lowercase(text)
69 | text = collapse_whitespace(text)
70 | return text
71 |
72 |
73 | def transliteration_cleaners(text):
74 | '''Pipeline for non-English text that transliterates to ASCII.'''
75 | text = convert_to_ascii(text)
76 | text = lowercase(text)
77 | text = collapse_whitespace(text)
78 | return text
79 |
80 |
81 | def english_cleaners(text):
82 | '''Pipeline for English text, including number and abbreviation expansion.'''
83 | text = convert_to_ascii(text)
84 | text = lowercase(text)
85 | text = expand_numbers(text)
86 | text = expand_abbreviations(text)
87 | text = collapse_whitespace(text)
88 | return text
89 |
--------------------------------------------------------------------------------
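The module docstring above describes selecting cleaners by name. A minimal sketch of how a comma-delimited "cleaners" hyperparameter maps onto these functions (mirroring _clean_text in text/__init__.py):

    from text import cleaners

    names = 'english_cleaners'.split(',')  # e.g. taken from the "cleaners" hparam
    text = 'Mr. Müller bought 2 apples for $3.50.'
    for name in names:
      text = getattr(cleaners, name)(text)
    print(text)  # mister muller bought two apples for three dollars, fifty cents.
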
/text/cmudict.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
4 | valid_symbols = [
5 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
6 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
7 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
8 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
9 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
10 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
11 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
12 | ]
13 |
14 | _valid_symbol_set = set(valid_symbols)
15 |
16 |
17 | class CMUDict:
18 | '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
19 | def __init__(self, file_or_path, keep_ambiguous=True):
20 | if isinstance(file_or_path, str):
21 | with open(file_or_path, encoding='latin-1') as f:
22 | entries = _parse_cmudict(f)
23 | else:
24 | entries = _parse_cmudict(file_or_path)
25 | if not keep_ambiguous:
26 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
27 | self._entries = entries
28 |
29 |
30 | def __len__(self):
31 | return len(self._entries)
32 |
33 |
34 | def lookup(self, word):
35 | '''Returns list of ARPAbet pronunciations of the given word.'''
36 | return self._entries.get(word.upper())
37 |
38 |
39 |
40 | _alt_re = re.compile(r'\([0-9]+\)')
41 |
42 |
43 | def _parse_cmudict(file):
44 | cmudict = {}
45 | for line in file:
46 |     if len(line) and ('A' <= line[0] <= 'Z' or line[0] == "'"):
47 |       parts = line.split('  ')  # word and pronunciation are separated by two spaces
48 | word = re.sub(_alt_re, '', parts[0])
49 | pronunciation = _get_pronunciation(parts[1])
50 | if pronunciation:
51 | if word in cmudict:
52 | cmudict[word].append(pronunciation)
53 | else:
54 | cmudict[word] = [pronunciation]
55 | return cmudict
56 |
57 |
58 | def _get_pronunciation(s):
59 | parts = s.strip().split(' ')
60 | for part in parts:
61 | if part not in _valid_symbol_set:
62 | return None
63 | return ' '.join(parts)
64 |
--------------------------------------------------------------------------------
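A small usage sketch for the wrapper above, using a hypothetical two-line dictionary (note the two-space separator between word and pronunciation, and the "(1)" suffix CMUDict uses for alternate pronunciations):

    import io
    from text import cmudict

    data = 'HELLO  HH AH0 L OW1\nHELLO(1)  HH EH0 L OW1\n'
    c = cmudict.CMUDict(io.StringIO(data))
    print(len(c))             # 1: both lines collapse onto the word HELLO
    print(c.lookup('hello'))  # ['HH AH0 L OW1', 'HH EH0 L OW1']

    # With keep_ambiguous=False, words with multiple pronunciations are dropped:
    print(cmudict.CMUDict(io.StringIO(data), keep_ambiguous=False).lookup('hello'))  # None
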
/text/numbers.py:
--------------------------------------------------------------------------------
1 | import inflect
2 | import re
3 |
4 |
5 | _inflect = inflect.engine()
6 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
7 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
8 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
9 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
10 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
11 | _number_re = re.compile(r'[0-9]+')
12 |
13 |
14 | def _remove_commas(m):
15 | return m.group(1).replace(',', '')
16 |
17 |
18 | def _expand_decimal_point(m):
19 | return m.group(1).replace('.', ' point ')
20 |
21 |
22 | def _expand_dollars(m):
23 | match = m.group(1)
24 | parts = match.split('.')
25 | if len(parts) > 2:
26 | return match + ' dollars' # Unexpected format
27 | dollars = int(parts[0]) if parts[0] else 0
28 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
29 | if dollars and cents:
30 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
31 | cent_unit = 'cent' if cents == 1 else 'cents'
32 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
33 | elif dollars:
34 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
35 | return '%s %s' % (dollars, dollar_unit)
36 | elif cents:
37 | cent_unit = 'cent' if cents == 1 else 'cents'
38 | return '%s %s' % (cents, cent_unit)
39 | else:
40 | return 'zero dollars'
41 |
42 |
43 | def _expand_ordinal(m):
44 | return _inflect.number_to_words(m.group(0))
45 |
46 |
47 | def _expand_number(m):
48 | num = int(m.group(0))
49 | if num > 1000 and num < 3000:
50 | if num == 2000:
51 | return 'two thousand'
52 | elif num > 2000 and num < 2010:
53 | return 'two thousand ' + _inflect.number_to_words(num % 100)
54 | elif num % 100 == 0:
55 | return _inflect.number_to_words(num // 100) + ' hundred'
56 | else:
57 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
58 | else:
59 | return _inflect.number_to_words(num, andword='')
60 |
61 |
62 | def normalize_numbers(text):
63 | text = re.sub(_comma_number_re, _remove_commas, text)
64 | text = re.sub(_pounds_re, r'\1 pounds', text)
65 | text = re.sub(_dollars_re, _expand_dollars, text)
66 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
67 | text = re.sub(_ordinal_re, _expand_ordinal, text)
68 | text = re.sub(_number_re, _expand_number, text)
69 | return text
70 |
--------------------------------------------------------------------------------
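The substitution order in normalize_numbers matters: commas are stripped first so grouped digits reach _expand_number whole, and currency is expanded before bare numbers so the unit words get attached. A worked example (assuming the inflect package is installed):

    from text.numbers import normalize_numbers

    print(normalize_numbers('I paid $250,000 on March 3rd, 1999'))
    # I paid two hundred fifty thousand dollars on March third, nineteen ninety-nine
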
/text/symbols.py:
--------------------------------------------------------------------------------
1 | '''
2 | Defines the set of symbols used in text input to the model.
3 |
4 | The default is a set of ASCII characters that works well for English or text that has been run
5 | through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
6 | '''
7 | from text import cmudict
8 |
9 | _pad = '_'
10 | _eos = '~'
11 | _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
12 |
13 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
14 | _arpabet = ['@' + s for s in cmudict.valid_symbols]
15 |
16 | # Export all symbols:
17 | symbols = [_pad, _eos] + list(_characters) + _arpabet
18 |
--------------------------------------------------------------------------------
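A minimal sketch of the resulting symbol table (IDs are simply list positions; these values match the expectations in tests/text_test.py):

    from text.symbols import symbols

    print(symbols[:5])   # ['_', '~', 'A', 'B', 'C'] -- padding, EOS, then characters
    print(symbols[-1])   # '@ZH' -- ARPAbet symbols come last, prefixed with '@'
    print(len(symbols))  # 149: padding + EOS + 63 characters + 84 ARPAbet symbols
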
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datetime import datetime
3 | import math
4 | import os
5 | import subprocess
6 | import time
7 | import tensorflow as tf
8 | import traceback
9 |
10 | from datasets.datafeeder import DataFeeder
11 | from hparams import hparams, hparams_debug_string
12 | from models import create_model
13 | from text import sequence_to_text
14 | from util import audio, infolog, plot, ValueWindow
15 | log = infolog.log
16 |
17 |
18 | def get_git_commit():
19 | subprocess.check_output(['git', 'diff-index', '--quiet', 'HEAD']) # Verify client is clean
20 | commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode().strip()[:10]
21 | log('Git commit: %s' % commit)
22 | return commit
23 |
24 |
25 | def add_stats(model):
26 | with tf.variable_scope('stats') as scope:
27 | tf.summary.histogram('linear_outputs', model.linear_outputs)
28 | tf.summary.histogram('linear_targets', model.linear_targets)
29 | tf.summary.histogram('mel_outputs', model.mel_outputs)
30 | tf.summary.histogram('mel_targets', model.mel_targets)
31 | tf.summary.scalar('loss_mel', model.mel_loss)
32 | tf.summary.scalar('loss_linear', model.linear_loss)
33 | tf.summary.scalar('learning_rate', model.learning_rate)
34 | tf.summary.scalar('loss', model.loss)
35 | gradient_norms = [tf.norm(grad) for grad in model.gradients]
36 | tf.summary.histogram('gradient_norm', gradient_norms)
37 | tf.summary.scalar('max_gradient_norm', tf.reduce_max(gradient_norms))
38 | return tf.summary.merge_all()
39 |
40 |
41 | def time_string():
42 | return datetime.now().strftime('%Y-%m-%d %H:%M')
43 |
44 |
45 | def train(log_dir, args):
46 | commit = get_git_commit() if args.git else 'None'
47 | checkpoint_path = os.path.join(log_dir, 'model.ckpt')
48 | input_path = os.path.join(args.base_dir, args.input)
49 | log('Checkpoint path: %s' % checkpoint_path)
50 | log('Loading training data from: %s' % input_path)
51 | log('Using model: %s' % args.model)
52 | log(hparams_debug_string())
53 |
54 | # Set up DataFeeder:
55 | coord = tf.train.Coordinator()
56 | with tf.variable_scope('datafeeder') as scope:
57 | feeder = DataFeeder(coord, input_path, hparams)
58 |
59 | # Set up model:
60 | global_step = tf.Variable(0, name='global_step', trainable=False)
61 | with tf.variable_scope('model') as scope:
62 | model = create_model(args.model, hparams)
63 | model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.linear_targets)
64 | model.add_loss()
65 | model.add_optimizer(global_step)
66 | stats = add_stats(model)
67 |
68 | # Bookkeeping:
69 | step = 0
70 | time_window = ValueWindow(100)
71 | loss_window = ValueWindow(100)
72 | saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)
73 |
74 | # Train!
75 | with tf.Session() as sess:
76 | try:
77 | summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
78 | sess.run(tf.global_variables_initializer())
79 |
80 | if args.restore_step:
81 | # Restore from a checkpoint if the user requested it.
82 | restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
83 | saver.restore(sess, restore_path)
84 | log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
85 | else:
86 | log('Starting new training run at commit: %s' % commit, slack=True)
87 |
88 | feeder.start_in_session(sess)
89 |
90 | while not coord.should_stop():
91 | start_time = time.time()
92 | step, loss, opt = sess.run([global_step, model.loss, model.optimize])
93 | time_window.append(time.time() - start_time)
94 | loss_window.append(loss)
95 | message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
96 | step, time_window.average, loss, loss_window.average)
97 | log(message, slack=(step % args.checkpoint_interval == 0))
98 |
99 | if loss > 100 or math.isnan(loss):
100 | log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
101 | raise Exception('Loss Exploded')
102 |
103 | if step % args.summary_interval == 0:
104 | log('Writing summary at step: %d' % step)
105 | summary_writer.add_summary(sess.run(stats), step)
106 |
107 | if step % args.checkpoint_interval == 0:
108 | log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
109 | saver.save(sess, checkpoint_path, global_step=step)
110 | log('Saving audio and alignment...')
111 | input_seq, spectrogram, alignment = sess.run([
112 | model.inputs[0], model.linear_outputs[0], model.alignments[0]])
113 | waveform = audio.inv_spectrogram(spectrogram.T)
114 | audio.save_wav(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
115 | plot.plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step),
116 | info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss))
117 | log('Input: %s' % sequence_to_text(input_seq))
118 |
119 | except Exception as e:
120 | log('Exiting due to exception: %s' % e, slack=True)
121 | traceback.print_exc()
122 | coord.request_stop(e)
123 |
124 |
125 | def main():
126 | parser = argparse.ArgumentParser()
127 | parser.add_argument('--base_dir', default=os.path.expanduser('~/tacotron'))
128 | parser.add_argument('--input', default='training/train.txt')
129 | parser.add_argument('--model', default='tacotron')
130 | parser.add_argument('--name', help='Name of the run. Used for logging. Defaults to model name.')
131 | parser.add_argument('--hparams', default='',
132 | help='Hyperparameter overrides as a comma-separated list of name=value pairs')
133 | parser.add_argument('--restore_step', type=int, help='Global step to restore from checkpoint.')
134 | parser.add_argument('--summary_interval', type=int, default=100,
135 | help='Steps between running summary ops.')
136 | parser.add_argument('--checkpoint_interval', type=int, default=1000,
137 | help='Steps between writing checkpoints.')
138 | parser.add_argument('--slack_url', help='Slack webhook URL to get periodic reports.')
139 | parser.add_argument('--tf_log_level', type=int, default=1, help='Tensorflow C++ log level.')
140 | parser.add_argument('--git', action='store_true', help='If set, verify that the client is clean.')
141 | args = parser.parse_args()
142 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_log_level)
143 | run_name = args.name or args.model
144 | log_dir = os.path.join(args.base_dir, 'logs-%s' % run_name)
145 | os.makedirs(log_dir, exist_ok=True)
146 | infolog.init(os.path.join(log_dir, 'train.log'), run_name, args.slack_url)
147 | hparams.parse(args.hparams)
148 | train(log_dir, args)
149 |
150 |
151 | if __name__ == '__main__':
152 | main()
153 |
--------------------------------------------------------------------------------
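Overrides given via --hparams are handed to tf.contrib.training.HParams.parse as a comma-separated list of name=value pairs. A short sketch of what that flag does (the override values here are illustrative):

    from hparams import hparams

    # Equivalent to passing --hparams 'batch_size=16,initial_learning_rate=0.0005'
    hparams.parse('batch_size=16,initial_learning_rate=0.0005')
    print(hparams.batch_size)  # 16
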
/util/__init__.py:
--------------------------------------------------------------------------------
1 | class ValueWindow():
2 | def __init__(self, window_size=100):
3 | self._window_size = window_size
4 | self._values = []
5 |
6 | def append(self, x):
7 | self._values = self._values[-(self._window_size - 1):] + [x]
8 |
9 | @property
10 | def sum(self):
11 | return sum(self._values)
12 |
13 | @property
14 | def count(self):
15 | return len(self._values)
16 |
17 | @property
18 | def average(self):
19 | return self.sum / max(1, self.count)
20 |
21 | def reset(self):
22 | self._values = []
23 |
--------------------------------------------------------------------------------
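A tiny usage sketch: ValueWindow keeps only the most recent window_size values, so average is a moving average over recent training steps:

    from util import ValueWindow

    w = ValueWindow(window_size=3)
    for x in [1.0, 2.0, 3.0, 4.0]:
      w.append(x)                # the oldest value (1.0) falls out of the window
    print(w.count, w.average)    # 3 3.0 -- the mean of [2.0, 3.0, 4.0]
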
/util/audio.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import librosa.filters
3 | import math
4 | import numpy as np
5 | import tensorflow as tf
6 | import scipy.io.wavfile, scipy.signal  # submodules used below; not exposed by a bare "import scipy"
7 | from hparams import hparams
8 |
9 |
10 | def load_wav(path):
11 | return librosa.core.load(path, sr=hparams.sample_rate)[0]
12 |
13 |
14 | def save_wav(wav, path):
15 | wav *= 32767 / max(0.01, np.max(np.abs(wav)))
16 | scipy.io.wavfile.write(path, hparams.sample_rate, wav.astype(np.int16))
17 |
18 |
19 | def preemphasis(x):
20 | return scipy.signal.lfilter([1, -hparams.preemphasis], [1], x)
21 |
22 |
23 | def inv_preemphasis(x):
24 | return scipy.signal.lfilter([1], [1, -hparams.preemphasis], x)
25 |
26 |
27 | def spectrogram(y):
28 | D = _stft(preemphasis(y))
29 | S = _amp_to_db(np.abs(D)) - hparams.ref_level_db
30 | return _normalize(S)
31 |
32 |
33 | def inv_spectrogram(spectrogram):
34 | '''Converts spectrogram to waveform using librosa'''
35 | S = _db_to_amp(_denormalize(spectrogram) + hparams.ref_level_db) # Convert back to linear
36 | return inv_preemphasis(_griffin_lim(S ** hparams.power)) # Reconstruct phase
37 |
38 |
39 | def inv_spectrogram_tensorflow(spectrogram):
40 | '''Builds computational graph to convert spectrogram to waveform using TensorFlow.
41 |
42 | Unlike inv_spectrogram, this does NOT invert the preemphasis. The caller should call
43 | inv_preemphasis on the output after running the graph.
44 | '''
45 | S = _db_to_amp_tensorflow(_denormalize_tensorflow(spectrogram) + hparams.ref_level_db)
46 | return _griffin_lim_tensorflow(tf.pow(S, hparams.power))
47 |
48 |
49 | def melspectrogram(y):
50 | D = _stft(preemphasis(y))
51 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hparams.ref_level_db
52 | return _normalize(S)
53 |
54 |
55 | def find_endpoint(wav, threshold_db=-40, min_silence_sec=0.8):
56 | window_length = int(hparams.sample_rate * min_silence_sec)
57 | hop_length = int(window_length / 4)
58 | threshold = _db_to_amp(threshold_db)
59 | for x in range(hop_length, len(wav) - window_length, hop_length):
60 | if np.max(wav[x:x+window_length]) < threshold:
61 | return x + hop_length
62 | return len(wav)
63 |
64 |
65 | def _griffin_lim(S):
66 | '''librosa implementation of Griffin-Lim
67 | Based on https://github.com/librosa/librosa/issues/434
68 | '''
69 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
70 |   S_complex = np.abs(S).astype(np.complex128)  # np.complex alias was removed in NumPy 1.24
71 | y = _istft(S_complex * angles)
72 | for i in range(hparams.griffin_lim_iters):
73 | angles = np.exp(1j * np.angle(_stft(y)))
74 | y = _istft(S_complex * angles)
75 | return y
76 |
77 |
78 | def _griffin_lim_tensorflow(S):
79 | '''TensorFlow implementation of Griffin-Lim
80 | Based on https://github.com/Kyubyong/tensorflow-exercises/blob/master/Audio_Processing.ipynb
81 | '''
82 | with tf.variable_scope('griffinlim'):
83 | # TensorFlow's stft and istft operate on a batch of spectrograms; create batch of size 1
84 | S = tf.expand_dims(S, 0)
85 | S_complex = tf.identity(tf.cast(S, dtype=tf.complex64))
86 | y = _istft_tensorflow(S_complex)
87 | for i in range(hparams.griffin_lim_iters):
88 | est = _stft_tensorflow(y)
89 | angles = est / tf.cast(tf.maximum(1e-8, tf.abs(est)), tf.complex64)
90 | y = _istft_tensorflow(S_complex * angles)
91 | return tf.squeeze(y, 0)
92 |
93 |
94 | def _stft(y):
95 | n_fft, hop_length, win_length = _stft_parameters()
96 | return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
97 |
98 |
99 | def _istft(y):
100 | _, hop_length, win_length = _stft_parameters()
101 | return librosa.istft(y, hop_length=hop_length, win_length=win_length)
102 |
103 |
104 | def _stft_tensorflow(signals):
105 | n_fft, hop_length, win_length = _stft_parameters()
106 | return tf.contrib.signal.stft(signals, win_length, hop_length, n_fft, pad_end=False)
107 |
108 |
109 | def _istft_tensorflow(stfts):
110 | n_fft, hop_length, win_length = _stft_parameters()
111 | return tf.contrib.signal.inverse_stft(stfts, win_length, hop_length, n_fft)
112 |
113 |
114 | def _stft_parameters():
115 | n_fft = (hparams.num_freq - 1) * 2
116 | hop_length = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
117 | win_length = int(hparams.frame_length_ms / 1000 * hparams.sample_rate)
118 | return n_fft, hop_length, win_length
119 |
120 |
121 | # Conversions:
122 |
123 | _mel_basis = None
124 |
125 | def _linear_to_mel(spectrogram):
126 | global _mel_basis
127 | if _mel_basis is None:
128 | _mel_basis = _build_mel_basis()
129 | return np.dot(_mel_basis, spectrogram)
130 |
131 | def _build_mel_basis():
132 | n_fft = (hparams.num_freq - 1) * 2
133 |   return librosa.filters.mel(sr=hparams.sample_rate, n_fft=n_fft, n_mels=hparams.num_mels)
134 |
135 | def _amp_to_db(x):
136 | return 20 * np.log10(np.maximum(1e-5, x))
137 |
138 | def _db_to_amp(x):
139 | return np.power(10.0, x * 0.05)
140 |
141 | def _db_to_amp_tensorflow(x):
142 | return tf.pow(tf.ones(tf.shape(x)) * 10.0, x * 0.05)
143 |
144 | def _normalize(S):
145 | return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1)
146 |
147 | def _denormalize(S):
148 | return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db
149 |
150 | def _denormalize_tensorflow(S):
151 | return (tf.clip_by_value(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db
152 |
--------------------------------------------------------------------------------
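A round-trip sketch using the functions above (the wav path is illustrative; reconstruction quality depends on hparams.griffin_lim_iters and hparams.power):

    from util import audio

    wav = audio.load_wav('sample.wav')       # resampled to hparams.sample_rate
    spec = audio.spectrogram(wav)            # (num_freq, frames), normalized dB
    restored = audio.inv_spectrogram(spec)   # Griffin-Lim phase reconstruction
    audio.save_wav(restored, 'restored.wav')
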
/util/infolog.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | from datetime import datetime
3 | import json
4 | from threading import Thread
5 | from urllib.request import Request, urlopen
6 |
7 |
8 | _format = '%Y-%m-%d %H:%M:%S.%f'
9 | _file = None
10 | _run_name = None
11 | _slack_url = None
12 |
13 |
14 | def init(filename, run_name, slack_url=None):
15 | global _file, _run_name, _slack_url
16 | _close_logfile()
17 |   _file = open(filename, 'a', encoding='utf-8')
18 | _file.write('\n-----------------------------------------------------------------\n')
19 | _file.write('Starting new training run\n')
20 | _file.write('-----------------------------------------------------------------\n')
21 | _run_name = run_name
22 | _slack_url = slack_url
23 |
24 |
25 | def log(msg, slack=False):
26 | print(msg)
27 | if _file is not None:
28 | _file.write('[%s] %s\n' % (datetime.now().strftime(_format)[:-3], msg))
29 | if slack and _slack_url is not None:
30 | Thread(target=_send_slack, args=(msg,)).start()
31 |
32 |
33 | def _close_logfile():
34 | global _file
35 | if _file is not None:
36 | _file.close()
37 | _file = None
38 |
39 |
40 | def _send_slack(msg):
41 | req = Request(_slack_url)
42 | req.add_header('Content-Type', 'application/json')
43 | urlopen(req, json.dumps({
44 | 'username': 'tacotron',
45 | 'icon_emoji': ':taco:',
46 | 'text': '*%s*: %s' % (_run_name, msg)
47 | }).encode())
48 |
49 |
50 | atexit.register(_close_logfile)
51 |
--------------------------------------------------------------------------------
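A short usage sketch (the webhook URL is optional; without it, slack=True is a no-op):

    from util import infolog

    infolog.init('train.log', run_name='demo')   # slack_url omitted here
    infolog.log('step 100: loss=0.123')          # goes to stdout and train.log
    infolog.log('checkpoint saved', slack=True)  # would also post to Slack if configured
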
/util/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | import matplotlib.pyplot as plt
4 |
5 |
6 | def plot_alignment(alignment, path, info=None):
7 | fig, ax = plt.subplots()
8 | im = ax.imshow(
9 | alignment,
10 | aspect='auto',
11 | origin='lower',
12 | interpolation='none')
13 | fig.colorbar(im, ax=ax)
14 | xlabel = 'Decoder timestep'
15 | if info is not None:
16 | xlabel += '\n\n' + info
17 | plt.xlabel(xlabel)
18 | plt.ylabel('Encoder timestep')
19 | plt.tight_layout()
20 |   plt.savefig(path, format='png')
21 |   plt.close(fig)  # free the figure; plot_alignment is called once per checkpoint
--------------------------------------------------------------------------------
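A usage sketch for the plotting helper, with a random matrix standing in for a real attention alignment (which is encoder steps by decoder steps):

    import numpy as np
    from util import plot

    alignment = np.random.rand(80, 400)  # hypothetical (encoder, decoder) shape
    plot.plot_alignment(alignment, 'align.png', info='demo, step=0')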