├── .gitignore ├── .gitmodules ├── CoordConv.py ├── LICENSE ├── README.md ├── app.py ├── audio_processing.py ├── data_utils.py ├── demo_guide.md ├── distributed.py ├── filelists ├── iemocap_spk_emo_all_test.txt ├── iemocap_spk_emo_all_train.txt ├── iemocap_spk_emo_all_valid.txt ├── koemo_spk_emo_all6_test.txt ├── koemo_spk_emo_all6_train.txt ├── koemo_spk_emo_all6_valid.txt ├── koemo_spk_emo_all_test.txt ├── koemo_spk_emo_all_train.txt └── koemo_spk_emo_all_valid.txt ├── fp16_optimizer.py ├── hparams.py ├── inference.ipynb ├── layers.py ├── logger.py ├── loss_function.py ├── loss_scaler.py ├── model.py ├── modules.py ├── multiproc.py ├── plotting_utils.py ├── requirements.txt ├── res ├── alignment.gif ├── demo.png ├── interpolation.png ├── kldiv.png ├── overview.png ├── reconloss.png ├── scatter.png ├── trainingloss.png ├── tsne.png └── validloss.png ├── samples ├── interpolation │ ├── ang0.3_sad0.6.wav │ ├── ang0.6_sad0.3.wav │ ├── ang1.0.wav │ ├── hap0.3_sad0.6.wav │ ├── hap0.6_sad0.3.wav │ ├── hap1.0.wav │ ├── neu0.3_sad0.6.wav │ ├── neu0.6_sad0.3.wav │ ├── neu1.0.wav │ └── sad1.0.wav ├── mix │ ├── hap0.25_ang0.75.wav │ ├── hap0.25_sad0.25_ang0.5.wav │ └── neu0.25_hap0.25_sad0.25_ang0.25.wav └── refs │ ├── recorded_ang.wav │ ├── recorded_hap.wav │ ├── recorded_neu.wav │ ├── recorded_sad.wav │ ├── ref_ang.wav │ ├── ref_hap.wav │ ├── ref_neu.wav │ └── ref_sad.wav ├── stft.py ├── synthesizer.py ├── text ├── LICENSE ├── __init__.py ├── cleaners.py ├── cmudict.py ├── ko_dictionary.py ├── korean.py ├── numbers_.py └── symbols.py ├── train.py ├── utils.py └── web ├── static ├── css │ └── main.css ├── js │ └── main.js └── uploads │ ├── KoreanEmotionSpeech │ └── koemo_spk_emo_all_test.txt └── templates └── index.html /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints/ 3 | .vscode/ 4 | .git/ 5 | web/audio/ 6 | web/static/uploads/ 7 | README_bac.md 8 | notes.md 9 | models/ 10 | outdir/ 11 | filelists/ 12 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "waveglow"] 2 | path = waveglow 3 | url = https://github.com/NVIDIA/waveglow 4 | branch = master 5 | -------------------------------------------------------------------------------- /CoordConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.modules.conv as conv 4 | from hparams import create_hparams 5 | 6 | hparams = create_hparams() 7 | 8 | class AddCoords(nn.Module): 9 | def __init__(self, rank, with_r=False): 10 | super(AddCoords, self).__init__() 11 | self.rank = rank 12 | self.with_r = with_r 13 | 14 | def forward(self, input_tensor): 15 | """ 16 | :param input_tensor: shape (N, C_in, H, W) 17 | :return: 18 | """ 19 | if self.rank == 1: 20 | batch_size_shape, channel_in_shape, dim_x = input_tensor.shape 21 | xx_range = torch.arange(dim_x, dtype=torch.int32) 22 | xx_channel = xx_range[None, None, :] 23 | 24 | xx_channel = xx_channel.float() / (dim_x - 1) 25 | xx_channel = xx_channel * 2 - 1 26 | xx_channel = xx_channel.repeat(batch_size_shape, 1, 1) 27 | 28 | if torch.cuda.is_available: 29 | input_tensor = input_tensor.cuda() 30 | xx_channel = xx_channel.cuda() 31 | out = torch.cat([input_tensor, xx_channel], dim=1) 32 | 33 | if self.with_r: 34 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2)) 35 | out = torch.cat([out, rr], 
dim=1) 36 | 37 | elif self.rank == 2: 38 | batch_size_shape, channel_in_shape, dim_y, dim_x = input_tensor.shape 39 | xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32) 40 | yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32) 41 | 42 | xx_range = torch.arange(dim_y, dtype=torch.int32) 43 | yy_range = torch.arange(dim_x, dtype=torch.int32) 44 | xx_range = xx_range[None, None, :, None] 45 | yy_range = yy_range[None, None, :, None] 46 | 47 | xx_channel = torch.matmul(xx_range, xx_ones) 48 | yy_channel = torch.matmul(yy_range, yy_ones) 49 | 50 | # transpose y 51 | yy_channel = yy_channel.permute(0, 1, 3, 2) 52 | 53 | xx_channel = xx_channel.float() / (dim_y - 1) 54 | yy_channel = yy_channel.float() / (dim_x - 1) 55 | 56 | xx_channel = xx_channel * 2 - 1 57 | yy_channel = yy_channel * 2 - 1 58 | 59 | xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1) 60 | yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1) 61 | 62 | if torch.cuda.is_available: 63 | input_tensor = input_tensor.cuda() 64 | xx_channel = xx_channel.cuda() 65 | yy_channel = yy_channel.cuda() 66 | if hparams.fp16_run: 67 | input_tensor = input_tensor.half() 68 | xx_channel = xx_channel.half() 69 | yy_channel = yy_channel.half() 70 | out = torch.cat([input_tensor, xx_channel, yy_channel], dim=1) 71 | 72 | if self.with_r: 73 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)) 74 | out = torch.cat([out, rr], dim=1) 75 | 76 | elif self.rank == 3: 77 | batch_size_shape, channel_in_shape, dim_z, dim_y, dim_x = input_tensor.shape 78 | xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32) 79 | yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32) 80 | zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32) 81 | 82 | xy_range = torch.arange(dim_y, dtype=torch.int32) 83 | xy_range = xy_range[None, None, None, :, None] 84 | 85 | yz_range = torch.arange(dim_z, dtype=torch.int32) 86 | yz_range = yz_range[None, None, None, :, None] 87 | 88 | zx_range = torch.arange(dim_x, dtype=torch.int32) 89 | zx_range = zx_range[None, None, None, :, None] 90 | 91 | xy_channel = torch.matmul(xy_range, xx_ones) 92 | xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], dim=2) 93 | 94 | yz_channel = torch.matmul(yz_range, yy_ones) 95 | yz_channel = yz_channel.permute(0, 1, 3, 4, 2) 96 | yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], dim=4) 97 | 98 | zx_channel = torch.matmul(zx_range, zz_ones) 99 | zx_channel = zx_channel.permute(0, 1, 4, 2, 3) 100 | zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], dim=3) 101 | 102 | if torch.cuda.is_available: 103 | input_tensor = input_tensor.cuda() 104 | xx_channel = xx_channel.cuda() 105 | yy_channel = yy_channel.cuda() 106 | zz_channel = zz_channel.cuda() 107 | out = torch.cat([input_tensor, xx_channel, yy_channel, zz_channel], dim=1) 108 | 109 | if self.with_r: 110 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + 111 | torch.pow(yy_channel - 0.5, 2) + 112 | torch.pow(zz_channel - 0.5, 2)) 113 | out = torch.cat([out, rr], dim=1) 114 | else: 115 | raise NotImplementedError 116 | 117 | return out 118 | 119 | 120 | class CoordConv1d(conv.Conv1d): 121 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 122 | padding=0, dilation=1, groups=1, bias=True, with_r=False): 123 | super(CoordConv1d, self).__init__(in_channels, out_channels, kernel_size, 124 | stride, padding, dilation, groups, bias) 125 | self.rank = 1 126 | self.addcoords = AddCoords(self.rank, with_r) 127 | self.conv = 
nn.Conv1d(in_channels + self.rank + int(with_r), out_channels, 128 | kernel_size, stride, padding, dilation, groups, bias) 129 | 130 | def forward(self, input_tensor): 131 | """ 132 | input_tensor_shape: (N, C_in,H,W) 133 | output_tensor_shape: N,C_out,H_out,W_out) 134 | :return: CoordConv2d Result 135 | """ 136 | out = self.addcoords(input_tensor) 137 | out = self.conv(out) 138 | 139 | return out 140 | 141 | 142 | class CoordConv2d(conv.Conv2d): 143 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 144 | padding=0, dilation=1, groups=1, bias=True, with_r=False): 145 | super(CoordConv2d, self).__init__(in_channels, out_channels, kernel_size, 146 | stride, padding, dilation, groups, bias) 147 | self.rank = 2 148 | self.addcoords = AddCoords(self.rank, with_r) 149 | self.conv = nn.Conv2d(in_channels + self.rank + int(with_r), out_channels, 150 | kernel_size, stride, padding, dilation, groups, bias) 151 | 152 | def forward(self, input_tensor): 153 | """ 154 | input_tensor_shape: (N, C_in,H,W) 155 | output_tensor_shape: N,C_out,H_out,W_out) 156 | :return: CoordConv2d Result 157 | """ 158 | out = self.addcoords(input_tensor) 159 | out = self.conv(out) 160 | 161 | return out 162 | 163 | 164 | class CoordConv3d(conv.Conv3d): 165 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 166 | padding=0, dilation=1, groups=1, bias=True, with_r=False): 167 | super(CoordConv3d, self).__init__(in_channels, out_channels, kernel_size, 168 | stride, padding, dilation, groups, bias) 169 | self.rank = 3 170 | self.addcoords = AddCoords(self.rank, with_r) 171 | self.conv = nn.Conv3d(in_channels + self.rank + int(with_r), out_channels, 172 | kernel_size, stride, padding, dilation, groups, bias) 173 | 174 | def forward(self, input_tensor): 175 | """ 176 | input_tensor_shape: (N, C_in,H,W) 177 | output_tensor_shape: N,C_out,H_out,W_out) 178 | :return: CoordConv2d Result 179 | """ 180 | out = self.addcoords(input_tensor) 181 | out = self.conv(out) 182 | 183 | return out -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, NVIDIA Corporation 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tacotron2-vae 2 | 3 | ## Overview 4 | ![overview](./res/overview.png) 5 | 6 | - Generate emotional voices by receiving text and emotional label as input 7 | - Korean Version of ["Learning Latent Representations for Style Control and Transfer in End-to-end Speech Synthesis"](https://arxiv.org/pdf/1812.04342.pdf) 8 | 9 | ## Data 10 | 1. Dataset 11 | * Korean Speech Emotion Dataset ([more info](http://aicompanion.or.kr/kor/main/)) 12 | * Single Female Voice Actor recorded six diffrent emotions(neutral, happy, sad, angry, disgust, fearful), each with 3,000 sentences. 13 | * For training, I used four emotions(neutral,happy,sad,angry), total 21.25 hours 14 | 15 | 2. Text 16 | * test: `python -m text.cleaners` 17 | * examples 18 | ``` 19 | 감정있는 한국어 목소리 생성 20 | ==> 21 | ['ᄀ', 'ㅏ', 'ㅁ', 'ᄌ', 'ㅓ', 'ㅇ', 'ᄋ', 'ㅣ', 'ㅆ', 'ᄂ', 'ㅡ', 'ㄴ', ' ', 'ᄒ', 'ㅏ', 'ㄴ', 'ᄀ', 'ㅜ', 'ㄱ', 'ᄋ', 'ㅓ', ' ', 'ᄆ', 'ㅗ', 'ㄱ', 'ᄉ', 'ㅗ', 'ᄅ', 'ㅣ', ' ', 'ᄉ', 'ㅐ', 'ㅇ', 'ᄉ', 'ㅓ', 'ㅇ', '~'] 22 | ==> 23 | [2, 21, 57, 14, 25, 62, 13, 41, 61, 4, 39, 45, 79, 20, 21, 45, 2, 34, 42, 13, 25, 79, 8, 29, 42, 11, 29, 7, 41, 79, 11, 22, 62, 11, 25, 62, 1] 24 | ``` 25 | 26 | 3. Audio 27 | * sampling rate: 16000 28 | * filter length: 1024 29 | * hop length: 256 30 | * win length: 1024 31 | * n_mel: 80 32 | * mel_fmin: 0 33 | * mel_fmax: 8000 34 | 35 | 4. Training files 36 | * `./filelists/*.txt` 37 | * path | text | speaker | emotion 38 | * examples 39 | ``` 40 | /KoreanEmotionSpeech/wav/neu/neu_00002289.wav|선생님이 초록색으로 머리를 염색하고 나타나서 모두들 깜짝 놀랐다.|0|0 41 | /KoreanEmotionSpeech/wav/sad/sad_00002266.wav|과외 선생님이 열심히 지도해준 덕택에 수학실력이 점점 늘고 있다.|0|1 42 | /KoreanEmotionSpeech/wav/ang/ang_00000019.wav|명백한 것은 각 당이 투사하고 있는 실상과 허상이 있다면 이제 허상은 걷어들여야 한다는 것이다.|0|2 43 | /KoreanEmotionSpeech/wav/hap/hap_00001920.wav|강력한 스크럽으로 상쾌한 양치효과를 주네요.|0|3 44 | ``` 45 | 46 | ## Training 47 | 1. Prepare Datasets 48 | 2. Clone this repo: `git clone https://github.com/jinhan/tacotron2-vae.git` 49 | 3. CD into this repo: `cd tacotron2-vae` 50 | 4. Initialize submodule: `git submodule init; git submodule update` 51 | 5. Update .wav paths: `sed -i -- 's,DUMMY,ljs_dataset_folder/wavs,g' filelists/*.txt` 52 | 6. Install requirements `pip install -r requirements.txt` 53 | 7. Training: `python train.py --output_directory=outdir --log_directory=logdir -- hparams=training_files='filelists/koemo_spk_emo_all_train.txt',validation_files='filelists/koemo_spk_emo_all_valid.txt',anneal_function='constant',batch_size=6` 54 | 8. Monitoring: `tensorboard --logdir=outdir/logdir --host=127.0.0.1` 55 | 9. Training results (~ 250,000 steps) 56 | 57 | ![alignment](./res/alignment.gif) 58 | 59 | ## Visualization 60 | source: `inference.ipynb` 61 | 62 | 1. 
Load Models 63 | 64 | tacotron2-vae model 65 | 66 | ```python 67 | model = load_model(hparams) 68 | model.load_state_dict(torch.load(checkpoint_path)['state_dict']) 69 | _ = model.eval() 70 | ``` 71 | 72 | WaveGlow vocoder model 73 | 74 | ```python 75 | waveglow = torch.load(waveglow_path)['model'] 76 | waveglow.cuda() 77 | ``` 78 | 79 | 2. Load Data 80 | - 'Prosody' is the output of the fully connected layer to match dimension of z and text-encoded output. 81 | 82 | ```python 83 | path = './filelists/koemo_spk_emo_all_test.txt' 84 | with open(path, encoding='utf-8') as f: 85 | filepaths_and_text = [line.strip().split("|") for line in f] 86 | 87 | model.eval() 88 | prosody_outputs = [] 89 | emotions = [] 90 | mus = [] 91 | zs = [] 92 | 93 | for audio_path, _, _, emotion in tqdm(filepaths_and_text): 94 | melspec = load_mel(audio_path) 95 | prosody, mu, _, z = model.vae_gst(melspec) 96 | prosody_outputs.append(prosody.squeeze(1).cpu().data) 97 | mus.append(mu.cpu().data) 98 | zs.append(z.cpu().data) 99 | emotions.append(int(emotion)) 100 | 101 | prosody_outputs = torch.cat(prosody_outputs, dim=0) 102 | emotions = np.array(emotions) 103 | mus = torch.cat(mus, dim=0) 104 | zs = torch.cat(zs, dim=0) 105 | ``` 106 | 107 | 3. Scatter plot 108 | 109 | ```python 110 | colors = 'r','b','g','y' 111 | labels = 'neu','sad','ang','hap' 112 | 113 | data_x = mus.data.numpy() 114 | data_y = emotions 115 | 116 | plt.figure(figsize=(10,10)) 117 | for i, (c, label) in enumerate(zip(colors, labels)): 118 | plt.scatter(data_x[data_y==i,0], data_x[data_y==i,1], c=c, label=label, alpha=0.5) 119 | 120 | axes = plt.gca() 121 | plt.grid(True) 122 | plt.legend(loc='upper left') 123 | ``` 124 | ![scatter plot](./res/scatter.png) 125 | 126 | 4. t-SNE plot 127 | 128 | ```python 129 | colors = 'r','b','g','y' 130 | labels = 'neu','sad','ang','hap' 131 | 132 | data_x = mus 133 | data_y = emotions 134 | 135 | tsne_model = TSNE(n_components=2, random_state=0, init='random') 136 | tsne_all_data = tsne_model.fit_transform(data_x) 137 | tsne_all_y_data = data_y 138 | 139 | plt.figure(figsize=(10,10)) 140 | for i, (c, label) in enumerate(zip(colors, labels)): 141 | plt.scatter(tsne_all_data[tsne_all_y_data==i,0], tsne_all_data[tsne_all_y_data==i,1], c=c, label=label, alpha=0.5) 142 | 143 | plt.grid(True) 144 | plt.legend(loc='upper left') 145 | ``` 146 | ![t-SNE plot](./res/tsne.png) 147 | 148 | 149 | ## Inference 150 | source: `inference.ipynb` 151 | 152 | ### Reference Audio 153 | - Generate voice that follows the style of the reference audio 154 | 155 | Reference audio 156 | 157 | ```python 158 | def generate_audio_vae_by_ref(text, ref_audio): 159 | transcript_outputs = TextEncoder(text) 160 | 161 | print("reference audio") 162 | ipd.display(ipd.Audio(ref_audio, rate=hparams.sampling_rate)) 163 | 164 | ref_audio_mel = load_mel(ref_audio) 165 | latent_vector, _, _, _ = model.vae_gst(ref_audio_mel) 166 | latent_vector = latent_vector.unsqueeze(1).expand_as(transcript_outputs) 167 | 168 | encoder_outputs = transcript_outputs + latent_vector 169 | 170 | synth, mel_outputs = Decoder(encoder_outputs) 171 | 172 | ipd.display(ipd.Audio(synth[0].data.cpu().numpy(), rate=hparams.sampling_rate)) 173 | ipd.display(plot_data(mel_outputs.data.cpu().numpy()[0])) 174 | ``` 175 | 176 | Generate voice 177 | 178 | ```python 179 | text = "이 모델을 이용하면 같은 문장을 여러가지 스타일로 말할 수 있습니다." 
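# ref_wav below points to an angry utterance from the Korean emotion dataset used
# for training; any 16 kHz wav whose style should be transferred can be used instead.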
180 | ref_wav = "/KoreanEmotionSpeech/wav/ang/ang_00000100.wav" 181 | generate_audio_vae_by_ref(text, ref_wav) 182 | ``` 183 | ### Interpolation 184 | - Create a new z by multiply ratios to the centroids. 185 | 186 | Interpolation 187 | 188 | ```python 189 | def generate_audio_vae(text, ref_audio, trg_audio, ratios): 190 | transcript_outputs = TextEncoder(text) 191 | 192 | for ratio in ratios: 193 | latent_vector = ref_audio * ratio + trg_audio * (1.0-ratio) 194 | latent_vector = torch.FloatTensor(latent_vector).cuda() 195 | latent_vector = model.vae_gst.fc3(latent_vector) 196 | 197 | encoder_outputs = transcript_outputs + latent_vector 198 | 199 | synth, mel_outputs_postnet = Decoder(encoder_outputs) 200 | ipd.display(ipd.Audio(synth[0].data.cpu().numpy(), rate=hparams.sampling_rate)) 201 | ipd.display(plot_data(mel_outputs_postnet.data.cpu().numpy()[0])) 202 | ``` 203 | 204 | Get Centroids 205 | 206 | ```python 207 | encoded = zs.data.numpy() 208 | neu = np.mean(encoded[emotions==0,:], axis=0) 209 | sad = np.mean(encoded[emotions==1,:], axis=0) 210 | ang = np.mean(encoded[emotions==2,:], axis=0) 211 | hap = np.mean(encoded[emotions==3,:], axis=0) 212 | ``` 213 | 214 | Generate voice 215 | 216 | ```python 217 | text = "이 모델을 이용하면 같은 문장을 여러가지 스타일로 말할 수 있습니다." 218 | ref_audio = hap 219 | trg_audio = sad 220 | ratios = [1.0, 0.64, 0.34, 0.0] 221 | generate_audio_vae(text, ref_audio, trg_audio, ratios) 222 | ``` 223 | 224 | 225 | 226 | ### Mixer 227 | - Result of mixing more than two labels at a desired ratio 228 | 229 | Mixer 230 | 231 | ```python 232 | def generate_audio_vae_mix(text, ratios): 233 | transcript_outputs = TextEncoder(text) 234 | 235 | latent_vector = ratios[0]*neu + ratios[1]*hap + ratios[2]*sad + ratios[3]*ang 236 | latent_vector = torch.FloatTensor(latent_vector).cuda() 237 | latent_vector = model.vae_gst.fc3(latent_vector) 238 | 239 | encoder_outputs = transcript_outputs + latent_vector 240 | 241 | synth, mel_outputs = Decoder(encoder_outputs) 242 | 243 | ipd.display(ipd.Audio(synth[0].data.cpu().numpy(), rate=hparams.sampling_rate)) 244 | ipd.display(plot_data(mel_outputs.data.cpu().numpy()[0])) 245 | ``` 246 | 247 | Generate voice 248 | 249 | ```python 250 | text = "이 모델을 이용하면 같은 문장을 여러가지 스타일로 말할 수 있습니다." 251 | ratios = [0.0, 0.25, 0.0, 0.75] #neu, hap, sad, ang 252 | generate_audio_vae_mix(text, ratios) 253 | ``` 254 | 255 | ## Demo page 256 | 1. Run: `python app.py --checkpoint_path="./models/032902_vae_250000" --waveglow_path="./models/waveglow_130000"` 257 | 2. Mix: Generate voices by adjusting the ratio of netural, sad, happy, and angry 258 | 3. 
Ref Audio: Generate voices by testset audio as a reference audio 259 | 260 | ![demo page](./res/demo.png) 261 | 262 | 263 | ## Samples 264 | - Interpolation: Result of interpolating between two labels at 1.0, 0.66, 0.33, and 0.0 265 | - refs: Result of recorded audio as a reference audio 266 | - mix: Result of mixing more than two labels at a desired ratio 267 | 268 | ## References 269 | - Tacotron2: https://github.com/NVIDIA/tacotron2 270 | - Prosody Encoder: https://github.com/KinglittleQ/GST-Tacotron/blob/master/GST.py 271 | - WaveGlow: https://github.com/NVIDIA/waveglow 272 | - Export Images from tfevents: https://github.com/anderskm/exportTensorFlowLog 273 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | #!flask/bin/python 2 | 3 | import os, traceback, json 4 | import hashlib 5 | import argparse 6 | from flask_cors import CORS, cross_origin 7 | from flask import Flask, request, render_template, jsonify, \ 8 | send_from_directory, make_response, send_file 9 | from synthesizer import Synthesizer 10 | from utils import str2bool, makedirs, add_postfix 11 | import base64 12 | 13 | 14 | ROOT_PATH = "web" 15 | AUDIO_DIR = "audio" 16 | AUDIO_PATH = os.path.join(ROOT_PATH, AUDIO_DIR) 17 | 18 | base_path = os.path.dirname(os.path.realpath(__file__)) 19 | static_path = os.path.join(base_path, 'web/static') 20 | 21 | global_config = None 22 | synthesizer = Synthesizer() 23 | app = Flask(__name__, root_path=ROOT_PATH, static_url_path='') 24 | CORS(app) 25 | 26 | def generate_audio_response(text, condition_on_ref, ref_audio, ratios): 27 | hashed_text = hashlib.md5(text.encode('utf-8')).hexdigest() 28 | 29 | relative_dir_path = os.path.join(AUDIO_DIR, 'tacotron2-vae') 30 | relative_audio_path = os.path.join( 31 | relative_dir_path, "{}.wav".format(hashed_text)) 32 | real_path = os.path.join(ROOT_PATH, relative_audio_path) 33 | makedirs(os.path.dirname(real_path)) 34 | print(ref_audio) 35 | if condition_on_ref: 36 | ref_audio = ref_audio.replace('/uploads', '/home/jhoh/dataset') 37 | 38 | try: 39 | synthesizer.synthesize(text, real_path, condition_on_ref, ref_audio, ratios) 40 | except Exception as e: 41 | traceback.print_exc() 42 | return jsonify(success=False), 400 43 | 44 | return send_file( 45 | relative_audio_path, 46 | mimetype="audio/wav", 47 | as_attachment=True, 48 | attachment_filename=hashed_text + ".wav") 49 | 50 | def generate_api_response(args): 51 | print(args) 52 | text = args['text'] 53 | print(text) 54 | condition_on_ref = False 55 | ref_audio = None 56 | 57 | n = float(args['neu']) 58 | s = float(args['sad']) 59 | h = float(args['hap']) 60 | a = float(args['ang']) 61 | sigma = n+s+h+a 62 | if sigma: 63 | ratios = [round(x / sigma * 100)/100 for x in [n, s, h, a]] 64 | else: 65 | ratios = [1.0, 0.0, 0.0, 0.0] 66 | 67 | hashed_text = hashlib.md5(text.encode('utf-8')).hexdigest() 68 | 69 | relative_dir_path = os.path.join(AUDIO_DIR, 'tacotron2-vae') 70 | relative_audio_path = os.path.join( 71 | relative_dir_path, "{}.wav".format(hashed_text)) 72 | real_path = os.path.join(ROOT_PATH, relative_audio_path) 73 | makedirs(os.path.dirname(real_path)) 74 | 75 | if condition_on_ref: 76 | ref_audio = ref_audio.replace('/uploads', '/home/jhoh/dataset') 77 | 78 | try: 79 | synthesizer.synthesize(text, real_path, condition_on_ref, ref_audio, ratios) 80 | except Exception as e: 81 | traceback.print_exc() 82 | return jsonify(success=False), 400 83 | 84 | b64_data = 
base64.b64encode(open(real_path, "rb").read()) 85 | return json.dumps({"params":{ 86 | "text":text, 87 | "neu": n, "hap": h, "sad": s, "ang": a}, 88 | "data": str(b64_data.decode('utf-8'))}) 89 | 90 | @app.route('/') 91 | def index(): 92 | text = request.args.get('text') or "듣고 싶은 문장을 입력해 주세요." 93 | return render_template('index.html', text=text) 94 | 95 | @app.route('/api', methods=['POST']) 96 | def API(): 97 | args = json.loads(request.data) 98 | 99 | return generate_api_response(args) 100 | 101 | @app.route('/generate') 102 | def view_method(): 103 | text = request.args.get('text') 104 | condition_on_ref = request.args.get('con') 105 | 106 | if text: 107 | if condition_on_ref=='true': 108 | ref_audio = request.args.get('ref') 109 | condition_on_ref = True 110 | ratios = None 111 | 112 | return generate_audio_response(text, condition_on_ref, ref_audio, ratios) # ref_audi, ratios 113 | else: 114 | n = float(request.args.get('n')) 115 | s = float(request.args.get('s')) 116 | h = float(request.args.get('h')) 117 | a = float(request.args.get('a')) 118 | sigma = n+s+h+a 119 | if sigma: 120 | ratios = [round(x / sigma * 100)/100 for x in [n, s, h, a]] 121 | else: 122 | ratios = [1.0, 0.0, 0.0, 0.0] 123 | 124 | ref_audio = None 125 | condition_on_ref = False 126 | 127 | return generate_audio_response(text, condition_on_ref, ref_audio, ratios) # ref_audi, ratios 128 | else: 129 | return {} 130 | 131 | @app.route('/js/') 132 | def send_js(path): 133 | return send_from_directory( 134 | os.path.join(static_path, 'js'), path) 135 | 136 | @app.route('/css/') 137 | def send_css(path): 138 | return send_from_directory( 139 | os.path.join(static_path, 'css'), path) 140 | 141 | @app.route('/audio/') 142 | def send_audio(path): 143 | return send_from_directory( 144 | os.path.join(static_path, 'audio'), path) 145 | 146 | @app.route('/uploads/') 147 | def send_uploads(path): 148 | return send_from_directory( 149 | os.path.join(static_path, 'uploads'), path) 150 | 151 | if __name__ == '__main__': 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument('--checkpoint_path', required=True) 154 | parser.add_argument('--waveglow_path', required=True) 155 | parser.add_argument('--port', default=51000, type=int) 156 | parser.add_argument('--debug', default=False, type=str2bool) 157 | parser.add_argument('--is_korean', default=True, type=str2bool) 158 | config = parser.parse_args() 159 | 160 | if os.path.exists(config.checkpoint_path): 161 | synthesizer.load(config.checkpoint_path, config.waveglow_path) 162 | else: 163 | print(" [!] load_path not found: {}".format(config.checkpoint_path)) 164 | 165 | app.run(host='0.0.0.0', threaded=True, port=config.port, debug=config.debug) 166 | -------------------------------------------------------------------------------- /audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.signal import get_window 4 | import librosa.util as librosa_util 5 | 6 | 7 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800, 8 | n_fft=800, dtype=np.float32, norm=None): 9 | """ 10 | # from librosa 0.6 11 | Compute the sum-square envelope of a window function at a given hop length. 12 | 13 | This is used to estimate modulation effects induced by windowing 14 | observations in short-time fourier transforms. 
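Concretely, the returned envelope is wss[n] = sum_i win_sq[n - i * hop_length], where win_sq is the squared, normalized analysis window zero-padded to n_fft and i indexes the n_frames analysis frames.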
15 | 16 | Parameters 17 | ---------- 18 | window : string, tuple, number, callable, or list-like 19 | Window specification, as in `get_window` 20 | 21 | n_frames : int > 0 22 | The number of analysis frames 23 | 24 | hop_length : int > 0 25 | The number of samples to advance between frames 26 | 27 | win_length : [optional] 28 | The length of the window function. By default, this matches `n_fft`. 29 | 30 | n_fft : int > 0 31 | The length of each analysis frame. 32 | 33 | dtype : np.dtype 34 | The data type of the output 35 | 36 | Returns 37 | ------- 38 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 39 | The sum-squared envelope of the window function 40 | """ 41 | if win_length is None: 42 | win_length = n_fft 43 | 44 | n = n_fft + hop_length * (n_frames - 1) 45 | x = np.zeros(n, dtype=dtype) 46 | 47 | # Compute the squared window at the desired length 48 | win_sq = get_window(window, win_length, fftbins=True) 49 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 50 | win_sq = librosa_util.pad_center(win_sq, n_fft) 51 | 52 | # Fill the envelope 53 | for i in range(n_frames): 54 | sample = i * hop_length 55 | x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 56 | return x 57 | 58 | 59 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 60 | """ 61 | PARAMS 62 | ------ 63 | magnitudes: spectrogram magnitudes 64 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 65 | """ 66 | 67 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 68 | angles = angles.astype(np.float32) 69 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 70 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 71 | 72 | for i in range(n_iters): 73 | _, angles = stft_fn.transform(signal) 74 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 75 | return signal 76 | 77 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 78 | """ 79 | PARAMS 80 | ------ 81 | C: compression factor 82 | """ 83 | return torch.log(torch.clamp(x, min=clip_val) * C) 84 | 85 | 86 | def dynamic_range_decompression(x, C=1): 87 | """ 88 | PARAMS 89 | ------ 90 | C: compression factor used to compress 91 | """ 92 | return torch.exp(x) / C 93 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.utils.data 5 | 6 | import layers 7 | from utils import load_wav_to_torch, load_filepaths_and_text 8 | from text import text_to_sequence 9 | 10 | 11 | class TextMelLoader(torch.utils.data.Dataset): 12 | """ 13 | 1) loads audio,text pairs 14 | 2) normalizes text and converts them to sequences of one-hot vectors 15 | 3) computes mel-spectrograms from audio files. 
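Each line of the input filelist is expected to follow the format used in filelists/*.txt: audio_path|text|speaker_id|emotion_id.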
16 | """ 17 | def __init__(self, audiopaths_and_text, hparams): 18 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) 19 | self.text_cleaners = hparams.text_cleaners 20 | self.max_wav_value = hparams.max_wav_value 21 | self.sampling_rate = hparams.sampling_rate 22 | self.load_mel_from_disk = hparams.load_mel_from_disk 23 | self.n_speakers = hparams.n_speakers 24 | self.n_emotions = hparams.n_emotions 25 | self.stft = layers.TacotronSTFT( 26 | hparams.filter_length, hparams.hop_length, hparams.win_length, 27 | hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin, 28 | hparams.mel_fmax) 29 | random.seed(1234) 30 | random.shuffle(self.audiopaths_and_text) 31 | 32 | def get_mel_text_pair(self, audiopath_and_text): 33 | # separate filename and text 34 | audiopath, text, speaker, emotion = audiopath_and_text[0], audiopath_and_text[1], audiopath_and_text[2], audiopath_and_text[3] # filelists/*.txt 구조대로 parsing 35 | text = self.get_text(text) # int_tensor[char_index, ....] 36 | mel = self.get_mel(audiopath) # [] 37 | speaker = self.get_speaker(speaker) # 현재는 single speaker 38 | emotion = self.get_emotion(emotion) 39 | 40 | return (text, mel, speaker, emotion) 41 | 42 | def get_mel(self, filename): 43 | if not self.load_mel_from_disk: 44 | audio, sampling_rate = load_wav_to_torch(filename) 45 | if sampling_rate != self.stft.sampling_rate: 46 | raise ValueError("{} SR doesn't match target {} SR".format( 47 | sampling_rate, self.stft.sampling_rate)) 48 | audio_norm = audio / self.max_wav_value 49 | audio_norm = audio_norm.unsqueeze(0) 50 | audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) 51 | melspec = self.stft.mel_spectrogram(audio_norm) 52 | melspec = torch.squeeze(melspec, 0) 53 | else: 54 | melspec = torch.from_numpy(np.load(filename)) 55 | assert melspec.size(0) == self.stft.n_mel_channels, ( 56 | 'Mel dimension mismatch: given {}, expected {}'.format( 57 | melspec.size(0), self.stft.n_mel_channels)) 58 | 59 | return melspec 60 | 61 | def get_text(self, text): 62 | text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners)) 63 | return text_norm 64 | 65 | def get_speaker(self, speaker): 66 | speaker_vector = np.zeros(self.n_speakers) 67 | speaker_vector[int(speaker)] = 1 68 | return torch.Tensor(speaker_vector.astype(dtype=np.float32)) 69 | 70 | def get_emotion(self, emotion): 71 | emotion_vector = np.zeros(self.n_emotions) 72 | emotion_vector[int(emotion)] = 1 73 | return torch.Tensor(emotion_vector.astype(dtype=np.float32)) 74 | 75 | def __getitem__(self, index): 76 | return self.get_mel_text_pair(self.audiopaths_and_text[index]) 77 | 78 | def __len__(self): 79 | return len(self.audiopaths_and_text) 80 | 81 | 82 | class TextMelCollate(): 83 | """ Zero-pads model inputs and targets based on number of frames per setep 84 | """ 85 | def __init__(self, n_frames_per_step): 86 | self.n_frames_per_step = n_frames_per_step 87 | 88 | def __call__(self, batch): 89 | """Collate's training batch from normalized text and mel-spectrogram 90 | PARAMS 91 | ------ 92 | batch: [[text_normalized, mel_normalized], ...] 
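RETURNS
-------
text_padded, input_lengths, mel_padded, gate_padded, output_lengths, speakers, emotions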
93 | """ 94 | # Right zero-pad all one-hot text sequences to max input length 95 | input_lengths, ids_sorted_decreasing = torch.sort( 96 | torch.LongTensor([len(x[0]) for x in batch]), 97 | dim=0, descending=True) 98 | max_input_len = input_lengths[0] 99 | 100 | text_padded = torch.LongTensor(len(batch), max_input_len) 101 | text_padded.zero_() 102 | for i in range(len(ids_sorted_decreasing)): 103 | text = batch[ids_sorted_decreasing[i]][0] 104 | text_padded[i, :text.size(0)] = text 105 | 106 | speakers = torch.LongTensor(len(batch), len(batch[0][2])) 107 | for i in range(len(ids_sorted_decreasing)): 108 | speaker = batch[ids_sorted_decreasing[i]][2] 109 | speakers[i, :] = speaker 110 | 111 | emotions = torch.LongTensor(len(batch), len(batch[0][3])) 112 | for i in range(len(ids_sorted_decreasing)): 113 | emotion = batch[ids_sorted_decreasing[i]][3] 114 | emotions[i, :] = emotion 115 | 116 | # Right zero-pad mel-spec 117 | num_mels = batch[0][1].size(0) 118 | max_target_len = max([x[1].size(1) for x in batch]) 119 | # max_target_len = min(max([x[1].size(1) for x in batch]), 1000) # max_len 1000 120 | if max_target_len % self.n_frames_per_step != 0: 121 | max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step 122 | assert max_target_len % self.n_frames_per_step == 0 123 | 124 | # include mel padded and gate padded 125 | mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len) 126 | mel_padded.zero_() 127 | gate_padded = torch.FloatTensor(len(batch), max_target_len) 128 | gate_padded.zero_() 129 | output_lengths = torch.LongTensor(len(batch)) 130 | for i in range(len(ids_sorted_decreasing)): 131 | mel = batch[ids_sorted_decreasing[i]][1] 132 | mel_padded[i, :, :mel.size(1)] = mel 133 | gate_padded[i, mel.size(1)-1:] = 1 134 | output_lengths[i] = mel.size(1) 135 | 136 | return text_padded, input_lengths, mel_padded, gate_padded, \ 137 | output_lengths, speakers, emotions 138 | -------------------------------------------------------------------------------- /demo_guide.md: -------------------------------------------------------------------------------- 1 | 2 | # TTS Demo Guide 3 | > HCID VPN 연결이 필요합니다. 4 | 5 | ## Demo page 6 | ### 실행 7 | * URL: http://mind.snu.ac.kr:5907 8 | * 재실행 방법 9 | > conda activate tts 확인 10 | 11 | ``` 12 | ssh {admin}@mind.snu.ac.kr 13 | screen -r tts 14 | python app.py --checkpoint_path="./models/032902_vae_250000" --waveglow_path="./models/waveglow_130000" 15 | ``` 16 | 17 | 18 | ### 기능 19 | * Mix: netural, sad, happy, angry 정도를 반영하여 문장을 읽습니다. 20 | * Ref Audio: 랜덤하게 선택된 음성의 스타일을 따라 읽습니다. 
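The Mix and Ref Audio buttons call the `/generate` route of `app.py` with plain query parameters, so the demo can also be scripted directly. A minimal sketch, assuming the demo host above and the `requests` package (parameter names follow `view_method` in `app.py`):

```python
# Sketch: fetch a synthesized wav from the demo's /generate endpoint.
# text, con, n, s, h, a mirror the query parameters read in app.py's view_method.
import requests

resp = requests.get(
    "http://mind.snu.ac.kr:5907/generate",
    params={"text": "안녕하세요.", "con": "false",
            "n": "1.0", "s": "0.0", "h": "0.0", "a": "0.0"},
)
resp.raise_for_status()
with open("out.wav", "wb") as f:  # the server responds with audio/wav
    f.write(resp.content)
```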
21 | 22 | ![demo page](./res/demo.png) 23 | 24 | 25 | ## API 26 | 27 | ### Request 28 | * Method: POST 29 | * URL: http://mind.snu.ac.kr:5907/api 30 | 31 | ### Request Parameters 32 | 33 | | Key | 설명 | type | 34 | | - | - | - | 35 | | text | 듣고 싶은 문장을 입력 | string | 36 | | neu | neutral 정도 (0.0 ~ 1.0) | string | 37 | | hap | happy 정도 (0.0 ~ 1.0) | string | 38 | | sad | sad 정도 (0.0 ~ 1.0) | string | 39 | | ang| angry 정도 (0.0 ~ 1.0) | string | 40 | 41 | ### Requests 예제 42 | ``` 43 | curl --request POST \ 44 | --header 'Content-Type: application/json' \ 45 | --data '{"text":"안녕하세요.", "neu":"1.0", "hap":"0.0", "sad":"0.0", "ang":"0.0"}' \ 46 | http://mind.snu.ac.kr:5907/api 47 | ``` 48 | 49 | ### Response Parameters 50 | * Response 51 | 52 | | Key | 설명 | type | 53 | | - | - | - | 54 | | params | request parameters (text, neu, hap, sad, ang) | object | 55 | | data | TTS 결과 (base64) | string | 56 | 57 | * params 58 | 59 | | Key | 설명 | type | 60 | | - | - | - | 61 | | text | 듣고 싶은 문장을 입력 | string | 62 | | neu | neutral 정도 (0.0 ~ 1.0) | float | 63 | | hap | happy 정도 (0.0 ~ 1.0) | float | 64 | | sad | sad 정도 (0.0 ~ 1.0) | float | 65 | | ang| angry 정도 (0.0 ~ 1.0) | float | 66 | 67 | 68 | ### Response 예제 69 | ``` 70 | { 71 | "params": { 72 | "text": "\uc548\ub155\ud558\uc138\uc694.", 73 | "neu": 1.0, 74 | "hap": 0.0, 75 | "sad": 0.0, 76 | "ang": 0.0 77 | }, 78 | "data": "UklGRj..." 79 | } 80 | ``` 81 | -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.nn.modules import Module 4 | from torch.autograd import Variable 5 | 6 | def _flatten_dense_tensors(tensors): 7 | """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of 8 | same dense type. 9 | Since inputs are dense, the resulting tensor will be a concatenated 1D 10 | buffer. Element-wise operation on this buffer will be equivalent to 11 | operating individually. 12 | Arguments: 13 | tensors (Iterable[Tensor]): dense tensors to flatten. 14 | Returns: 15 | A contiguous 1D buffer containing input tensors. 16 | """ 17 | if len(tensors) == 1: 18 | return tensors[0].contiguous().view(-1) 19 | flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0) 20 | return flat 21 | 22 | def _unflatten_dense_tensors(flat, tensors): 23 | """View a flat buffer using the sizes of tensors. Assume that tensors are of 24 | same dense type, and that flat is given by _flatten_dense_tensors. 25 | Arguments: 26 | flat (Tensor): flattened dense tensors to unflatten. 27 | tensors (Iterable[Tensor]): dense tensors whose sizes will be used to 28 | unflatten flat. 29 | Returns: 30 | Unflattened dense tensors with sizes same as tensors and values from 31 | flat. 32 | """ 33 | outputs = [] 34 | offset = 0 35 | for tensor in tensors: 36 | numel = tensor.numel() 37 | outputs.append(flat.narrow(0, offset, numel).view_as(tensor)) 38 | offset += numel 39 | return tuple(outputs) 40 | 41 | 42 | ''' 43 | This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py 44 | launcher included with this example. It assumes that your run is using multiprocess with 1 45 | GPU/process, that the model is on the correct device, and that torch.set_device has been 46 | used to set the device. 
47 | 48 | Parameters are broadcasted to the other processes on initialization of DistributedDataParallel, 49 | and will be allreduced at the finish of the backward pass. 50 | ''' 51 | class DistributedDataParallel(Module): 52 | 53 | def __init__(self, module): 54 | super(DistributedDataParallel, self).__init__() 55 | #fallback for PyTorch 0.3 56 | if not hasattr(dist, '_backend'): 57 | self.warn_on_half = True 58 | else: 59 | self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 60 | 61 | self.module = module 62 | 63 | for p in self.module.state_dict().values(): 64 | if not torch.is_tensor(p): 65 | continue 66 | dist.broadcast(p, 0) 67 | 68 | def allreduce_params(): 69 | if(self.needs_reduction): 70 | self.needs_reduction = False 71 | buckets = {} 72 | for param in self.module.parameters(): 73 | if param.requires_grad and param.grad is not None: 74 | tp = type(param.data) 75 | if tp not in buckets: 76 | buckets[tp] = [] 77 | buckets[tp].append(param) 78 | if self.warn_on_half: 79 | if torch.cuda.HalfTensor in buckets: 80 | print("WARNING: gloo dist backend for half parameters may be extremely slow." + 81 | " It is recommended to use the NCCL backend in this case. This currently requires" + 82 | "PyTorch built from top of tree master.") 83 | self.warn_on_half = False 84 | 85 | for tp in buckets: 86 | bucket = buckets[tp] 87 | grads = [param.grad.data for param in bucket] 88 | coalesced = _flatten_dense_tensors(grads) 89 | dist.all_reduce(coalesced) 90 | coalesced /= dist.get_world_size() 91 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 92 | buf.copy_(synced) 93 | 94 | for param in list(self.module.parameters()): 95 | def allreduce_hook(*unused): 96 | param._execution_engine.queue_callback(allreduce_params) 97 | if param.requires_grad: 98 | param.register_hook(allreduce_hook) 99 | 100 | def forward(self, *inputs, **kwargs): 101 | self.needs_reduction = True 102 | return self.module(*inputs, **kwargs) 103 | 104 | ''' 105 | def _sync_buffers(self): 106 | buffers = list(self.module._all_buffers()) 107 | if len(buffers) > 0: 108 | # cross-node buffer sync 109 | flat_buffers = _flatten_dense_tensors(buffers) 110 | dist.broadcast(flat_buffers, 0) 111 | for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): 112 | buf.copy_(synced) 113 | def train(self, mode=True): 114 | # Clear NCCL communicator and CUDA event cache of the default group ID, 115 | # These cache will be recreated at the later call. This is currently a 116 | # work-around for a potential NCCL deadlock. 
117 | if dist._backend == dist.dist_backend.NCCL: 118 | dist._clear_group_cache() 119 | super(DistributedDataParallel, self).train(mode) 120 | self.module.train(mode) 121 | ''' 122 | ''' 123 | Modifies existing model to do gradient allreduce, but doesn't change class 124 | so you don't need "module" 125 | ''' 126 | def apply_gradient_allreduce(module): 127 | if not hasattr(dist, '_backend'): 128 | module.warn_on_half = True 129 | else: 130 | module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 131 | 132 | for p in module.state_dict().values(): 133 | if not torch.is_tensor(p): 134 | continue 135 | dist.broadcast(p, 0) 136 | 137 | def allreduce_params(): 138 | if(module.needs_reduction): 139 | module.needs_reduction = False 140 | buckets = {} 141 | for param in module.parameters(): 142 | if param.requires_grad and param.grad is not None: 143 | # tp = type(param.data) 144 | tp = param.data.dtype 145 | if tp not in buckets: 146 | buckets[tp] = [] 147 | buckets[tp].append(param) 148 | if module.warn_on_half: 149 | if torch.cuda.HalfTensor in buckets: 150 | print("WARNING: gloo dist backend for half parameters may be extremely slow." + 151 | " It is recommended to use the NCCL backend in this case. This currently requires" + 152 | "PyTorch built from top of tree master.") 153 | module.warn_on_half = False 154 | 155 | for tp in buckets: 156 | bucket = buckets[tp] 157 | grads = [param.grad.data for param in bucket] 158 | coalesced = _flatten_dense_tensors(grads) 159 | dist.all_reduce(coalesced) 160 | coalesced /= dist.get_world_size() 161 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 162 | buf.copy_(synced) 163 | 164 | for param in list(module.parameters()): 165 | def allreduce_hook(*unused): 166 | Variable._execution_engine.queue_callback(allreduce_params) 167 | if param.requires_grad: 168 | param.register_hook(allreduce_hook) 169 | 170 | def set_needs_reduction(self, input, output): 171 | self.needs_reduction = True 172 | 173 | module.register_forward_hook(set_needs_reduction) 174 | return module 175 | -------------------------------------------------------------------------------- /fp16_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | from torch.nn.parameter import Parameter 5 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 6 | 7 | from loss_scaler import DynamicLossScaler, LossScaler 8 | 9 | FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) 10 | HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) 11 | 12 | def conversion_helper(val, conversion): 13 | """Apply conversion to val. 
Recursively apply conversion if `val` is a nested tuple/list structure.""" 14 | if not isinstance(val, (tuple, list)): 15 | return conversion(val) 16 | rtn = [conversion_helper(v, conversion) for v in val] 17 | if isinstance(val, tuple): 18 | rtn = tuple(rtn) 19 | return rtn 20 | 21 | def fp32_to_fp16(val): 22 | """Convert fp32 `val` to fp16""" 23 | def half_conversion(val): 24 | val_typecheck = val 25 | if isinstance(val_typecheck, (Parameter, Variable)): 26 | val_typecheck = val.data 27 | if isinstance(val_typecheck, FLOAT_TYPES): 28 | val = val.half() 29 | return val 30 | return conversion_helper(val, half_conversion) 31 | 32 | def fp16_to_fp32(val): 33 | """Convert fp16 `val` to fp32""" 34 | def float_conversion(val): 35 | val_typecheck = val 36 | if isinstance(val_typecheck, (Parameter, Variable)): 37 | val_typecheck = val.data 38 | if isinstance(val_typecheck, HALF_TYPES): 39 | val = val.float() 40 | return val 41 | return conversion_helper(val, float_conversion) 42 | 43 | class FP16_Module(nn.Module): 44 | def __init__(self, module): 45 | super(FP16_Module, self).__init__() 46 | self.add_module('module', module.half()) 47 | 48 | def forward(self, *inputs, **kwargs): 49 | return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) 50 | 51 | class FP16_Optimizer(object): 52 | """ 53 | FP16_Optimizer is designed to wrap an existing PyTorch optimizer, 54 | and enable an fp16 model to be trained using a master copy of fp32 weights. 55 | 56 | Args: 57 | optimizer (torch.optim.optimizer): Existing optimizer containing initialized fp16 parameters. Internally, FP16_Optimizer replaces the passed optimizer's fp16 parameters with new fp32 parameters copied from the original ones. FP16_Optimizer also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy after each step. 58 | static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale fp16 gradients computed by the model. Scaled gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so static_loss_scale should not affect learning rate. 59 | dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any static_loss_scale option. 60 | 61 | """ 62 | 63 | def __init__(self, optimizer, static_loss_scale=1.0, dynamic_loss_scale=False): 64 | if not torch.cuda.is_available: 65 | raise SystemError('Cannot use fp16 without CUDA') 66 | 67 | self.fp16_param_groups = [] 68 | self.fp32_param_groups = [] 69 | self.fp32_flattened_groups = [] 70 | for i, param_group in enumerate(optimizer.param_groups): 71 | print("FP16_Optimizer processing param group {}:".format(i)) 72 | fp16_params_this_group = [] 73 | fp32_params_this_group = [] 74 | for param in param_group['params']: 75 | if param.requires_grad: 76 | if param.type() == 'torch.cuda.HalfTensor': 77 | print("FP16_Optimizer received torch.cuda.HalfTensor with {}" 78 | .format(param.size())) 79 | fp16_params_this_group.append(param) 80 | elif param.type() == 'torch.cuda.FloatTensor': 81 | print("FP16_Optimizer received torch.cuda.FloatTensor with {}" 82 | .format(param.size())) 83 | fp32_params_this_group.append(param) 84 | else: 85 | raise TypeError("Wrapped parameters must be either " 86 | "torch.cuda.FloatTensor or torch.cuda.HalfTensor. 
" 87 | "Received {}".format(param.type())) 88 | 89 | fp32_flattened_this_group = None 90 | if len(fp16_params_this_group) > 0: 91 | fp32_flattened_this_group = _flatten_dense_tensors( 92 | [param.detach().data.clone().float() for param in fp16_params_this_group]) 93 | 94 | fp32_flattened_this_group = Variable(fp32_flattened_this_group, requires_grad = True) 95 | 96 | fp32_flattened_this_group.grad = fp32_flattened_this_group.new( 97 | *fp32_flattened_this_group.size()) 98 | 99 | # python's lovely list concatenation via + 100 | if fp32_flattened_this_group is not None: 101 | param_group['params'] = [fp32_flattened_this_group] + fp32_params_this_group 102 | else: 103 | param_group['params'] = fp32_params_this_group 104 | 105 | self.fp16_param_groups.append(fp16_params_this_group) 106 | self.fp32_param_groups.append(fp32_params_this_group) 107 | self.fp32_flattened_groups.append(fp32_flattened_this_group) 108 | 109 | # print("self.fp32_flattened_groups = ", self.fp32_flattened_groups) 110 | # print("self.fp16_param_groups = ", self.fp16_param_groups) 111 | 112 | self.optimizer = optimizer.__class__(optimizer.param_groups) 113 | 114 | # self.optimizer.load_state_dict(optimizer.state_dict()) 115 | 116 | self.param_groups = self.optimizer.param_groups 117 | 118 | if dynamic_loss_scale: 119 | self.dynamic_loss_scale = True 120 | self.loss_scaler = DynamicLossScaler() 121 | else: 122 | self.dynamic_loss_scale = False 123 | self.loss_scaler = LossScaler(static_loss_scale) 124 | 125 | self.overflow = False 126 | self.first_closure_call_this_step = True 127 | 128 | def zero_grad(self): 129 | """ 130 | Zero fp32 and fp16 parameter grads. 131 | """ 132 | self.optimizer.zero_grad() 133 | for fp16_group in self.fp16_param_groups: 134 | for param in fp16_group: 135 | if param.grad is not None: 136 | param.grad.detach_() # This does appear in torch.optim.optimizer.zero_grad(), 137 | # but I'm not sure why it's needed. 138 | param.grad.zero_() 139 | 140 | def _check_overflow(self): 141 | params = [] 142 | for group in self.fp16_param_groups: 143 | for param in group: 144 | params.append(param) 145 | for group in self.fp32_param_groups: 146 | for param in group: 147 | params.append(param) 148 | self.overflow = self.loss_scaler.has_overflow(params) 149 | 150 | def _update_scale(self, has_overflow=False): 151 | self.loss_scaler.update_scale(has_overflow) 152 | 153 | def _copy_grads_fp16_to_fp32(self): 154 | for fp32_group, fp16_group in zip(self.fp32_flattened_groups, self.fp16_param_groups): 155 | if len(fp16_group) > 0: 156 | # print(fp16_group) 157 | # This might incur one more deep copy than is necessary. 
158 | fp32_group.grad.data.copy_( 159 | _flatten_dense_tensors([fp16_param.grad.data for fp16_param in fp16_group])) 160 | 161 | def _downscale_fp32(self): 162 | if self.loss_scale != 1.0: 163 | for param_group in self.optimizer.param_groups: 164 | for param in param_group['params']: 165 | param.grad.data.mul_(1./self.loss_scale) 166 | 167 | def clip_fp32_grads(self, clip=-1): 168 | if not self.overflow: 169 | fp32_params = [] 170 | for param_group in self.optimizer.param_groups: 171 | for param in param_group['params']: 172 | fp32_params.append(param) 173 | if clip > 0: 174 | return torch.nn.utils.clip_grad_norm(fp32_params, clip) 175 | 176 | def _copy_params_fp32_to_fp16(self): 177 | for fp16_group, fp32_group in zip(self.fp16_param_groups, self.fp32_flattened_groups): 178 | if len(fp16_group) > 0: 179 | for fp16_param, fp32_data in zip(fp16_group, 180 | _unflatten_dense_tensors(fp32_group.data, fp16_group)): 181 | fp16_param.data.copy_(fp32_data) 182 | 183 | def state_dict(self): 184 | """ 185 | Returns a dict containing the current state of this FP16_Optimizer instance. 186 | This dict contains attributes of FP16_Optimizer, as well as the state_dict 187 | of the contained Pytorch optimizer. 188 | 189 | Untested. 190 | """ 191 | state_dict = {} 192 | state_dict['loss_scaler'] = self.loss_scaler 193 | state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale 194 | state_dict['overflow'] = self.overflow 195 | state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step 196 | state_dict['optimizer_state_dict'] = self.optimizer.state_dict() 197 | return state_dict 198 | 199 | def load_state_dict(self, state_dict): 200 | """ 201 | Loads a state_dict created by an earlier call to state_dict. 202 | 203 | Untested. 204 | """ 205 | self.loss_scaler = state_dict['loss_scaler'] 206 | self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] 207 | self.overflow = state_dict['overflow'] 208 | self.first_closure_call_this_step = state_dict['first_closure_call_this_step'] 209 | self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) 210 | 211 | def step(self, closure=None): # could add clip option. 212 | """ 213 | If no closure is supplied, step should be called after fp16_optimizer_obj.backward(loss). 214 | step updates the fp32 master copy of parameters using the optimizer supplied to 215 | FP16_Optimizer's constructor, then copies the updated fp32 params into the fp16 params 216 | originally referenced by Fp16_Optimizer's constructor, so the user may immediately run 217 | another forward pass using their model. 218 | 219 | If a closure is supplied, step may be called without a prior call to self.backward(loss). 220 | However, the user should take care that any loss.backward() call within the closure 221 | has been replaced by fp16_optimizer_obj.backward(loss). 222 | 223 | Args: 224 | closure (optional): Closure that will be supplied to the underlying optimizer originally passed to FP16_Optimizer's constructor. closure should call zero_grad on the FP16_Optimizer object, compute the loss, call .backward(loss), and return the loss. 225 | 226 | Closure example:: 227 | 228 | # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an 229 | # existing pytorch optimizer. 230 | for input, target in dataset: 231 | def closure(): 232 | optimizer.zero_grad() 233 | output = model(input) 234 | loss = loss_fn(output, target) 235 | optimizer.backward(loss) 236 | return loss 237 | optimizer.step(closure) 238 | 239 | .. 
note:: 240 | The only changes that need to be made compared to 241 | `ordinary optimizer closures`_ are that "optimizer" itself should be an instance of 242 | FP16_Optimizer, and that the call to loss.backward should be replaced by 243 | optimizer.backward(loss). 244 | 245 | .. warning:: 246 | Currently, calling step with a closure is not compatible with dynamic loss scaling. 247 | 248 | .. _`ordinary optimizer closures`: 249 | http://pytorch.org/docs/master/optim.html#optimizer-step-closure 250 | """ 251 | if closure is not None and isinstance(self.loss_scaler, DynamicLossScaler): 252 | raise TypeError("Using step with a closure is currently not " 253 | "compatible with dynamic loss scaling.") 254 | 255 | scale = self.loss_scaler.loss_scale 256 | self._update_scale(self.overflow) 257 | 258 | if self.overflow: 259 | print("OVERFLOW! Skipping step. Attempted loss scale: {}".format(scale)) 260 | return 261 | 262 | if closure is not None: 263 | self._step_with_closure(closure) 264 | else: 265 | self.optimizer.step() 266 | 267 | self._copy_params_fp32_to_fp16() 268 | 269 | return 270 | 271 | def _step_with_closure(self, closure): 272 | def wrapped_closure(): 273 | if self.first_closure_call_this_step: 274 | """ 275 | We expect that the fp16 params are initially fresh on entering self.step(), 276 | so _copy_params_fp32_to_fp16() is unnecessary the first time wrapped_closure() 277 | is called within self.optimizer.step(). 278 | """ 279 | self.first_closure_call_this_step = False 280 | else: 281 | """ 282 | If self.optimizer.step() internally calls wrapped_closure more than once, 283 | it may update the fp32 params after each call. However, self.optimizer 284 | doesn't know about the fp16 params at all. If the fp32 params get updated, 285 | we can't rely on self.optimizer to refresh the fp16 params. We need 286 | to handle that manually: 287 | """ 288 | self._copy_params_fp32_to_fp16() 289 | 290 | """ 291 | Our API expects the user to give us ownership of the backward() call by 292 | replacing all calls to loss.backward() with optimizer.backward(loss). 293 | This requirement holds whether or not the call to backward() is made within 294 | a closure. 295 | If the user is properly calling optimizer.backward(loss) within "closure," 296 | calling closure() here will give the fp32 master params fresh gradients 297 | for the optimizer to play with, 298 | so all wrapped_closure needs to do is call closure() and return the loss. 299 | """ 300 | temp_loss = closure() 301 | return temp_loss 302 | 303 | self.optimizer.step(wrapped_closure) 304 | 305 | self.first_closure_call_this_step = True 306 | 307 | def backward(self, loss, update_fp32_grads=True): 308 | """ 309 | fp16_optimizer_obj.backward performs the following conceptual operations: 310 | 311 | fp32_loss = loss.float() (see first Note below) 312 | 313 | scaled_loss = fp32_loss*loss_scale 314 | 315 | scaled_loss.backward(), which accumulates scaled gradients into the .grad attributes of the 316 | fp16 model's leaves. 317 | 318 | fp16 grads are then copied to the stored fp32 params' .grad attributes (see second Note). 319 | 320 | Finally, fp32 grads are divided by loss_scale. 321 | 322 | In this way, after fp16_optimizer_obj.backward, the fp32 parameters have fresh gradients, 323 | and fp16_optimizer_obj.step may be called. 324 | 325 | .. note:: 326 | Converting the loss to fp32 before applying the loss scale provides some 327 | additional safety against overflow if the user has supplied an fp16 value. 
328 | However, for maximum overflow safety, the user should 329 | compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to 330 | fp16_optimizer_obj.backward. 331 | 332 | .. note:: 333 | The gradients found in an fp16 model's leaves after a call to 334 | fp16_optimizer_obj.backward should not be regarded as valid in general, 335 | because it's possible 336 | they have been scaled (and in the case of dynamic loss scaling, 337 | the scale factor may silently change over time). 338 | If the user wants to inspect gradients after a call to fp16_optimizer_obj.backward, 339 | he/she should query the .grad attribute of FP16_Optimizer's stored fp32 parameters. 340 | 341 | Args: 342 | loss: The loss output by the user's model. loss may be either float or half (but see first Note above). 343 | update_fp32_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay this copy, which is useful to eliminate redundant fp16->fp32 grad copies if fp16_optimizer_obj.backward is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling fp16_optimizer_obj.update_fp32_grads before calling fp16_optimizer_obj.step. 344 | 345 | Example:: 346 | 347 | # Ordinary operation: 348 | optimizer.backward(loss) 349 | 350 | # Naive operation with multiple losses (technically valid, but less efficient): 351 | # fp32 grads will be correct after the second call, but 352 | # the first call incurs an unnecessary fp16->fp32 grad copy. 353 | optimizer.backward(loss1) 354 | optimizer.backward(loss2) 355 | 356 | # More efficient way to handle multiple losses: 357 | # The fp16->fp32 grad copy is delayed until fp16 grads from all 358 | # losses have been accumulated. 359 | optimizer.backward(loss1, update_fp32_grads=False) 360 | optimizer.backward(loss2, update_fp32_grads=False) 361 | optimizer.update_fp32_grads() 362 | """ 363 | self.loss_scaler.backward(loss.float()) 364 | if update_fp32_grads: 365 | self.update_fp32_grads() 366 | 367 | def update_fp32_grads(self): 368 | """ 369 | Copy the .grad attribute from stored references to fp16 parameters to 370 | the .grad attribute of the master fp32 parameters that are directly 371 | updated by the optimizer. :attr:`update_fp32_grads` only needs to be called if 372 | fp16_optimizer_obj.backward was called with update_fp32_grads=False. 373 | """ 374 | if self.dynamic_loss_scale: 375 | self._check_overflow() 376 | if self.overflow: return 377 | self._copy_grads_fp16_to_fp32() 378 | self._downscale_fp32() 379 | 380 | @property 381 | def loss_scale(self): 382 | return self.loss_scaler.loss_scale 383 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def create_hparams(hparams_string=None, verbose=False): 4 | """Create model hyperparameters. 
Parse nondefault from given string.""" 5 | 6 | hparams = tf.contrib.training.HParams( 7 | ################################ 8 | # Experiment Parameters # 9 | ################################ 10 | epochs=300, 11 | iters_per_checkpoint=500, 12 | seed=1234, 13 | dynamic_loss_scaling=True, 14 | fp16_run=False, 15 | distributed_run=False, 16 | 17 | dist_backend="nccl", 18 | dist_url="tcp://localhost:54321", 19 | cudnn_enabled=True, 20 | cudnn_benchmark=True, 21 | 22 | ################################ 23 | # Data Parameters # 24 | ################################ 25 | load_mel_from_disk=False, 26 | training_files='filelists/ms_kor_train.txt', 27 | validation_files='filelists/ms_kor_val.txt', 28 | text_cleaners=['korean_cleaners'], # english_cleaners, korean_cleaners 29 | sort_by_length=False, 30 | 31 | ################################ 32 | # Audio Parameters # 33 | ################################ 34 | max_wav_value=32768.0, 35 | sampling_rate=16000, 36 | filter_length=1024, 37 | hop_length=256, # number of audio frames between STFT columns, default win_length/4 38 | win_length=1024, # win_length int <= n_fft: FFT window size (frequency domain), defaults to win_length = n_fft 39 | n_mel_channels=80, 40 | mel_fmin=0.0, 41 | mel_fmax=8000.0, 42 | 43 | ################################ 44 | # Model Parameters # 45 | ################################ 46 | n_symbols = 80, # set 80 if you use korean_cleaners, 65 if you use english_cleaners 47 | symbols_embedding_dim=512, 48 | 49 | # Transcript encoder parameters 50 | encoder_kernel_size = 5, 51 | encoder_n_convolutions = 3, 52 | encoder_embedding_dim = 512, 53 | 54 | # Speaker embedding parameters 55 | n_speakers = 1, 56 | speaker_embedding_dim=16, 57 | 58 | # ---------------------------------------- # 59 | # emotion 60 | n_emotions = 4, # number of emotion labels 61 | emotion_embedding_dim=16, 62 | 63 | # reference encoder 64 | E = 512, 65 | ref_enc_filters = [32, 32, 64, 64, 128, 128], 66 | ref_enc_size = [3, 3], 67 | ref_enc_strides = [2, 2], 68 | ref_enc_pad = [1, 1], 69 | ref_enc_gru_size = 512 // 2, 70 | 71 | z_latent_dim = 32, 72 | anneal_function = 'logistic', 73 | anneal_k = 0.0025, 74 | anneal_x0 = 10000, 75 | anneal_upper = 0.2, 76 | anneal_lag = 50000, 77 | 78 | # Prosody embedding parameters 79 | prosody_n_convolutions = 6, 80 | prosody_conv_dim_in = [1, 32, 32, 64, 64, 128], 81 | prosody_conv_dim_out = [32, 32, 64, 64, 128, 128], 82 | prosody_conv_kernel = 3, 83 | prosody_conv_stride = 2, 84 | prosody_embedding_dim = 128, 85 | 86 | # Decoder parameters 87 | n_frames_per_step=1, # currently only 1 is supported 88 | decoder_rnn_dim=1024, 89 | prenet_dim=256, 90 | max_decoder_steps=1000, 91 | gate_threshold=0.5, 92 | p_attention_dropout=0.1, 93 | p_decoder_dropout=0.1, 94 | 95 | # Attention parameters 96 | attention_rnn_dim=1024, 97 | attention_dim=128, 98 | 99 | # Location Layer parameters 100 | attention_location_n_filters=32, 101 | attention_location_kernel_size=31, 102 | 103 | # Mel-post processing network parameters 104 | postnet_embedding_dim=512, 105 | postnet_kernel_size=5, 106 | postnet_n_convolutions=5, 107 | 108 | ################################ 109 | # Optimization Hyperparameters # 110 | ################################ 111 | use_saved_learning_rate=False, 112 | learning_rate=1e-3, 113 | weight_decay=1e-6, 114 | grad_clip_thresh=1.0, 115 | batch_size=64, 116 | mask_padding=True # set model's padded outputs to padded values 117 | ) 118 | 119 | if hparams_string: 120 | tf.logging.info('Parsing command line hparams: %s', 
hparams_string) 121 | hparams.parse(hparams_string) 122 | 123 | if verbose: 124 | tf.logging.info('Final parsed hparams: %s', hparams.values()) 125 | 126 | return hparams 127 | 128 | if __name__=='__main__': 129 | hp = create_hparams(verbose=True) 130 | print(hp.batch_size) -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from librosa.filters import mel as librosa_mel_fn 3 | from audio_processing import dynamic_range_compression, dynamic_range_decompression 4 | from stft import STFT 5 | 6 | 7 | class LinearNorm(torch.nn.Module): 8 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 9 | super(LinearNorm, self).__init__() 10 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 11 | 12 | torch.nn.init.xavier_uniform_( 13 | self.linear_layer.weight, 14 | gain=torch.nn.init.calculate_gain(w_init_gain)) 15 | 16 | def forward(self, x): 17 | return self.linear_layer(x) 18 | 19 | 20 | class ConvNorm(torch.nn.Module): 21 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 22 | padding=None, dilation=1, bias=True, w_init_gain='linear'): 23 | super(ConvNorm, self).__init__() 24 | if padding is None: 25 | assert(kernel_size % 2 == 1) 26 | padding = int(dilation * (kernel_size - 1) / 2) 27 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 28 | kernel_size=kernel_size, stride=stride, 29 | padding=padding, dilation=dilation, 30 | bias=bias) 31 | torch.nn.init.xavier_uniform_( 32 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) 33 | 34 | def forward(self, signal): 35 | conv_signal = self.conv(signal) 36 | return conv_signal 37 | 38 | class ConvNorm2D(torch.nn.Module): 39 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 40 | padding=None, dilation=1, bias=True, w_init_gain='linear'): 41 | super(ConvNorm2D, self).__init__() 42 | self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 43 | kernel_size=kernel_size, stride=stride, 44 | padding=padding, dilation=dilation, 45 | groups=1, bias=bias) 46 | torch.nn.init.xavier_uniform_( 47 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) 48 | 49 | def forward(self, signal): 50 | conv_signal = self.conv(signal) 51 | return conv_signal 52 | 53 | 54 | class TacotronSTFT(torch.nn.Module): 55 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 56 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 57 | mel_fmax=8000.0): 58 | super(TacotronSTFT, self).__init__() 59 | self.n_mel_channels = n_mel_channels 60 | self.sampling_rate = sampling_rate 61 | self.stft_fn = STFT(filter_length, hop_length, win_length) 62 | mel_basis = librosa_mel_fn( 63 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 64 | mel_basis = torch.from_numpy(mel_basis).float() 65 | self.register_buffer('mel_basis', mel_basis) 66 | 67 | def spectral_normalize(self, magnitudes): 68 | output = dynamic_range_compression(magnitudes) 69 | return output 70 | 71 | def spectral_de_normalize(self, magnitudes): 72 | output = dynamic_range_decompression(magnitudes) 73 | return output 74 | 75 | def mel_spectrogram(self, y, ref_level_db = 20, magnitude_power=1.5): 76 | """Computes mel-spectrograms from a batch of waves 77 | PARAMS 78 | ------ 79 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 80 | 81 | RETURNS 82 | ------- 83 | mel_output: torch.FloatTensor of shape (B, 
n_mel_channels, T) 84 | """ 85 | assert(torch.min(y.data) >= -1) 86 | assert(torch.max(y.data) <= 1) 87 | 88 | magnitudes, phases = self.stft_fn.transform(y) 89 | magnitudes = magnitudes.data 90 | mel_output = torch.matmul(self.mel_basis, magnitudes) 91 | mel_output = self.spectral_normalize(mel_output) 92 | return mel_output 93 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch.nn.functional as F 3 | from tensorboardX import SummaryWriter 4 | from plotting_utils import plot_alignment_to_numpy, plot_spectrogram_to_numpy 5 | from plotting_utils import plot_gate_outputs_to_numpy, plot_scatter 6 | 7 | 8 | class Tacotron2Logger(SummaryWriter): 9 | def __init__(self, logdir): 10 | super(Tacotron2Logger, self).__init__(logdir) 11 | 12 | def log_training(self, reduced_loss, grad_norm, learning_rate, duration, recon_loss, kl_div, kl_weight, 13 | iteration): 14 | self.add_scalar("training.loss", reduced_loss, iteration) 15 | self.add_scalar("grad.norm", grad_norm, iteration) 16 | self.add_scalar("learning.rate", learning_rate, iteration) 17 | self.add_scalar("duration", duration, iteration) 18 | self.add_scalar("kl_div", kl_div, iteration) 19 | self.add_scalar("kl_weight", kl_weight, iteration) 20 | self.add_scalar("recon_loss", recon_loss, iteration) 21 | 22 | def log_validation(self, reduced_loss, model, y, y_pred, iteration): 23 | self.add_scalar("validation.loss", reduced_loss, iteration) 24 | _, mel_outputs, gate_outputs, alignments, mus, _, _, emotions = y_pred 25 | mel_targets, gate_targets = y 26 | print(emotions) 27 | 28 | # plot distribution of parameters 29 | for tag, value in model.named_parameters(): 30 | tag = tag.replace('.', '/') 31 | self.add_histogram(tag, value.data.cpu().numpy(), iteration) 32 | 33 | # plot alignment, mel target and predicted, gate target and predicted 34 | idx = random.randint(0, alignments.size(0) - 1) 35 | self.add_image( 36 | "alignment", 37 | plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T), 38 | iteration) 39 | self.add_image( 40 | "mel_target", 41 | plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()), 42 | iteration) 43 | self.add_image( 44 | "mel_predicted", 45 | plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()), 46 | iteration) 47 | self.add_image( 48 | "gate", 49 | plot_gate_outputs_to_numpy( 50 | gate_targets[idx].data.cpu().numpy(), 51 | F.sigmoid(gate_outputs[idx]).data.cpu().numpy()), 52 | iteration) 53 | self.add_image( 54 | "latent_dim", 55 | plot_scatter(mus, emotions), 56 | iteration) 57 | -------------------------------------------------------------------------------- /loss_function.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import numpy as np 4 | 5 | 6 | class Tacotron2Loss_VAE(nn.Module): 7 | def __init__(self, hparams): 8 | super(Tacotron2Loss_VAE, self).__init__() 9 | self.anneal_function = hparams.anneal_function 10 | self.lag = hparams.anneal_lag 11 | self.k = hparams.anneal_k 12 | self.x0 = hparams.anneal_x0 13 | self.upper = hparams.anneal_upper 14 | 15 | def kl_anneal_function(self, anneal_function, lag, step, k, x0, upper): 16 | if anneal_function == 'logistic': 17 | return float(upper/(upper+np.exp(-k*(step-x0)))) 18 | elif anneal_function == 'linear': 19 | if step > lag: 20 | return min(upper, step/x0) 21 | else: 22 | return 0 23 | elif anneal_function 
== 'constant': 24 | return 0.001 25 | 26 | 27 | def forward(self, model_output, targets, step): 28 | mel_target, gate_target = targets[0], targets[1] 29 | mel_target.requires_grad = False 30 | gate_target.requires_grad = False 31 | gate_target = gate_target.view(-1, 1) 32 | 33 | mel_out, mel_out_postnet, gate_out, _, mu, logvar, _, _ = model_output 34 | gate_out = gate_out.view(-1, 1) 35 | mel_loss = nn.MSELoss()(mel_out, mel_target) + \ 36 | nn.MSELoss()(mel_out_postnet, mel_target) 37 | gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) 38 | 39 | kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 40 | kl_weight = self.kl_anneal_function(self.anneal_function, self.lag, step, self.k, self.x0, self.upper) 41 | 42 | recon_loss = mel_loss + gate_loss 43 | total_loss = recon_loss + kl_weight*kl_loss 44 | 45 | return total_loss, recon_loss, kl_loss, kl_weight -------------------------------------------------------------------------------- /loss_scaler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class LossScaler: 4 | 5 | def __init__(self, scale=1): 6 | self.cur_scale = scale 7 | 8 | # `params` is a list / generator of torch.Variable 9 | def has_overflow(self, params): 10 | return False 11 | 12 | # `x` is a torch.Tensor 13 | def _has_inf_or_nan(x): 14 | return False 15 | 16 | # `overflow` is boolean indicating whether we overflowed in gradient 17 | def update_scale(self, overflow): 18 | pass 19 | 20 | @property 21 | def loss_scale(self): 22 | return self.cur_scale 23 | 24 | def scale_gradient(self, module, grad_in, grad_out): 25 | return tuple(self.loss_scale * g for g in grad_in) 26 | 27 | def backward(self, loss): 28 | scaled_loss = loss*self.loss_scale 29 | scaled_loss.backward() 30 | 31 | class DynamicLossScaler: 32 | 33 | def __init__(self, 34 | init_scale=2**32, 35 | scale_factor=2., 36 | scale_window=1000): 37 | self.cur_scale = init_scale 38 | self.cur_iter = 0 39 | self.last_overflow_iter = -1 40 | self.scale_factor = scale_factor 41 | self.scale_window = scale_window 42 | 43 | # `params` is a list / generator of torch.Variable 44 | def has_overflow(self, params): 45 | # return False 46 | for p in params: 47 | if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data): 48 | return True 49 | 50 | return False 51 | 52 | # `x` is a torch.Tensor 53 | def _has_inf_or_nan(x): 54 | cpu_sum = float(x.float().sum()) 55 | if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: 56 | return True 57 | return False 58 | 59 | # `overflow` is boolean indicating whether we overflowed in gradient 60 | def update_scale(self, overflow): 61 | if overflow: 62 | #self.cur_scale /= self.scale_factor 63 | self.cur_scale = max(self.cur_scale/self.scale_factor, 1) 64 | self.last_overflow_iter = self.cur_iter 65 | else: 66 | if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: 67 | self.cur_scale *= self.scale_factor 68 | # self.cur_scale = 1 69 | self.cur_iter += 1 70 | 71 | @property 72 | def loss_scale(self): 73 | return self.cur_scale 74 | 75 | def scale_gradient(self, module, grad_in, grad_out): 76 | return tuple(self.loss_scale * g for g in grad_in) 77 | 78 | def backward(self, loss): 79 | scaled_loss = loss*self.loss_scale 80 | scaled_loss.backward() 81 | 82 | ############################################################## 83 | # Example usage below here -- assuming it's in a separate file 84 | ############################################################## 85 | if 
__name__ == "__main__": 86 | import torch 87 | from torch.autograd import Variable 88 | from dynamic_loss_scaler import DynamicLossScaler 89 | 90 | # N is batch size; D_in is input dimension; 91 | # H is hidden dimension; D_out is output dimension. 92 | N, D_in, H, D_out = 64, 1000, 100, 10 93 | 94 | # Create random Tensors to hold inputs and outputs, and wrap them in Variables. 95 | x = Variable(torch.randn(N, D_in), requires_grad=False) 96 | y = Variable(torch.randn(N, D_out), requires_grad=False) 97 | 98 | w1 = Variable(torch.randn(D_in, H), requires_grad=True) 99 | w2 = Variable(torch.randn(H, D_out), requires_grad=True) 100 | parameters = [w1, w2] 101 | 102 | learning_rate = 1e-6 103 | optimizer = torch.optim.SGD(parameters, lr=learning_rate) 104 | loss_scaler = DynamicLossScaler() 105 | 106 | for t in range(500): 107 | y_pred = x.mm(w1).clamp(min=0).mm(w2) 108 | loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale 109 | print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) 110 | print('Iter {} scaled loss: {}'.format(t, loss.data[0])) 111 | print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) 112 | 113 | # Run backprop 114 | optimizer.zero_grad() 115 | loss.backward() 116 | 117 | # Check for overflow 118 | has_overflow = DynamicLossScaler.has_overflow(parameters) 119 | 120 | # If no overflow, unscale grad and update as usual 121 | if not has_overflow: 122 | for param in parameters: 123 | param.grad.data.mul_(1. / loss_scaler.loss_scale) 124 | optimizer.step() 125 | # Otherwise, don't do anything -- ie, skip iteration 126 | else: 127 | print('OVERFLOW!') 128 | 129 | # Update loss scale for next iteration 130 | loss_scaler.update_scale(has_overflow) 131 | 132 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from math import sqrt, ceil 2 | import torch 3 | from torch.autograd import Variable 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from layers import ConvNorm, ConvNorm2D, LinearNorm 7 | from utils import to_gpu, get_mask_from_lengths 8 | from fp16_optimizer import fp32_to_fp16, fp16_to_fp32 9 | from modules import VAE_GST 10 | 11 | drop_rate = 0.5 12 | class LocationLayer(nn.Module): 13 | def __init__(self, attention_n_filters, attention_kernel_size, 14 | attention_dim): 15 | super(LocationLayer, self).__init__() 16 | padding = int((attention_kernel_size - 1) / 2) 17 | self.location_conv = ConvNorm(2, attention_n_filters, 18 | kernel_size=attention_kernel_size, 19 | padding=padding, bias=False, stride=1, 20 | dilation=1) 21 | self.location_dense = LinearNorm(attention_n_filters, attention_dim, 22 | bias=False, w_init_gain='tanh') 23 | 24 | def forward(self, attention_weights_cat): 25 | processed_attention = self.location_conv(attention_weights_cat) 26 | processed_attention = processed_attention.transpose(1, 2) 27 | processed_attention = self.location_dense(processed_attention) 28 | return processed_attention 29 | 30 | 31 | class Attention(nn.Module): 32 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 33 | attention_location_n_filters, attention_location_kernel_size): 34 | super(Attention, self).__init__() 35 | self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, 36 | bias=False, w_init_gain='tanh') 37 | self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, 38 | w_init_gain='tanh') 39 | self.v = LinearNorm(attention_dim, 1, 
bias=False) 40 | self.location_layer = LocationLayer(attention_location_n_filters, 41 | attention_location_kernel_size, 42 | attention_dim) 43 | self.score_mask_value = -float("inf") 44 | 45 | def get_alignment_energies(self, query, processed_memory, 46 | attention_weights_cat): 47 | """ 48 | PARAMS 49 | ------ 50 | query: decoder output (batch, n_mel_channels * n_frames_per_step) 51 | processed_memory: processed encoder outputs (B, T_in, attention_dim) 52 | attention_weights_cat: cumulative and prev. att weights (B, 2, max_time) 53 | 54 | RETURNS 55 | ------- 56 | alignment (batch, max_time) 57 | """ 58 | 59 | processed_query = self.query_layer(query.unsqueeze(1)) 60 | processed_attention_weights = self.location_layer(attention_weights_cat) 61 | energies = self.v(torch.tanh( 62 | processed_query + processed_attention_weights + processed_memory)) 63 | 64 | energies = energies.squeeze(-1) 65 | return energies 66 | 67 | def forward(self, attention_hidden_state, memory, processed_memory, 68 | attention_weights_cat, mask): 69 | """ 70 | PARAMS 71 | ------ 72 | attention_hidden_state: attention rnn last output 73 | memory: encoder outputs 74 | processed_memory: processed encoder outputs 75 | attention_weights_cat: previous and cumulative attention weights 76 | mask: binary mask for padded data 77 | """ 78 | alignment = self.get_alignment_energies( 79 | attention_hidden_state, processed_memory, attention_weights_cat) 80 | 81 | if mask is not None: 82 | alignment.data.masked_fill_(mask, self.score_mask_value) 83 | 84 | attention_weights = F.softmax(alignment, dim=1) 85 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) 86 | attention_context = attention_context.squeeze(1) 87 | 88 | return attention_context, attention_weights 89 | 90 | 91 | class Prenet(nn.Module): 92 | def __init__(self, in_dim, sizes): 93 | super(Prenet, self).__init__() 94 | in_sizes = [in_dim] + sizes[:-1] 95 | self.layers = nn.ModuleList( 96 | [LinearNorm(in_size, out_size, bias=False) 97 | for (in_size, out_size) in zip(in_sizes, sizes)]) 98 | 99 | def forward(self, x): 100 | for linear in self.layers: 101 | x = F.dropout(F.relu(linear(x)), p=drop_rate, training=True) 102 | return x 103 | 104 | 105 | class Postnet(nn.Module): 106 | """Postnet 107 | - Five 1-d convolutions with 512 channels and kernel size 5 108 | """ 109 | 110 | def __init__(self, hparams): 111 | super(Postnet, self).__init__() 112 | self.convolutions = nn.ModuleList() 113 | 114 | self.convolutions.append( 115 | nn.Sequential( 116 | ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim, 117 | kernel_size=hparams.postnet_kernel_size, stride=1, 118 | padding=int((hparams.postnet_kernel_size - 1) / 2), 119 | dilation=1, w_init_gain='tanh'), 120 | nn.BatchNorm1d(hparams.postnet_embedding_dim)) 121 | ) 122 | 123 | for i in range(1, hparams.postnet_n_convolutions - 1): 124 | self.convolutions.append( 125 | nn.Sequential( 126 | ConvNorm(hparams.postnet_embedding_dim, 127 | hparams.postnet_embedding_dim, 128 | kernel_size=hparams.postnet_kernel_size, stride=1, 129 | padding=int((hparams.postnet_kernel_size - 1) / 2), 130 | dilation=1, w_init_gain='tanh'), 131 | nn.BatchNorm1d(hparams.postnet_embedding_dim)) 132 | ) 133 | 134 | self.convolutions.append( 135 | nn.Sequential( 136 | ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels, 137 | kernel_size=hparams.postnet_kernel_size, stride=1, 138 | padding=int((hparams.postnet_kernel_size - 1) / 2), 139 | dilation=1, w_init_gain='linear'), 140 | 
nn.BatchNorm1d(hparams.n_mel_channels)) 141 | ) 142 | 143 | def forward(self, x): 144 | for i in range(len(self.convolutions) - 1): 145 | x = F.dropout(torch.tanh(self.convolutions[i](x)), drop_rate, self.training) 146 | x = F.dropout(self.convolutions[-1](x), drop_rate, self.training) 147 | 148 | return x 149 | 150 | 151 | class Encoder(nn.Module): 152 | """Encoder module: 153 | - Three 1-d convolution banks 154 | - Bidirectional LSTM 155 | """ 156 | def __init__(self, hparams): 157 | super(Encoder, self).__init__() 158 | 159 | convolutions = [] 160 | for _ in range(hparams.encoder_n_convolutions): 161 | conv_layer = nn.Sequential( 162 | ConvNorm(hparams.encoder_embedding_dim, 163 | hparams.encoder_embedding_dim, 164 | kernel_size=hparams.encoder_kernel_size, stride=1, 165 | padding=int((hparams.encoder_kernel_size - 1) / 2), 166 | dilation=1, w_init_gain='relu'), 167 | nn.BatchNorm1d(hparams.encoder_embedding_dim)) 168 | convolutions.append(conv_layer) 169 | self.convolutions = nn.ModuleList(convolutions) 170 | 171 | self.lstm = nn.LSTM(hparams.encoder_embedding_dim, 172 | int(hparams.encoder_embedding_dim / 2), 1, 173 | batch_first=True, bidirectional=True) 174 | 175 | def forward(self, x, input_lengths): 176 | for conv in self.convolutions: 177 | x = F.dropout(F.relu(conv(x)), drop_rate, self.training) 178 | 179 | x = x.transpose(1, 2) 180 | 181 | # pytorch tensor are not reversible, hence the conversion 182 | input_lengths = input_lengths.cpu().numpy() 183 | x = nn.utils.rnn.pack_padded_sequence( 184 | x, input_lengths, batch_first=True) 185 | 186 | self.lstm.flatten_parameters() 187 | outputs, _ = self.lstm(x) 188 | 189 | outputs, _ = nn.utils.rnn.pad_packed_sequence( 190 | outputs, batch_first=True) 191 | 192 | return outputs 193 | 194 | def inference(self, x): 195 | for conv in self.convolutions: 196 | x = F.dropout(F.relu(conv(x)), drop_rate, self.training) 197 | 198 | x = x.transpose(1, 2) 199 | 200 | self.lstm.flatten_parameters() 201 | outputs, _ = self.lstm(x) 202 | 203 | return outputs 204 | 205 | 206 | class Decoder(nn.Module): 207 | def __init__(self, hparams): 208 | super(Decoder, self).__init__() 209 | self.n_mel_channels = hparams.n_mel_channels 210 | self.n_frames_per_step = hparams.n_frames_per_step 211 | self.encoder_embedding_dim = hparams.encoder_embedding_dim 212 | self.attention_rnn_dim = hparams.attention_rnn_dim 213 | self.decoder_rnn_dim = hparams.decoder_rnn_dim 214 | self.prenet_dim = hparams.prenet_dim 215 | self.max_decoder_steps = hparams.max_decoder_steps 216 | self.gate_threshold = hparams.gate_threshold 217 | self.p_attention_dropout = hparams.p_attention_dropout 218 | self.p_decoder_dropout = hparams.p_decoder_dropout 219 | 220 | self.prenet = Prenet( 221 | hparams.n_mel_channels * hparams.n_frames_per_step, 222 | [hparams.prenet_dim, hparams.prenet_dim]) 223 | 224 | self.attention_rnn = nn.LSTMCell( 225 | hparams.prenet_dim + self.encoder_embedding_dim, 226 | hparams.attention_rnn_dim) 227 | 228 | self.attention_layer = Attention( 229 | hparams.attention_rnn_dim, self.encoder_embedding_dim, 230 | hparams.attention_dim, hparams.attention_location_n_filters, 231 | hparams.attention_location_kernel_size) 232 | 233 | self.decoder_rnn = nn.LSTMCell( 234 | hparams.attention_rnn_dim + self.encoder_embedding_dim, 235 | hparams.decoder_rnn_dim, 1) 236 | 237 | self.linear_projection = LinearNorm( 238 | hparams.decoder_rnn_dim + self.encoder_embedding_dim, 239 | hparams.n_mel_channels * hparams.n_frames_per_step) 240 | 241 | self.gate_layer = LinearNorm( 242 | 
hparams.decoder_rnn_dim + self.encoder_embedding_dim, 1, 243 | bias=True, w_init_gain='sigmoid') 244 | 245 | def get_go_frame(self, memory): 246 | """ Gets all zeros frames to use as first decoder input 247 | PARAMS 248 | ------ 249 | memory: encoder outputs 250 | 251 | RETURNS 252 | ------- 253 | decoder_input: all zeros frames 254 | """ 255 | B = memory.size(0) 256 | decoder_input = Variable(memory.data.new( 257 | B, self.n_mel_channels * self.n_frames_per_step).zero_()) 258 | return decoder_input 259 | 260 | def initialize_decoder_states(self, memory, mask): 261 | """ Initializes attention rnn states, decoder rnn states, attention 262 | weights, attention cumulative weights, attention context, stores memory 263 | and stores processed memory 264 | PARAMS 265 | ------ 266 | memory: Encoder outputs 267 | mask: Mask for padded data if training, expects None for inference 268 | """ 269 | B = memory.size(0) 270 | MAX_TIME = memory.size(1) 271 | 272 | self.attention_hidden = Variable(memory.data.new( 273 | B, self.attention_rnn_dim).zero_()) 274 | self.attention_cell = Variable(memory.data.new( 275 | B, self.attention_rnn_dim).zero_()) 276 | 277 | self.decoder_hidden = Variable(memory.data.new( 278 | B, self.decoder_rnn_dim).zero_()) 279 | self.decoder_cell = Variable(memory.data.new( 280 | B, self.decoder_rnn_dim).zero_()) 281 | 282 | self.attention_weights = Variable(memory.data.new( 283 | B, MAX_TIME).zero_()) 284 | self.attention_weights_cum = Variable(memory.data.new( 285 | B, MAX_TIME).zero_()) 286 | self.attention_context = Variable(memory.data.new( 287 | B, self.encoder_embedding_dim).zero_()) 288 | 289 | self.memory = memory 290 | self.processed_memory = self.attention_layer.memory_layer(memory) 291 | self.mask = mask 292 | 293 | def parse_decoder_inputs(self, decoder_inputs): 294 | """ Prepares decoder inputs, i.e. mel outputs 295 | PARAMS 296 | ------ 297 | decoder_inputs: inputs used for teacher-forced training, i.e. 
mel-specs 298 | 299 | RETURNS 300 | ------- 301 | inputs: processed decoder inputs 302 | 303 | """ 304 | # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels) 305 | decoder_inputs = decoder_inputs.transpose(1, 2) 306 | decoder_inputs = decoder_inputs.view( 307 | decoder_inputs.size(0), 308 | int(decoder_inputs.size(1)/self.n_frames_per_step), -1) 309 | # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels) 310 | decoder_inputs = decoder_inputs.transpose(0, 1) 311 | return decoder_inputs 312 | 313 | def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments): 314 | """ Prepares decoder outputs for output 315 | PARAMS 316 | ------ 317 | mel_outputs: 318 | gate_outputs: gate output energies 319 | alignments: 320 | 321 | RETURNS 322 | ------- 323 | mel_outputs: 324 | gate_outputs: gate output energies 325 | alignments: 326 | """ 327 | # (T_out, B) -> (B, T_out) 328 | alignments = torch.stack(alignments).transpose(0, 1) 329 | # (T_out, B) -> (B, T_out) 330 | 331 | gate_outputs = torch.stack(gate_outputs) 332 | if len(gate_outputs.size()) == 1: 333 | gate_outputs.unsqueeze_(1) 334 | gate_outputs = gate_outputs.transpose(0, 1) 335 | gate_outputs = gate_outputs.contiguous() 336 | # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels) 337 | mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous() 338 | # decouple frames per step 339 | mel_outputs = mel_outputs.view( 340 | mel_outputs.size(0), -1, self.n_mel_channels) 341 | # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out) 342 | mel_outputs = mel_outputs.transpose(1, 2) 343 | 344 | return mel_outputs, gate_outputs, alignments 345 | 346 | def decode(self, decoder_input): 347 | """ Decoder step using stored states, attention and memory 348 | PARAMS 349 | ------ 350 | decoder_input: previous mel output 351 | 352 | RETURNS 353 | ------- 354 | mel_output: 355 | gate_output: gate output energies 356 | attention_weights: 357 | """ 358 | cell_input = torch.cat((decoder_input, self.attention_context), -1) 359 | self.attention_hidden, self.attention_cell = self.attention_rnn( 360 | cell_input, (self.attention_hidden, self.attention_cell)) 361 | self.attention_hidden = F.dropout( 362 | self.attention_hidden, self.p_attention_dropout, self.training) 363 | self.attention_cell = F.dropout( 364 | self.attention_cell, self.p_attention_dropout, self.training) 365 | 366 | attention_weights_cat = torch.cat( 367 | (self.attention_weights.unsqueeze(1), 368 | self.attention_weights_cum.unsqueeze(1)), dim=1) 369 | self.attention_context, self.attention_weights = self.attention_layer( 370 | self.attention_hidden, self.memory, self.processed_memory, 371 | attention_weights_cat, self.mask) 372 | 373 | self.attention_weights_cum += self.attention_weights 374 | decoder_input = torch.cat( 375 | (self.attention_hidden, self.attention_context), -1) 376 | self.decoder_hidden, self.decoder_cell = self.decoder_rnn( 377 | decoder_input, (self.decoder_hidden, self.decoder_cell)) 378 | self.decoder_hidden = F.dropout( 379 | self.decoder_hidden, self.p_decoder_dropout, self.training) 380 | self.decoder_cell = F.dropout( 381 | self.decoder_cell, self.p_decoder_dropout, self.training) 382 | 383 | decoder_hidden_attention_context = torch.cat( 384 | (self.decoder_hidden, self.attention_context), dim=1) 385 | decoder_output = self.linear_projection( 386 | decoder_hidden_attention_context) 387 | 388 | gate_prediction = self.gate_layer(decoder_hidden_attention_context) 389 | return decoder_output, gate_prediction, self.attention_weights 390 | 391 | 
def forward(self, memory, decoder_inputs, memory_lengths): 392 | """ Decoder forward pass for training 393 | PARAMS 394 | ------ 395 | memory: Encoder outputs 396 | decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs 397 | memory_lengths: Encoder output lengths for attention masking. 398 | 399 | RETURNS 400 | ------- 401 | mel_outputs: mel outputs from the decoder 402 | gate_outputs: gate outputs from the decoder 403 | alignments: sequence of attention weights from the decoder 404 | """ 405 | 406 | decoder_input = self.get_go_frame(memory).unsqueeze(0) 407 | decoder_inputs = self.parse_decoder_inputs(decoder_inputs) 408 | decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0) 409 | decoder_inputs = self.prenet(decoder_inputs) 410 | 411 | self.initialize_decoder_states( 412 | memory, mask=~get_mask_from_lengths(memory_lengths)) 413 | 414 | mel_outputs, gate_outputs, alignments = [], [], [] 415 | while len(mel_outputs) < decoder_inputs.size(0) - 1: 416 | decoder_input = decoder_inputs[len(mel_outputs)] 417 | mel_output, gate_output, attention_weights = self.decode( 418 | decoder_input) 419 | mel_outputs += [mel_output.squeeze(1)] 420 | gate_outputs += [gate_output.squeeze()] 421 | alignments += [attention_weights] 422 | 423 | mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( 424 | mel_outputs, gate_outputs, alignments) 425 | 426 | return mel_outputs, gate_outputs, alignments 427 | 428 | def inference(self, memory): 429 | """ Decoder inference 430 | PARAMS 431 | ------ 432 | memory: Encoder outputs 433 | 434 | RETURNS 435 | ------- 436 | mel_outputs: mel outputs from the decoder 437 | gate_outputs: gate outputs from the decoder 438 | alignments: sequence of attention weights from the decoder 439 | """ 440 | decoder_input = self.get_go_frame(memory) 441 | 442 | self.initialize_decoder_states(memory, mask=None) 443 | 444 | mel_outputs, gate_outputs, alignments = [], [], [] 445 | while True: 446 | decoder_input = self.prenet(decoder_input) 447 | mel_output, gate_output, alignment = self.decode(decoder_input) 448 | 449 | mel_outputs += [mel_output.squeeze(1)] 450 | gate_outputs += [gate_output] 451 | alignments += [alignment] 452 | 453 | if torch.sigmoid(gate_output.data) > self.gate_threshold: 454 | break 455 | elif len(mel_outputs) == self.max_decoder_steps: 456 | print("Warning! 
Reached max decoder steps") 457 | break 458 | 459 | decoder_input = mel_output 460 | 461 | mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( 462 | mel_outputs, gate_outputs, alignments) 463 | 464 | return mel_outputs, gate_outputs, alignments 465 | 466 | 467 | class Tacotron2(nn.Module): 468 | def __init__(self, hparams): 469 | super(Tacotron2, self).__init__() 470 | self.mask_padding = hparams.mask_padding 471 | self.fp16_run = hparams.fp16_run 472 | self.n_mel_channels = hparams.n_mel_channels 473 | self.n_frames_per_step = hparams.n_frames_per_step 474 | self.transcript_embedding = nn.Embedding( 475 | hparams.n_symbols, hparams.symbols_embedding_dim) 476 | self.speaker_embedding = LinearNorm( 477 | hparams.n_speakers, hparams.speaker_embedding_dim, bias=True, w_init_gain='tanh') 478 | self.emotion_embedding = LinearNorm( 479 | hparams.n_emotions, hparams.emotion_embedding_dim, bias=True, w_init_gain='tanh') 480 | std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) 481 | val = sqrt(3.0) * std # uniform bounds for std 482 | self.transcript_embedding.weight.data.uniform_(-val, val) 483 | self.encoder = Encoder(hparams) 484 | self.decoder = Decoder(hparams) 485 | self.postnet = Postnet(hparams) 486 | 487 | self.vae_gst = VAE_GST(hparams) 488 | 489 | def parse_batch(self, batch): 490 | text_padded, input_lengths, mel_padded, gate_padded, \ 491 | output_lengths, speakers, emotions = batch 492 | text_padded = to_gpu(text_padded).long() 493 | speakers = to_gpu(speakers).float() 494 | emotions = to_gpu(emotions).float() 495 | input_lengths = to_gpu(input_lengths).long() 496 | max_len = torch.max(input_lengths.data).item() 497 | mel_padded = to_gpu(mel_padded).float() 498 | gate_padded = to_gpu(gate_padded).float() 499 | output_lengths = to_gpu(output_lengths).long() 500 | 501 | return ( 502 | (text_padded, input_lengths, mel_padded, max_len, output_lengths, speakers, emotions), 503 | (mel_padded, gate_padded)) 504 | 505 | def parse_input(self, inputs): 506 | inputs = fp32_to_fp16(inputs) if self.fp16_run else inputs 507 | return inputs 508 | 509 | def parse_output(self, outputs, output_lengths=None): 510 | if self.mask_padding and output_lengths is not None: 511 | mask = ~get_mask_from_lengths(output_lengths) 512 | mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1)) 513 | mask = mask.permute(1, 0, 2) 514 | 515 | outputs[0].data.masked_fill_(mask, 0.0) 516 | outputs[1].data.masked_fill_(mask, 0.0) 517 | outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies 518 | 519 | outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs 520 | return outputs 521 | 522 | def forward(self, inputs): 523 | inputs, input_lengths, targets, _, \ 524 | output_lengths, speakers, emotions = self.parse_input(inputs) 525 | input_lengths, output_lengths = input_lengths.data, output_lengths.data 526 | 527 | ## added 528 | transcript_embedded_inputs = self.transcript_embedding(inputs).transpose(1, 2) 529 | 530 | # [N, transcript_T, int(encoder_dim/2)] 531 | transcript_outputs = self.encoder(transcript_embedded_inputs, input_lengths) 532 | 533 | transcript_outputs_size = list(transcript_outputs.shape) 534 | 535 | prosody_outputs, mu, logvar, z = self.vae_gst(targets) # get z 536 | prosody_outputs = prosody_outputs.unsqueeze(1).expand_as(transcript_outputs) 537 | encoder_outputs = transcript_outputs + prosody_outputs # for decoder input 538 | 539 | mel_outputs, gate_outputs, alignments = self.decoder( 540 | encoder_outputs, targets, 
memory_lengths=input_lengths) 541 | 542 | mel_outputs_postnet = self.postnet(mel_outputs) 543 | mel_outputs_postnet = mel_outputs + mel_outputs_postnet 544 | 545 | return self.parse_output( 546 | [mel_outputs, mel_outputs_postnet, gate_outputs, alignments, mu, logvar, z, emotions], 547 | output_lengths) -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from layers import LinearNorm 6 | from CoordConv import CoordConv2d 7 | 8 | class VAE_GST(nn.Module): 9 | def __init__(self, hparams): 10 | super().__init__() 11 | self.ref_encoder = ReferenceEncoder(hparams) 12 | self.fc1 = nn.Linear(hparams.ref_enc_gru_size, hparams.z_latent_dim) 13 | self.fc2 = nn.Linear(hparams.ref_enc_gru_size, hparams.z_latent_dim) 14 | self.fc3 = nn.Linear(hparams.z_latent_dim, hparams.E) 15 | 16 | def reparameterize(self, mu, logvar): 17 | if self.training: 18 | std = torch.exp(0.5 * logvar) 19 | eps = torch.randn_like(std) 20 | return eps.mul(std).add_(mu) 21 | else: 22 | return mu 23 | 24 | def forward(self, inputs): 25 | enc_out = self.ref_encoder(inputs) 26 | mu = self.fc1(enc_out) 27 | logvar = self.fc2(enc_out) 28 | z = self.reparameterize(mu, logvar) 29 | style_embed = self.fc3(z) 30 | 31 | return style_embed, mu, logvar, z 32 | 33 | 34 | class ReferenceEncoder(nn.Module): 35 | ''' 36 | inputs --- [N, Ty/r, n_mels*r] mels 37 | outputs --- [N, ref_enc_gru_size] 38 | ''' 39 | 40 | def __init__(self, hparams): 41 | super().__init__() 42 | K = len(hparams.ref_enc_filters) 43 | filters = [1] + hparams.ref_enc_filters 44 | # Using CoordConv as the first layer is said to preserve positional information well. 
https://arxiv.org/pdf/1811.02122.pdf 45 | convs = [CoordConv2d(in_channels=filters[0], 46 | out_channels=filters[0 + 1], 47 | kernel_size=(3, 3), 48 | stride=(2, 2), 49 | padding=(1, 1), with_r=True)] 50 | convs2 = [nn.Conv2d(in_channels=filters[i], 51 | out_channels=filters[i + 1], 52 | kernel_size=(3, 3), 53 | stride=(2, 2), 54 | padding=(1, 1)) for i in range(1,K)] 55 | convs.extend(convs2) 56 | self.convs = nn.ModuleList(convs) 57 | self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=hparams.ref_enc_filters[i]) for i in range(K)]) 58 | 59 | out_channels = self.calculate_channels(hparams.n_mel_channels, 3, 2, 1, K) 60 | self.gru = nn.GRU(input_size=hparams.ref_enc_filters[-1] * out_channels, 61 | hidden_size=hparams.E // 2, 62 | batch_first=True) 63 | self.n_mels = hparams.n_mel_channels 64 | 65 | def forward(self, inputs): 66 | N = inputs.size(0) 67 | out = inputs.contiguous().view(N, 1, -1, self.n_mels) # [N, 1, Ty, n_mels] 68 | for conv, bn in zip(self.convs, self.bns): 69 | out = conv(out) 70 | out = bn(out) 71 | out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] 72 | 73 | out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] 74 | T = out.size(1) 75 | N = out.size(0) 76 | out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] 77 | 78 | memory, out = self.gru(out) # out --- [1, N, E//2] 79 | 80 | return out.squeeze(0) 81 | 82 | def calculate_channels(self, L, kernel_size, stride, pad, n_convs): 83 | for i in range(n_convs): 84 | L = (L - kernel_size + 2 * pad) // stride + 1 85 | return L -------------------------------------------------------------------------------- /multiproc.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import sys 4 | import subprocess 5 | 6 | argslist = list(sys.argv)[1:] 7 | num_gpus = torch.cuda.device_count() 8 | argslist.append('--n_gpus={}'.format(num_gpus)) 9 | workers = [] 10 | job_id = time.strftime("%Y_%m_%d-%H%M%S") 11 | argslist.append("--group_name=group_{}".format(job_id)) 12 | 13 | for i in range(num_gpus): 14 | argslist.append('--rank={}'.format(i)) 15 | stdout = None if i == 0 else open("logs/{}_GPU_{}.log".format(job_id, i), 16 | "w") 17 | print(argslist) 18 | p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout) 19 | workers.append(p) 20 | argslist = argslist[:-1] 21 | 22 | for p in workers: 23 | p.wait() 24 | -------------------------------------------------------------------------------- /plotting_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import matplotlib.pylab as plt 4 | import numpy as np 5 | 6 | 7 | def save_figure_to_numpy(fig): 8 | # save it to a numpy array. 
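# Descriptive note on the conversion below: fig.canvas.tostring_rgb() returns the
# Agg-rendered figure as raw RGB bytes; parsing them with np.fromstring gives a flat
# uint8 array, and reshaping with get_width_height()[::-1] + (3,) yields an
# (height, width, 3) image array that tensorboardX's add_image can log directly.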
9 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 10 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 11 | return data 12 | 13 | 14 | def plot_alignment_to_numpy(alignment, info=None): 15 | fig, ax = plt.subplots(figsize=(6, 4)) 16 | im = ax.imshow(alignment, aspect='auto', origin='lower', 17 | interpolation='none') 18 | fig.colorbar(im, ax=ax) 19 | xlabel = 'Decoder timestep' 20 | if info is not None: 21 | xlabel += '\n\n' + info 22 | plt.xlabel(xlabel) 23 | plt.ylabel('Encoder timestep') 24 | plt.tight_layout() 25 | 26 | fig.canvas.draw() 27 | data = save_figure_to_numpy(fig) 28 | plt.close() 29 | return data 30 | 31 | 32 | def plot_spectrogram_to_numpy(spectrogram): 33 | fig, ax = plt.subplots(figsize=(12, 3)) 34 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 35 | interpolation='none') 36 | plt.colorbar(im, ax=ax) 37 | plt.xlabel("Frames") 38 | plt.ylabel("Channels") 39 | plt.tight_layout() 40 | 41 | fig.canvas.draw() 42 | data = save_figure_to_numpy(fig) 43 | plt.close() 44 | return data 45 | 46 | 47 | def plot_gate_outputs_to_numpy(gate_targets, gate_outputs): 48 | fig, ax = plt.subplots(figsize=(12, 3)) 49 | ax.scatter(range(len(gate_targets)), gate_targets, alpha=0.5, 50 | color='green', marker='+', s=1, label='target') 51 | ax.scatter(range(len(gate_outputs)), gate_outputs, alpha=0.5, 52 | color='red', marker='.', s=1, label='predicted') 53 | 54 | plt.xlabel("Frames (Green target, Red predicted)") 55 | plt.ylabel("Gate State") 56 | plt.tight_layout() 57 | 58 | fig.canvas.draw() 59 | data = save_figure_to_numpy(fig) 60 | plt.close() 61 | return data 62 | 63 | def plot_scatter(mus, y): 64 | """ 65 | Scatter plot that can be drawn in tensorboardX 66 | """ 67 | colors = 'r','b','g','y' 68 | labels = 'neu','sad','ang','hap' 69 | 70 | mus = mus.cpu().numpy() 71 | y = y.cpu().numpy() 72 | y = np.argmax(y,1) 73 | 74 | fig, ax = plt.subplots(figsize=(12,12)) 75 | for i, (c, label) in enumerate(zip(colors, labels)): 76 | ax.scatter(mus[y==i,0], mus[y==i,1], c=c, label=label, alpha=0.5) 77 | 78 | plt.legend(loc='upper left') 79 | 80 | fig.canvas.draw() 81 | data = save_figure_to_numpy(fig) 82 | plt.close() 83 | return data -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==2.1.0 2 | tensorflow 3 | inflect==0.2.5 4 | librosa==0.6.0 5 | scipy==1.0.0 6 | tensorboardX==1.1 7 | Unidecode==1.0.22 8 | pillow 9 | nltk==3.4.5 10 | jamo==0.4.1 -------------------------------------------------------------------------------- /res/alignment.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/res/alignment.gif -------------------------------------------------------------------------------- /res/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/res/demo.png -------------------------------------------------------------------------------- /res/interpolation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/res/interpolation.png -------------------------------------------------------------------------------- /res/kldiv.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/res/kldiv.png -------------------------------------------------------------------------------- /res/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/res/overview.png -------------------------------------------------------------------------------- /res/reconloss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/res/reconloss.png -------------------------------------------------------------------------------- /res/scatter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/res/scatter.png -------------------------------------------------------------------------------- /res/trainingloss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/res/trainingloss.png -------------------------------------------------------------------------------- /res/tsne.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/res/tsne.png -------------------------------------------------------------------------------- /res/validloss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/res/validloss.png -------------------------------------------------------------------------------- /samples/interpolation/ang0.3_sad0.6.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/interpolation/ang0.3_sad0.6.wav -------------------------------------------------------------------------------- /samples/interpolation/ang0.6_sad0.3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/interpolation/ang0.6_sad0.3.wav -------------------------------------------------------------------------------- /samples/interpolation/ang1.0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/interpolation/ang1.0.wav -------------------------------------------------------------------------------- /samples/interpolation/hap0.3_sad0.6.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/interpolation/hap0.3_sad0.6.wav -------------------------------------------------------------------------------- /samples/interpolation/hap0.6_sad0.3.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/interpolation/hap0.6_sad0.3.wav -------------------------------------------------------------------------------- /samples/interpolation/hap1.0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/interpolation/hap1.0.wav -------------------------------------------------------------------------------- /samples/interpolation/neu0.3_sad0.6.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/interpolation/neu0.3_sad0.6.wav -------------------------------------------------------------------------------- /samples/interpolation/neu0.6_sad0.3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/interpolation/neu0.6_sad0.3.wav -------------------------------------------------------------------------------- /samples/interpolation/neu1.0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/interpolation/neu1.0.wav -------------------------------------------------------------------------------- /samples/interpolation/sad1.0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/interpolation/sad1.0.wav -------------------------------------------------------------------------------- /samples/mix/hap0.25_ang0.75.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/mix/hap0.25_ang0.75.wav -------------------------------------------------------------------------------- /samples/mix/hap0.25_sad0.25_ang0.5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/mix/hap0.25_sad0.25_ang0.5.wav -------------------------------------------------------------------------------- /samples/mix/neu0.25_hap0.25_sad0.25_ang0.25.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/mix/neu0.25_hap0.25_sad0.25_ang0.25.wav -------------------------------------------------------------------------------- /samples/refs/recorded_ang.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/refs/recorded_ang.wav -------------------------------------------------------------------------------- /samples/refs/recorded_hap.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/refs/recorded_hap.wav -------------------------------------------------------------------------------- 
/samples/refs/recorded_neu.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/refs/recorded_neu.wav -------------------------------------------------------------------------------- /samples/refs/recorded_sad.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/refs/recorded_sad.wav -------------------------------------------------------------------------------- /samples/refs/ref_ang.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/refs/ref_ang.wav -------------------------------------------------------------------------------- /samples/refs/ref_hap.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/refs/ref_hap.wav -------------------------------------------------------------------------------- /samples/refs/ref_neu.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/refs/ref_neu.wav -------------------------------------------------------------------------------- /samples/refs/ref_sad.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jinhan/tacotron2-vae/5fbe8ff968c87130638c60d588b229889725a55b/samples/refs/ref_sad.wav -------------------------------------------------------------------------------- /stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) 2017, Prem Seetharaman 5 | All rights reserved. 6 | 7 | * Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, this 14 | list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from this 19 | software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | """ 32 | 33 | import torch 34 | import numpy as np 35 | import torch.nn.functional as F 36 | from torch.autograd import Variable 37 | from scipy.signal import get_window 38 | from librosa.util import pad_center, tiny 39 | from audio_processing import window_sumsquare 40 | 41 | 42 | class STFT(torch.nn.Module): 43 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 44 | def __init__(self, filter_length=800, hop_length=200, win_length=800, 45 | window='hann'): 46 | super(STFT, self).__init__() 47 | self.filter_length = filter_length 48 | self.hop_length = hop_length 49 | self.win_length = win_length 50 | self.window = window 51 | self.forward_transform = None 52 | scale = self.filter_length / self.hop_length 53 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 54 | 55 | cutoff = int((self.filter_length / 2 + 1)) 56 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 57 | np.imag(fourier_basis[:cutoff, :])]) 58 | 59 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 60 | inverse_basis = torch.FloatTensor( 61 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 62 | 63 | if window is not None: 64 | assert(filter_length >= win_length) 65 | # get window and zero center pad it to filter_length 66 | fft_window = get_window(window, win_length, fftbins=True) 67 | fft_window = pad_center(fft_window, filter_length) 68 | fft_window = torch.from_numpy(fft_window).float() 69 | 70 | # window the bases 71 | forward_basis *= fft_window 72 | inverse_basis *= fft_window 73 | 74 | self.register_buffer('forward_basis', forward_basis.float()) 75 | self.register_buffer('inverse_basis', inverse_basis.float()) 76 | 77 | def transform(self, input_data): 78 | num_batches = input_data.size(0) 79 | num_samples = input_data.size(1) 80 | 81 | self.num_samples = num_samples 82 | 83 | # similar to librosa, reflect-pad the input 84 | input_data = input_data.view(num_batches, 1, num_samples) 85 | input_data = F.pad( 86 | input_data.unsqueeze(1), 87 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 88 | mode='reflect') 89 | input_data = input_data.squeeze(1) 90 | 91 | forward_transform = F.conv1d( 92 | input_data, 93 | Variable(self.forward_basis, requires_grad=False), 94 | stride=self.hop_length, 95 | padding=0) 96 | 97 | cutoff = int((self.filter_length / 2) + 1) 98 | real_part = forward_transform[:, :cutoff, :] 99 | imag_part = forward_transform[:, cutoff:, :] 100 | 101 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 102 | phase = torch.autograd.Variable( 103 | torch.atan2(imag_part.data, real_part.data)) 104 | 105 | return magnitude, phase 106 | 107 | def inverse(self, magnitude, phase): 108 | recombine_magnitude_phase = torch.cat( 109 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 110 | 111 | inverse_transform = F.conv_transpose1d( 112 | recombine_magnitude_phase, 113 | Variable(self.inverse_basis, requires_grad=False), 114 | stride=self.hop_length, 115 | padding=0) 116 | 117 | 
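# conv_transpose1d above overlap-adds the windowed frames, so the waveform is still modulated
# by the summed squared synthesis window. The block below divides that envelope out
# (window_sumsquare) wherever it is numerically safe and rescales by filter_length / hop_length,
# mirroring librosa's inverse STFT.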
if self.window is not None: 118 | window_sum = window_sumsquare( 119 | self.window, magnitude.size(-1), hop_length=self.hop_length, 120 | win_length=self.win_length, n_fft=self.filter_length, 121 | dtype=np.float32) 122 | # remove modulation effects 123 | approx_nonzero_indices = torch.from_numpy( 124 | np.where(window_sum > tiny(window_sum))[0]) 125 | window_sum = torch.autograd.Variable( 126 | torch.from_numpy(window_sum), requires_grad=False) 127 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 128 | 129 | # scale by hop ratio 130 | inverse_transform *= float(self.filter_length) / self.hop_length 131 | 132 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 133 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] 134 | 135 | return inverse_transform 136 | 137 | def forward(self, input_data): 138 | self.magnitude, self.phase = self.transform(input_data) 139 | reconstruction = self.inverse(self.magnitude, self.phase) 140 | return reconstruction 141 | -------------------------------------------------------------------------------- /synthesizer.py: -------------------------------------------------------------------------------- 1 | # import io 2 | # import os 3 | # import re 4 | # import librosa 5 | # import argparse 6 | # import numpy as np 7 | # from glob import glob 8 | # from tqdm import tqdm 9 | # import tensorflow as tf 10 | # from functools import partial 11 | 12 | # from hparams import hparams 13 | # from models import create_model, get_most_recent_checkpoint 14 | # from audio import save_audio, inv_spectrogram, inv_preemphasis, \ 15 | # inv_spectrogram_tensorflow 16 | # from utils import plot, PARAMS_NAME, load_json, load_hparams, \ 17 | # add_prefix, add_postfix, get_time, parallel_run, makedirs, str2bool 18 | 19 | # from text.korean import tokenize 20 | # from text import text_to_sequence, sequence_to_text 21 | import sys 22 | sys.path.append('waveglow/') 23 | import numpy as np 24 | import torch 25 | 26 | from hparams import create_hparams 27 | from model import Tacotron2 28 | from layers import TacotronSTFT 29 | from train import load_model 30 | from text import text_to_sequence 31 | 32 | from utils import load_wav_to_torch 33 | from scipy.io.wavfile import write 34 | import os 35 | import time 36 | import librosa 37 | 38 | # from sklearn.manifold import TSNE 39 | # import matplotlib 40 | # matplotlib.use("Agg") 41 | # import matplotlib.pylab as plt 42 | # %matplotlib inline 43 | # import IPython.display as ipd 44 | from tqdm import tqdm 45 | 46 | class Synthesizer(object): 47 | def __init__(self): 48 | super().__init__() 49 | self.hparams = create_hparams() 50 | self.hparams.sampling_rate = 16000 51 | self.hparams.max_decoder_steps = 600 52 | 53 | self.stft = TacotronSTFT( 54 | self.hparams.filter_length, self.hparams.hop_length, self.hparams.win_length, 55 | self.hparams.n_mel_channels, self.hparams.sampling_rate, self.hparams.mel_fmin, 56 | self.hparams.mel_fmax) 57 | 58 | def load_mel(self, path): 59 | audio, sampling_rate = load_wav_to_torch(path) 60 | if sampling_rate != self.hparams.sampling_rate: 61 | raise ValueError("{} SR doesn't match target {} SR".format( 62 | sampling_rate, self.stft.sampling_rate)) 63 | audio_norm = audio / self.hparams.max_wav_value 64 | audio_norm = audio_norm.unsqueeze(0) 65 | audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) 66 | melspec = self.stft.mel_spectrogram(audio_norm) 67 | melspec = melspec.cuda() 68 | return melspec 69 | 70 | # 
def close(self): 71 | # tf.reset_default_graph() 72 | # self.sess.close() 73 | 74 | def load(self, checkpoint_path, waveglow_path): 75 | self.model = load_model(self.hparams) 76 | self.model.load_state_dict(torch.load(checkpoint_path)['state_dict']) 77 | _ = self.model.eval() 78 | 79 | self.waveglow = torch.load(waveglow_path)['model'] 80 | self.waveglow.cuda() 81 | 82 | path = './web/static/uploads/koemo_spk_emo_all_test.txt' 83 | with open(path, encoding='utf-8') as f: 84 | filepaths_and_text = [line.strip().split("|") for line in f] 85 | 86 | base_path = os.path.dirname(checkpoint_path) 87 | data_path = os.path.basename(checkpoint_path) + '_' + path.rsplit('_', 1)[1].split('.')[0] + '.npz' 88 | npz_path = os.path.join(base_path, data_path) 89 | 90 | if os.path.exists(npz_path): 91 | d = np.load(npz_path) 92 | zs = d['zs'] 93 | emotions = d['emotions'] 94 | else: 95 | emotions = [] 96 | zs = [] 97 | for audio_path, _, _, emotion in tqdm(filepaths_and_text): 98 | melspec = self.load_mel(audio_path) 99 | _, _, _, z = self.model.vae_gst(melspec) 100 | zs.append(z.cpu().data) 101 | emotions.append(int(emotion)) 102 | emotions = np.array(emotions) # list이면 안됨 -> ndarray 103 | zs = torch.cat(zs, dim=0).data.numpy() 104 | d = {'zs':zs, 'emotions':emotions} 105 | np.savez(npz_path, **d) 106 | 107 | self.neu = np.mean(zs[emotions==0,:], axis=0) 108 | self.sad = np.mean(zs[emotions==1,:], axis=0) 109 | self.ang = np.mean(zs[emotions==2,:], axis=0) 110 | self.hap = np.mean(zs[emotions==3,:], axis=0) 111 | 112 | def synthesize(self, text, path, condition_on_ref, ref_audio, ratios): 113 | print(ratios) 114 | sequence = np.array(text_to_sequence(text, ['korean_cleaners']))[None, :] 115 | sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long() 116 | inputs = self.model.parse_input(sequence) 117 | transcript_embedded_inputs = self.model.transcript_embedding(inputs).transpose(1,2) 118 | transcript_outputs = self.model.encoder.inference(transcript_embedded_inputs) 119 | print(condition_on_ref) 120 | 121 | if condition_on_ref: 122 | #ref_audio = '/data1/jinhan/KoreanEmotionSpeech/wav/hap/hap_00000001.wav' 123 | ref_audio_mel = self.load_mel(ref_audio) 124 | latent_vector, _, _, _ = self.model.vae_gst(ref_audio_mel) 125 | latent_vector = latent_vector.unsqueeze(1).expand_as(transcript_outputs) 126 | 127 | else: # condition on emotion ratio 128 | latent_vector = ratios[0] * self.neu + ratios[1] * self.sad + \ 129 | ratios[2] * self.hap + ratios[3] * self.ang 130 | latent_vector = torch.FloatTensor(latent_vector).cuda() 131 | latent_vector = self.model.vae_gst.fc3(latent_vector) 132 | 133 | encoder_outputs = transcript_outputs + latent_vector 134 | 135 | decoder_input = self.model.decoder.get_go_frame(encoder_outputs) 136 | self.model.decoder.initialize_decoder_states(encoder_outputs, mask=None) 137 | mel_outputs, gate_outputs, alignments = [], [], [] 138 | 139 | while True: 140 | decoder_input = self.model.decoder.prenet(decoder_input) 141 | mel_output, gate_output, alignment = self.model.decoder.decode(decoder_input) 142 | 143 | mel_outputs += [mel_output] 144 | gate_outputs += [gate_output] 145 | alignments += [alignment] 146 | 147 | if torch.sigmoid(gate_output.data) > self.hparams.gate_threshold: 148 | # print(torch.sigmoid(gate_output.data), gate_output.data) 149 | break 150 | if len(mel_outputs) == self.hparams.max_decoder_steps: 151 | print("Warning! 
Reached max decoder steps") 152 | break 153 | 154 | decoder_input = mel_output 155 | 156 | mel_outputs, gate_outputs, alignments = self.model.decoder.parse_decoder_outputs( 157 | mel_outputs, gate_outputs, alignments) 158 | mel_outputs_postnet = self.model.postnet(mel_outputs) 159 | mel_outputs_postnet = mel_outputs + mel_outputs_postnet 160 | # print(mel_outputs_postnet.shape) 161 | 162 | with torch.no_grad(): 163 | synth = self.waveglow.infer(mel_outputs, sigma=0.666) 164 | 165 | # return synth[0].data.cpu().numpy() 166 | # path = add_postfix(path, idx) 167 | # print(path) 168 | librosa.output.write_wav(path, synth[0].data.cpu().numpy(), 16000) 169 | 170 | 171 | if __name__ == "__main__": 172 | import argparse  # only needed for this CLI entry point 173 | 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument('--load_path', required=True, help='Tacotron2-VAE checkpoint path') 176 | parser.add_argument('--waveglow_path', required=True, help='WaveGlow checkpoint path') 177 | parser.add_argument('--sample_path', default="samples") 178 | parser.add_argument('--text', required=True) 179 | parser.add_argument('--ref_audio', default=None, help='optional reference wav to condition on') 180 | parser.add_argument('--ratios', default='1,0,0,0', help='emotion weights in the order neu,sad,hap,ang') 181 | config = parser.parse_args() 182 | 183 | os.makedirs(config.sample_path, exist_ok=True) 184 | 185 | synthesizer = Synthesizer() 186 | synthesizer.load(config.load_path, config.waveglow_path) 187 | 188 | out_path = os.path.join(config.sample_path, 'synthesized.wav') 189 | ratios = [float(r) for r in config.ratios.split(',')] 190 | synthesizer.synthesize(config.text, out_path, config.ref_audio is not None, 191 | config.ref_audio, ratios) 192 | -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE.
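Before moving on to the text-processing package, here is a minimal usage sketch of the Synthesizer class defined above. The checkpoint paths are placeholders, and load() additionally expects the test filelist it references (web/static/uploads/koemo_spk_emo_all_test.txt) together with its audio, since the per-emotion latent means are computed from that data.

from synthesizer import Synthesizer

synth = Synthesizer()
# Placeholder paths -- substitute your own trained Tacotron2-VAE and WaveGlow checkpoints.
synth.load('models/checkpoint_250000', 'models/waveglow.pt')

# Emotion-ratio mixing: weights are applied as
# ratios[0]*neu + ratios[1]*sad + ratios[2]*hap + ratios[3]*ang.
synth.synthesize("감정있는 한국어 목소리 생성", "mix.wav",
                 condition_on_ref=False, ref_audio=None,
                 ratios=[0.0, 0.3, 0.0, 0.6])

# Reference conditioning: copy the style of a recorded prompt instead of mixing.
synth.synthesize("감정있는 한국어 목소리 생성", "ref.wav",
                 condition_on_ref=True,
                 ref_audio='samples/refs/ref_ang.wav', ratios=None)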
20 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | from text.symbols import eng_symbols, kor_symbols 5 | from hparams import create_hparams 6 | 7 | hparam = create_hparams() 8 | cleaner_names = hparam.text_cleaners 9 | 10 | # Mappings from symbol to numeric ID and vice versa: 11 | symbols = "" 12 | _symbol_to_id = {} 13 | _id_to_symbol = {} 14 | 15 | # Regular expression matching text enclosed in curly braces: 16 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 17 | 18 | def change_symbol(cleaner_names): 19 | symbols = "" 20 | global _symbol_to_id 21 | global _id_to_symbol 22 | if cleaner_names == ["english_cleaners"]: symbols = eng_symbols 23 | if cleaner_names == ["korean_cleaners"]: symbols = kor_symbols 24 | 25 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 26 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 27 | 28 | change_symbol(cleaner_names) 29 | 30 | def text_to_sequence(text, cleaner_names): 31 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 32 | 33 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 34 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 35 | 36 | Args: 37 | text: string to convert to a sequence 38 | cleaner_names: names of the cleaner functions to run the text through 39 | 40 | Returns: 41 | List of integers corresponding to the symbols in the text 42 | ''' 43 | sequence = [] 44 | change_symbol(cleaner_names) 45 | # Check for curly braces and treat their contents as ARPAbet: 46 | while len(text): 47 | m = _curly_re.match(text) 48 | try: 49 | if not m: 50 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 51 | break 52 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 53 | sequence += _arpabet_to_sequence(m.group(2)) 54 | text = m.group(3) 55 | except: 56 | # print(text) 57 | exit() 58 | # Append EOS token 59 | sequence.append(_symbol_to_id['~']) 60 | return sequence 61 | 62 | 63 | def sequence_to_text(sequence): 64 | '''Converts a sequence of IDs back to a string''' 65 | result = '' 66 | for symbol_id in sequence: 67 | if symbol_id in _id_to_symbol: 68 | s = _id_to_symbol[symbol_id] 69 | # Enclose ARPAbet back in curly braces: 70 | if len(s) > 1 and s[0] == '@': 71 | s = '{%s}' % s[1:] 72 | result += s 73 | return result.replace('}{', ' ') 74 | 75 | 76 | def _clean_text(text, cleaner_names): 77 | for name in cleaner_names: 78 | cleaner = getattr(cleaners, name) 79 | if not cleaner: 80 | raise Exception('Unknown cleaner: %s' % name) 81 | text = cleaner(text) 82 | # print(text) 83 | return text 84 | 85 | 86 | def _symbols_to_sequence(symbols): 87 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 88 | 89 | 90 | def _arpabet_to_sequence(text): 91 | return _symbols_to_sequence(['@' + s for s in text.split()]) 92 | 93 | 94 | def _should_keep_symbol(s): 95 | return s in _symbol_to_id and s is not '_' and s is not '~' 96 | 97 | if __name__ == "__main__": 98 | # print(text_to_sequence('this is test sentence.? ', ['english_cleaners'])) 99 | # print(text_to_sequence('Chapter one of Jane eyre. This is there librivox recording. All librivox recordings are in the public domain. 
For more information or to volunteer please visit librivox dot org.', ['english_cleaners'])) 100 | # print(text_to_sequence('Recording by Elisabeth Klett.', ['english_cleaners'])) 101 | # print(text_to_sequence('테스트 문장입니다.? ', ['korean_cleaners'])) 102 | # print(_clean_text('AB테스트 문장입니다.? ', ['korean_cleaners'])) 103 | # print(_clean_text('mp3 파일을 홈페이지에서 다운로드 받으시기 바랍니다.',['korean_cleaners'])) 104 | # print(_clean_text("마가렛 대처의 별명은 '철의 여인'이었다.", ['korean_cleaners'])) 105 | # print(_clean_text("제 전화번호는 01012345678이에요.", ['korean_cleaners'])) 106 | # print(_clean_text("‘아줌마’는 결혼한 여자를 뜻한다.", ['korean_cleaners'])) 107 | # print(text_to_sequence("‘아줌마’는 결혼한 여자를 뜻한다.", ['korean_cleaners'])) 108 | # print(runKoG2P('AB테스트 문장입니다.?', 'text/rulebook.txt')) 109 | # print(runKoG2P("마가렛 대처의 별명은 '철의 여인'이었다.", 'text/rulebook.txt')) 110 | # print(text_to_sequence("나는 난 닫았다 닫 라면 랄 바보 밥 아잉 앙 샀다 삿 잦았다 잦 항", ['korean_cleaners'])) 111 | print(text_to_sequence("감정있는 한국어 목소리 생성", ['korean_cleaners'])) 112 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | import re 16 | from unidecode import unidecode 17 | from text.numbers_ import normalize_numbers 18 | from text.korean import tokenize as ko_tokenize 19 | 20 | 21 | # Regular expression matching whitespace: 22 | _whitespace_re = re.compile(r'\s+') 23 | 24 | # List of (regular expression, replacement) pairs for abbreviations: 25 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 26 | ('mrs', 'misess'), 27 | ('mr', 'mister'), 28 | ('dr', 'doctor'), 29 | ('st', 'saint'), 30 | ('co', 'company'), 31 | ('jr', 'junior'), 32 | ('maj', 'major'), 33 | ('gen', 'general'), 34 | ('drs', 'doctors'), 35 | ('rev', 'reverend'), 36 | ('lt', 'lieutenant'), 37 | ('hon', 'honorable'), 38 | ('sgt', 'sergeant'), 39 | ('capt', 'captain'), 40 | ('esq', 'esquire'), 41 | ('ltd', 'limited'), 42 | ('col', 'colonel'), 43 | ('ft', 'fort'), 44 | ]] 45 | 46 | 47 | def expand_abbreviations(text): 48 | for regex, replacement in _abbreviations: 49 | text = re.sub(regex, replacement, text) 50 | return text 51 | 52 | 53 | def expand_numbers(text): 54 | return normalize_numbers(text) 55 | 56 | 57 | def lowercase(text): 58 | return text.lower() 59 | 60 | 61 | def collapse_whitespace(text): 62 | return re.sub(_whitespace_re, ' ', text) 63 | 64 | 65 | def convert_to_ascii(text): 66 | return unidecode(text) 67 | 68 | 69 | def basic_cleaners(text): 70 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 71 | text = lowercase(text) 72 | text = collapse_whitespace(text) 73 | return text 74 | 75 | 76 | def transliteration_cleaners(text): 77 | '''Pipeline for non-English text that transliterates to ASCII.''' 78 | text = convert_to_ascii(text) 79 | text = lowercase(text) 80 | text = collapse_whitespace(text) 81 | return text 82 | 83 | 84 | def english_cleaners(text): 85 | '''Pipeline for English text, including number and abbreviation expansion.''' 86 | text = convert_to_ascii(text) 87 | text = lowercase(text) 88 | text = expand_numbers(text) 89 | text = expand_abbreviations(text) 90 | text = collapse_whitespace(text) 91 | return text 92 | 93 | def korean_cleaners(text): 94 | '''Pipeline for Korean text, including number and abbreviation expansion.''' 95 | text = ko_tokenize(text, as_id=False) 96 | 97 | return text 98 | 99 | if __name__=='__main__': 100 | print(korean_cleaners("감정있는 한국어 목소리 생성")) -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 14 | ] 15 | 16 | _valid_symbol_set = set(valid_symbols) 17 | 18 | 19 | class CMUDict: 20 | '''Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 21 | def __init__(self, file_or_path, keep_ambiguous=True): 22 | if isinstance(file_or_path, str): 23 | with open(file_or_path, encoding='latin-1') as f: 24 | entries = _parse_cmudict(f) 25 | else: 26 | entries = _parse_cmudict(file_or_path) 27 | if not keep_ambiguous: 28 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 29 | self._entries = entries 30 | 31 | 32 | def __len__(self): 33 | return len(self._entries) 34 | 35 | 36 | def lookup(self, word): 37 | '''Returns list of ARPAbet pronunciations of the given word.''' 38 | return self._entries.get(word.upper()) 39 | 40 | 41 | 42 | _alt_re = re.compile(r'\([0-9]+\)') 43 | 44 | 45 | def _parse_cmudict(file): 46 | cmudict = {} 47 | for line in file: 48 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 49 | parts = line.split(' ') 50 | word = re.sub(_alt_re, '', parts[0]) 51 | pronunciation = _get_pronunciation(parts[1]) 52 | if pronunciation: 53 | if word in cmudict: 54 | cmudict[word].append(pronunciation) 55 | else: 56 | cmudict[word] = [pronunciation] 57 | return cmudict 58 | 59 | 60 | def _get_pronunciation(s): 61 | parts = s.strip().split(' ') 62 | for part in parts: 63 | if part not in _valid_symbol_set: 64 | return None 65 | return ' '.join(parts) 66 | -------------------------------------------------------------------------------- /text/ko_dictionary.py: -------------------------------------------------------------------------------- 1 | etc_dictionary = { 2 | '2 30대': '이삼십대', 3 | '20~30대': '이삼십대', 4 | '20, 30대': '이십대 삼십대', 5 | '1+1': '원플러스원', 6 | '3에서 6개월인': '3개월에서 육개월인', 7 | 'mp3': '엠피쓰리', 8 | } 9 | 10 | english_dictionary = { 11 | 'Devsisters': '데브시스터즈', 12 | 'track': '트랙', 13 | 14 | # krbook 15 | 'LA': '엘에이', 16 | 'LG': '엘지', 17 | 'KOREA': '코리아', 18 | 'JSA': '제이에스에이', 19 | 'PGA': '피지에이', 20 | 'GA': '지에이', 21 | 'idol': '아이돌', 22 | 'KTX': '케이티엑스', 23 | 'AC': '에이씨', 24 | 'DVD': '디비디', 25 | 'US': '유에스', 26 | 'CNN': '씨엔엔', 27 | 'LPGA': '엘피지에이', 28 | 'P': '피', 29 | 'p': '피', 30 | 'L': '엘', 31 | 'l': '엘', 32 | 'T': '티', 33 | 't': '티', 34 | 'B': '비', 35 | 'b': '비', 36 | 'C': '씨', 37 | 'c': '씨', 38 | 'BIFF': '비아이에프에프', 39 | 'GV': '지비', 40 | 41 | # JTBC 42 | 'IT': '아이티', 43 | 'IQ': '아이큐', 44 | 'JTBC': '제이티비씨', 45 | 'trickle down effect': '트리클 다운 이펙트', 46 | 'trickle up effect': '트리클 업 이펙트', 47 | 'down': '다운', 48 | 'up': '업', 49 | 'FCK': '에프씨케이', 50 | 'AP': '에이피', 51 | 'WHERETHEWILDTHINGSARE': '', 52 | 'Rashomon Effect': '', 53 | 'O': '오', 54 | 'o': '오', 55 | 'OO': '오오', 56 | 'B': '비', 57 | 'b': '비', 58 | 'GDP': '지디피', 59 | 'CIPA': '씨아이피에이', 60 | 'YS': '와이에스', 61 | 'Y': '와이', 62 | 'y': '와이', 63 | 'S': '에스', 64 | 's': '에스', 65 | 'JTBC': '제이티비씨', 66 | 'PC': '피씨', 67 | 'bill': '빌', 68 | 'Halmuny': '하모니', ##### 69 | 'X': '엑스', 70 | 'SNS': '에스엔에스', 71 | 'ability': '어빌리티', 72 | 'shy': '', 73 | 'CCTV': '씨씨티비', 74 | 'IT': '아이티', 75 | 'the tenth man': '더 텐쓰 맨', #### 76 | 'L': '엘', 77 | 'PC': '피씨', 78 | 'YSDJJPMB': '', ######## 79 | 'Content Attitude Timing': '컨텐트 애티튜드 타이밍', 80 | 'CAT': '캣', 81 | 'IS': '아이에스', 82 | 'SNS': '에스엔에스', 83 | 'K': '케이', 84 | 'k': '케이', 85 | 'Y': '와이', 86 | 'y': '와이', 87 | 'KDI': '케이디아이', 88 | 'DOC': '디오씨', 89 | 'CIA': '씨아이에이', 90 | 'PBS': '피비에스', 91 | 'D': '디', 92 | 'd': '디', 93 | 'PPropertyPositionPowerPrisonP' 94 | 'S': '에스', 95 | 'francisco': '프란시스코', 96 | 'I': '아이', 97 | 'i': '아이', 98 | 'III': '아이아이', ###### 99 | 'No joke': '노 조크', 100 | 'BBK': '비비케이', 101 | 'LA': '엘에이', 102 | 'Don': '', 103 | 't worry be happy': ' 워리 비 
해피', 104 | 'NO': '엔오', ##### 105 | 'it was our sky': '잇 워즈 아워 스카이', 106 | 'it is our sky': '잇 이즈 아워 스카이', #### 107 | 'NEIS': '엔이아이에스', ##### 108 | 'IMF': '아이엠에프', 109 | 'apology': '어폴로지', 110 | 'humble': '험블', 111 | 'M': '엠', 112 | 'm': '엠', 113 | 'Nowhere Man': '노웨어 맨', 114 | 'The Tenth Man': '더 텐쓰 맨', 115 | 'PBS': '피비에스', 116 | 'BBC': '비비씨', 117 | 'MRJ': '엠알제이', 118 | 'CCTV': '씨씨티비', 119 | 'Pick me up': '픽 미 업', 120 | 'DNA': '디엔에이', 121 | 'UN': '유엔', 122 | 'STOP': '스탑', ##### 123 | 'PRESS': '프레스', ##### 124 | 'not to be': '낫 투비', 125 | 'Denial': '디나이얼', 126 | 'G': '지', 127 | 'g': '지', 128 | 'IMF': '아이엠에프', 129 | 'GDP': '지디피', 130 | 'JTBC': '제이티비씨', 131 | 'Time flies like an arrow': '타임 플라이즈 라이크 언 애로우', 132 | 'DDT': '디디티', 133 | 'AI': '에이아이', 134 | 'Z': '제트', 135 | 'z': '제트', 136 | 'OECD': '오이씨디', 137 | 'N': '앤', 138 | 'n': '앤', 139 | 'A': '에이', 140 | 'a': '에이', 141 | 'MB': '엠비', 142 | 'EH': '이에이치', 143 | 'IS': '아이에스', 144 | 'TV': '티비', 145 | 'MIT': '엠아이티', 146 | 'KBO': '케이비오', 147 | 'I love America': '아이 러브 아메리카', 148 | 'SF': '에스에프', 149 | 'Q': '큐', 150 | 'q': '큐', 151 | 'KFX': '케이에프엑스', 152 | 'PM': '피엠', 153 | 'Prime Minister': '프라임 미니스터', 154 | 'Swordline': '스워드라인', 155 | 'TBS': '티비에스', 156 | 'DDT': '디디티', 157 | 'CS': '씨에스', 158 | 'Reflecting Absence': '리플렉팅 앱센스', 159 | 'PBS': '피비에스', 160 | 'Drum being beaten by everyone': '드럼 빙 비튼 바이 에브리원', 161 | 'negative pressure': '네거티브 프레셔', 162 | 'F': '에프', 163 | 'f': '에프', 164 | 'KIA': '기아', 165 | 'FTA': '에프티에이', 166 | 'Que sais-je': '', 167 | 'UFC': '유에프씨', 168 | 'P': '피', 169 | 'p': '피', 170 | 'DJ': '디제이', 171 | 'Chaebol': '채벌', 172 | 'BBC': '비비씨', 173 | 'OECD': '오이씨디', 174 | 'BC': '삐씨', 175 | 'C': '씨', 176 | 'c': '씨', 177 | 'B': '씨', 178 | 'b': '씨', 179 | 'KY': '케이와이', 180 | 'K': '케이', 181 | 'k': '케이', 182 | 'CEO': '씨이오', 183 | 'YH': '와이에치', 184 | 'IS': '아이에스', 185 | 'who are you': '후 얼 유', 186 | 'Y': '와이', 187 | 'y': '와이', 188 | 'The Devils Advocate': '더 데빌즈 어드보카트', 189 | 'YS': '와이에스', 190 | 'so sorry': '쏘 쏘리', 191 | 'Santa': '산타', 192 | 'Big Endian': '빅 엔디안', 193 | 'Small Endian': '스몰 엔디안', 194 | 'Oh Captain My Captain': '오 캡틴 마이 캡틴', 195 | 'AIB': '에이아이비', 196 | 'K': '케이', 197 | 'k': '케이', 198 | 'PBS': '피비에스', 199 | } 200 | -------------------------------------------------------------------------------- /text/korean.py: -------------------------------------------------------------------------------- 1 | # Code based on 2 | 3 | import re 4 | import os 5 | import ast 6 | import json 7 | from jamo import hangul_to_jamo, h2j, j2h, hcj_to_jamo, is_hcj 8 | from jamo.jamo import _jamo_char_to_hcj 9 | 10 | from .ko_dictionary import english_dictionary, etc_dictionary 11 | #from ko_dictionary import english_dictionary, etc_dictionary 12 | 13 | PAD = '_' 14 | EOS = '~' 15 | PUNC = '!\'(),-.:;?' 16 | SPACE = ' ' 17 | 18 | JAMO_LEADS = "".join([chr(_) for _ in range(0x1100, 0x1113)]) 19 | JAMO_VOWELS = "".join([chr(_) for _ in range(0x1161, 0x1176)]) 20 | JAMO_TAILS = "".join([chr(_) for _ in range(0x11A8, 0x11C3)]) 21 | 22 | VALID_CHARS = JAMO_LEADS + JAMO_VOWELS + JAMO_TAILS + PUNC + SPACE 23 | ALL_SYMBOLS = PAD + EOS + VALID_CHARS 24 | ALL_SYMBOLS_1 = "_~ᄀᄁᄂᄃᄄᄅᄆᄇᄈᄉᄊᄋᄌᄍᄎᄏᄐᄑᄒㅏㅐㅑㅒㅓㅔㅕㅖㅗㅘㅙㅚㅛㅜㅝㅞㅟㅠㅡㅢㅣㄱㄲㄳㄴㄵㄶㅇㄹㄺㄻㄼㄽㄾㄿㅀㅁㅂㅄㅅㅆㅇㅈㅊㅋㅌㅍㅎ!'(),-.:;? " 25 | ALL_SYMBOLS_2 = "_~ㄱㄲㄳㄴㄵㄶㄷㄸㄹㄺㄻㄼㄾㅀㅁㅂㅃㅄㅅㅆㅇㅈㅉㅊㅋㅌㅍㅎㅏㅐㅑㅒㅓㅔㅕㅖㅗㅘㅙㅚㅛㅜㅝㅞㅟㅠㅡㅢㅣ!'(),-.:;? " 26 | ALL_SYMBOLS_3 = "_~ᄀᄂᄃᄅᄆᄇᄉᄋᄌᄎᄏᄐᄑᄒㅏㅐㅑㅒㅓㅔㅕㅖㅗㅘㅙㅚㅛㅜㅝㅞㅟㅠㅡㅢㅣㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎ!'(),-.;? " 27 | ALL_SYMBOLS_4 = "_~ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㅏㅐㅑㅒㅓㅔㅕㅖㅗㅘㅙㅚㅛㅜㅝㅞㅟㅠㅡㅢㅣ!'(),-.;? 
" 28 | 29 | char_to_id_1 = {s: i for i, s in enumerate(ALL_SYMBOLS_1)} 30 | char_to_id_2 = {s: i for i, s in enumerate(ALL_SYMBOLS_2)} 31 | char_to_id_3 = {s: i for i, s in enumerate(ALL_SYMBOLS_3)} 32 | char_to_id_4 = {s: i for i, s in enumerate(ALL_SYMBOLS_4)} 33 | char_to_id = {c: i for i, c in enumerate(ALL_SYMBOLS)} 34 | id_to_char = {i: c for i, c in enumerate(ALL_SYMBOLS)} 35 | 36 | quote_checker = """([`"'"“‘])(.+?)([`"'"”’])""" 37 | 38 | def is_lead(char): 39 | return char in JAMO_LEADS 40 | 41 | def is_vowel(char): 42 | return char in JAMO_VOWELS 43 | 44 | def is_tail(char): 45 | return char in JAMO_TAILS 46 | 47 | def get_mode(char): 48 | if is_lead(char): 49 | return 0 50 | elif is_vowel(char): 51 | return 1 52 | elif is_tail(char): 53 | return 2 54 | else: 55 | return -1 56 | 57 | def _get_text_from_candidates(candidates): 58 | if len(candidates) == 0: 59 | return "" 60 | elif len(candidates) == 1: 61 | return _jamo_char_to_hcj(candidates[0]) 62 | else: 63 | return j2h(**dict(zip(["lead", "vowel", "tail"], candidates))) 64 | 65 | def jamo_to_korean(text): 66 | text = h2j(text) 67 | 68 | idx = 0 69 | new_text = "" 70 | candidates = [] 71 | 72 | while True: 73 | if idx >= len(text): 74 | new_text += _get_text_from_candidates(candidates) 75 | break 76 | 77 | char = text[idx] 78 | mode = get_mode(char) 79 | 80 | if mode == 0: 81 | new_text += _get_text_from_candidates(candidates) 82 | candidates = [char] 83 | elif mode == -1: 84 | new_text += _get_text_from_candidates(candidates) 85 | new_text += char 86 | candidates = [] 87 | else: 88 | candidates.append(char) 89 | 90 | idx += 1 91 | return new_text 92 | 93 | num_to_kor = { 94 | '0': '영', 95 | '1': '일', 96 | '2': '이', 97 | '3': '삼', 98 | '4': '사', 99 | '5': '오', 100 | '6': '육', 101 | '7': '칠', 102 | '8': '팔', 103 | '9': '구', 104 | } 105 | 106 | unit_to_kor1 = { 107 | '%': '퍼센트', 108 | 'cm': '센치미터', 109 | 'mm': '밀리미터', 110 | 'km': '킬로미터', 111 | 'kg': '킬로그람', 112 | } 113 | unit_to_kor2 = { 114 | 'm': '미터', 115 | } 116 | 117 | upper_to_kor = { 118 | 'A': '에이', 119 | 'B': '비', 120 | 'C': '씨', 121 | 'D': '디', 122 | 'E': '이', 123 | 'F': '에프', 124 | 'G': '지', 125 | 'H': '에이치', 126 | 'I': '아이', 127 | 'J': '제이', 128 | 'K': '케이', 129 | 'L': '엘', 130 | 'M': '엠', 131 | 'N': '엔', 132 | 'O': '오', 133 | 'P': '피', 134 | 'Q': '큐', 135 | 'R': '알', 136 | 'S': '에스', 137 | 'T': '티', 138 | 'U': '유', 139 | 'V': '브이', 140 | 'W': '더블유', 141 | 'X': '엑스', 142 | 'Y': '와이', 143 | 'Z': '지', 144 | } 145 | 146 | def compare_sentence_with_jamo(text1, text2): 147 | return h2j(text1) != h2j(text2) 148 | 149 | def load_symbols_1(): 150 | jamo = "_~ᄀᄁᄂᄃᄄᄅᄆᄇᄈᄉᄊᄋᄌᄍᄎᄏᄐᄑ하ᅢᅣᅤᅥᅦᅧᅨᅩᅪᅫᅬᅭᅮᅯᅰᅱᅲᅳᅴᅵᆨᆩᆪᆫᆬᆭᆮᆯᆰᆱᆲᆳᆴᆵᆶᆷᆸᆹᆺᆻᆼᆽᆾᆿᇀᇁᇂ!'(),-.:;? " 151 | hj = "_~ᄀᄁᄂᄃᄄᄅᄆᄇᄈᄉᄊᄋᄌᄍᄎᄏᄐᄑᄒㅏㅐㅑㅒㅓㅔㅕㅖㅗㅘㅙㅚㅛㅜㅝㅞㅟㅠㅡㅢㅣㄱㄲㄳㄴㄵㄶㅇㄹㄺㄻㄼㄽㄾㄿㅀㅁㅂㅄㅅㅆㅇㅈㅊㅋㅌㅍㅎ!'(),-.:;? " 152 | assert len(jamo) == len(hj) 153 | j2hj = {j: h for j, h in zip(jamo, hj)} 154 | return j2hj 155 | 156 | def load_symbols_2(): 157 | jamo = "_~ᄀᄁᄂᄃᄄᄅᄆᄇᄈᄉᄊᄋᄌᄍᄎᄏᄐᄑ하ᅢᅣᅤᅥᅦᅧᅨᅩᅪᅫᅬᅭᅮᅯᅰᅱᅲᅳᅴᅵᆨᆩᆪᆫᆬᆭᆮᆯᆰᆱᆲᆳᆴᆵᆶᆷᆸᆹᆺᆻᆼᆽᆾᆿᇀᇁᇂ!'(),-.:;? " 158 | hcj = "_~ㄱㄲㄴㄷㄸㄹㅁㅂㅃㅅㅆㅇㅈㅉㅊㅋㅌㅍㅎㅏㅐㅑㅒㅓㅔㅕㅖㅗㅘㅙㅚㅛㅜㅝㅞㅟㅠㅡㅢㅣㄱㄲㄳㄴㄵㄶㄷㄹㄺㄻㄼㄽㄾㄿㅀㅁㅂㅄㅅㅆㅇㅈㅊㅋㅌㅍㅎ!'(),-.:;? " 159 | assert len(jamo) == len(hcj) 160 | j2hcj = {j: h for j, h in zip(jamo, hcj)} 161 | return j2hcj 162 | 163 | def load_symbols_3(): 164 | jamo = "_~ᄀᄁᄂᄃᄄᄅᄆᄇᄈᄉᄊᄋᄌᄍᄎᄏᄐᄑ하ᅢᅣᅤᅥᅦᅧᅨᅩᅪᅫᅬᅭᅮᅯᅰᅱᅲᅳᅴᅵᆨᆩᆪᆫᆬᆭᆮᆯᆰᆱᆲᆳᆴᆵᆶᆷᆸᆹᆺᆻᆼᆽᆾᆿᇀᇁᇂ!'(),-.:;? 
" 165 | sj = "_|~|ᄀ|ᄀᄀ|ᄂ|ᄃ|ᄃᄃ|ᄅ|ᄆ|ᄇ|ᄇᄇ|ᄉ|ᄉᄉ|ᄋ|ᄌ|ᄌᄌ|ᄎ|ᄏ|ᄐ|ᄑ|ᄒ|ㅏ|ㅐ|ㅑ|ㅒ|ㅓ|ㅔ|ㅕ|ㅖ|ㅗ|ㅘ|ㅙ|ㅚ|ㅛ|ㅜ|ㅝ|ㅞ|ㅟ|ㅠ|ㅡ|ㅢ|ㅣ|ㄱ|ㄱㄱ|ㄱㅅ|ㄴ|ㄴㅈ|ㄴㅎ|ㄷ|ㄹ|ㄹㄱ|ㄹㅁ|ㄹㅂ|ㄹㅅ|ㄹㅌ|ㄹㅍ|ㄹㅎ|ㅁ|ㅂ|ㅂㅅ|ㅅ|ㅅㅅ|ㅇ|ㅈ|ㅊ|ㅋ|ㅌ|ㅍ|ㅎ|!|'|(|)|,|-|.|:|;|?| " 166 | assert len(jamo) == len(sj.split("|")) 167 | j2sj = {j: s for j, s in zip(jamo, sj.split("|"))} 168 | return j2sj 169 | 170 | def load_symbols_4(): 171 | jamo = "_~ᄀᄁᄂᄃᄄᄅᄆᄇᄈᄉᄊᄋᄌᄍᄎᄏᄐᄑ하ᅢᅣᅤᅥᅦᅧᅨᅩᅪᅫᅬᅭᅮᅯᅰᅱᅲᅳᅴᅵᆨᆩᆪᆫᆬᆭᆮᆯᆰᆱᆲᆳᆴᆵᆶᆷᆸᆹᆺᆻᆼᆽᆾᆿᇀᇁᇂ!'(),-.:;? " 172 | shcj = "_|~|ㄱ|ㄱㄱ|ㄴ|ㄷ|ㄷㄷ|ㄹ|ㅁ|ㅂ|ㅂㅂ|ㅅ|ㅅㅅ|ㅇ|ㅈ|ㅈㅈ|ㅊ|ㅋ|ㅌ|ㅍ|ㅎ|ㅏ|ㅐ|ㅑ|ㅒ|ㅓ|ㅔ|ㅕ|ㅖ|ㅗ|ㅘ|ㅙ|ㅚ|ㅛ|ㅜ|ㅝ|ㅞ|ㅟ|ㅠ|ㅡ|ㅢ|ㅣ|ㄱ|ㄱㄱ|ㄱㅅ|ㄴ|ㄴㅈ|ㄴㅎ|ㄷ|ㄹ|ㄹㄱ|ㄹㅁ|ㄹㅂ|ㄹㅅ|ㄹㅌ|ㄹㅍ|ㄹㅎ|ㅁ|ㅂ|ㅂㅅ|ㅅ|ㅅㅅ|ㅇ|ㅈ|ㅊ|ㅋ|ㅌ|ㅍ|ㅎ|!|'|(|)|,|-|.|:|;|?| " 173 | assert len(jamo) == len(shcj.split("|")) 174 | j2shcj = {j: s for j, s in zip(jamo, shcj.split("|"))} 175 | return j2shcj 176 | 177 | def tokenize(text, as_id=False, symbol_type=1, debug=False): 178 | 179 | j2hj, j2hcj, j2sj, j2shcj = load_symbols_1(), load_symbols_2(), load_symbols_3(), load_symbols_4() 180 | 181 | text = normalize(text) 182 | pre_tokens = list(hangul_to_jamo(text)) 183 | pre_tokens = [hcj_to_jamo(_, "lead") if is_hcj(_) else _ for _ in pre_tokens] 184 | tokens = [] 185 | 186 | if symbol_type == 1: 187 | if debug: 188 | print(char_to_id_1) 189 | for token in pre_tokens: 190 | tokens += list(j2hj[token]) 191 | 192 | if as_id: 193 | return [char_to_id_1[token] for token in tokens] + [char_to_id_1[EOS]] 194 | else: 195 | return [token for token in tokens] + [EOS] 196 | 197 | elif symbol_type == 2: 198 | if debug: 199 | print(char_to_id_2) 200 | for token in pre_tokens: 201 | tokens += list(j2hcj[token]) 202 | 203 | if as_id: 204 | return [char_to_id_2[token] for token in tokens] + [char_to_id_2[EOS]] 205 | else: 206 | return [token for token in tokens] + [EOS] 207 | 208 | elif symbol_type == 3: 209 | if debug: 210 | print(char_to_id_3) 211 | for token in pre_tokens: 212 | tokens += list(j2sj[token]) 213 | 214 | if as_id: 215 | return [char_to_id_3[token] for token in tokens] + [char_to_id_3[EOS]] 216 | else: 217 | return [token for token in tokens] + [EOS] 218 | 219 | elif symbol_type == 4: 220 | if debug: 221 | print(char_to_id_4) 222 | for token in pre_tokens: 223 | tokens += list(j2shcj[token]) 224 | 225 | if as_id: 226 | return [char_to_id_4[token] for token in tokens] + [char_to_id_4[EOS]] 227 | else: 228 | return [token for token in tokens] + [EOS] 229 | 230 | def tokenizer_fn(iterator, symbol_type): 231 | return (token for x in iterator for token in tokenize(x, as_id=False, symbol_type=symbol_type)) 232 | 233 | def normalize(text): 234 | text = text.strip() 235 | 236 | text = text.replace("'", "") 237 | text = text.replace('"', "") 238 | 239 | text = re.sub('\(\d+일\)', '', text) 240 | text = re.sub('\([⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]+\)', '', text) 241 | 242 | text = normalize_with_dictionary(text, etc_dictionary) 243 | text = normalize_english(text) 244 | text = re.sub('[a-zA-Z]+', normalize_upper, text) 245 | 246 | text = normalize_quote(text) 247 | text = normalize_number(text) 248 | 249 | return text 250 | 251 | def normalize_with_dictionary(text, dic): 252 | if any(key in text for key in dic.keys()): 253 | pattern = re.compile('|'.join(re.escape(key) for key in dic.keys())) 254 | return pattern.sub(lambda x: dic[x.group()], text) 255 | else: 256 | return text 257 | 258 | def normalize_english(text): 259 | def fn(m): 260 | word = m.group() 261 | if word in english_dictionary: 262 | return english_dictionary.get(word) 263 | else: 264 | return word 265 | 266 | text = re.sub("([A-Za-z]+)", fn, text) 267 | return text 268 | 269 | def 
normalize_upper(text): 270 | text = text.group(0) 271 | 272 | if all([char.isupper() for char in text]): 273 | return "".join(upper_to_kor[char] for char in text) 274 | else: 275 | return text 276 | 277 | def normalize_quote(text): 278 | def fn(found_text): 279 | from nltk import sent_tokenize # NLTK doesn't along with multiprocessing 280 | 281 | found_text = found_text.group() 282 | unquoted_text = found_text[1:-1] 283 | 284 | sentences = sent_tokenize(unquoted_text) 285 | return " ".join(["'{}'".format(sent) for sent in sentences]) 286 | 287 | return re.sub(quote_checker, fn, text) 288 | 289 | number_checker = "([+-]?\d[\d,]*)[\.]?\d*" 290 | count_checker = "(시|명|가지|살|마리|포기|송이|수|톨|통|점|개|벌|척|채|다발|그루|자루|줄|켤레|그릇|잔|마디|상자|사람|곡|병|판)" 291 | 292 | def normalize_number(text): 293 | text = normalize_with_dictionary(text, unit_to_kor1) 294 | text = normalize_with_dictionary(text, unit_to_kor2) 295 | text = re.sub(number_checker + count_checker, 296 | lambda x: number_to_korean(x, True), text) 297 | text = re.sub(number_checker, 298 | lambda x: number_to_korean(x, False), text) 299 | return text 300 | 301 | num_to_kor1 = [""] + list("일이삼사오육칠팔구") 302 | num_to_kor2 = [""] + list("만억조경해") 303 | num_to_kor3 = [""] + list("십백천") 304 | 305 | #count_to_kor1 = [""] + ["하나","둘","셋","넷","다섯","여섯","일곱","여덟","아홉"] 306 | count_to_kor1 = [""] + ["한","두","세","네","다섯","여섯","일곱","여덟","아홉"] 307 | 308 | count_tenth_dict = { 309 | "십": "열", 310 | "두십": "스물", 311 | "세십": "서른", 312 | "네십": "마흔", 313 | "다섯십": "쉰", 314 | "여섯십": "예순", 315 | "일곱십": "일흔", 316 | "여덟십": "여든", 317 | "아홉십": "아흔", 318 | } 319 | 320 | 321 | 322 | def number_to_korean(num_str, is_count=False): 323 | if is_count: 324 | num_str, unit_str = num_str.group(1), num_str.group(2) 325 | else: 326 | num_str, unit_str = num_str.group(), "" 327 | 328 | num_str = num_str.replace(',', '') 329 | try: 330 | num = ast.literal_eval(num_str) 331 | except: 332 | num = int(num_str) 333 | 334 | if num == 0: 335 | return "영" 336 | 337 | check_float = num_str.split('.') 338 | if len(check_float) == 2: 339 | digit_str, float_str = check_float 340 | elif len(check_float) >= 3: 341 | raise Exception(" [!] Wrong number format") 342 | else: 343 | digit_str, float_str = check_float[0], None 344 | 345 | if is_count and float_str is not None: 346 | raise Exception(" [!] 
`is_count` and float number does not fit each other") 347 | 348 | digit = int(digit_str) 349 | 350 | if digit_str.startswith("-"): 351 | digit, digit_str = abs(digit), str(abs(digit)) 352 | 353 | kor = "" 354 | size = len(str(digit)) 355 | tmp = [] 356 | 357 | for i, v in enumerate(digit_str, start=1): 358 | v = int(v) 359 | 360 | if v != 0: 361 | if is_count: 362 | tmp += count_to_kor1[v] 363 | else: 364 | tmp += num_to_kor1[v] 365 | 366 | tmp += num_to_kor3[(size - i) % 4] 367 | 368 | if (size - i) % 4 == 0 and len(tmp) != 0: 369 | kor += "".join(tmp) 370 | tmp = [] 371 | kor += num_to_kor2[int((size - i) / 4)] 372 | 373 | if is_count: 374 | if kor.startswith("한") and len(kor) > 1: 375 | kor = kor[1:] 376 | 377 | if any(word in kor for word in count_tenth_dict): 378 | kor = re.sub( 379 | '|'.join(count_tenth_dict.keys()), 380 | lambda x: count_tenth_dict[x.group()], kor) 381 | 382 | if not is_count and kor.startswith("일") and len(kor) > 1: 383 | kor = kor[1:] 384 | 385 | if float_str is not None: 386 | kor += "쩜 " 387 | kor += re.sub('\d', lambda x: num_to_kor[x.group()], float_str) 388 | 389 | if num_str.startswith("+"): 390 | kor = "플러스 " + kor 391 | elif num_str.startswith("-"): 392 | kor = "마이너스 " + kor 393 | 394 | return kor + unit_str 395 | 396 | if __name__ == "__main__": 397 | def test_normalize(text): 398 | print(text) 399 | print(normalize(text)) 400 | print("="*30) 401 | 402 | test_normalize("JTBC는 JTBCs를 DY는 A가 Absolute") 403 | test_normalize("오늘(13일) 101마리 강아지가") 404 | #test_normalize('"저돌"(猪突) 입니다.') 405 | #test_normalize('비대위원장이 지난 1월 이런 말을 했습니다. “난 그냥 산돼지처럼 돌파하는 스타일이다”') 406 | test_normalize("지금은 -12.35%였고 종류는 5가지와 19가지, 그리고 55가지였다") 407 | test_normalize("JTBC는 TH와 K 양이 2017년 9월 12일 오후 12시에 24살이 된다") 408 | -------------------------------------------------------------------------------- /text/numbers_.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 13 | _number_re = re.compile(r'[0-9]+') 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(',', '') 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace('.', ' point ') 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split('.') 27 | if len(parts) > 2: 28 | return match + ' dollars' # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 33 | cent_unit = 'cent' if cents == 1 else 'cents' 34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 37 | return '%s %s' % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = 'cent' if cents == 1 else 'cents' 40 | return '%s %s' % (cents, cent_unit) 41 | else: 42 | return 'zero dollars' 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return 'two thousand' 54 | 
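# 2001-2009 are read as "two thousand one" ... "two thousand nine"; other years in the
# 1000-3000 range fall through to the century ("nineteen hundred") or pairwise
# ("nineteen ninety-nine") readings below.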
elif num > 2000 and num < 2010: 55 | return 'two thousand ' + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + ' hundred' 58 | else: 59 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 60 | else: 61 | return _inflect.number_to_words(num, andword='') 62 | 63 | 64 | def normalize_numbers(text): 65 | text = re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r'\1 pounds', text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' 7 | from text import cmudict 8 | from text.korean import ALL_SYMBOLS_1 9 | 10 | _pad = '_' 11 | _punctuation = '!\'(),.:;? ' 12 | _special = '-' 13 | _end = '~' 14 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 15 | 16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 17 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 18 | 19 | # Export all symbols: 20 | eng_symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + list(_end) #+ _arpabet 21 | kor_symbols = ALL_SYMBOLS_1 22 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import math 5 | from numpy import finfo 6 | 7 | import torch 8 | from distributed import apply_gradient_allreduce 9 | import torch.distributed as dist 10 | from torch.utils.data.distributed import DistributedSampler 11 | from torch.utils.data import DataLoader 12 | 13 | from fp16_optimizer import FP16_Optimizer 14 | 15 | from model import Tacotron2 16 | from data_utils import TextMelLoader, TextMelCollate 17 | from loss_function import Tacotron2Loss_VAE 18 | from logger import Tacotron2Logger 19 | from hparams import create_hparams 20 | 21 | 22 | def batchnorm_to_float(module): 23 | """Converts batch norm modules to FP32""" 24 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 25 | module.float() 26 | for child in module.children(): 27 | batchnorm_to_float(child) 28 | return module 29 | 30 | 31 | def reduce_tensor(tensor, n_gpus): 32 | rt = tensor.clone() 33 | dist.all_reduce(rt, op=dist.reduce_op.SUM) 34 | rt /= n_gpus 35 | return rt 36 | 37 | 38 | def init_distributed(hparams, n_gpus, rank, group_name): 39 | assert torch.cuda.is_available(), "Distributed mode requires CUDA." 40 | print("Initializing Distributed") 41 | 42 | # Set cuda device so everything is done on the right GPU. 
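# rank % device_count maps each process's global rank to a local GPU index:
# with 4 visible GPUs, ranks 0-3 use cuda:0-3 and rank 4 wraps back to cuda:0.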
43 | torch.cuda.set_device(rank % torch.cuda.device_count()) 44 | 45 | # Initialize distributed communication 46 | dist.init_process_group( 47 | backend=hparams.dist_backend, init_method=hparams.dist_url, 48 | world_size=n_gpus, rank=rank, group_name=group_name) 49 | 50 | print("Done initializing distributed") 51 | 52 | 53 | def prepare_dataloaders(hparams): 54 | # Get data, data loaders and collate function ready 55 | trainset = TextMelLoader(hparams.training_files, hparams) 56 | valset = TextMelLoader(hparams.validation_files, hparams) 57 | collate_fn = TextMelCollate(hparams.n_frames_per_step) 58 | 59 | train_sampler = DistributedSampler(trainset) \ 60 | if hparams.distributed_run else None 61 | 62 | train_loader = DataLoader(trainset, num_workers=1, shuffle=False, 63 | sampler=train_sampler, 64 | batch_size=hparams.batch_size, pin_memory=False, 65 | drop_last=True, collate_fn=collate_fn) 66 | return train_loader, valset, collate_fn 67 | 68 | 69 | def prepare_directories_and_logger(output_directory, log_directory, rank): 70 | if rank == 0: 71 | if not os.path.isdir(output_directory): 72 | os.makedirs(output_directory) 73 | os.chmod(output_directory, 0o775) 74 | logger = Tacotron2Logger(os.path.join(output_directory, log_directory)) 75 | else: 76 | logger = None 77 | return logger 78 | 79 | 80 | def load_model(hparams): 81 | model = Tacotron2(hparams).cuda() 82 | if hparams.fp16_run: 83 | model = batchnorm_to_float(model.half()) 84 | model.decoder.attention_layer.score_mask_value = float(finfo('float16').min) 85 | 86 | if hparams.distributed_run: 87 | model = apply_gradient_allreduce(model) 88 | 89 | return model 90 | 91 | 92 | def warm_start_model(checkpoint_path, model): 93 | assert os.path.isfile(checkpoint_path) 94 | print("Warm starting model from checkpoint '{}'".format(checkpoint_path)) 95 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 96 | model.load_state_dict(checkpoint_dict['state_dict']) 97 | return model 98 | 99 | 100 | def load_checkpoint(checkpoint_path, model, optimizer): 101 | assert os.path.isfile(checkpoint_path) 102 | print("Loading checkpoint '{}'".format(checkpoint_path)) 103 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 104 | model.load_state_dict(checkpoint_dict['state_dict']) 105 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 106 | learning_rate = checkpoint_dict['learning_rate'] 107 | iteration = checkpoint_dict['iteration'] 108 | print("Loaded checkpoint '{}' from iteration {}" .format( 109 | checkpoint_path, iteration)) 110 | return model, optimizer, learning_rate, iteration 111 | 112 | 113 | def save_checkpoint(model, optimizer, learning_rate, iteration, filepath): 114 | print("Saving model and optimizer state at iteration {} to {}".format( 115 | iteration, filepath)) 116 | torch.save({'iteration': iteration, 117 | 'state_dict': model.state_dict(), 118 | 'optimizer': optimizer.state_dict(), 119 | 'learning_rate': learning_rate}, filepath) 120 | 121 | 122 | def validate(model, criterion, valset, iteration, batch_size, n_gpus, 123 | collate_fn, logger, distributed_run, rank): 124 | """Handles all the validation scoring and printing""" 125 | model.eval() 126 | with torch.no_grad(): 127 | val_sampler = DistributedSampler(valset) if distributed_run else None 128 | val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1, 129 | shuffle=False, batch_size=batch_size, 130 | pin_memory=False, collate_fn=collate_fn) 131 | 132 | val_loss = 0.0 133 | for i, batch in enumerate(val_loader): 134 | x, y = 
model.parse_batch(batch) 135 | y_pred = model(x) 136 | loss, _, _, _ = criterion(y_pred, y, iteration) 137 | if distributed_run: 138 | reduced_val_loss = reduce_tensor(loss.data, n_gpus).item() 139 | else: 140 | reduced_val_loss = loss.item() 141 | val_loss += reduced_val_loss 142 | val_loss = val_loss / (i + 1) 143 | 144 | model.train() 145 | if rank == 0: 146 | print("Validation loss {}: {:9f} ".format(iteration, reduced_val_loss)) 147 | logger.log_validation(reduced_val_loss, model, y, y_pred, iteration) 148 | 149 | 150 | def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, 151 | rank, group_name, hparams): 152 | """Training and validation logging results to tensorboard and stdout 153 | 154 | Params 155 | ------ 156 | output_directory (string): directory to save checkpoints 157 | log_directory (string) directory to save tensorboard logs 158 | checkpoint_path(string): checkpoint path 159 | n_gpus (int): number of gpus 160 | rank (int): rank of current gpu 161 | hparams (object): comma separated list of "name=value" pairs. 162 | """ 163 | if hparams.distributed_run: 164 | init_distributed(hparams, n_gpus, rank, group_name) 165 | 166 | torch.manual_seed(hparams.seed) 167 | torch.cuda.manual_seed(hparams.seed) 168 | 169 | model = load_model(hparams) 170 | learning_rate = hparams.learning_rate 171 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, 172 | weight_decay=hparams.weight_decay) 173 | if hparams.fp16_run: 174 | optimizer = FP16_Optimizer( 175 | optimizer, dynamic_loss_scale=hparams.dynamic_loss_scaling) 176 | 177 | if hparams.distributed_run: 178 | model = apply_gradient_allreduce(model) 179 | 180 | criterion = Tacotron2Loss_VAE(hparams) 181 | 182 | logger = prepare_directories_and_logger( 183 | output_directory, log_directory, rank) 184 | 185 | train_loader, valset, collate_fn = prepare_dataloaders(hparams) 186 | 187 | # Load checkpoint if one exists 188 | iteration = 0 189 | epoch_offset = 0 190 | if checkpoint_path is not None: 191 | if warm_start: 192 | model = warm_start_model(checkpoint_path, model) 193 | else: 194 | model, optimizer, _learning_rate, iteration = load_checkpoint( 195 | checkpoint_path, model, optimizer) 196 | if hparams.use_saved_learning_rate: 197 | learning_rate = _learning_rate 198 | iteration += 1 # next iteration is iteration + 1 199 | epoch_offset = max(0, int(iteration / len(train_loader))) 200 | 201 | model.train() 202 | # ================ MAIN TRAINNIG LOOP! 
=================== 203 | step = 0 204 | for epoch in range(epoch_offset, hparams.epochs): 205 | print("Epoch: {}".format(epoch)) 206 | for i, batch in enumerate(train_loader): 207 | start = time.perf_counter() 208 | for param_group in optimizer.param_groups: 209 | param_group['lr'] = learning_rate 210 | 211 | model.zero_grad() 212 | x, y = model.parse_batch(batch) 213 | y_pred = model(x) 214 | 215 | loss, recon_loss, kl, kl_weight = criterion(y_pred, y, iteration) 216 | if hparams.distributed_run: 217 | reduced_loss = reduce_tensor(loss.data, n_gpus).item() 218 | else: 219 | reduced_loss = loss.item() 220 | 221 | if hparams.fp16_run: 222 | optimizer.backward(loss) 223 | grad_norm = optimizer.clip_fp32_grads(hparams.grad_clip_thresh) 224 | else: 225 | loss.backward() 226 | grad_norm = torch.nn.utils.clip_grad_norm_( 227 | model.parameters(), hparams.grad_clip_thresh) 228 | 229 | optimizer.step() 230 | 231 | overflow = optimizer.overflow if hparams.fp16_run else False 232 | 233 | if not overflow and not math.isnan(reduced_loss) and rank == 0: 234 | duration = time.perf_counter() - start 235 | print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format( 236 | iteration, reduced_loss, grad_norm, duration)) 237 | logger.log_training( 238 | reduced_loss, grad_norm, learning_rate, duration, recon_loss, kl, kl_weight, iteration) 239 | 240 | if not overflow and (iteration % hparams.iters_per_checkpoint == 0): 241 | validate(model, criterion, valset, iteration, 242 | hparams.batch_size, n_gpus, collate_fn, logger, 243 | hparams.distributed_run, rank) 244 | if rank == 0: 245 | checkpoint_path = os.path.join( 246 | output_directory, "checkpoint_{}".format(iteration)) 247 | save_checkpoint(model, optimizer, learning_rate, iteration, 248 | checkpoint_path) 249 | 250 | iteration += 1 251 | 252 | 253 | if __name__ == '__main__': 254 | parser = argparse.ArgumentParser() 255 | parser.add_argument('-o', '--output_directory', type=str, 256 | help='directory to save checkpoints') 257 | parser.add_argument('-l', '--log_directory', type=str, 258 | help='directory to save tensorboard logs') 259 | parser.add_argument('-c', '--checkpoint_path', type=str, default=None, 260 | required=False, help='checkpoint path') 261 | parser.add_argument('--warm_start', action='store_true', 262 | help='load the model only (warm start)') 263 | parser.add_argument('--n_gpus', type=int, default=1, 264 | required=False, help='number of gpus') 265 | parser.add_argument('--rank', type=int, default=0, 266 | required=False, help='rank of current gpu') 267 | parser.add_argument('--group_name', type=str, default='group_name', 268 | required=False, help='Distributed group name') 269 | parser.add_argument('--hparams', type=str, 270 | required=False, help='comma separated name=value pairs') 271 | 272 | args = parser.parse_args() 273 | hparams = create_hparams(args.hparams) 274 | 275 | torch.backends.cudnn.enabled = hparams.cudnn_enabled 276 | torch.backends.cudnn.benchmark = hparams.cudnn_benchmark 277 | 278 | print("FP16 Run:", hparams.fp16_run) 279 | print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling) 280 | print("Distributed Run:", hparams.distributed_run) 281 | print("cuDNN Enabled:", hparams.cudnn_enabled) 282 | print("cuDNN Benchmark:", hparams.cudnn_benchmark) 283 | 284 | train(args.output_directory, args.log_directory, args.checkpoint_path, 285 | args.warm_start, args.n_gpus, args.rank, args.group_name, hparams) 286 | -------------------------------------------------------------------------------- /utils.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.io.wavfile import read 3 | import librosa 4 | import torch 5 | import os 6 | 7 | max_wav_value=32768.0 8 | 9 | def get_mask_from_lengths(lengths): 10 | max_len = torch.max(lengths).item() 11 | ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len)) 12 | mask = (ids < lengths.unsqueeze(1)).byte() 13 | return mask 14 | 15 | 16 | def load_wav_to_torch(full_path): 17 | sampling_rate, data = read(full_path) 18 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 19 | 20 | 21 | def load_filepaths_and_text(filename, split="|"): 22 | with open(filename, encoding='utf-8') as f: 23 | filepaths_and_text = [line.strip().split(split) for line in f] 24 | return filepaths_and_text 25 | 26 | 27 | def to_gpu(x): 28 | x = x.contiguous() 29 | 30 | if torch.cuda.is_available(): 31 | x = x.cuda(non_blocking=True) 32 | return torch.autograd.Variable(x) 33 | 34 | def str2bool(v): 35 | return v.lower() in ('true', '1') 36 | 37 | def makedirs(path): 38 | if not os.path.exists(path): 39 | print(" [*] Make directories : {}".format(path)) 40 | os.makedirs(path) 41 | 42 | def add_postfix(path, postfix): 43 | path_without_ext, ext = path.rsplit('.', 1) 44 | return "{}.{}.{}".format(path_without_ext, postfix, ext) 45 | -------------------------------------------------------------------------------- /web/static/css/main.css: -------------------------------------------------------------------------------- 1 | 2 | 3 | /* Feature Sliders */ 4 | 5 | .SliderFrame { 6 | /* width: 160px; */ 7 | margin: auto; 8 | display: table; 9 | padding-top: 1.5em; 10 | } 11 | 12 | 13 | .SliderLabel { 14 | display: table-cell; 15 | vertical-align: left; 16 | } 17 | 18 | .SliderLabel p { 19 | margin: 0.1em 0.5em; 20 | min-width: 8em; 21 | font-size: 1.0em; 22 | line-height: 1.1em; 23 | } 24 | 25 | .Slider { 26 | /* display: inline-block; */ 27 | /* margin: 10px 0 0px 0 !important; */ 28 | width: 190px !important; 29 | vertical-align: middle; 30 | display: table-cell; 31 | } 32 | 33 | .RefAudioFrame { 34 | /* width: 320px; */ 35 | margin: auto; 36 | display: table; 37 | padding-top: 1.5em; 38 | } -------------------------------------------------------------------------------- /web/static/js/main.js: -------------------------------------------------------------------------------- 1 | var sw; 2 | var wavesurfer; 3 | 4 | var defaultSpeed = 0.03; 5 | var defaultAmplitude = 0.3; 6 | 7 | var activeColors = [[32,133,252], [94,252,169], [253,71,103]]; 8 | var inactiveColors = [[241,243,245], [206,212,218], [222,226,230], [173,181,189]]; 9 | 10 | function generate(ip, port, text, n, s, h, a, condition_on_ref, ref_audio) {//}, speaker_id) { 11 | $("#synthesize").addClass("is-loading"); 12 | 13 | var uri = 'http://' + ip + ':' + port 14 | var url = uri + '/generate?text=' + encodeURIComponent(text) + "&n=" + n + "&s=" + s + "&h=" + h + "&a=" + a + "&con=" + condition_on_ref + "&ref=" + ref_audio; 15 | console.log(url); 16 | fetch(url, {cache: 'no-cache', mode: 'cors'}) 17 | .then(function(res) { 18 | if (!res.ok) throw Error(res.statusText) 19 | return res.blob() 20 | }).then(function(blob) { 21 | var reader = new FileReader(); 22 | reader.readAsDataURL(blob); 23 | reader.onloadend = function() { 24 | base64data = reader.result; 25 | console.log("base64", base64data); 26 | } 27 | var url = URL.createObjectURL(blob); 28 | inProgress = true; 29 | // console.log(url); 30 | wavesurfer.load(url); 31 | 
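// The /generate response is a wav blob: it is wrapped in an object URL, handed to
// wavesurfer.load(), and the 'ready' handler registered in init() starts playback automatically.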
$("#synthesize").removeClass("is-loading"); 32 | }).catch(function(err) { 33 | console.log(err); 34 | // console.log("error error"); 35 | inProgress = false; 36 | $("#synthesize").removeClass("is-loading"); 37 | }); 38 | } 39 | 40 | (function(window, document, undefined){ 41 | window.onload = init; 42 | 43 | 44 | function init(){ 45 | 46 | wavesurfer = WaveSurfer.create({ 47 | container: '#waveform', 48 | waveColor: '#017AFD', 49 | barWidth: 5, 50 | progressColor: 'navy', 51 | cursorColr: '#fff', 52 | normalize:true, 53 | }); 54 | 55 | wavesurfer.on('ready', function () { 56 | 57 | wavesurfer.play(); 58 | }); 59 | 60 | wavesurfer.on('finish', function () { 61 | 62 | }); 63 | 64 | wavesurfer.on('audioprocess', function () { 65 | if(wavesurfer.isPlaying()) { 66 | var totalTime = wavesurfer.getDuration(), 67 | currentTime = wavesurfer.getCurrentTime(); 68 | 69 | var timer_total = document.getElementById('time-total'); 70 | var mins = Math.floor(totalTime / 60); 71 | var secs = Math.floor(totalTime % 60); 72 | if (secs < 10) { 73 | secs = '0' + String(secs); 74 | } 75 | timer_total.innerText = mins + ':' + secs; 76 | 77 | var timer_current = document.getElementById('time-current'); 78 | var mins = Math.floor(currentTime / 60); 79 | var secs = Math.floor(currentTime % 60); 80 | if (secs < 10) { 81 | secs = '0' + String(secs); 82 | } 83 | timer_current.innerText = mins + ':' + secs; 84 | } 85 | }); 86 | 87 | var loadFile = { 88 | contents:"Null", 89 | init: function() { 90 | $.ajax({ 91 | url:"/uploads/koemo_spk_emo_all_test.txt", 92 | dataType: "text", 93 | async:false, 94 | success: function(data) { 95 | var allText = data; 96 | var split = allText.split('\n') 97 | var randomNum = Math.floor(Math.random() * split.length); 98 | var randomLine = split[randomNum]; 99 | loadFile.contents = randomLine.split('|')[0].replace('/home/jhoh/dataset', '/uploads'); 100 | } 101 | }); 102 | } 103 | } 104 | 105 | var condition_on_ref = false; 106 | 107 | $(document).ready(function() { 108 | 109 | 110 | $("#mix").click(function() { 111 | $(".SliderFrame").css("display", ""); 112 | $(".RefAudioFrame").css("display", "none"); 113 | condition_on_ref = false; 114 | }); 115 | $("#refaudio").click(function() { 116 | $(".SliderFrame").css("display", "none"); 117 | $(".RefAudioFrame").css("display", ""); 118 | condition_on_ref = true; 119 | }); 120 | 121 | $("#neu").change(function(){ 122 | var neu = this.value; 123 | }); 124 | $("#sad").change(function(){ 125 | var sad = this.value; 126 | }); 127 | $("#hap").change(function(){ 128 | var hap = this.value; 129 | }); 130 | $("#ang").change(function(){ 131 | var ang = this.value; 132 | }); 133 | 134 | $('#random-audio').click(function() { 135 | loadFile.init(); 136 | console.log(loadFile.contents); 137 | 138 | var audio = document.getElementById('audio'); 139 | audio.src = loadFile.contents; 140 | console.log(audio.src); 141 | audio.load(); 142 | }); 143 | }); 144 | 145 | 146 | $(document).on('click', "#synthesize", function() { 147 | synthesize(); 148 | }); 149 | 150 | function synthesize() { 151 | var text = $("#text").val().trim(); 152 | var text_length = text.length; 153 | var ref_audio = $("#audio").attr('src'); 154 | 155 | generate('mind.snu.ac.kr', 5907, text, neu.value, sad.value, hap.value, ang.value, condition_on_ref, ref_audio); 156 | 157 | var lowpass = wavesurfer.backend.ac.createGain(); 158 | wavesurfer.backend.setFilter(lowpass); 159 | } 160 | 161 | } 162 | })(window, document, undefined); 163 | 
--------------------------------------------------------------------------------
/web/static/uploads/KoreanEmotionSpeech:
--------------------------------------------------------------------------------
1 | /home/jhoh/dataset/KoreanEmotionSpeech/
--------------------------------------------------------------------------------
/web/templates/index.html:
--------------------------------------------------------------------------------
[The 113 lines of markup in index.html were stripped during extraction; only text fragments survive. Recoverable content: the demo page is titled "Expressive Voice Generation" and shows the labels Neutral, Sad, Happy, and Angry, which main.css and main.js indicate belong to the emotion-mixture sliders, alongside a reference-audio section and a waveform player with a 0:00 / 0:00 time readout driven by web/static/js/main.js.]
--------------------------------------------------------------------------------