├── README.md
├── data_load.py
├── hyperparams.py
├── modules.py
├── networks.py
├── prepro.py
├── ref1
│   ├── 01_au_f.wav
│   ├── 02_au_f.wav
│   ├── 03_au_f.wav
│   ├── 04_au_f.wav
│   ├── 05_au_f.wav
│   ├── 06_br_f.wav
│   ├── 07_br_f.wav
│   ├── 08_br_f.wav
│   ├── 09_br_f.wav
│   ├── 10_br_f.wav
│   ├── 11_am_f1.wav
│   ├── 12_am_f1.wav
│   ├── 13_am_m.wav
│   ├── 14_am_m.wav
│   ├── 15_am_m.wav
│   └── 16_am_m.wav
├── ref2
│   ├── 01_am_f.wav
│   ├── 02_am_f.wav
│   ├── 03_am_f.wav
│   ├── 04_am_f.wav
│   ├── 05_am_f.wav
│   ├── 06_am_f.wav
│   ├── 07_am_f.wav
│   ├── 08_am_f.wav
│   ├── 09_am_f.wav
│   ├── 10_am_f.wav
│   ├── 11_am_f2.wav
│   ├── 12_am_f2.wav
│   ├── 13_am_f.wav
│   ├── 14_am_f.wav
│   ├── 15_am_f.wav
│   └── 16_am_f.wav
├── synthesize.py
├── test_sents.txt
├── train.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # A TensorFlow Implementation of Expressive Tacotron
2 | 
3 | This project aims at implementing the paper [Towards End-to-End Prosody Transfer for Expressive Speech Synthesis with Tacotron](https://arxiv.org/abs/1803.09047) to verify its concept. Most of the baseline code is based on my previous [Tacotron implementation](https://github.com/Kyubyong/tacotron).
4 | 
5 | ## Requirements
6 | 
7 | * NumPy >= 1.11.1
8 | * TensorFlow >= 1.3
9 | * librosa
10 | * tqdm
11 | * matplotlib
12 | * scipy
13 | 
14 | ## Data
15 | 
16 | 
17 | 
18 | Because the paper used internal data, I trained the model on the [LJ Speech Dataset](https://keithito.com/LJ-Speech-Dataset/).
19 | 
20 | The LJ Speech Dataset has recently become a widely used benchmark for TTS because it is publicly available. It contains 24 hours of reasonable-quality samples.
21 | 
22 | ## Training
23 | * STEP 0. Download the [LJ Speech Dataset](https://keithito.com/LJ-Speech-Dataset/) or prepare your own data.
24 | * STEP 1. Adjust hyperparameters in `hyperparams.py`. (If you want to do preprocessing, set `prepro` to True.)
25 | * STEP 2. Run `python train.py`. (If you set `prepro` to True, run `python prepro.py` first.)
26 | * STEP 3. Run `python eval.py` regularly during training.
27 | 
28 | ## Sample Synthesis
29 | 
30 | I generate speech samples from the same script as the one used for the original [web demo](https://google.github.io/tacotron/publications/end_to_end_prosody_transfer/). You can find it in `test_sents.txt`.
31 | 
32 | * Run `python synthesize.py` and check the files in `samples`.
33 | 
34 | 
35 | ## Samples
36 | 
37 | The 16 sample sentences in the first chapter of the original [web demo](https://google.github.io/tacotron/publications/end_to_end_prosody_transfer/) are used for sample synthesis. Two audio clips per sentence are used for prosody embedding--a reference voice and a base voice.
38 | In most cases, the two differ in gender or region. The samples below are organized as follows:
39 | 
40 | * 1a: the first reference audio
41 | * 1b: sample embedded with 1a's prosody
42 | * 1c: the second reference audio (base)
43 | * 1d: sample embedded with 1c's prosody
44 | 
45 | Check out the samples at each step.
46 | 
47 | * [130k steps](https://soundcloud.com/kyubyong-park/sets/expressive_tacotron_130k)
48 | * [420k steps](https://soundcloud.com/kyubyong-park/sets/expressive_tacotron_420k)
49 | 
50 | ## Analysis
51 | * Listening to the results at 130k steps, it is not clear whether the model has learned the prosody.
52 | * It is clear that different reference audios produce different samples.
53 | * Some samples are worth noting. For example, listen to the four audios of no. 15. The stress on "right" was obviously transferred.
54 | * Check out no. 9, whose reference audios are sung. They are fun.
55 | 
56 | ## Notes
57 | 
58 | * Because this repo focuses on investigating the concept of the paper, I did not follow some of its details.
59 | * The paper used phoneme inputs, whereas I stuck to graphemes.
60 | * Instead of the Bahdanau attention used here, the paper used GMM attention.
61 | * The original audio samples were generated with a WaveNet vocoder.
62 | * I'm still not sure whether what the paper claims to be a prosody embedding can really be isolated from the speaker.
63 | * For prosody embedding, the authors employed conv2d. Why not conv1d?
64 | * What happens when the reference audio's text or sentence structure is totally different from the inference script?
65 | * If I have time, I'd like to implement their second paper: [Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis](https://arxiv.org/abs/1803.09017)
66 | 
67 | April 2018,
68 | Kyubyong Park
69 | 
--------------------------------------------------------------------------------
/data_load.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python2
3 | '''
4 | By kyubyong park. kbpark.linguist@gmail.com.
5 | https://www.github.com/kyubyong/expressive_tacotron
6 | '''
7 | 
8 | from __future__ import print_function
9 | 
10 | from hyperparams import Hyperparams as hp
11 | import numpy as np
12 | import tensorflow as tf
13 | from utils import *
14 | import codecs
15 | import re
16 | import os
17 | import unicodedata
18 | 
19 | def load_vocab():
20 |     char2idx = {char: idx for idx, char in enumerate(hp.vocab)}
21 |     idx2char = {idx: char for idx, char in enumerate(hp.vocab)}
22 |     return char2idx, idx2char
23 | 
24 | def text_normalize(text):
25 |     text = ''.join(char for char in unicodedata.normalize('NFD', text)
26 |                    if unicodedata.category(char) != 'Mn')  # Strip accents
27 | 
28 |     text = re.sub(u"[^{}]".format(hp.vocab), " ", text)
29 |     text = re.sub("[ ]+", " ", text)
30 |     return text
31 | 
32 | def load_data(mode="train"):
33 |     # Load vocabulary
34 |     char2idx, idx2char = load_vocab()
35 | 
36 |     if mode == "train":
37 |         # Parse
38 |         fpaths, texts = [], []
39 |         transcript = os.path.join(hp.data, 'metadata.csv')
40 |         lines = codecs.open(transcript, 'r', 'utf-8').readlines()
41 | 
42 |         for line in lines:
43 |             fname, _, text = line.strip().split("|")
44 |             fpath = os.path.join(hp.data, "wavs", fname + ".wav")
45 | 
46 |             fpaths.append(fpath)
47 | 
48 |             text = text_normalize(text) + u"␃"  # ␃: EOS
49 |             text = [char2idx[char] for char in text]
50 |             texts.append(np.array(text, np.int32).tostring())
51 |         return fpaths, texts
52 |     else:
53 |         # Parse
54 |         lines = codecs.open(hp.test_data, 'r', 'utf-8').readlines()[1:]
55 |         sents = [text_normalize(line.split(" ", 1)[-1]).strip() + u"␃" for line in lines]  # text normalization, ␃: EOS
56 |         texts = np.zeros((len(lines), hp.Tx), np.int32)
57 |         for i, sent in enumerate(sents):
58 |             texts[i, :len(sent)] = [char2idx[char] for char in sent]
59 |         return texts
60 | 
61 | def get_batch():
62 |     """Loads training data and puts them in queues."""
63 |     with tf.device('/cpu:0'):
64 |         # Load data
65 |         fpaths, texts = load_data()  # list
66 | 
67 |         # Calc total batch count
68 |         num_batch = len(fpaths) // hp.batch_size
69 | 
70 |         fpaths = tf.convert_to_tensor(fpaths)
71 |         texts = tf.convert_to_tensor(texts)
72 | 
73 |         # Create Queues
74 |         fpath, text = tf.train.slice_input_producer([fpaths, texts], shuffle=True)
75 | 
76 |         # Text parsing
77 |         text = 
tf.decode_raw(text, tf.int32) # (None,) 78 | 79 | # Padding 80 | text = tf.pad(text, ([0, hp.Tx], ))[:hp.Tx] # (Tx,) 81 | 82 | if hp.prepro: 83 | def _load_spectrograms(fpath): 84 | fname = os.path.basename(fpath) 85 | mel = "mels/{}".format(fname.replace("wav", "npy")) 86 | mag = "mags/{}".format(fname.replace("wav", "npy")) 87 | return fname, np.load(mel), np.load(mag) 88 | 89 | fname, mel, mag = tf.py_func(_load_spectrograms, [fpath], [tf.string, tf.float32, tf.float32]) 90 | else: 91 | fname, mel, mag = tf.py_func(load_spectrograms, [fpath], [tf.string, tf.float32, tf.float32]) # (None, n_mels) 92 | 93 | # Add shape information 94 | fname.set_shape(()) 95 | text.set_shape((hp.Tx,)) 96 | mel.set_shape((None, hp.n_mels*hp.r)) 97 | mag.set_shape((None, hp.n_fft//2+1)) 98 | 99 | # Batching 100 | texts, mels, mags, fnames = tf.train.batch([text, mel, mag, fname], 101 | num_threads=8, 102 | batch_size=hp.batch_size, 103 | capacity=hp.batch_size * 64, 104 | allow_smaller_final_batch=False, 105 | dynamic_pad=True) 106 | 107 | return texts, mels, mags, fnames, num_batch 108 | 109 | -------------------------------------------------------------------------------- /hyperparams.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #/usr/bin/python2 3 | ''' 4 | By kyubyong park. kbpark.linguist@gmail.com. 5 | https://www.github.com/kyubyong/expressive_tacotron 6 | ''' 7 | class Hyperparams: 8 | '''Hyper parameters''' 9 | # pipeline 10 | prepro = True # if True, run `python prepro.py` first before running `python train.py`. 11 | 12 | vocab = u'''␀␃ !',-.:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz''' # ␀: Padding ␃: End of Sentence 13 | 14 | # data 15 | data = "/data/private/voice/LJSpeech-1.0" 16 | test_data = 'test_sents.txt' 17 | ref_audio = 'ref1/*.wav' 18 | Tx = 188 # Fixed length of text length. 19 | 20 | # signal processing 21 | sr = 22050 # Sample rate. 22 | n_fft = 2048 # fft points (samples) 23 | frame_shift = 0.0125 # seconds 24 | frame_length = 0.05 # seconds 25 | hop_length = int(sr*frame_shift) # samples. 26 | win_length = int(sr*frame_length) # samples. 27 | n_mels = 80 # Number of Mel banks to generate 28 | power = 1.2 # Exponent for amplifying the predicted magnitude 29 | n_iter = 50 # Number of inversion iterations 30 | preemphasis = .97 # or None 31 | max_db = 100 32 | ref_db = 20 33 | 34 | # model 35 | embed_size = 256 # alias = E 36 | encoder_num_banks = 16 37 | decoder_num_banks = 8 38 | num_highwaynet_blocks = 4 39 | r = 5 # Reduction factor. 40 | dropout_rate = .5 41 | 42 | # training scheme 43 | lr = 0.001 # Initial learning rate. 44 | logdir = "logdir" 45 | sampledir = 'samples' 46 | batch_size = 32 47 | num_iterations = 1000000 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #/usr/bin/python2 3 | ''' 4 | By kyubyong park. kbpark.linguist@gmail.com. 5 | https://www.github.com/kyubyong/expressive_tacotron 6 | ''' 7 | 8 | from __future__ import print_function 9 | 10 | from hyperparams import Hyperparams as hp 11 | import tensorflow as tf 12 | 13 | 14 | def embed(inputs, vocab_size, num_units, zero_pad=True, scope="embedding", reuse=None): 15 | '''Embeds a given tensor. 16 | 17 | Args: 18 | inputs: A `Tensor` with type `int32` or `int64` containing the ids 19 | to be looked up in `lookup table`. 20 | vocab_size: An int. 
Vocabulary size. 21 | num_units: An int. Number of embedding hidden units. 22 | zero_pad: A boolean. If True, all the values of the fist row (id 0) 23 | should be constant zeros. 24 | scope: Optional scope for `variable_scope`. 25 | reuse: Boolean, whether to reuse the weights of a previous layer 26 | by the same name. 27 | 28 | Returns: 29 | A `Tensor` with one more rank than inputs's. The last dimesionality 30 | should be `num_units`. 31 | ''' 32 | with tf.variable_scope(scope, reuse=reuse): 33 | lookup_table = tf.get_variable('lookup_table', 34 | dtype=tf.float32, 35 | shape=[vocab_size, num_units], 36 | initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.01)) 37 | if zero_pad: 38 | lookup_table = tf.concat((tf.zeros(shape=[1, num_units]), 39 | lookup_table[1:, :]), 0) 40 | return tf.nn.embedding_lookup(lookup_table, inputs) 41 | 42 | 43 | def bn(inputs, 44 | is_training=True, 45 | activation_fn=None, 46 | scope="bn", 47 | reuse=None): 48 | '''Applies batch normalization. 49 | 50 | Args: 51 | inputs: A tensor with 2 or more dimensions, where the first dimension has 52 | `batch_size`. If type is `bn`, the normalization is over all but 53 | the last dimension. Or if type is `ln`, the normalization is over 54 | the last dimension. Note that this is different from the native 55 | `tf.contrib.layers.batch_norm`. For this I recommend you change 56 | a line in ``tensorflow/contrib/layers/python/layers/layer.py` 57 | as follows. 58 | Before: mean, variance = nn.moments(inputs, axis, keep_dims=True) 59 | After: mean, variance = nn.moments(inputs, [-1], keep_dims=True) 60 | is_training: Whether or not the layer is in training mode. 61 | activation_fn: Activation function. 62 | scope: Optional scope for `variable_scope`. 63 | reuse: Boolean, whether to reuse the weights of a previous layer 64 | by the same name. 65 | 66 | Returns: 67 | A tensor with the same shape and data dtype as `inputs`. 68 | ''' 69 | inputs_shape = inputs.get_shape() 70 | inputs_rank = inputs_shape.ndims 71 | 72 | # use fused batch norm if inputs_rank in [2, 3, 4] as it is much faster. 73 | # pay attention to the fact that fused_batch_norm requires shape to be rank 4 of NHWC. 74 | if inputs_rank in [2, 3, 4]: 75 | if inputs_rank == 2: 76 | inputs = tf.expand_dims(inputs, axis=1) 77 | inputs = tf.expand_dims(inputs, axis=2) 78 | elif inputs_rank == 3: 79 | inputs = tf.expand_dims(inputs, axis=1) 80 | 81 | outputs = tf.contrib.layers.batch_norm(inputs=inputs, 82 | center=True, 83 | scale=True, 84 | updates_collections=None, 85 | is_training=is_training, 86 | scope=scope, 87 | fused=True, 88 | reuse=reuse) 89 | # restore original shape 90 | if inputs_rank == 2: 91 | outputs = tf.squeeze(outputs, axis=[1, 2]) 92 | elif inputs_rank == 3: 93 | outputs = tf.squeeze(outputs, axis=1) 94 | else: # fallback to naive batch norm 95 | outputs = tf.contrib.layers.batch_norm(inputs=inputs, 96 | center=True, 97 | scale=True, 98 | updates_collections=None, 99 | is_training=is_training, 100 | scope=scope, 101 | reuse=reuse, 102 | fused=False) 103 | if activation_fn is not None: 104 | outputs = activation_fn(outputs) 105 | 106 | return outputs 107 | 108 | def conv1d(inputs, 109 | filters=None, 110 | size=1, 111 | rate=1, 112 | padding="SAME", 113 | use_bias=False, 114 | activation_fn=None, 115 | scope="conv1d", 116 | reuse=None): 117 | ''' 118 | Args: 119 | inputs: A 3-D tensor with shape of [batch, time, depth]. 120 | filters: An int. Number of outputs (=activation maps) 121 | size: An int. Filter size. 122 | rate: An int. 
Dilation rate. 123 | padding: Either `same` or `valid` or `causal` (case-insensitive). 124 | use_bias: A boolean. 125 | scope: Optional scope for `variable_scope`. 126 | reuse: Boolean, whether to reuse the weights of a previous layer 127 | by the same name. 128 | ''' 129 | with tf.variable_scope(scope): 130 | if padding.lower()=="causal": 131 | # pre-padding for causality 132 | pad_len = (size - 1) * rate # padding size 133 | inputs = tf.pad(inputs, [[0, 0], [pad_len, 0], [0, 0]]) 134 | padding = "valid" 135 | 136 | if filters is None: 137 | filters = inputs.get_shape().as_list[-1] 138 | 139 | params = {"inputs":inputs, "filters":filters, "kernel_size":size, 140 | "dilation_rate":rate, "padding":padding, "activation":activation_fn, 141 | "use_bias":use_bias, "reuse":reuse} 142 | 143 | outputs = tf.layers.conv1d(**params) 144 | return outputs 145 | 146 | def conv1d_banks(inputs, K=16, is_training=True, scope="conv1d_banks", reuse=None): 147 | '''Applies a series of conv1d separately. 148 | 149 | Args: 150 | inputs: A 3d tensor with shape of [N, T, C] 151 | K: An int. The size of conv1d banks. That is, 152 | The `inputs` are convolved with K filters: 1, 2, ..., K. 153 | is_training: A boolean. This is passed to an argument of `bn`. 154 | scope: Optional scope for `variable_scope`. 155 | reuse: Boolean, whether to reuse the weights of a previous layer 156 | by the same name. 157 | 158 | Returns: 159 | A 3d tensor with shape of [N, T, K*Hp.embed_size//2]. 160 | ''' 161 | with tf.variable_scope(scope, reuse=reuse): 162 | outputs = conv1d(inputs, hp.embed_size//2, 1) # k=1 163 | for k in range(2, K+1): # k = 2...K 164 | with tf.variable_scope("num_{}".format(k)): 165 | output = conv1d(inputs, hp.embed_size // 2, k) 166 | outputs = tf.concat((outputs, output), -1) 167 | outputs = bn(outputs, is_training=is_training, activation_fn=tf.nn.relu) 168 | return outputs # (N, T, E//2*K) 169 | 170 | def gru(inputs, num_units=None, bidirection=False, scope="gru", reuse=None): 171 | '''Applies a GRU. 172 | 173 | Args: 174 | inputs: A 3d tensor with shape of [N, T, C]. 175 | num_units: An int. The number of hidden units. 176 | bidirection: A boolean. If True, bidirectional results 177 | are concatenated. 178 | scope: Optional scope for `variable_scope`. 179 | reuse: Boolean, whether to reuse the weights of a previous layer 180 | by the same name. 181 | 182 | Returns: 183 | If bidirection is True, a 3d tensor with shape of [N, T, 2*num_units], 184 | otherwise [N, T, num_units]. 185 | ''' 186 | with tf.variable_scope(scope, reuse=reuse): 187 | if num_units is None: 188 | num_units = inputs.get_shape().as_list()[-1] 189 | 190 | cell = tf.contrib.rnn.GRUCell(num_units) 191 | if bidirection: 192 | cell_bw = tf.contrib.rnn.GRUCell(num_units) 193 | outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell, cell_bw, inputs, dtype=tf.float32) 194 | return tf.concat(outputs, 2) 195 | else: 196 | outputs, _ = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32) 197 | return outputs 198 | 199 | def attention_decoder(inputs, memory, num_units=None, scope="attention_decoder", reuse=None): 200 | '''Applies a GRU to `inputs`, while attending `memory`. 201 | Args: 202 | inputs: A 3d tensor with shape of [N, T', C']. Decoder inputs. 203 | memory: A 3d tensor with shape of [N, T, C]. Outputs of encoder network. 204 | num_units: An int. Attention size. 205 | scope: Optional scope for `variable_scope`. 206 | reuse: Boolean, whether to reuse the weights of a previous layer 207 | by the same name. 
208 | 209 | Returns: 210 | A 3d tensor with shape of [N, T, num_units]. 211 | ''' 212 | with tf.variable_scope(scope, reuse=reuse): 213 | if num_units is None: 214 | num_units = inputs.get_shape().as_list[-1] 215 | 216 | attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units, 217 | memory) 218 | decoder_cell = tf.contrib.rnn.GRUCell(num_units) 219 | cell_with_attention = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, 220 | attention_mechanism, 221 | num_units, 222 | alignment_history=True) 223 | outputs, state = tf.nn.dynamic_rnn(cell_with_attention, inputs, dtype=tf.float32) #( N, T', 16) 224 | 225 | return outputs, state 226 | 227 | def prenet(inputs, num_units=None, is_training=True, scope="prenet", reuse=None): 228 | '''Prenet for transcript_encoder and Decoder1. 229 | Args: 230 | inputs: A 2D or 3D tensor. 231 | num_units: A list of two integers. or None. 232 | is_training: A python boolean. 233 | scope: Optional scope for `variable_scope`. 234 | reuse: Boolean, whether to reuse the weights of a previous layer 235 | by the same name. 236 | 237 | Returns: 238 | A 3D tensor of shape [N, T, num_units/2]. 239 | ''' 240 | if num_units is None: 241 | num_units = [hp.embed_size, hp.embed_size//2] 242 | 243 | with tf.variable_scope(scope, reuse=reuse): 244 | outputs = tf.layers.dense(inputs, units=num_units[0], activation=tf.nn.relu, name="dense1") 245 | outputs = tf.layers.dropout(outputs, rate=hp.dropout_rate, training=is_training, name="dropout1") 246 | outputs = tf.layers.dense(outputs, units=num_units[1], activation=tf.nn.relu, name="dense2") 247 | outputs = tf.layers.dropout(outputs, rate=hp.dropout_rate, training=is_training, name="dropout2") 248 | return outputs # (N, ..., num_units[1]) 249 | 250 | def highwaynet(inputs, num_units=None, scope="highwaynet", reuse=None): 251 | '''Highway networks, see https://arxiv.org/abs/1505.00387 252 | 253 | Args: 254 | inputs: A 3D tensor of shape [N, T, W]. 255 | num_units: An int or `None`. Specifies the number of units in the highway layer 256 | or uses the input size if `None`. 257 | scope: Optional scope for `variable_scope`. 258 | reuse: Boolean, whether to reuse the weights of a previous layer 259 | by the same name. 260 | 261 | Returns: 262 | A 3D tensor of shape [N, T, W]. 263 | ''' 264 | if not num_units: 265 | num_units = inputs.get_shape()[-1] 266 | 267 | with tf.variable_scope(scope, reuse=reuse): 268 | H = tf.layers.dense(inputs, units=num_units, activation=tf.nn.relu, name="dense1") 269 | T = tf.layers.dense(inputs, units=num_units, activation=tf.nn.sigmoid, 270 | bias_initializer=tf.constant_initializer(-1.0), name="dense2") 271 | outputs = H*T + inputs*(1.-T) 272 | return outputs 273 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #/usr/bin/python2 3 | ''' 4 | By kyubyong park. kbpark.linguist@gmail.com. 5 | https://www.github.com/kyubyong/expressive_tacotron 6 | ''' 7 | 8 | from __future__ import print_function 9 | 10 | from hyperparams import Hyperparams as hp 11 | from modules import * 12 | import tensorflow as tf 13 | 14 | 15 | def transcript_encoder(inputs, is_training=True, scope="encoder", reuse=None): 16 | ''' 17 | Args: 18 | inputs: A 3d tensor with shape of [N, Tx, E], with dtype of int32. Text inputs. 19 | is_training: Whether or not the layer is in training mode. 
20 | scope: Optional scope for `variable_scope` 21 | reuse: Boolean, whether to reuse the weights of a previous layer 22 | by the same name. 23 | 24 | Returns: 25 | A collection of text hidden vectors. Has the shape of (N, Tx, E). 26 | ''' 27 | with tf.variable_scope(scope, reuse=reuse): 28 | # Encoder pre-net 29 | prenet_out = prenet(inputs, is_training=is_training) # (N, Tx, E/2) 30 | 31 | # Encoder CBHG 32 | ## Conv1D banks 33 | enc = conv1d_banks(prenet_out, K=hp.encoder_num_banks, is_training=is_training) # (N, Tx, K*E/2) 34 | 35 | ## Max pooling 36 | enc = tf.layers.max_pooling1d(enc, pool_size=2, strides=1, padding="same") # (N, Tx, K*E/2) 37 | 38 | ## Conv1D projections 39 | enc = conv1d(enc, filters=hp.embed_size//2, size=3, scope="conv1d_1") # (N, Tx, E/2) 40 | enc = bn(enc, is_training=is_training, activation_fn=tf.nn.relu, scope="conv1d_1") 41 | 42 | enc = conv1d(enc, filters=hp.embed_size // 2, size=3, scope="conv1d_2") # (N, Tx, E/2) 43 | enc = bn(enc, is_training=is_training, scope="conv1d_2") 44 | 45 | enc += prenet_out # (N, Tx, E/2) # residual connections 46 | 47 | ## Highway Nets 48 | for i in range(hp.num_highwaynet_blocks): 49 | enc = highwaynet(enc, num_units=hp.embed_size//2, 50 | scope='highwaynet_{}'.format(i)) # (N, Tx, E/2) 51 | 52 | ## Bidirectional GRU 53 | texts = gru(enc, num_units=hp.embed_size//2, bidirection=True) # (N, Tx, E) 54 | 55 | return texts 56 | 57 | 58 | def reference_encoder(inputs, is_training=True, scope="encoder", reuse=None): 59 | ''' 60 | Args: 61 | inputs: A 3d tensor with shape of (N, Ty, n_mels), with dtype of float32. 62 | Melspectrogram of reference audio. 63 | is_training: Whether or not the layer is in training mode. 64 | scope: Optional scope for `variable_scope` 65 | reuse: Boolean, whether to reuse the weights of a previous layer 66 | by the same name. 67 | 68 | Returns: 69 | Prosody vectors. Has the shape of (N, 128). 
70 | ''' 71 | with tf.variable_scope(scope, reuse=reuse): 72 | # 6-Layer Strided Conv2D -> (N, T/64, n_mels/64, 128) 73 | tensor = tf.layers.conv2d(inputs=inputs, filters=32, kernel_size=3, strides=2, padding='SAME') 74 | tensor = bn(tensor, is_training=is_training, activation_fn=tf.nn.relu, scope="bn1") 75 | 76 | tensor = tf.layers.conv2d(inputs=tensor, filters=32, kernel_size=3, strides=2, padding='SAME') 77 | tensor = bn(tensor, is_training=is_training, activation_fn=tf.nn.relu, scope="bn2") 78 | 79 | tensor = tf.layers.conv2d(inputs=tensor, filters=64, kernel_size=3, strides=2, padding='SAME') 80 | tensor = bn(tensor, is_training=is_training, activation_fn=tf.nn.relu, scope="bn3") 81 | 82 | tensor = tf.layers.conv2d(inputs=tensor, filters=64, kernel_size=3, strides=2, padding='SAME') 83 | tensor = bn(tensor, is_training=is_training, activation_fn=tf.nn.relu, scope="bn4") 84 | 85 | tensor = tf.layers.conv2d(inputs=tensor, filters=128, kernel_size=3, strides=2, padding='SAME') 86 | tensor = bn(tensor, is_training=is_training, activation_fn=tf.nn.relu, scope="bn5") 87 | 88 | tensor = tf.layers.conv2d(inputs=tensor, filters=128, kernel_size=3, strides=2, padding='SAME') 89 | tensor = bn(tensor, is_training=is_training, activation_fn=tf.nn.relu, scope="bn6") 90 | 91 | # Unroll -> (N, T/64, 128*n_mels/64) 92 | N, _, W, C = tensor.get_shape().as_list() 93 | tensor = tf.reshape(tensor, (N, -1, W*C)) 94 | 95 | # GRU -> (N, T/64, 128) -> (N, 128) 96 | tensor = gru(tensor, num_units=128, bidirection=False, scope="gru") 97 | tensor = tensor[:, -1, :] 98 | 99 | # FC -> (N, 128) 100 | prosody = tf.layers.dense(tensor, 128, activation=tf.nn.tanh) 101 | 102 | return prosody 103 | 104 | def decoder1(inputs, memory, is_training=True, scope="decoder1", reuse=None): 105 | ''' 106 | Args: 107 | inputs: A 3d tensor with shape of [N, Ty/r, n_mels(*r)]. Shifted log melspectrogram of sound files. 108 | memory: A 3d tensor with shape of [N, Tx, E]. 109 | is_training: Whether or not the layer is in training mode. 110 | scope: Optional scope for `variable_scope` 111 | reuse: Boolean, whether to reuse the weights of a previous layer 112 | by the same name. 113 | 114 | Returns 115 | Predicted log melspectrogram tensor with shape of [N, Ty/r, n_mels*r]. 116 | ''' 117 | with tf.variable_scope(scope, reuse=reuse): 118 | # Decoder pre-net 119 | inputs = prenet(inputs, is_training=is_training) # (N, Ty/r, E/2) 120 | 121 | # Attention RNN 122 | dec, state = attention_decoder(inputs, memory, num_units=hp.embed_size) # (N, Ty/r, E) 123 | 124 | ## for attention monitoring 125 | alignments = tf.transpose(state.alignment_history.stack(),[1,2,0]) 126 | 127 | # Decoder RNNs 128 | dec += gru(dec, hp.embed_size, bidirection=False, scope="decoder_gru1") # (N, Ty/r, E) 129 | dec += gru(dec, hp.embed_size, bidirection=False, scope="decoder_gru2") # (N, Ty/r, E) 130 | 131 | # Outputs => (N, Ty/r, n_mels*r) 132 | mel_hats = tf.layers.dense(dec, hp.n_mels*hp.r) 133 | 134 | return mel_hats, alignments 135 | 136 | def decoder2(inputs, is_training=True, scope="decoder2", reuse=None): 137 | '''Decoder Post-processing net = CBHG 138 | Args: 139 | inputs: A 3d tensor with shape of [N, Ty/r, n_mels*r]. Log magnitude spectrogram of sound files. 140 | It is recovered to its original shape. 141 | is_training: Whether or not the layer is in training mode. 142 | scope: Optional scope for `variable_scope` 143 | reuse: Boolean, whether to reuse the weights of a previous layer 144 | by the same name. 
145 | 146 | Returns 147 | Predicted linear spectrogram tensor with shape of [N, Ty, 1+n_fft//2]. 148 | ''' 149 | with tf.variable_scope(scope, reuse=reuse): 150 | # Restore shape -> (N, Ty, n_mels) 151 | inputs = tf.reshape(inputs, [tf.shape(inputs)[0], -1, hp.n_mels]) 152 | 153 | # Conv1D bank 154 | dec = conv1d_banks(inputs, K=hp.decoder_num_banks, is_training=is_training) # (N, Ty, E*K/2) 155 | 156 | # Max pooling 157 | dec = tf.layers.max_pooling1d(dec, pool_size=2, strides=1, padding="same") # (N, Ty, E*K/2) 158 | 159 | ## Conv1D projections 160 | dec = conv1d(dec, filters=hp.embed_size // 2, size=3, scope="conv1d_1") # (N, Tx, E/2) 161 | dec = bn(dec, is_training=is_training, activation_fn=tf.nn.relu, scope="conv1d_1") 162 | 163 | dec = conv1d(dec, filters=hp.n_mels, size=3, scope="conv1d_2") # (N, Tx, E/2) 164 | dec = bn(dec, is_training=is_training, scope="conv1d_2") 165 | 166 | # Extra affine transformation for dimensionality sync 167 | dec = tf.layers.dense(dec, hp.embed_size//2) # (N, Ty, E/2) 168 | 169 | # Highway Nets 170 | for i in range(4): 171 | dec = highwaynet(dec, num_units=hp.embed_size//2, 172 | scope='highwaynet_{}'.format(i)) # (N, Ty, E/2) 173 | 174 | # Bidirectional GRU 175 | dec = gru(dec, hp.embed_size//2, bidirection=True) # (N, Ty, E) 176 | 177 | # Outputs => (N, Ty, 1+n_fft//2) 178 | outputs = tf.layers.dense(dec, 1+hp.n_fft//2) 179 | 180 | return outputs 181 | -------------------------------------------------------------------------------- /prepro.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #/usr/bin/python2 3 | ''' 4 | By kyubyong park. kbpark.linguist@gmail.com. 5 | https://www.github.com/kyubyong/expressive_tacotron 6 | ''' 7 | 8 | from __future__ import print_function 9 | 10 | from utils import load_spectrograms 11 | import os 12 | from data_load import load_data 13 | import numpy as np 14 | from tqdm import tqdm 15 | from multiprocessing import Pool 16 | 17 | NUM_JOBS = 4 18 | 19 | # Utility function 20 | def f(fpath): 21 | fname, mel, mag = load_spectrograms(fpath) 22 | np.save("mels/{}".format(fname.replace("wav", "npy")), mel) 23 | np.save("mags/{}".format(fname.replace("wav", "npy")), mag) 24 | return None 25 | 26 | # Load data 27 | fpaths, _ = load_data() # list 28 | 29 | # Creates folders 30 | if not os.path.exists("mels"): os.mkdir("mels") 31 | if not os.path.exists("mags"): os.mkdir("mags") 32 | 33 | # Creates pool 34 | p = Pool(NUM_JOBS) 35 | 36 | total_files = len(fpaths) 37 | with tqdm(total=total_files) as pbar: 38 | for _ in tqdm(p.imap_unordered(f, fpaths)): 39 | pbar.update() -------------------------------------------------------------------------------- /ref1/01_au_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/01_au_f.wav -------------------------------------------------------------------------------- /ref1/02_au_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/02_au_f.wav -------------------------------------------------------------------------------- /ref1/03_au_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/03_au_f.wav 
-------------------------------------------------------------------------------- /ref1/04_au_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/04_au_f.wav -------------------------------------------------------------------------------- /ref1/05_au_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/05_au_f.wav -------------------------------------------------------------------------------- /ref1/06_br_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/06_br_f.wav -------------------------------------------------------------------------------- /ref1/07_br_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/07_br_f.wav -------------------------------------------------------------------------------- /ref1/08_br_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/08_br_f.wav -------------------------------------------------------------------------------- /ref1/09_br_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/09_br_f.wav -------------------------------------------------------------------------------- /ref1/10_br_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/10_br_f.wav -------------------------------------------------------------------------------- /ref1/11_am_f1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/11_am_f1.wav -------------------------------------------------------------------------------- /ref1/12_am_f1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/12_am_f1.wav -------------------------------------------------------------------------------- /ref1/13_am_m.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/13_am_m.wav -------------------------------------------------------------------------------- /ref1/14_am_m.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/14_am_m.wav -------------------------------------------------------------------------------- /ref1/15_am_m.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/15_am_m.wav -------------------------------------------------------------------------------- /ref1/16_am_m.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref1/16_am_m.wav -------------------------------------------------------------------------------- /ref2/01_am_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/01_am_f.wav -------------------------------------------------------------------------------- /ref2/02_am_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/02_am_f.wav -------------------------------------------------------------------------------- /ref2/03_am_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/03_am_f.wav -------------------------------------------------------------------------------- /ref2/04_am_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/04_am_f.wav -------------------------------------------------------------------------------- /ref2/05_am_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/05_am_f.wav -------------------------------------------------------------------------------- /ref2/06_am_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/06_am_f.wav -------------------------------------------------------------------------------- /ref2/07_am_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/07_am_f.wav -------------------------------------------------------------------------------- /ref2/08_am_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/08_am_f.wav -------------------------------------------------------------------------------- /ref2/09_am_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/09_am_f.wav -------------------------------------------------------------------------------- /ref2/10_am_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/10_am_f.wav -------------------------------------------------------------------------------- /ref2/11_am_f2.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/11_am_f2.wav -------------------------------------------------------------------------------- /ref2/12_am_f2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/12_am_f2.wav -------------------------------------------------------------------------------- /ref2/13_am_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/13_am_f.wav -------------------------------------------------------------------------------- /ref2/14_am_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/14_am_f.wav -------------------------------------------------------------------------------- /ref2/15_am_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/15_am_f.wav -------------------------------------------------------------------------------- /ref2/16_am_f.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/expressive_tacotron/49293bbe6eb6034e2e214483c49fcf709ffbf878/ref2/16_am_f.wav -------------------------------------------------------------------------------- /synthesize.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # /usr/bin/python2 3 | ''' 4 | By kyubyong park. kbpark.linguist@gmail.com. 
5 | https://www.github.com/kyubyong/tacotron 6 | ''' 7 | 8 | from __future__ import print_function 9 | 10 | from hyperparams import Hyperparams as hp 11 | import tqdm 12 | from data_load import load_data 13 | import tensorflow as tf 14 | from train import Graph 15 | from utils import spectrogram2wav, load_spectrograms 16 | from scipy.io.wavfile import write 17 | import os 18 | import sys 19 | from glob import glob 20 | import numpy as np 21 | from math import ceil 22 | 23 | 24 | def looper(ref, start, batch_size): 25 | num = int(ceil(float(ref.shape[0]) / batch_size)) + 1 26 | tiled = np.tile(ref, (num, 1, 1))[start:start + batch_size] 27 | return tiled, start + batch_size % ref.shape[0] 28 | 29 | 30 | def synthesize(): 31 | if not os.path.exists(hp.sampledir): 32 | os.mkdir(hp.sampledir) 33 | 34 | # Load data 35 | texts = load_data(mode="synthesize") 36 | 37 | # pad texts to multiple of batch_size 38 | texts_len = texts.shape[0] 39 | num_batches = int(ceil(float(texts_len) / hp.batch_size)) 40 | padding_len = num_batches * hp.batch_size - texts_len 41 | texts = np.pad( 42 | texts, ((0, padding_len), (0, 0)), 'constant', constant_values=0 43 | ) 44 | 45 | # reference audio 46 | mels, maxlen = [], 0 47 | files = glob(hp.ref_audio) 48 | for f in files: 49 | _, mel, _ = load_spectrograms(f) 50 | mel = np.reshape(mel, (-1, hp.n_mels)) 51 | maxlen = max(maxlen, mel.shape[0]) 52 | mels.append(mel) 53 | 54 | ref = np.zeros((len(mels), maxlen, hp.n_mels), np.float32) 55 | for i, m in enumerate(mels): 56 | ref[i, :m.shape[0], :] = m 57 | 58 | # Load graph 59 | g = Graph(mode="synthesize") 60 | print("Graph loaded") 61 | 62 | saver = tf.train.Saver() 63 | with tf.Session() as sess: 64 | if len(sys.argv) == 1: 65 | saver.restore(sess, tf.train.latest_checkpoint(hp.logdir)) 66 | print("Restored latest checkpoint") 67 | else: 68 | saver.restore(sess, sys.argv[1]) 69 | print("Restored checkpoint: %s" % sys.argv[1]) 70 | 71 | batches = [ 72 | texts[i:i + hp.batch_size] 73 | for i in range(0, texts.shape[0], hp.batch_size) 74 | ] 75 | start = 0 76 | batch_index = 0 77 | # Feed Forward 78 | for batch in batches: 79 | ref_batch, start = looper(ref, start, hp.batch_size) 80 | ## mel 81 | y_hat = np.zeros( 82 | (batch.shape[0], 200, hp.n_mels * hp.r), np.float32 83 | ) # hp.n_mels*hp.r 84 | for j in tqdm.tqdm(range(200)): 85 | _y_hat = sess.run( 86 | g.y_hat, {g.x: batch, 87 | g.y: y_hat, 88 | g.ref: ref_batch} 89 | ) 90 | y_hat[:, j, :] = _y_hat[:, j, :] 91 | ## mag 92 | mags = sess.run(g.z_hat, {g.y_hat: y_hat}) 93 | for i, mag in enumerate(mags): 94 | index_label = batch_index * hp.batch_size + i + 1 95 | if index_label > texts_len: 96 | break 97 | print("File {}.wav is being generated ...".format(index_label)) 98 | audio = spectrogram2wav(mag) 99 | write( 100 | os.path.join(hp.sampledir, '{}.wav'.format(index_label)), 101 | hp.sr, audio 102 | ) 103 | 104 | batch_index += 1 105 | 106 | 107 | if __name__ == '__main__': 108 | synthesize() 109 | print("Done") 110 | -------------------------------------------------------------------------------- /test_sents.txt: -------------------------------------------------------------------------------- 1 | Audio samples from "Towards End-to-End Prosody Transfer for Expressive Speech Synthesis with Tacotron" https://google.github.io/tacotron/publications/end_to_end_prosody_transfer/ 2 | 1. How do bureaucrats wrap presents? With lots of red tape. 3 | 2. Why are libraries so strict? They have to go by the book. 4 | 3. Why are fish so smart? 
Because they hang out in schools so much. 5 | 4. Heaps of things. Like fairy bread, how the surf is today and why magpies swoop. 6 | 5. The past, the present, and the future walk into a bar. It was tense. 7 | 6. I usually down a cup of java script. Then I put on nature sounds and run a few strenuous searches to improve my speed 8 | 7. I don't have eyes, but I don't need them to know the vibe in here feels good 9 | 8. What time do you go to the dentist? At tooth-hurty! 10 | 9. Sweet dreams are made of these. Friendly Assistants who work hard to please 11 | 10. You are what you eat. So I guess I'm a whole lot of data and a little bit of pizza recipes. 12 | 11. Men say they know many things; But lo! they have taken wings, The arts and sciences, And a thousand appliances; The wind that blows Is all that any body knows. 13 | 12. Do you prefer chocolate or jelly? Which would you like in your belly? You could make a good case, For a cool ice cream base, But I'd argue against vermicelli 14 | 13. Halloween Edition it is! Remember to follow the moves as I say them. 15 | 14. Why are archaeologists so annoyed? They always have a bone to pick. 16 | 15. That one sailed RIGHT over my head. 17 | 16. Wear your heart on your sleeve. It'll terrify people. -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #/usr/bin/python2 3 | ''' 4 | By kyubyong park. kbpark.linguist@gmail.com. 5 | https://www.github.com/kyubyong/expressive_tacotron 6 | ''' 7 | 8 | from __future__ import print_function 9 | 10 | import sys 11 | import os 12 | from hyperparams import Hyperparams as hp 13 | import tensorflow as tf 14 | from tqdm import tqdm 15 | from data_load import get_batch, load_vocab 16 | from modules import * 17 | from networks import transcript_encoder, reference_encoder, decoder1, decoder2 18 | from utils import * 19 | 20 | class Graph: 21 | def __init__(self, mode="train"): 22 | # Load vocabulary 23 | self.char2idx, self.idx2char = load_vocab() 24 | 25 | # Set phase 26 | is_training=True if mode=="train" else False 27 | 28 | # Graph 29 | # Data Feeding 30 | # x: Text. int32. (N, Tx) or (32, 188) 31 | # y: Reduced melspectrogram. float32. (N, Ty//r, n_mels*r) or (32, ?, 400) 32 | # z: Magnitude. (N, Ty, n_fft//2+1) or (32, ?, 1025) 33 | # ref: Melspectrogram of Reference audio. float32. 
(N, Ty, n_mels) or (32, ?, 80) 34 | if mode=="train": 35 | self.x, self.y, self.z, self.fnames, self.num_batch = get_batch() 36 | self.ref = tf.reshape(self.y, (hp.batch_size, -1, hp.n_mels)) 37 | else: # Synthesize 38 | self.x = tf.placeholder(tf.int32, shape=(hp.batch_size, hp.Tx)) 39 | self.y = tf.placeholder(tf.float32, shape=(hp.batch_size, None, hp.n_mels*hp.r)) 40 | self.ref = tf.placeholder(tf.float32, shape=(hp.batch_size, None, hp.n_mels)) 41 | 42 | # Get encoder/decoder inputs 43 | self.transcript_inputs = embed(self.x, len(hp.vocab), hp.embed_size) # (N, Tx, E) 44 | self.reference_inputs = tf.expand_dims(self.ref, -1) 45 | 46 | self.decoder_inputs = tf.concat((tf.zeros_like(self.y[:, :1, :]), self.y[:, :-1, :]), 1) # (N, Ty/r, n_mels*r) 47 | self.decoder_inputs = self.decoder_inputs[:, :, -hp.n_mels:] # feed last frames only (N, Ty/r, n_mels) 48 | 49 | # Networks 50 | with tf.variable_scope("net"): 51 | # Encoder 52 | self.texts = transcript_encoder(self.transcript_inputs, is_training=is_training) # (N, Tx=188, E) 53 | self.prosody = reference_encoder(self.reference_inputs, is_training=is_training) # (N, 128) 54 | self.prosody = tf.expand_dims(self.prosody, 1) # (N, 1, 128) 55 | self.prosody = tf.tile(self.prosody, (1, hp.Tx, 1)) # (N, Tx=188, 128) 56 | self.memory = tf.concat((self.texts, self.prosody), -1) # (N, Tx, E+128) 57 | 58 | # Decoder1 59 | self.y_hat, self.alignments = decoder1(self.decoder_inputs, 60 | self.memory, 61 | is_training=is_training) # (N, T_y//r, n_mels*r) 62 | # Decoder2 or postprocessing 63 | self.z_hat = decoder2(self.y_hat, is_training=is_training) # (N, T_y//r, (1+n_fft//2)*r) 64 | 65 | # monitor 66 | self.audio = tf.py_func(spectrogram2wav, [self.z_hat[0]], tf.float32) 67 | 68 | if mode=="train": 69 | # Loss 70 | self.loss1 = tf.reduce_mean(tf.abs(self.y_hat - self.y)) 71 | self.loss2 = tf.reduce_mean(tf.abs(self.z_hat - self.z)) 72 | self.loss = self.loss1 + self.loss2 73 | 74 | # Training Scheme 75 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 76 | self.lr = learning_rate_decay(hp.lr, global_step=self.global_step) 77 | self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr) 78 | 79 | ## gradient clipping 80 | self.gvs = self.optimizer.compute_gradients(self.loss) 81 | self.clipped = [] 82 | for grad, var in self.gvs: 83 | grad = tf.clip_by_norm(grad, 5.) 
84 | self.clipped.append((grad, var)) 85 | self.train_op = self.optimizer.apply_gradients(self.clipped, global_step=self.global_step) 86 | 87 | # Summary 88 | tf.summary.scalar('{}/loss1'.format(mode), self.loss1) 89 | tf.summary.scalar('{}/loss2'.format(mode), self.loss2) 90 | tf.summary.scalar('{}/lr'.format(mode), self.lr) 91 | 92 | tf.summary.image("{}/mel_gt".format(mode), tf.expand_dims(self.y, -1), max_outputs=1) 93 | tf.summary.image("{}/mel_hat".format(mode), tf.expand_dims(self.y_hat, -1), max_outputs=1) 94 | tf.summary.image("{}/mag_gt".format(mode), tf.expand_dims(self.z, -1), max_outputs=1) 95 | tf.summary.image("{}/mag_hat".format(mode), tf.expand_dims(self.z_hat, -1), max_outputs=1) 96 | 97 | tf.summary.audio("{}/sample".format(mode), tf.expand_dims(self.audio, 0), hp.sr) 98 | self.merged = tf.summary.merge_all() 99 | 100 | if __name__ == '__main__': 101 | g = Graph(); print("Training Graph loaded") 102 | 103 | sv = tf.train.Supervisor(logdir=hp.logdir, save_summaries_secs=60, save_model_secs=0) 104 | with sv.managed_session() as sess: 105 | 106 | if len(sys.argv) == 2: 107 | sv.saver.restore(sess, sys.argv[1]) 108 | print("Model restored.") 109 | 110 | while 1: 111 | for _ in tqdm(range(g.num_batch), total=g.num_batch, ncols=70, leave=False, unit='b'): 112 | _, gs = sess.run([g.train_op, g.global_step]) 113 | 114 | # Write checkpoint files 115 | if gs % 1000 == 0: 116 | sv.saver.save(sess, hp.logdir + '/model_gs_{}k'.format(gs//1000)) 117 | 118 | # plot the first alignment for logging 119 | al = sess.run(g.alignments) 120 | plot_alignment(al[0], gs) 121 | 122 | if gs > hp.num_iterations: 123 | break 124 | 125 | print("Done") 126 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # /usr/bin/python2 3 | ''' 4 | By kyubyong park. kbpark.linguist@gmail.com. 5 | https://www.github.com/kyubyong/expressive_tacotron 6 | ''' 7 | from __future__ import print_function, division 8 | 9 | from hyperparams import Hyperparams as hp 10 | import numpy as np 11 | import tensorflow as tf 12 | import librosa 13 | import copy 14 | import matplotlib 15 | matplotlib.use('pdf') 16 | import matplotlib.pyplot as plt 17 | from scipy import signal 18 | import os 19 | 20 | 21 | def get_spectrograms(fpath): 22 | '''Returns normalized log(melspectrogram) and log(magnitude) from `sound_file`. 23 | Args: 24 | sound_file: A string. The full path of a sound file. 
25 | 26 | Returns: 27 | mel: A 2d array of shape (T, n_mels) <- Transposed 28 | mag: A 2d array of shape (T, 1+n_fft/2) <- Transposed 29 | ''' 30 | # Loading sound file 31 | y, sr = librosa.load(fpath, sr=hp.sr) 32 | 33 | # Trimming 34 | y, _ = librosa.effects.trim(y) 35 | 36 | # Preemphasis 37 | y = np.append(y[0], y[1:] - hp.preemphasis * y[:-1]) 38 | 39 | # stft 40 | linear = librosa.stft(y=y, 41 | n_fft=hp.n_fft, 42 | hop_length=hp.hop_length, 43 | win_length=hp.win_length) 44 | 45 | # magnitude spectrogram 46 | mag = np.abs(linear) # (1+n_fft//2, T) 47 | 48 | # mel spectrogram 49 | mel_basis = librosa.filters.mel(hp.sr, hp.n_fft, hp.n_mels) # (n_mels, 1+n_fft//2) 50 | mel = np.dot(mel_basis, mag) # (n_mels, t) 51 | 52 | # to decibel 53 | mel = 20 * np.log10(np.maximum(1e-5, mel)) 54 | mag = 20 * np.log10(np.maximum(1e-5, mag)) 55 | 56 | # normalize 57 | mel = np.clip((mel - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1) 58 | mag = np.clip((mag - hp.ref_db + hp.max_db) / hp.max_db, 1e-8, 1) 59 | 60 | # Transpose 61 | mel = mel.T.astype(np.float32) # (T, n_mels) 62 | mag = mag.T.astype(np.float32) # (T, 1+n_fft//2) 63 | 64 | return mel, mag 65 | 66 | 67 | def spectrogram2wav(mag): 68 | '''# Generate wave file from spectrogram''' 69 | # transpose 70 | mag = mag.T 71 | 72 | # de-noramlize 73 | mag = (np.clip(mag, 0, 1) * hp.max_db) - hp.max_db + hp.ref_db 74 | 75 | # to amplitude 76 | mag = np.power(10.0, mag * 0.05) 77 | 78 | # wav reconstruction 79 | wav = griffin_lim(mag) 80 | 81 | # de-preemphasis 82 | wav = signal.lfilter([1], [1, -hp.preemphasis], wav) 83 | 84 | # trim 85 | wav, _ = librosa.effects.trim(wav) 86 | 87 | return wav.astype(np.float32) 88 | 89 | 90 | def griffin_lim(spectrogram): 91 | '''Applies Griffin-Lim's raw. 92 | ''' 93 | X_best = copy.deepcopy(spectrogram) 94 | for i in range(hp.n_iter): 95 | X_t = invert_spectrogram(X_best) 96 | est = librosa.stft(X_t, hp.n_fft, hp.hop_length, win_length=hp.win_length) 97 | phase = est / np.maximum(1e-8, np.abs(est)) 98 | X_best = spectrogram * phase 99 | X_t = invert_spectrogram(X_best) 100 | y = np.real(X_t) 101 | 102 | return y 103 | 104 | 105 | def invert_spectrogram(spectrogram): 106 | ''' 107 | spectrogram: [f, t] 108 | ''' 109 | return librosa.istft(spectrogram, hp.hop_length, win_length=hp.win_length, window="hann") 110 | 111 | 112 | def plot_alignment(alignment, gs): 113 | """Plots the alignment 114 | alignments: A list of (numpy) matrix of shape (encoder_steps, decoder_steps) 115 | gs : (int) global step 116 | """ 117 | fig, ax = plt.subplots() 118 | im = ax.imshow(alignment) 119 | 120 | # cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7]) 121 | fig.colorbar(im) 122 | plt.title('{} Steps'.format(gs)) 123 | plt.savefig('{}/alignment_{}k.png'.format(hp.logdir, gs//1000), format='png') 124 | plt.close(fig) 125 | 126 | def learning_rate_decay(init_lr, global_step, warmup_steps=4000.): 127 | '''Noam scheme from tensor2tensor''' 128 | step = tf.cast(global_step + 1, dtype=tf.float32) 129 | return init_lr * warmup_steps ** 0.5 * tf.minimum(step * warmup_steps ** -1.5, step ** -0.5) 130 | 131 | def load_spectrograms(fpath): 132 | fname = os.path.basename(fpath) 133 | mel, mag = get_spectrograms(fpath) 134 | t = mel.shape[0] 135 | num_paddings = hp.r - (t % hp.r) if t % hp.r != 0 else 0 # for reduction 136 | mel = np.pad(mel, [[0, num_paddings], [0, 0]], mode="constant") 137 | mag = np.pad(mag, [[0, num_paddings], [0, 0]], mode="constant") 138 | return fname, mel.reshape((-1, hp.n_mels*hp.r)), mag 
--------------------------------------------------------------------------------
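
Below is a minimal, illustrative sketch (not part of the repository) of how a single reference clip ends up as the (N, Ty, n_mels) melspectrogram that reference_encoder consumes, mirroring what synthesize.py does before batching; the clip path is only an example from ref1.

# Illustrative sketch only: prepare one reference clip for the prosody (reference) encoder.
# load_spectrograms returns a reduced mel of shape (Ty/r, n_mels*r); synthesize.py unfolds it
# back to (Ty, n_mels) before feeding it as g.ref, and the graph adds the channel axis itself
# via tf.expand_dims(self.ref, -1) in train.py.
import numpy as np

from hyperparams import Hyperparams as hp
from utils import load_spectrograms

fpath = "ref1/01_au_f.wav"              # any clip matched by hp.ref_audio
_, mel, _ = load_spectrograms(fpath)    # (Ty/r, n_mels*r)
ref = np.reshape(mel, (-1, hp.n_mels))  # (Ty, n_mels)
ref = ref[np.newaxis, ...]              # (1, Ty, n_mels), i.e. a batch of one
print(ref.shape)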